summaryrefslogtreecommitdiffstats
path: root/Lib/test
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test')
-rw-r--r--Lib/test/buffer_tests.py8
-rw-r--r--Lib/test/crashers/README4
-rw-r--r--Lib/test/crashers/borrowed_ref_1.py29
-rw-r--r--Lib/test/crashers/borrowed_ref_2.py38
-rw-r--r--Lib/test/crashers/compiler_recursion.py5
-rw-r--r--Lib/test/crashers/loosing_mro_ref.py35
-rw-r--r--Lib/test/crashers/nasty_eq_vs_dict.py47
-rw-r--r--Lib/test/crashers/recursion_limit_too_high.py16
-rw-r--r--Lib/test/datetimetester.py128
-rw-r--r--Lib/test/decimaltestdata/extra.decTest13
-rw-r--r--Lib/test/dh512.pem9
-rw-r--r--Lib/test/exception_hierarchy.txt21
-rw-r--r--Lib/test/fork_wait.py10
-rw-r--r--Lib/test/future_test1.py (renamed from Lib/test/test_future1.py)0
-rw-r--r--Lib/test/future_test2.py (renamed from Lib/test/test_future2.py)0
-rw-r--r--Lib/test/json_tests/test_dump.py19
-rw-r--r--Lib/test/json_tests/test_scanstring.py11
-rw-r--r--Lib/test/keycert.passwd.pem33
-rw-r--r--Lib/test/list_tests.py41
-rw-r--r--Lib/test/lock_tests.py16
-rw-r--r--Lib/test/mailcap.txt39
-rw-r--r--Lib/test/math_testcases.txt114
-rw-r--r--Lib/test/memory_watchdog.py28
-rw-r--r--Lib/test/mock_socket.py3
-rw-r--r--Lib/test/multibytecodec_support.py (renamed from Lib/test/test_multibytecodec_support.py)26
-rw-r--r--Lib/test/namespace_pkgs/both_portions/foo/one.py1
-rw-r--r--Lib/test/namespace_pkgs/both_portions/foo/two.py1
-rw-r--r--Lib/test/namespace_pkgs/missing_directory.zipbin0 -> 515 bytes
-rw-r--r--Lib/test/namespace_pkgs/module_and_namespace_package/a_test.py1
-rw-r--r--Lib/test/namespace_pkgs/module_and_namespace_package/a_test/empty0
-rw-r--r--Lib/test/namespace_pkgs/nested_portion1.zipbin0 -> 556 bytes
-rw-r--r--Lib/test/namespace_pkgs/not_a_namespace_pkg/foo/__init__.py0
-rw-r--r--Lib/test/namespace_pkgs/not_a_namespace_pkg/foo/one.py1
-rw-r--r--Lib/test/namespace_pkgs/portion1/foo/one.py1
-rw-r--r--Lib/test/namespace_pkgs/portion2/foo/two.py1
-rw-r--r--Lib/test/namespace_pkgs/project1/parent/child/one.py1
-rw-r--r--Lib/test/namespace_pkgs/project2/parent/child/two.py1
-rw-r--r--Lib/test/namespace_pkgs/project3/parent/child/three.py1
-rw-r--r--Lib/test/namespace_pkgs/top_level_portion1.zipbin0 -> 332 bytes
-rw-r--r--Lib/test/pickletester.py150
-rwxr-xr-xLib/test/regrtest.py281
-rw-r--r--Lib/test/reperf.py4
-rw-r--r--Lib/test/script_helper.py7
-rw-r--r--Lib/test/seq_tests.py7
-rw-r--r--Lib/test/sortperf.py4
-rw-r--r--Lib/test/ssl_key.passwd.pem18
-rw-r--r--Lib/test/ssl_servers.py16
-rw-r--r--Lib/test/string_tests.py87
-rw-r--r--Lib/test/support.py370
-rw-r--r--Lib/test/test__locale.py40
-rw-r--r--Lib/test/test_abc.py196
-rw-r--r--Lib/test/test_abstract_numbers.py2
-rw-r--r--Lib/test/test_aifc.py13
-rw-r--r--Lib/test/test_argparse.py173
-rwxr-xr-xLib/test/test_array.py114
-rw-r--r--Lib/test/test_ast.py450
-rw-r--r--Lib/test/test_asyncore.py202
-rw-r--r--Lib/test/test_base64.py165
-rw-r--r--Lib/test/test_bigmem.py223
-rw-r--r--Lib/test/test_binascii.py34
-rw-r--r--Lib/test/test_bisect.py112
-rw-r--r--Lib/test/test_bool.py10
-rw-r--r--Lib/test/test_buffer.py4291
-rw-r--r--Lib/test/test_bufio.py11
-rw-r--r--Lib/test/test_builtin.py159
-rw-r--r--Lib/test/test_bytes.py244
-rw-r--r--Lib/test/test_bz2.py614
-rw-r--r--Lib/test/test_calendar.py223
-rw-r--r--Lib/test/test_capi.py116
-rw-r--r--Lib/test/test_cgi.py22
-rw-r--r--Lib/test/test_cgitb.py70
-rw-r--r--Lib/test/test_cmd.py2
-rw-r--r--Lib/test/test_cmd_line.py16
-rw-r--r--Lib/test/test_cmd_line_script.py111
-rw-r--r--Lib/test/test_code.py2
-rw-r--r--Lib/test/test_code_module.py72
-rw-r--r--Lib/test/test_codeccallbacks.py161
-rw-r--r--Lib/test/test_codecencodings_cn.py39
-rw-r--r--Lib/test/test_codecencodings_hk.py10
-rw-r--r--Lib/test/test_codecencodings_iso2022.py14
-rw-r--r--Lib/test/test_codecencodings_jp.py118
-rw-r--r--Lib/test/test_codecencodings_kr.py39
-rw-r--r--Lib/test/test_codecencodings_tw.py10
-rw-r--r--Lib/test/test_codecmaps_cn.py8
-rw-r--r--Lib/test/test_codecmaps_hk.py4
-rw-r--r--Lib/test/test_codecmaps_jp.py12
-rw-r--r--Lib/test/test_codecmaps_kr.py8
-rw-r--r--Lib/test/test_codecmaps_tw.py9
-rw-r--r--Lib/test/test_codecs.py464
-rw-r--r--Lib/test/test_coding.py4
-rw-r--r--Lib/test/test_collections.py96
-rw-r--r--Lib/test/test_compile.py67
-rw-r--r--Lib/test/test_concurrent_futures.py62
-rw-r--r--Lib/test/test_configparser.py (renamed from Lib/test/test_cfgparser.py)89
-rw-r--r--Lib/test/test_contextlib.py225
-rw-r--r--Lib/test/test_copy.py187
-rw-r--r--Lib/test/test_cprofile.py23
-rw-r--r--Lib/test/test_crashers.py38
-rw-r--r--Lib/test/test_crypt.py28
-rw-r--r--Lib/test/test_csv.py11
-rw-r--r--Lib/test/test_curses.py50
-rw-r--r--Lib/test/test_dbm.py4
-rw-r--r--Lib/test/test_decimal.py3768
-rw-r--r--Lib/test/test_deque.py15
-rw-r--r--Lib/test/test_descr.py204
-rw-r--r--Lib/test/test_descrtut.py3
-rw-r--r--Lib/test/test_devpoll.py94
-rw-r--r--Lib/test/test_dict.py120
-rw-r--r--Lib/test/test_dictcomps.py5
-rw-r--r--Lib/test/test_dis.py111
-rw-r--r--Lib/test/test_doctest.py446
-rw-r--r--Lib/test/test_dummy_thread.py4
-rw-r--r--Lib/test/test_email.py14
-rw-r--r--Lib/test/test_email/__init__.py150
-rw-r--r--Lib/test/test_email/__main__.py3
-rw-r--r--Lib/test/test_email/data/PyBanner048.gifbin0 -> 954 bytes
-rw-r--r--Lib/test/test_email/data/audiotest.aubin0 -> 28144 bytes
-rw-r--r--Lib/test/test_email/data/msg_01.txt19
-rw-r--r--Lib/test/test_email/data/msg_02.txt135
-rw-r--r--Lib/test/test_email/data/msg_03.txt16
-rw-r--r--Lib/test/test_email/data/msg_04.txt37
-rw-r--r--Lib/test/test_email/data/msg_05.txt28
-rw-r--r--Lib/test/test_email/data/msg_06.txt33
-rw-r--r--Lib/test/test_email/data/msg_07.txt83
-rw-r--r--Lib/test/test_email/data/msg_08.txt24
-rw-r--r--Lib/test/test_email/data/msg_09.txt24
-rw-r--r--Lib/test/test_email/data/msg_10.txt39
-rw-r--r--Lib/test/test_email/data/msg_11.txt7
-rw-r--r--Lib/test/test_email/data/msg_12.txt36
-rw-r--r--Lib/test/test_email/data/msg_12a.txt38
-rw-r--r--Lib/test/test_email/data/msg_13.txt94
-rw-r--r--Lib/test/test_email/data/msg_14.txt23
-rw-r--r--Lib/test/test_email/data/msg_15.txt52
-rw-r--r--Lib/test/test_email/data/msg_16.txt123
-rw-r--r--Lib/test/test_email/data/msg_17.txt12
-rw-r--r--Lib/test/test_email/data/msg_18.txt6
-rw-r--r--Lib/test/test_email/data/msg_19.txt43
-rw-r--r--Lib/test/test_email/data/msg_20.txt22
-rw-r--r--Lib/test/test_email/data/msg_21.txt20
-rw-r--r--Lib/test/test_email/data/msg_22.txt46
-rw-r--r--Lib/test/test_email/data/msg_23.txt8
-rw-r--r--Lib/test/test_email/data/msg_24.txt10
-rw-r--r--Lib/test/test_email/data/msg_25.txt117
-rw-r--r--Lib/test/test_email/data/msg_26.txt46
-rw-r--r--Lib/test/test_email/data/msg_27.txt15
-rw-r--r--Lib/test/test_email/data/msg_28.txt25
-rw-r--r--Lib/test/test_email/data/msg_29.txt22
-rw-r--r--Lib/test/test_email/data/msg_30.txt23
-rw-r--r--Lib/test/test_email/data/msg_31.txt15
-rw-r--r--Lib/test/test_email/data/msg_32.txt14
-rw-r--r--Lib/test/test_email/data/msg_33.txt29
-rw-r--r--Lib/test/test_email/data/msg_34.txt19
-rw-r--r--Lib/test/test_email/data/msg_35.txt4
-rw-r--r--Lib/test/test_email/data/msg_36.txt40
-rw-r--r--Lib/test/test_email/data/msg_37.txt22
-rw-r--r--Lib/test/test_email/data/msg_38.txt101
-rw-r--r--Lib/test/test_email/data/msg_39.txt83
-rw-r--r--Lib/test/test_email/data/msg_40.txt10
-rw-r--r--Lib/test/test_email/data/msg_41.txt8
-rw-r--r--Lib/test/test_email/data/msg_42.txt20
-rw-r--r--Lib/test/test_email/data/msg_43.txt217
-rw-r--r--Lib/test/test_email/data/msg_44.txt33
-rw-r--r--Lib/test/test_email/data/msg_45.txt33
-rw-r--r--Lib/test/test_email/data/msg_46.txt23
-rw-r--r--Lib/test/test_email/test__encoded_words.py187
-rw-r--r--Lib/test/test_email/test__header_value_parser.py2552
-rw-r--r--Lib/test/test_email/test_asian_codecs.py82
-rw-r--r--Lib/test/test_email/test_defect_handling.py320
-rw-r--r--Lib/test/test_email/test_email.py5009
-rw-r--r--Lib/test/test_email/test_generator.py199
-rw-r--r--Lib/test/test_email/test_headerregistry.py1515
-rw-r--r--Lib/test/test_email/test_inversion.py45
-rw-r--r--Lib/test/test_email/test_message.py18
-rw-r--r--Lib/test/test_email/test_parser.py36
-rw-r--r--Lib/test/test_email/test_pickleable.py74
-rw-r--r--Lib/test/test_email/test_policy.py322
-rw-r--r--Lib/test/test_email/test_utils.py136
-rw-r--r--Lib/test/test_email/torture_test.py136
-rw-r--r--Lib/test/test_enumerate.py30
-rw-r--r--Lib/test/test_epoll.py3
-rw-r--r--Lib/test/test_exceptions.py124
-rw-r--r--Lib/test/test_extcall.py87
-rw-r--r--Lib/test/test_faulthandler.py593
-rw-r--r--Lib/test/test_file.py24
-rw-r--r--Lib/test/test_fileinput.py630
-rw-r--r--Lib/test/test_fileio.py5
-rw-r--r--Lib/test/test_float.py21
-rw-r--r--Lib/test/test_format.py46
-rw-r--r--Lib/test/test_fractions.py12
-rw-r--r--Lib/test/test_frozen.py26
-rw-r--r--Lib/test/test_ftplib.py207
-rw-r--r--Lib/test/test_funcattrs.py55
-rw-r--r--Lib/test/test_functools.py93
-rw-r--r--Lib/test/test_future.py12
-rw-r--r--Lib/test/test_gc.py143
-rw-r--r--Lib/test/test_gdb.py143
-rw-r--r--Lib/test/test_generators.py42
-rw-r--r--Lib/test/test_genericpath.py58
-rw-r--r--Lib/test/test_genexps.py8
-rw-r--r--Lib/test/test_getargs2.py114
-rw-r--r--Lib/test/test_glob.py6
-rw-r--r--Lib/test/test_grammar.py123
-rw-r--r--Lib/test/test_gzip.py136
-rw-r--r--Lib/test/test_hash.py24
-rw-r--r--Lib/test/test_hashlib.py4
-rw-r--r--Lib/test/test_heapq.py35
-rw-r--r--Lib/test/test_hmac.py123
-rw-r--r--Lib/test/test_htmlparser.py13
-rw-r--r--Lib/test/test_http_cookiejar.py11
-rw-r--r--Lib/test/test_http_cookies.py9
-rw-r--r--Lib/test/test_httplib.py184
-rw-r--r--Lib/test/test_httpservers.py73
-rw-r--r--Lib/test/test_imaplib.py89
-rw-r--r--Lib/test/test_imp.py172
-rw-r--r--Lib/test/test_import.py528
-rw-r--r--Lib/test/test_importhooks.py13
-rw-r--r--Lib/test/test_importlib.py5
-rw-r--r--Lib/test/test_importlib/__init__.py33
-rw-r--r--Lib/test/test_importlib/__main__.py20
-rw-r--r--Lib/test/test_importlib/abc.py99
-rw-r--r--Lib/test/test_importlib/builtin/__init__.py12
-rw-r--r--Lib/test/test_importlib/builtin/test_finder.py55
-rw-r--r--Lib/test/test_importlib/builtin/test_loader.py105
-rw-r--r--Lib/test/test_importlib/builtin/util.py7
-rw-r--r--Lib/test/test_importlib/extension/__init__.py13
-rw-r--r--Lib/test/test_importlib/extension/test_case_sensitivity.py50
-rw-r--r--Lib/test/test_importlib/extension/test_finder.py46
-rw-r--r--Lib/test/test_importlib/extension/test_loader.py79
-rw-r--r--Lib/test/test_importlib/extension/test_path_hook.py32
-rw-r--r--Lib/test/test_importlib/extension/util.py20
-rw-r--r--Lib/test/test_importlib/frozen/__init__.py13
-rw-r--r--Lib/test/test_importlib/frozen/test_finder.py47
-rw-r--r--Lib/test/test_importlib/frozen/test_loader.py121
-rw-r--r--Lib/test/test_importlib/import_/__init__.py13
-rw-r--r--Lib/test/test_importlib/import_/test___package__.py119
-rw-r--r--Lib/test/test_importlib/import_/test_api.py67
-rw-r--r--Lib/test/test_importlib/import_/test_caching.py88
-rw-r--r--Lib/test/test_importlib/import_/test_fromlist.py127
-rw-r--r--Lib/test/test_importlib/import_/test_meta_path.py115
-rw-r--r--Lib/test/test_importlib/import_/test_packages.py112
-rw-r--r--Lib/test/test_importlib/import_/test_path.py120
-rw-r--r--Lib/test/test_importlib/import_/test_relative_imports.py217
-rw-r--r--Lib/test/test_importlib/import_/util.py28
-rw-r--r--Lib/test/test_importlib/regrtest.py17
-rw-r--r--Lib/test/test_importlib/source/__init__.py13
-rw-r--r--Lib/test/test_importlib/source/test_abc_loader.py906
-rw-r--r--Lib/test/test_importlib/source/test_case_sensitivity.py70
-rw-r--r--Lib/test/test_importlib/source/test_file_loader.py484
-rw-r--r--Lib/test/test_importlib/source/test_finder.py191
-rw-r--r--Lib/test/test_importlib/source/test_path_hook.py32
-rw-r--r--Lib/test/test_importlib/source/test_source_encoding.py123
-rw-r--r--Lib/test/test_importlib/source/util.py97
-rw-r--r--Lib/test/test_importlib/test_abc.py103
-rw-r--r--Lib/test/test_importlib/test_api.py202
-rw-r--r--Lib/test/test_importlib/test_locks.py129
-rw-r--r--Lib/test/test_importlib/test_util.py208
-rw-r--r--Lib/test/test_importlib/util.py140
-rw-r--r--Lib/test/test_inspect.py1143
-rw-r--r--Lib/test/test_int.py39
-rw-r--r--Lib/test/test_io.py146
-rw-r--r--Lib/test/test_ipaddress.py1649
-rw-r--r--Lib/test/test_isinstance.py2
-rw-r--r--Lib/test/test_iter.py43
-rw-r--r--Lib/test/test_itertools.py415
-rw-r--r--Lib/test/test_keywordonlyarg.py2
-rw-r--r--Lib/test/test_lib2to3.py4
-rw-r--r--Lib/test/test_list.py28
-rw-r--r--Lib/test/test_locale.py4
-rw-r--r--Lib/test/test_logging.py1807
-rw-r--r--Lib/test/test_long.py80
-rw-r--r--Lib/test/test_lzma.py1531
-rw-r--r--Lib/test/test_macpath.py8
-rw-r--r--Lib/test/test_mailbox.py56
-rw-r--r--Lib/test/test_mailcap.py221
-rw-r--r--Lib/test/test_marshal.py20
-rw-r--r--Lib/test/test_math.py43
-rw-r--r--Lib/test/test_memoryio.py2
-rw-r--r--Lib/test/test_memoryview.py93
-rw-r--r--Lib/test/test_metaclass.py16
-rw-r--r--Lib/test/test_minidom.py77
-rw-r--r--Lib/test/test_mmap.py32
-rw-r--r--Lib/test/test_module.py99
-rw-r--r--Lib/test/test_modulefinder.py63
-rw-r--r--Lib/test/test_multibytecodec.py16
-rw-r--r--Lib/test/test_multiprocessing.py1053
-rw-r--r--Lib/test/test_mutants.py291
-rw-r--r--Lib/test/test_namespace_pkgs.py294
-rw-r--r--Lib/test/test_nntplib.py27
-rw-r--r--Lib/test/test_ntpath.py13
-rw-r--r--Lib/test/test_numeric_tower.py2
-rw-r--r--Lib/test/test_optparse.py20
-rw-r--r--Lib/test/test_os.py849
-rw-r--r--Lib/test/test_ossaudiodev.py16
-rw-r--r--Lib/test/test_osx_env.py3
-rw-r--r--Lib/test/test_parser.py11
-rw-r--r--Lib/test/test_pdb.py3
-rw-r--r--Lib/test/test_peepholer.py63
-rw-r--r--Lib/test/test_pep277.py66
-rw-r--r--Lib/test/test_pep292.py33
-rw-r--r--Lib/test/test_pep3131.py7
-rw-r--r--Lib/test/test_pep3151.py211
-rw-r--r--Lib/test/test_pep380.py965
-rw-r--r--Lib/test/test_pickle.py28
-rw-r--r--Lib/test/test_pipes.py15
-rw-r--r--Lib/test/test_pkg.py33
-rw-r--r--Lib/test/test_pkgimport.py4
-rw-r--r--Lib/test/test_pkgutil.py102
-rw-r--r--Lib/test/test_platform.py51
-rw-r--r--Lib/test/test_plistlib.py6
-rw-r--r--Lib/test/test_poplib.py15
-rw-r--r--Lib/test/test_posix.py637
-rw-r--r--Lib/test/test_posixpath.py20
-rw-r--r--Lib/test/test_pprint.py4
-rw-r--r--Lib/test/test_print.py26
-rw-r--r--Lib/test/test_property.py23
-rw-r--r--Lib/test/test_pty.py19
-rw-r--r--Lib/test/test_pulldom.py347
-rw-r--r--Lib/test/test_pydoc.py21
-rw-r--r--Lib/test/test_raise.py39
-rw-r--r--Lib/test/test_random.py6
-rw-r--r--Lib/test/test_range.py105
-rw-r--r--Lib/test/test_re.py141
-rw-r--r--Lib/test/test_reprlib.py82
-rw-r--r--Lib/test/test_richcmp.py1
-rw-r--r--Lib/test/test_runpy.py58
-rw-r--r--Lib/test/test_sax.py4
-rw-r--r--Lib/test/test_sched.py135
-rw-r--r--Lib/test/test_scope.py19
-rw-r--r--Lib/test/test_select.py22
-rw-r--r--Lib/test/test_set.py21
-rw-r--r--Lib/test/test_shelve.py6
-rw-r--r--Lib/test/test_shlex.py21
-rw-r--r--Lib/test/test_shutil.py978
-rw-r--r--Lib/test/test_signal.py478
-rw-r--r--Lib/test/test_site.py5
-rw-r--r--Lib/test/test_smtpd.py368
-rw-r--r--Lib/test/test_smtplib.py80
-rw-r--r--Lib/test/test_smtpnet.py42
-rw-r--r--Lib/test/test_socket.py2803
-rw-r--r--Lib/test/test_socketserver.py2
-rw-r--r--Lib/test/test_ssl.py449
-rw-r--r--Lib/test/test_startfile.py2
-rw-r--r--Lib/test/test_stat.py66
-rw-r--r--Lib/test/test_string.py70
-rw-r--r--Lib/test/test_strlit.py32
-rw-r--r--Lib/test/test_struct.py71
-rw-r--r--Lib/test/test_structseq.py5
-rw-r--r--Lib/test/test_subprocess.py271
-rw-r--r--Lib/test/test_sundry.py1
-rw-r--r--Lib/test/test_super.py49
-rw-r--r--Lib/test/test_support.py199
-rw-r--r--Lib/test/test_sys.py179
-rw-r--r--Lib/test/test_sys_settrace.py10
-rw-r--r--Lib/test/test_sysconfig.py119
-rw-r--r--Lib/test/test_tarfile.py160
-rw-r--r--Lib/test/test_telnetlib.py1
-rw-r--r--Lib/test/test_tempfile.py277
-rw-r--r--Lib/test/test_textwrap.py148
-rw-r--r--Lib/test/test_threaded_import.py58
-rw-r--r--Lib/test/test_threading.py21
-rw-r--r--Lib/test/test_threadsignals.py6
-rw-r--r--Lib/test/test_time.py486
-rw-r--r--Lib/test/test_tokenize.py185
-rw-r--r--Lib/test/test_tools.py37
-rw-r--r--Lib/test/test_trace.py63
-rw-r--r--Lib/test/test_traceback.py15
-rw-r--r--Lib/test/test_tuple.py29
-rw-r--r--Lib/test/test_types.py583
-rw-r--r--Lib/test/test_ucn.py90
-rw-r--r--Lib/test/test_unicode.py579
-rw-r--r--Lib/test/test_unicode_file.py17
-rw-r--r--Lib/test/test_unicodedata.py16
-rw-r--r--Lib/test/test_urllib.py78
-rw-r--r--Lib/test/test_urllib2.py129
-rw-r--r--Lib/test/test_urllib2_localnet.py7
-rw-r--r--Lib/test/test_urllib2net.py13
-rw-r--r--Lib/test/test_urllibnet.py34
-rw-r--r--Lib/test/test_userlist.py6
-rwxr-xr-xLib/test/test_userstring.py11
-rw-r--r--Lib/test/test_uuid.py4
-rw-r--r--Lib/test/test_venv.py203
-rw-r--r--Lib/test/test_wait3.py7
-rw-r--r--Lib/test/test_warnings.py66
-rw-r--r--Lib/test/test_webbrowser.py192
-rw-r--r--Lib/test/test_winsound.py8
-rw-r--r--Lib/test/test_wsgiref.py10
-rw-r--r--Lib/test/test_xml_etree.py1341
-rw-r--r--Lib/test/test_xml_etree_c.py104
-rw-r--r--Lib/test/test_xmlrpc.py119
-rw-r--r--Lib/test/test_xmlrpc_net.py4
-rw-r--r--Lib/test/test_zipfile.py373
-rw-r--r--Lib/test/test_zipfile64.py20
-rw-r--r--Lib/test/test_zipimport.py24
-rw-r--r--Lib/test/test_zipimport_support.py5
-rw-r--r--Lib/test/test_zlib.py73
-rw-r--r--Lib/test/testbz2_bigmem.bz2bin3023 -> 0 bytes
-rw-r--r--Lib/test/threaded_import_hangers.py13
-rw-r--r--Lib/test/tokenize_tests.txt8
398 files changed, 56847 insertions, 5707 deletions
diff --git a/Lib/test/buffer_tests.py b/Lib/test/buffer_tests.py
index 6d20f7d..cf54c28 100644
--- a/Lib/test/buffer_tests.py
+++ b/Lib/test/buffer_tests.py
@@ -200,7 +200,13 @@ class MixinBytesBufferCommonTests(object):
self.marshal(b'abc\ndef\r\nghi\n\r').splitlines())
self.assertEqual([b'', b'abc', b'def', b'ghi', b''],
self.marshal(b'\nabc\ndef\r\nghi\n\r').splitlines())
+ self.assertEqual([b'', b'abc', b'def', b'ghi', b''],
+ self.marshal(b'\nabc\ndef\r\nghi\n\r').splitlines(False))
+ self.assertEqual([b'\n', b'abc\n', b'def\r\n', b'ghi\n', b'\r'],
+ self.marshal(b'\nabc\ndef\r\nghi\n\r').splitlines(True))
+ self.assertEqual([b'', b'abc', b'def', b'ghi', b''],
+ self.marshal(b'\nabc\ndef\r\nghi\n\r').splitlines(keepends=False))
self.assertEqual([b'\n', b'abc\n', b'def\r\n', b'ghi\n', b'\r'],
- self.marshal(b'\nabc\ndef\r\nghi\n\r').splitlines(1))
+ self.marshal(b'\nabc\ndef\r\nghi\n\r').splitlines(keepends=True))
self.assertRaises(TypeError, self.marshal(b'abc').splitlines, 42, 42)
diff --git a/Lib/test/crashers/README b/Lib/test/crashers/README
index 2a73e1b..0259a06 100644
--- a/Lib/test/crashers/README
+++ b/Lib/test/crashers/README
@@ -14,3 +14,7 @@ note if the cause is system or environment dependent and what the variables are.
Once the crash is fixed, the test case should be moved into an appropriate test
(even if it was originally from the test suite). This ensures the regression
doesn't happen again. And if it does, it should be easier to track down.
+
+Also see Lib/test_crashers.py which exercises the crashers in this directory.
+In particular, make sure to add any new infinite loop crashers to the black
+list so it doesn't try to run them.
diff --git a/Lib/test/crashers/borrowed_ref_1.py b/Lib/test/crashers/borrowed_ref_1.py
deleted file mode 100644
index b82f464..0000000
--- a/Lib/test/crashers/borrowed_ref_1.py
+++ /dev/null
@@ -1,29 +0,0 @@
-"""
-_PyType_Lookup() returns a borrowed reference.
-This attacks the call in dictobject.c.
-"""
-
-class A(object):
- pass
-
-class B(object):
- def __del__(self):
- print('hi')
- del D.__missing__
-
-class D(dict):
- class __missing__:
- def __init__(self, *args):
- pass
-
-
-d = D()
-a = A()
-a.cycle = a
-a.other = B()
-del a
-
-prev = None
-while 1:
- d[5]
- prev = (prev,)
diff --git a/Lib/test/crashers/borrowed_ref_2.py b/Lib/test/crashers/borrowed_ref_2.py
deleted file mode 100644
index 6e403eb..0000000
--- a/Lib/test/crashers/borrowed_ref_2.py
+++ /dev/null
@@ -1,38 +0,0 @@
-"""
-_PyType_Lookup() returns a borrowed reference.
-This attacks PyObject_GenericSetAttr().
-
-NB. on my machine this crashes in 2.5 debug but not release.
-"""
-
-class A(object):
- pass
-
-class B(object):
- def __del__(self):
- print("hi")
- del C.d
-
-class D(object):
- def __set__(self, obj, value):
- self.hello = 42
-
-class C(object):
- d = D()
-
- def g():
- pass
-
-
-c = C()
-a = A()
-a.cycle = a
-a.other = B()
-
-lst = [None] * 1000000
-i = 0
-del a
-while 1:
- c.d = 42 # segfaults in PyMethod_New(__func__=D.__set__, __self__=d)
- lst[i] = c.g # consume the free list of instancemethod objects
- i += 1
diff --git a/Lib/test/crashers/compiler_recursion.py b/Lib/test/crashers/compiler_recursion.py
deleted file mode 100644
index 4954bdd..0000000
--- a/Lib/test/crashers/compiler_recursion.py
+++ /dev/null
@@ -1,5 +0,0 @@
-"""
-The compiler (>= 2.5) recurses happily.
-"""
-
-compile('()'*9**5, '?', 'exec')
diff --git a/Lib/test/crashers/loosing_mro_ref.py b/Lib/test/crashers/loosing_mro_ref.py
deleted file mode 100644
index b3bcd32..0000000
--- a/Lib/test/crashers/loosing_mro_ref.py
+++ /dev/null
@@ -1,35 +0,0 @@
-"""
-There is a way to put keys of any type in a type's dictionary.
-I think this allows various kinds of crashes, but so far I have only
-found a convoluted attack of _PyType_Lookup(), which uses the mro of the
-type without holding a strong reference to it. Probably works with
-super.__getattribute__() too, which uses the same kind of code.
-"""
-
-class MyKey(object):
- def __hash__(self):
- return hash('mykey')
-
- def __eq__(self, other):
- # the following line decrefs the previous X.__mro__
- X.__bases__ = (Base2,)
- # trash all tuples of length 3, to make sure that the items of
- # the previous X.__mro__ are really garbage
- z = []
- for i in range(1000):
- z.append((i, None, None))
- return 0
-
-
-class Base(object):
- mykey = 'from Base'
-
-class Base2(object):
- mykey = 'from Base2'
-
-# you can't add a non-string key to X.__dict__, but it can be
-# there from the beginning :-)
-X = type('X', (Base,), {MyKey(): 5})
-
-print(X.mykey)
-# I get a segfault, or a slightly wrong assertion error in a debug build.
diff --git a/Lib/test/crashers/nasty_eq_vs_dict.py b/Lib/test/crashers/nasty_eq_vs_dict.py
deleted file mode 100644
index 85f7caf..0000000
--- a/Lib/test/crashers/nasty_eq_vs_dict.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# from http://mail.python.org/pipermail/python-dev/2001-June/015239.html
-
-# if you keep changing a dictionary while looking up a key, you can
-# provoke an infinite recursion in C
-
-# At the time neither Tim nor Michael could be bothered to think of a
-# way to fix it.
-
-class Yuck:
- def __init__(self):
- self.i = 0
-
- def make_dangerous(self):
- self.i = 1
-
- def __hash__(self):
- # direct to slot 4 in table of size 8; slot 12 when size 16
- return 4 + 8
-
- def __eq__(self, other):
- if self.i == 0:
- # leave dict alone
- pass
- elif self.i == 1:
- # fiddle to 16 slots
- self.__fill_dict(6)
- self.i = 2
- else:
- # fiddle to 8 slots
- self.__fill_dict(4)
- self.i = 1
-
- return 1
-
- def __fill_dict(self, n):
- self.i = 0
- dict.clear()
- for i in range(n):
- dict[i] = i
- dict[self] = "OK!"
-
-y = Yuck()
-dict = {y: "OK!"}
-
-z = Yuck()
-y.make_dangerous()
-print(dict[z])
diff --git a/Lib/test/crashers/recursion_limit_too_high.py b/Lib/test/crashers/recursion_limit_too_high.py
deleted file mode 100644
index ec64936..0000000
--- a/Lib/test/crashers/recursion_limit_too_high.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# The following example may crash or not depending on the platform.
-# E.g. on 32-bit Intel Linux in a "standard" configuration it seems to
-# crash on Python 2.5 (but not 2.4 nor 2.3). On Windows the import
-# eventually fails to find the module, possibly because we run out of
-# file handles.
-
-# The point of this example is to show that sys.setrecursionlimit() is a
-# hack, and not a robust solution. This example simply exercises a path
-# where it takes many C-level recursions, consuming a lot of stack
-# space, for each Python-level recursion. So 1000 times this amount of
-# stack space may be too much for standard platforms already.
-
-import sys
-if 'recursion_limit_too_high' in sys.modules:
- del sys.modules['recursion_limit_too_high']
-import recursion_limit_too_high
diff --git a/Lib/test/datetimetester.py b/Lib/test/datetimetester.py
index bb18630..931ef6f 100644
--- a/Lib/test/datetimetester.py
+++ b/Lib/test/datetimetester.py
@@ -979,7 +979,7 @@ class TestDate(HarmlessMixedComparison, unittest.TestCase):
# exempt such platforms (provided they return reasonable
# results!).
for insane in -1e200, 1e200:
- self.assertRaises(ValueError, self.theclass.fromtimestamp,
+ self.assertRaises(OverflowError, self.theclass.fromtimestamp,
insane)
def test_today(self):
@@ -1291,12 +1291,18 @@ class TestDate(HarmlessMixedComparison, unittest.TestCase):
self.assertTrue(self.theclass.min)
self.assertTrue(self.theclass.max)
- def test_strftime_out_of_range(self):
- # For nasty technical reasons, we can't handle years before 1000.
- cls = self.theclass
- self.assertEqual(cls(1000, 1, 1).strftime("%Y"), "1000")
- for y in 1, 49, 51, 99, 100, 999:
- self.assertRaises(ValueError, cls(y, 1, 1).strftime, "%Y")
+ def test_strftime_y2k(self):
+ for y in (1, 49, 70, 99, 100, 999, 1000, 1970):
+ d = self.theclass(y, 1, 1)
+ # Issue 13305: For years < 1000, the value is not always
+ # padded to 4 digits across platforms. The C standard
+ # assumes year >= 1900, so it does not specify the number
+ # of digits.
+ if d.strftime("%Y") != '%04d' % y:
+ # Year 42 returns '42', not padded
+ self.assertEqual(d.strftime("%Y"), '%d' % y)
+ # '0042' is obtained anyway
+ self.assertEqual(d.strftime("%4Y"), '%04d' % y)
def test_replace(self):
cls = self.theclass
@@ -1731,13 +1737,74 @@ class TestDateTime(TestDate):
got = self.theclass.utcfromtimestamp(ts)
self.verify_field_equality(expected, got)
+ # Run with US-style DST rules: DST begins 2 a.m. on second Sunday in
+ # March (M3.2.0) and ends 2 a.m. on first Sunday in November (M11.1.0).
+ @support.run_with_tz('EST+05EDT,M3.2.0,M11.1.0')
+ def test_timestamp_naive(self):
+ t = self.theclass(1970, 1, 1)
+ self.assertEqual(t.timestamp(), 18000.0)
+ t = self.theclass(1970, 1, 1, 1, 2, 3, 4)
+ self.assertEqual(t.timestamp(),
+ 18000.0 + 3600 + 2*60 + 3 + 4*1e-6)
+ # Missing hour may produce platform-dependent result
+ t = self.theclass(2012, 3, 11, 2, 30)
+ self.assertIn(self.theclass.fromtimestamp(t.timestamp()),
+ [t - timedelta(hours=1), t + timedelta(hours=1)])
+ # Ambiguous hour defaults to DST
+ t = self.theclass(2012, 11, 4, 1, 30)
+ self.assertEqual(self.theclass.fromtimestamp(t.timestamp()), t)
+
+ # Timestamp may raise an overflow error on some platforms
+ for t in [self.theclass(1,1,1), self.theclass(9999,12,12)]:
+ try:
+ s = t.timestamp()
+ except OverflowError:
+ pass
+ else:
+ self.assertEqual(self.theclass.fromtimestamp(s), t)
+
+ def test_timestamp_aware(self):
+ t = self.theclass(1970, 1, 1, tzinfo=timezone.utc)
+ self.assertEqual(t.timestamp(), 0.0)
+ t = self.theclass(1970, 1, 1, 1, 2, 3, 4, tzinfo=timezone.utc)
+ self.assertEqual(t.timestamp(),
+ 3600 + 2*60 + 3 + 4*1e-6)
+ t = self.theclass(1970, 1, 1, 1, 2, 3, 4,
+ tzinfo=timezone(timedelta(hours=-5), 'EST'))
+ self.assertEqual(t.timestamp(),
+ 18000 + 3600 + 2*60 + 3 + 4*1e-6)
def test_microsecond_rounding(self):
- # Test whether fromtimestamp "rounds up" floats that are less
- # than 1/2 microsecond smaller than an integer.
for fts in [self.theclass.fromtimestamp,
self.theclass.utcfromtimestamp]:
- self.assertEqual(fts(0.9999999), fts(1))
- self.assertEqual(fts(0.99999949).microsecond, 999999)
+ zero = fts(0)
+ self.assertEqual(zero.second, 0)
+ self.assertEqual(zero.microsecond, 0)
+ try:
+ minus_one = fts(-1e-6)
+ except OSError:
+ # localtime(-1) and gmtime(-1) is not supported on Windows
+ pass
+ else:
+ self.assertEqual(minus_one.second, 59)
+ self.assertEqual(minus_one.microsecond, 999999)
+
+ t = fts(-1e-8)
+ self.assertEqual(t, minus_one)
+ t = fts(-9e-7)
+ self.assertEqual(t, minus_one)
+ t = fts(-1e-7)
+ self.assertEqual(t, minus_one)
+
+ t = fts(1e-7)
+ self.assertEqual(t, zero)
+ t = fts(9e-7)
+ self.assertEqual(t, zero)
+ t = fts(0.99999949)
+ self.assertEqual(t.second, 0)
+ self.assertEqual(t.microsecond, 999999)
+ t = fts(0.9999999)
+ self.assertEqual(t.second, 0)
+ self.assertEqual(t.microsecond, 999999)
def test_insane_fromtimestamp(self):
# It's possible that some platform maps time_t to double,
@@ -1745,7 +1812,7 @@ class TestDateTime(TestDate):
# exempt such platforms (provided they return reasonable
# results!).
for insane in -1e200, 1e200:
- self.assertRaises(ValueError, self.theclass.fromtimestamp,
+ self.assertRaises(OverflowError, self.theclass.fromtimestamp,
insane)
def test_insane_utcfromtimestamp(self):
@@ -1754,7 +1821,7 @@ class TestDateTime(TestDate):
# exempt such platforms (provided they return reasonable
# results!).
for insane in -1e200, 1e200:
- self.assertRaises(ValueError, self.theclass.utcfromtimestamp,
+ self.assertRaises(OverflowError, self.theclass.utcfromtimestamp,
insane)
@unittest.skipIf(sys.platform == "win32", "Windows doesn't accept negative timestamps")
def test_negative_float_fromtimestamp(self):
@@ -1907,7 +1974,7 @@ class TestDateTime(TestDate):
# simply can't be applied to a naive object.
dt = self.theclass.now()
f = FixedOffset(44, "")
- self.assertRaises(TypeError, dt.astimezone) # not enough args
+ self.assertRaises(ValueError, dt.astimezone) # naive
self.assertRaises(TypeError, dt.astimezone, f, f) # too many args
self.assertRaises(TypeError, dt.astimezone, dt) # arg wrong type
self.assertRaises(ValueError, dt.astimezone, f) # naive
@@ -2479,7 +2546,7 @@ class TestTimeTZ(TestTime, TZInfoBase, unittest.TestCase):
self.assertEqual(t1, t2)
self.assertEqual(t1, t3)
self.assertEqual(t2, t3)
- self.assertRaises(TypeError, lambda: t4 == t5) # mixed tz-aware & naive
+ self.assertNotEqual(t4, t5) # mixed tz-aware & naive
self.assertRaises(TypeError, lambda: t4 < t5) # mixed tz-aware & naive
self.assertRaises(TypeError, lambda: t5 < t4) # mixed tz-aware & naive
@@ -2631,7 +2698,7 @@ class TestTimeTZ(TestTime, TZInfoBase, unittest.TestCase):
t2 = t2.replace(tzinfo=FixedOffset(None, ""))
self.assertEqual(t1, t2)
t2 = t2.replace(tzinfo=FixedOffset(0, ""))
- self.assertRaises(TypeError, lambda: t1 == t2)
+ self.assertNotEqual(t1, t2)
# In time w/ identical tzinfo objects, utcoffset is ignored.
class Varies(tzinfo):
@@ -2736,16 +2803,16 @@ class TestDateTimeTZ(TestDateTime, TZInfoBase, unittest.TestCase):
microsecond=1)
self.assertTrue(t1 > t2)
- # Make t2 naive and it should fail.
+ # Make t2 naive and it should differ.
t2 = self.theclass.min
- self.assertRaises(TypeError, lambda: t1 == t2)
+ self.assertNotEqual(t1, t2)
self.assertEqual(t2, t2)
# It's also naive if it has tzinfo but tzinfo.utcoffset() is None.
class Naive(tzinfo):
def utcoffset(self, dt): return None
t2 = self.theclass(5, 6, 7, tzinfo=Naive())
- self.assertRaises(TypeError, lambda: t1 == t2)
+ self.assertNotEqual(t1, t2)
self.assertEqual(t2, t2)
# OTOH, it's OK to compare two of these mixing the two ways of being
@@ -3188,8 +3255,6 @@ class TestDateTimeTZ(TestDateTime, TZInfoBase, unittest.TestCase):
self.assertTrue(dt.tzinfo is f44m)
# Replacing with degenerate tzinfo raises an exception.
self.assertRaises(ValueError, dt.astimezone, fnone)
- # Ditto with None tz.
- self.assertRaises(TypeError, dt.astimezone, None)
# Replacing with same tzinfo makes no change.
x = dt.astimezone(dt.tzinfo)
self.assertTrue(x.tzinfo is f44m)
@@ -3209,6 +3274,25 @@ class TestDateTimeTZ(TestDateTime, TZInfoBase, unittest.TestCase):
self.assertTrue(got.tzinfo is expected.tzinfo)
self.assertEqual(got, expected)
+ @support.run_with_tz('UTC')
+ def test_astimezone_default_utc(self):
+ dt = self.theclass.now(timezone.utc)
+ self.assertEqual(dt.astimezone(None), dt)
+ self.assertEqual(dt.astimezone(), dt)
+
+ # Note that offset in TZ variable has the opposite sign to that
+ # produced by %z directive.
+ @support.run_with_tz('EST+05EDT,M3.2.0,M11.1.0')
+ def test_astimezone_default_eastern(self):
+ dt = self.theclass(2012, 11, 4, 6, 30, tzinfo=timezone.utc)
+ local = dt.astimezone()
+ self.assertEqual(dt, local)
+ self.assertEqual(local.strftime("%z %Z"), "-0500 EST")
+ dt = self.theclass(2012, 11, 4, 5, 30, tzinfo=timezone.utc)
+ local = dt.astimezone()
+ self.assertEqual(dt, local)
+ self.assertEqual(local.strftime("%z %Z"), "-0400 EDT")
+
def test_aware_subtract(self):
cls = self.theclass
@@ -3262,7 +3346,7 @@ class TestDateTimeTZ(TestDateTime, TZInfoBase, unittest.TestCase):
t2 = t2.replace(tzinfo=FixedOffset(None, ""))
self.assertEqual(t1, t2)
t2 = t2.replace(tzinfo=FixedOffset(0, ""))
- self.assertRaises(TypeError, lambda: t1 == t2)
+ self.assertNotEqual(t1, t2)
# In datetime w/ identical tzinfo objects, utcoffset is ignored.
class Varies(tzinfo):
diff --git a/Lib/test/decimaltestdata/extra.decTest b/Lib/test/decimaltestdata/extra.decTest
index fe8b77a..b630d8e 100644
--- a/Lib/test/decimaltestdata/extra.decTest
+++ b/Lib/test/decimaltestdata/extra.decTest
@@ -222,12 +222,25 @@ extr1700 power 10 1e-999999999 -> 1.000000000000000 Inexact Rounded
extr1701 power 100.0 -557.71e-742888888 -> 1.000000000000000 Inexact Rounded
extr1702 power 10 1e-100 -> 1.000000000000000 Inexact Rounded
+-- Another one (see issue #12080). Thanks again to Stefan Krah.
+extr1703 power 4 -1.2e-999999999 -> 1.000000000000000 Inexact Rounded
+
-- A couple of interesting exact cases for power. Note that the specification
-- requires these to be reported as Inexact.
extr1710 power 1e375 56e-3 -> 1.000000000000000E+21 Inexact Rounded
extr1711 power 10000 0.75 -> 1000.000000000000 Inexact Rounded
extr1712 power 1e-24 0.875 -> 1.000000000000000E-21 Inexact Rounded
+-- Some more exact cases, exercising power with negative second argument.
+extr1720 power 400 -0.5 -> 0.05000000000000000 Inexact Rounded
+extr1721 power 4096 -0.75 -> 0.001953125000000000 Inexact Rounded
+extr1722 power 625e4 -0.25 -> 0.02000000000000000 Inexact Rounded
+
+-- Nonexact cases, to exercise some of the early exit conditions from
+-- _power_exact.
+extr1730 power 2048 -0.75 -> 0.003284751622084822 Inexact Rounded
+
+
-- Tests for the is_* boolean operations
precision: 9
maxExponent: 999
diff --git a/Lib/test/dh512.pem b/Lib/test/dh512.pem
new file mode 100644
index 0000000..200d16c
--- /dev/null
+++ b/Lib/test/dh512.pem
@@ -0,0 +1,9 @@
+-----BEGIN DH PARAMETERS-----
+MEYCQQD1Kv884bEpQBgRjXyEpwpy1obEAxnIByl6ypUM2Zafq9AKUJsCRtMIPWak
+XUGfnHy9iUsiGSa6q6Jew1XpKgVfAgEC
+-----END DH PARAMETERS-----
+
+These are the 512 bit DH parameters from "Assigned Number for SKIP Protocols"
+(http://www.skip-vpn.org/spec/numbers.html).
+See there for how they were generated.
+Note that g is not a generator, but this is not a problem since p is a safe prime.
diff --git a/Lib/test/exception_hierarchy.txt b/Lib/test/exception_hierarchy.txt
index 5037b33..1c1f69f 100644
--- a/Lib/test/exception_hierarchy.txt
+++ b/Lib/test/exception_hierarchy.txt
@@ -11,11 +11,6 @@ BaseException
+-- AssertionError
+-- AttributeError
+-- BufferError
- +-- EnvironmentError
- | +-- IOError
- | +-- OSError
- | +-- WindowsError (Windows)
- | +-- VMSError (VMS)
+-- EOFError
+-- ImportError
+-- LookupError
@@ -24,6 +19,22 @@ BaseException
+-- MemoryError
+-- NameError
| +-- UnboundLocalError
+ +-- OSError
+ | +-- BlockingIOError
+ | +-- ChildProcessError
+ | +-- ConnectionError
+ | | +-- BrokenPipeError
+ | | +-- ConnectionAbortedError
+ | | +-- ConnectionRefusedError
+ | | +-- ConnectionResetError
+ | +-- FileExistsError
+ | +-- FileNotFoundError
+ | +-- InterruptedError
+ | +-- IsADirectoryError
+ | +-- NotADirectoryError
+ | +-- PermissionError
+ | +-- ProcessLookupError
+ | +-- TimeoutError
+-- ReferenceError
+-- RuntimeError
| +-- NotImplementedError
diff --git a/Lib/test/fork_wait.py b/Lib/test/fork_wait.py
index 1caab1c..88527df 100644
--- a/Lib/test/fork_wait.py
+++ b/Lib/test/fork_wait.py
@@ -43,6 +43,7 @@ class ForkWait(unittest.TestCase):
self.assertEqual(spid, cpid)
self.assertEqual(status, 0, "cause = %d, exit = %d" % (status&0xff, status>>8))
+ @support.reap_threads
def test_wait(self):
for i in range(NUM_THREADS):
_thread.start_new(self.f, (i,))
@@ -69,7 +70,8 @@ class ForkWait(unittest.TestCase):
os._exit(n)
else:
# Parent
- self.wait_impl(cpid)
- # Tell threads to die
- self.stop = 1
- time.sleep(2*SHORTSLEEP) # Wait for threads to die
+ try:
+ self.wait_impl(cpid)
+ finally:
+ # Tell threads to die
+ self.stop = 1
diff --git a/Lib/test/test_future1.py b/Lib/test/future_test1.py
index 297c2e0..297c2e0 100644
--- a/Lib/test/test_future1.py
+++ b/Lib/test/future_test1.py
diff --git a/Lib/test/test_future2.py b/Lib/test/future_test2.py
index 3d7fc86..3d7fc86 100644
--- a/Lib/test/test_future2.py
+++ b/Lib/test/future_test2.py
diff --git a/Lib/test/json_tests/test_dump.py b/Lib/test/json_tests/test_dump.py
index fee972e..237ee35 100644
--- a/Lib/test/json_tests/test_dump.py
+++ b/Lib/test/json_tests/test_dump.py
@@ -1,6 +1,7 @@
from io import StringIO
from test.json_tests import PyTest, CTest
+from test.support import bigmemtest, _1G
class TestDump:
def test_dump(self):
@@ -29,4 +30,20 @@ class TestDump:
class TestPyDump(TestDump, PyTest): pass
-class TestCDump(TestDump, CTest): pass
+
+class TestCDump(TestDump, CTest):
+
+ # The size requirement here is hopefully over-estimated (actual
+ # memory consumption depending on implementation details, and also
+ # system memory management, since this may allocate a lot of
+ # small objects).
+
+ @bigmemtest(size=_1G, memuse=1)
+ def test_large_list(self, size):
+ N = int(30 * 1024 * 1024 * (size / _1G))
+ l = [1] * N
+ encoded = self.dumps(l)
+ self.assertEqual(len(encoded), N * 3)
+ self.assertEqual(encoded[:1], "[")
+ self.assertEqual(encoded[-2:], "1]")
+ self.assertEqual(encoded[1:-2], "1, " * (N - 1))
diff --git a/Lib/test/json_tests/test_scanstring.py b/Lib/test/json_tests/test_scanstring.py
index f82cdee..426c8dd 100644
--- a/Lib/test/json_tests/test_scanstring.py
+++ b/Lib/test/json_tests/test_scanstring.py
@@ -9,14 +9,9 @@ class TestScanstring:
scanstring('"z\\ud834\\udd20x"', 1, True),
('z\U0001d120x', 16))
- if sys.maxunicode == 65535:
- self.assertEqual(
- scanstring('"z\U0001d120x"', 1, True),
- ('z\U0001d120x', 6))
- else:
- self.assertEqual(
- scanstring('"z\U0001d120x"', 1, True),
- ('z\U0001d120x', 5))
+ self.assertEqual(
+ scanstring('"z\U0001d120x"', 1, True),
+ ('z\U0001d120x', 5))
self.assertEqual(
scanstring('"\\u007b"', 1, True),
diff --git a/Lib/test/keycert.passwd.pem b/Lib/test/keycert.passwd.pem
new file mode 100644
index 0000000..e905748
--- /dev/null
+++ b/Lib/test/keycert.passwd.pem
@@ -0,0 +1,33 @@
+-----BEGIN RSA PRIVATE KEY-----
+Proc-Type: 4,ENCRYPTED
+DEK-Info: DES-EDE3-CBC,1A8D9D2A02EC698A
+
+kJYbfZ8L0sfe9Oty3gw0aloNnY5E8fegRfQLZlNoxTl6jNt0nIwI8kDJ36CZgR9c
+u3FDJm/KqrfUoz8vW+qEnWhSG7QPX2wWGPHd4K94Yz/FgrRzZ0DoK7XxXq9gOtVA
+AVGQhnz32p+6WhfGsCr9ArXEwRZrTk/FvzEPaU5fHcoSkrNVAGX8IpSVkSDwEDQr
+Gv17+cfk99UV1OCza6yKHoFkTtrC+PZU71LomBabivS2Oc4B9hYuSR2hF01wTHP+
+YlWNagZOOVtNz4oKK9x9eNQpmfQXQvPPTfusexKIbKfZrMvJoxcm1gfcZ0H/wK6P
+6wmXSG35qMOOztCZNtperjs1wzEBXznyK8QmLcAJBjkfarABJX9vBEzZV0OUKhy+
+noORFwHTllphbmydLhu6ehLUZMHPhzAS5UN7srtpSN81eerDMy0RMUAwA7/PofX1
+94Me85Q8jP0PC9ETdsJcPqLzAPETEYu0ELewKRcrdyWi+tlLFrpE5KT/s5ecbl9l
+7B61U4Kfd1PIXc/siINhU3A3bYK+845YyUArUOnKf1kEox7p1RpD7yFqVT04lRTo
+cibNKATBusXSuBrp2G6GNuhWEOSafWCKJQAzgCYIp6ZTV2khhMUGppc/2H3CF6cO
+zX0KtlPVZC7hLkB6HT8SxYUwF1zqWY7+/XPPdc37MeEZ87Q3UuZwqORLY+Z0hpgt
+L5JXBCoklZhCAaN2GqwFLXtGiRSRFGY7xXIhbDTlE65Wv1WGGgDLMKGE1gOz3yAo
+2jjG1+yAHJUdE69XTFHSqSkvaloA1W03LdMXZ9VuQJ/ySXCie6ABAQ==
+-----END RSA PRIVATE KEY-----
+-----BEGIN CERTIFICATE-----
+MIICVDCCAb2gAwIBAgIJANfHOBkZr8JOMA0GCSqGSIb3DQEBBQUAMF8xCzAJBgNV
+BAYTAlhZMRcwFQYDVQQHEw5DYXN0bGUgQW50aHJheDEjMCEGA1UEChMaUHl0aG9u
+IFNvZnR3YXJlIEZvdW5kYXRpb24xEjAQBgNVBAMTCWxvY2FsaG9zdDAeFw0xMDEw
+MDgyMzAxNTZaFw0yMDEwMDUyMzAxNTZaMF8xCzAJBgNVBAYTAlhZMRcwFQYDVQQH
+Ew5DYXN0bGUgQW50aHJheDEjMCEGA1UEChMaUHl0aG9uIFNvZnR3YXJlIEZvdW5k
+YXRpb24xEjAQBgNVBAMTCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAw
+gYkCgYEA21vT5isq7F68amYuuNpSFlKDPrMUCa4YWYqZRt2OZ+/3NKaZ2xAiSwr7
+6MrQF70t5nLbSPpqE5+5VrS58SY+g/sXLiFd6AplH1wJZwh78DofbFYXUggktFMt
+pTyiX8jtP66bkcPkDADA089RI1TQR6Ca+n7HFa7c1fabVV6i3zkCAwEAAaMYMBYw
+FAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBBQUAA4GBAHPctQBEQ4wd
+BJ6+JcpIraopLn8BGhbjNWj40mmRqWB/NAWF6M5ne7KpGAu7tLeG4hb1zLaldK8G
+lxy2GPSRF6LFS48dpEj2HbMv2nvv6xxalDMJ9+DicWgAKTQ6bcX2j3GUkCR0g/T1
+CRlNBAAlvhKzO7Clpf9l0YKBEfraJByX
+-----END CERTIFICATE-----
diff --git a/Lib/test/list_tests.py b/Lib/test/list_tests.py
index be054ea..42e118b 100644
--- a/Lib/test/list_tests.py
+++ b/Lib/test/list_tests.py
@@ -418,6 +418,47 @@ class CommonTest(seq_tests.CommonTest):
self.assertRaises(TypeError, u.reverse, 42)
+ def test_clear(self):
+ u = self.type2test([2, 3, 4])
+ u.clear()
+ self.assertEqual(u, [])
+
+ u = self.type2test([])
+ u.clear()
+ self.assertEqual(u, [])
+
+ u = self.type2test([])
+ u.append(1)
+ u.clear()
+ u.append(2)
+ self.assertEqual(u, [2])
+
+ self.assertRaises(TypeError, u.clear, None)
+
+ def test_copy(self):
+ u = self.type2test([1, 2, 3])
+ v = u.copy()
+ self.assertEqual(v, [1, 2, 3])
+
+ u = self.type2test([])
+ v = u.copy()
+ self.assertEqual(v, [])
+
+ # test that it's indeed a copy and not a reference
+ u = self.type2test(['a', 'b'])
+ v = u.copy()
+ v.append('i')
+ self.assertEqual(u, ['a', 'b'])
+ self.assertEqual(v, u + ['i'])
+
+ # test that it's a shallow, not a deep copy
+ u = self.type2test([1, 2, [3, 4], 5])
+ v = u.copy()
+ self.assertEqual(u, v)
+ self.assertIs(v[3], u[3])
+
+ self.assertRaises(TypeError, u.copy, None)
+
def test_sort(self):
u = self.type2test([1, 0])
u.sort()
diff --git a/Lib/test/lock_tests.py b/Lib/test/lock_tests.py
index 094cc7a..bfbf44e 100644
--- a/Lib/test/lock_tests.py
+++ b/Lib/test/lock_tests.py
@@ -4,7 +4,7 @@ Various tests for synchronization primitives.
import sys
import time
-from _thread import start_new_thread, get_ident, TIMEOUT_MAX
+from _thread import start_new_thread, TIMEOUT_MAX
import threading
import unittest
@@ -31,7 +31,7 @@ class Bunch(object):
self.finished = []
self._can_exit = not wait_before_exit
def task():
- tid = get_ident()
+ tid = threading.get_ident()
self.started.append(tid)
try:
f()
@@ -255,6 +255,18 @@ class RLockTests(BaseLockTests):
lock.release()
self.assertRaises(RuntimeError, lock.release)
+ def test_release_save_unacquired(self):
+ # Cannot _release_save an unacquired lock
+ lock = self.locktype()
+ self.assertRaises(RuntimeError, lock._release_save)
+ lock.acquire()
+ lock.acquire()
+ lock.release()
+ lock.acquire()
+ lock.release()
+ lock.release()
+ self.assertRaises(RuntimeError, lock._release_save)
+
def test_different_thread(self):
# Cannot release from a different thread
lock = self.locktype()
diff --git a/Lib/test/mailcap.txt b/Lib/test/mailcap.txt
new file mode 100644
index 0000000..f61135d
--- /dev/null
+++ b/Lib/test/mailcap.txt
@@ -0,0 +1,39 @@
+# Mailcap file for test_mailcap; based on RFC 1524
+# Referred to by test_mailcap.py
+
+#
+# This is a comment.
+#
+
+application/frame; showframe %s; print="cat %s | lp"
+application/postscript; ps-to-terminal %s;\
+ needsterminal
+application/postscript; ps-to-terminal %s; \
+ compose=idraw %s
+application/x-dvi; xdvi %s
+application/x-movie; movieplayer %s; compose=moviemaker %s; \
+ description="Movie"; \
+ x11-bitmap="/usr/lib/Zmail/bitmaps/movie.xbm"
+application/*; echo "This is \"%t\" but \
+ is 50 \% Greek to me" \; cat %s; copiousoutput
+
+audio/basic; showaudio %s; compose=audiocompose %s; edit=audiocompose %s;\
+description="An audio fragment"
+audio/* ; /usr/local/bin/showaudio %t
+
+image/rgb; display %s
+#image/gif; display %s
+image/x-xwindowdump; display %s
+
+# The continuation char shouldn't \
+# make a difference in a comment.
+
+message/external-body; showexternal %s %{access-type} %{name} %{site} \
+ %{directory} %{mode} %{server}; needsterminal; composetyped = extcompose %s; \
+ description="A reference to data stored in an external location"
+
+text/richtext; shownonascii iso-8859-8 -e richtext -p %s; test=test "`echo \
+ %{charset} | tr '[A-Z]' '[a-z]'`" = iso-8859-8; copiousoutput
+
+video/mpeg; mpeg_play %s
+video/*; animate %s
diff --git a/Lib/test/math_testcases.txt b/Lib/test/math_testcases.txt
index 5e24335..9585188 100644
--- a/Lib/test/math_testcases.txt
+++ b/Lib/test/math_testcases.txt
@@ -517,3 +517,117 @@ expm10306 expm1 1.79e308 -> inf overflow
-- weaker version of expm10302
expm10307 expm1 709.5 -> 1.3549863193146328e+308
+
+-------------------------
+-- log2: log to base 2 --
+-------------------------
+
+-- special values
+log20000 log2 0.0 -> -inf divide-by-zero
+log20001 log2 -0.0 -> -inf divide-by-zero
+log20002 log2 inf -> inf
+log20003 log2 -inf -> nan invalid
+log20004 log2 nan -> nan
+
+-- exact value at 1.0
+log20010 log2 1.0 -> 0.0
+
+-- negatives
+log20020 log2 -5e-324 -> nan invalid
+log20021 log2 -1.0 -> nan invalid
+log20022 log2 -1.7e-308 -> nan invalid
+
+-- exact values at powers of 2
+log20100 log2 2.0 -> 1.0
+log20101 log2 4.0 -> 2.0
+log20102 log2 8.0 -> 3.0
+log20103 log2 16.0 -> 4.0
+log20104 log2 32.0 -> 5.0
+log20105 log2 64.0 -> 6.0
+log20106 log2 128.0 -> 7.0
+log20107 log2 256.0 -> 8.0
+log20108 log2 512.0 -> 9.0
+log20109 log2 1024.0 -> 10.0
+log20110 log2 2048.0 -> 11.0
+
+log20200 log2 0.5 -> -1.0
+log20201 log2 0.25 -> -2.0
+log20202 log2 0.125 -> -3.0
+log20203 log2 0.0625 -> -4.0
+
+-- values close to 1.0
+log20300 log2 1.0000000000000002 -> 3.2034265038149171e-16
+log20301 log2 1.0000000001 -> 1.4426951601859516e-10
+log20302 log2 1.00001 -> 1.4426878274712997e-5
+
+log20310 log2 0.9999999999999999 -> -1.6017132519074588e-16
+log20311 log2 0.9999999999 -> -1.4426951603302210e-10
+log20312 log2 0.99999 -> -1.4427022544056922e-5
+
+-- tiny values
+log20400 log2 5e-324 -> -1074.0
+log20401 log2 1e-323 -> -1073.0
+log20402 log2 1.5e-323 -> -1072.4150374992789
+log20403 log2 2e-323 -> -1072.0
+
+log20410 log2 1e-308 -> -1023.1538532253076
+log20411 log2 2.2250738585072014e-308 -> -1022.0
+log20412 log2 4.4501477170144028e-308 -> -1021.0
+log20413 log2 1e-307 -> -1019.8319251304202
+
+-- huge values
+log20500 log2 1.7976931348623157e+308 -> 1024.0
+log20501 log2 1.7e+308 -> 1023.9193879716706
+log20502 log2 8.9884656743115795e+307 -> 1023.0
+
+-- selection of random values
+log20600 log2 -7.2174324841039838e+289 -> nan invalid
+log20601 log2 -2.861319734089617e+265 -> nan invalid
+log20602 log2 -4.3507646894008962e+257 -> nan invalid
+log20603 log2 -6.6717265307520224e+234 -> nan invalid
+log20604 log2 -3.9118023786619294e+229 -> nan invalid
+log20605 log2 -1.5478221302505161e+206 -> nan invalid
+log20606 log2 -1.4380485131364602e+200 -> nan invalid
+log20607 log2 -3.7235198730382645e+185 -> nan invalid
+log20608 log2 -1.0472242235095724e+184 -> nan invalid
+log20609 log2 -5.0141781956163884e+160 -> nan invalid
+log20610 log2 -2.1157958031160324e+124 -> nan invalid
+log20611 log2 -7.9677558612567718e+90 -> nan invalid
+log20612 log2 -5.5553906194063732e+45 -> nan invalid
+log20613 log2 -16573900952607.953 -> nan invalid
+log20614 log2 -37198371019.888618 -> nan invalid
+log20615 log2 -6.0727115121422674e-32 -> nan invalid
+log20616 log2 -2.5406841656526057e-38 -> nan invalid
+log20617 log2 -4.9056766703267657e-43 -> nan invalid
+log20618 log2 -2.1646786075228305e-71 -> nan invalid
+log20619 log2 -2.470826790488573e-78 -> nan invalid
+log20620 log2 -3.8661709303489064e-165 -> nan invalid
+log20621 log2 -1.0516496976649986e-182 -> nan invalid
+log20622 log2 -1.5935458614317996e-255 -> nan invalid
+log20623 log2 -2.8750977267336654e-293 -> nan invalid
+log20624 log2 -7.6079466794732585e-296 -> nan invalid
+log20625 log2 3.2073253539988545e-307 -> -1018.1505544209213
+log20626 log2 1.674937885472249e-244 -> -809.80634755783126
+log20627 log2 1.0911259044931283e-214 -> -710.76679472274213
+log20628 log2 2.0275372624809709e-154 -> -510.55719818383272
+log20629 log2 7.3926087369631841e-115 -> -379.13564735312292
+log20630 log2 1.3480198206342423e-86 -> -285.25497445094436
+log20631 log2 8.9927384655719947e-83 -> -272.55127136401637
+log20632 log2 3.1452398713597487e-60 -> -197.66251564496875
+log20633 log2 7.0706573215457351e-55 -> -179.88420087782217
+log20634 log2 3.1258285390731669e-49 -> -161.13023800505653
+log20635 log2 8.2253046627829942e-41 -> -133.15898277355879
+log20636 log2 7.8691367397519897e+49 -> 165.75068202732419
+log20637 log2 2.9920561983925013e+64 -> 214.18453534573757
+log20638 log2 4.7827254553946841e+77 -> 258.04629628445673
+log20639 log2 3.1903566496481868e+105 -> 350.47616767491166
+log20640 log2 5.6195082449502419e+113 -> 377.86831861008250
+log20641 log2 9.9625658250651047e+125 -> 418.55752921228753
+log20642 log2 2.7358945220961532e+145 -> 483.13158636923413
+log20643 log2 2.785842387926931e+174 -> 579.49360214860280
+log20644 log2 2.4169172507252751e+193 -> 642.40529039289652
+log20645 log2 3.1689091206395632e+205 -> 682.65924573798395
+log20646 log2 2.535995592365391e+208 -> 692.30359597460460
+log20647 log2 6.2011236566089916e+233 -> 776.64177576730913
+log20648 log2 2.1843274820677632e+253 -> 841.57499717289647
+log20649 log2 8.7493931063474791e+297 -> 989.74182713073981
diff --git a/Lib/test/memory_watchdog.py b/Lib/test/memory_watchdog.py
new file mode 100644
index 0000000..88cca8d
--- /dev/null
+++ b/Lib/test/memory_watchdog.py
@@ -0,0 +1,28 @@
+"""Memory watchdog: periodically read the memory usage of the main test process
+and print it out, until terminated."""
+# stdin should refer to the process' /proc/<PID>/statm: we don't pass the
+# process' PID to avoid a race condition in case of - unlikely - PID recycling.
+# If the process crashes, reading from the /proc entry will fail with ESRCH.
+
+
+import os
+import sys
+import time
+
+
+try:
+ page_size = os.sysconf('SC_PAGESIZE')
+except (ValueError, AttributeError):
+ try:
+ page_size = os.sysconf('SC_PAGE_SIZE')
+ except (ValueError, AttributeError):
+ page_size = 4096
+
+while True:
+ sys.stdin.seek(0)
+ statm = sys.stdin.read()
+ data = int(statm.split()[5])
+ sys.stdout.write(" ... process data size: {data:.1f}G\n"
+ .format(data=data * page_size / (1024 ** 3)))
+ sys.stdout.flush()
+ time.sleep(1)
diff --git a/Lib/test/mock_socket.py b/Lib/test/mock_socket.py
index 8036932..d09e78c 100644
--- a/Lib/test/mock_socket.py
+++ b/Lib/test/mock_socket.py
@@ -106,7 +106,8 @@ def socket(family=None, type=None, proto=None):
return MockSocket()
-def create_connection(address, timeout=socket_module._GLOBAL_DEFAULT_TIMEOUT):
+def create_connection(address, timeout=socket_module._GLOBAL_DEFAULT_TIMEOUT,
+ source_address=None):
try:
int_port = int(address[1])
except ValueError:
diff --git a/Lib/test/test_multibytecodec_support.py b/Lib/test/multibytecodec_support.py
index ef63b69..26bac7b 100644
--- a/Lib/test/test_multibytecodec_support.py
+++ b/Lib/test/multibytecodec_support.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
#
-# test_multibytecodec_support.py
+# multibytecodec_support.py
# Common Unittest Routines for CJK codecs
#
@@ -108,12 +108,19 @@ class TestBase:
self.assertEqual(self.encode(sin,
"test.xmlcharnamereplace")[0], sout)
+ def test_callback_returns_bytes(self):
+ def myreplace(exc):
+ return (b"1234", exc.end)
+ codecs.register_error("test.cjktest", myreplace)
+ enc = self.encode("abc" + self.unmappedunicode + "def", "test.cjktest")[0]
+ self.assertEqual(enc, b"abc1234def")
+
def test_callback_wrong_objects(self):
def myreplace(exc):
return (ret, exc.end)
codecs.register_error("test.cjktest", myreplace)
- for ret in ([1, 2, 3], [], None, object(), b'string', b''):
+ for ret in ([1, 2, 3], [], None, object()):
self.assertRaises(TypeError, self.encode, self.unmappedunicode,
'test.cjktest')
@@ -264,21 +271,6 @@ class TestBase:
self.assertEqual(ostream.getvalue(), self.tstring[0])
-if len('\U00012345') == 2: # ucs2 build
- _unichr = chr
- def chr(v):
- if v >= 0x10000:
- return _unichr(0xd800 + ((v - 0x10000) >> 10)) + \
- _unichr(0xdc00 + ((v - 0x10000) & 0x3ff))
- else:
- return _unichr(v)
- _ord = ord
- def ord(c):
- if len(c) == 2:
- return 0x10000 + ((_ord(c[0]) - 0xd800) << 10) + \
- (ord(c[1]) - 0xdc00)
- else:
- return _ord(c)
class TestBase_Mapping(unittest.TestCase):
pass_enctest = []
diff --git a/Lib/test/namespace_pkgs/both_portions/foo/one.py b/Lib/test/namespace_pkgs/both_portions/foo/one.py
new file mode 100644
index 0000000..3080f6f
--- /dev/null
+++ b/Lib/test/namespace_pkgs/both_portions/foo/one.py
@@ -0,0 +1 @@
+attr = 'both_portions foo one'
diff --git a/Lib/test/namespace_pkgs/both_portions/foo/two.py b/Lib/test/namespace_pkgs/both_portions/foo/two.py
new file mode 100644
index 0000000..4131d3d
--- /dev/null
+++ b/Lib/test/namespace_pkgs/both_portions/foo/two.py
@@ -0,0 +1 @@
+attr = 'both_portions foo two'
diff --git a/Lib/test/namespace_pkgs/missing_directory.zip b/Lib/test/namespace_pkgs/missing_directory.zip
new file mode 100644
index 0000000..836a910
--- /dev/null
+++ b/Lib/test/namespace_pkgs/missing_directory.zip
Binary files differ
diff --git a/Lib/test/namespace_pkgs/module_and_namespace_package/a_test.py b/Lib/test/namespace_pkgs/module_and_namespace_package/a_test.py
new file mode 100644
index 0000000..43cbedb
--- /dev/null
+++ b/Lib/test/namespace_pkgs/module_and_namespace_package/a_test.py
@@ -0,0 +1 @@
+attr = 'in module'
diff --git a/Lib/test/namespace_pkgs/module_and_namespace_package/a_test/empty b/Lib/test/namespace_pkgs/module_and_namespace_package/a_test/empty
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/Lib/test/namespace_pkgs/module_and_namespace_package/a_test/empty
diff --git a/Lib/test/namespace_pkgs/nested_portion1.zip b/Lib/test/namespace_pkgs/nested_portion1.zip
new file mode 100644
index 0000000..8d22406
--- /dev/null
+++ b/Lib/test/namespace_pkgs/nested_portion1.zip
Binary files differ
diff --git a/Lib/test/namespace_pkgs/not_a_namespace_pkg/foo/__init__.py b/Lib/test/namespace_pkgs/not_a_namespace_pkg/foo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/Lib/test/namespace_pkgs/not_a_namespace_pkg/foo/__init__.py
diff --git a/Lib/test/namespace_pkgs/not_a_namespace_pkg/foo/one.py b/Lib/test/namespace_pkgs/not_a_namespace_pkg/foo/one.py
new file mode 100644
index 0000000..d8f5c83
--- /dev/null
+++ b/Lib/test/namespace_pkgs/not_a_namespace_pkg/foo/one.py
@@ -0,0 +1 @@
+attr = 'portion1 foo one'
diff --git a/Lib/test/namespace_pkgs/portion1/foo/one.py b/Lib/test/namespace_pkgs/portion1/foo/one.py
new file mode 100644
index 0000000..d8f5c83
--- /dev/null
+++ b/Lib/test/namespace_pkgs/portion1/foo/one.py
@@ -0,0 +1 @@
+attr = 'portion1 foo one'
diff --git a/Lib/test/namespace_pkgs/portion2/foo/two.py b/Lib/test/namespace_pkgs/portion2/foo/two.py
new file mode 100644
index 0000000..d092e1e
--- /dev/null
+++ b/Lib/test/namespace_pkgs/portion2/foo/two.py
@@ -0,0 +1 @@
+attr = 'portion2 foo two'
diff --git a/Lib/test/namespace_pkgs/project1/parent/child/one.py b/Lib/test/namespace_pkgs/project1/parent/child/one.py
new file mode 100644
index 0000000..2776fcd
--- /dev/null
+++ b/Lib/test/namespace_pkgs/project1/parent/child/one.py
@@ -0,0 +1 @@
+attr = 'parent child one'
diff --git a/Lib/test/namespace_pkgs/project2/parent/child/two.py b/Lib/test/namespace_pkgs/project2/parent/child/two.py
new file mode 100644
index 0000000..8b037bc
--- /dev/null
+++ b/Lib/test/namespace_pkgs/project2/parent/child/two.py
@@ -0,0 +1 @@
+attr = 'parent child two'
diff --git a/Lib/test/namespace_pkgs/project3/parent/child/three.py b/Lib/test/namespace_pkgs/project3/parent/child/three.py
new file mode 100644
index 0000000..f8abfe1
--- /dev/null
+++ b/Lib/test/namespace_pkgs/project3/parent/child/three.py
@@ -0,0 +1 @@
+attr = 'parent child three'
diff --git a/Lib/test/namespace_pkgs/top_level_portion1.zip b/Lib/test/namespace_pkgs/top_level_portion1.zip
new file mode 100644
index 0000000..3b866c9
--- /dev/null
+++ b/Lib/test/namespace_pkgs/top_level_portion1.zip
Binary files differ
diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py
index 4d491b0..fb04830 100644
--- a/Lib/test/pickletester.py
+++ b/Lib/test/pickletester.py
@@ -4,10 +4,11 @@ import pickle
import pickletools
import sys
import copyreg
+import weakref
from http.cookies import SimpleCookie
from test.support import (
- TestFailed, TESTFN, run_with_locale,
+ TestFailed, TESTFN, run_with_locale, no_tracing,
_2G, _4G, bigmemtest,
)
@@ -18,7 +19,7 @@ from pickle import bytes_types
# kind of outer loop.
protocols = range(pickle.HIGHEST_PROTOCOL + 1)
-character_size = 4 if sys.maxunicode > 0xFFFF else 2
+ascii_char_size = 1
# Return True if opcode code appears in the pickle, else False.
@@ -747,6 +748,18 @@ class AbstractPickleTests(unittest.TestCase):
u = self.loads(s)
self.assertEqual(t, u)
+ def test_ellipsis(self):
+ for proto in protocols:
+ s = self.dumps(..., proto)
+ u = self.loads(s)
+ self.assertEqual(..., u)
+
+ def test_notimplemented(self):
+ for proto in protocols:
+ s = self.dumps(NotImplemented, proto)
+ u = self.loads(s)
+ self.assertEqual(NotImplemented, u)
+
# Tests for protocol 2
def test_proto(self):
@@ -880,6 +893,25 @@ class AbstractPickleTests(unittest.TestCase):
self.assertEqual(B(x), B(y), detail)
self.assertEqual(x.__dict__, y.__dict__, detail)
+ def test_newobj_proxies(self):
+ # NEWOBJ should use the __class__ rather than the raw type
+ classes = myclasses[:]
+ # Cannot create weakproxies to these classes
+ for c in (MyInt, MyTuple):
+ classes.remove(c)
+ for proto in protocols:
+ for C in classes:
+ B = C.__base__
+ x = C(C.sample)
+ x.foo = 42
+ p = weakref.proxy(x)
+ s = self.dumps(p, proto)
+ y = self.loads(s)
+ self.assertEqual(type(y), type(x)) # rather than type(p)
+ detail = (proto, C, B, x, y, type(y))
+ self.assertEqual(B(x), B(y), detail)
+ self.assertEqual(x.__dict__, y.__dict__, detail)
+
# Register a type with copyreg, with extension code extcode. Pickle
# an object of that type. Check that the resulting pickle uses opcode
# (EXT[124]) under proto 2, and not in proto 1.
@@ -1040,13 +1072,13 @@ class AbstractPickleTests(unittest.TestCase):
y = self.loads(s)
self.assertEqual(y._reduce_called, 1)
+ @no_tracing
def test_bad_getattr(self):
x = BadGetattr()
for proto in 0, 1:
self.assertRaises(RuntimeError, self.dumps, x, proto)
# protocol 2 don't raise a RuntimeError.
d = self.dumps(x, 2)
- self.assertRaises(RuntimeError, self.loads, d)
def test_reduce_bad_iterator(self):
# Issue4176: crash when 4th and 5th items of __reduce__()
@@ -1136,6 +1168,15 @@ class AbstractPickleTests(unittest.TestCase):
empty = self.loads(b'\x80\x03U\x00q\x00.', encoding='koi8-r')
self.assertEqual(empty, '')
+ def test_int_pickling_efficiency(self):
+ # Test compacity of int representation (see issue #12744)
+ for proto in protocols:
+ sizes = [len(self.dumps(2**n, proto)) for n in range(70)]
+ # the size function is monotonic
+ self.assertEqual(sorted(sizes), sizes)
+ if proto >= 2:
+ self.assertLessEqual(sizes[-1], 14)
+
def check_negative_32b_binXXX(self, dumped):
if sys.maxsize > 2**32:
self.skipTest("test is only meaningful on 32-bit builds")
@@ -1217,7 +1258,7 @@ class BigmemPickleTests(unittest.TestCase):
# All protocols use 1-byte per printable ASCII character; we add another
# byte because the encoded form has to be copied into the internal buffer.
- @bigmemtest(size=_2G, memuse=2 + character_size, dry_run=False)
+ @bigmemtest(size=_2G, memuse=2 + ascii_char_size, dry_run=False)
def test_huge_str_32b(self, size):
data = "abcd" * (size // 4)
try:
@@ -1234,7 +1275,7 @@ class BigmemPickleTests(unittest.TestCase):
# BINUNICODE (protocols 1, 2 and 3) cannot carry more than
# 2**32 - 1 bytes of utf-8 encoded unicode.
- @bigmemtest(size=_4G, memuse=1 + character_size, dry_run=False)
+ @bigmemtest(size=_4G, memuse=1 + ascii_char_size, dry_run=False)
def test_huge_str_64b(self, size):
data = "a" * size
try:
@@ -1578,6 +1619,105 @@ class AbstractPicklerUnpicklerObjectTests(unittest.TestCase):
self.assertEqual(unpickler.load(), data)
+# Tests for dispatch_table attribute
+
+REDUCE_A = 'reduce_A'
+
+class AAA(object):
+ def __reduce__(self):
+ return str, (REDUCE_A,)
+
+class BBB(object):
+ pass
+
+class AbstractDispatchTableTests(unittest.TestCase):
+
+ def test_default_dispatch_table(self):
+ # No dispatch_table attribute by default
+ f = io.BytesIO()
+ p = self.pickler_class(f, 0)
+ with self.assertRaises(AttributeError):
+ p.dispatch_table
+ self.assertFalse(hasattr(p, 'dispatch_table'))
+
+ def test_class_dispatch_table(self):
+ # A dispatch_table attribute can be specified class-wide
+ dt = self.get_dispatch_table()
+
+ class MyPickler(self.pickler_class):
+ dispatch_table = dt
+
+ def dumps(obj, protocol=None):
+ f = io.BytesIO()
+ p = MyPickler(f, protocol)
+ self.assertEqual(p.dispatch_table, dt)
+ p.dump(obj)
+ return f.getvalue()
+
+ self._test_dispatch_table(dumps, dt)
+
+ def test_instance_dispatch_table(self):
+ # A dispatch_table attribute can also be specified instance-wide
+ dt = self.get_dispatch_table()
+
+ def dumps(obj, protocol=None):
+ f = io.BytesIO()
+ p = self.pickler_class(f, protocol)
+ p.dispatch_table = dt
+ self.assertEqual(p.dispatch_table, dt)
+ p.dump(obj)
+ return f.getvalue()
+
+ self._test_dispatch_table(dumps, dt)
+
+ def _test_dispatch_table(self, dumps, dispatch_table):
+ def custom_load_dump(obj):
+ return pickle.loads(dumps(obj, 0))
+
+ def default_load_dump(obj):
+ return pickle.loads(pickle.dumps(obj, 0))
+
+ # pickling complex numbers using protocol 0 relies on copyreg
+ # so check pickling a complex number still works
+ z = 1 + 2j
+ self.assertEqual(custom_load_dump(z), z)
+ self.assertEqual(default_load_dump(z), z)
+
+ # modify pickling of complex
+ REDUCE_1 = 'reduce_1'
+ def reduce_1(obj):
+ return str, (REDUCE_1,)
+ dispatch_table[complex] = reduce_1
+ self.assertEqual(custom_load_dump(z), REDUCE_1)
+ self.assertEqual(default_load_dump(z), z)
+
+ # check picklability of AAA and BBB
+ a = AAA()
+ b = BBB()
+ self.assertEqual(custom_load_dump(a), REDUCE_A)
+ self.assertIsInstance(custom_load_dump(b), BBB)
+ self.assertEqual(default_load_dump(a), REDUCE_A)
+ self.assertIsInstance(default_load_dump(b), BBB)
+
+ # modify pickling of BBB
+ dispatch_table[BBB] = reduce_1
+ self.assertEqual(custom_load_dump(a), REDUCE_A)
+ self.assertEqual(custom_load_dump(b), REDUCE_1)
+ self.assertEqual(default_load_dump(a), REDUCE_A)
+ self.assertIsInstance(default_load_dump(b), BBB)
+
+ # revert pickling of BBB and modify pickling of AAA
+ REDUCE_2 = 'reduce_2'
+ def reduce_2(obj):
+ return str, (REDUCE_2,)
+ dispatch_table[AAA] = reduce_2
+ del dispatch_table[BBB]
+ self.assertEqual(custom_load_dump(a), REDUCE_2)
+ self.assertIsInstance(custom_load_dump(b), BBB)
+ self.assertEqual(default_load_dump(a), REDUCE_A)
+ self.assertIsInstance(default_load_dump(b), BBB)
+
+
if __name__ == "__main__":
# Print some stuff that can be used to rewrite DATA{0,1,2}
from pickletools import dis
diff --git a/Lib/test/regrtest.py b/Lib/test/regrtest.py
index 84beb8d..636282e 100755
--- a/Lib/test/regrtest.py
+++ b/Lib/test/regrtest.py
@@ -20,6 +20,11 @@ python -E -Wd -m test [options] [test_name1 ...]
Options:
-h/--help -- print this text and exit
+--timeout TIMEOUT
+ -- dump the traceback and exit if a test takes more
+ than TIMEOUT seconds; disabled if TIMEOUT is negative
+ or equals to zero
+--wait -- wait for user input, e.g., allow a debugger to be attached
Verbosity
@@ -44,6 +49,9 @@ Selecting tests
-- specify which special resource intensive tests to run
-M/--memlimit LIMIT
-- run very large memory-consuming tests
+ --testdir DIR
+ -- execute test files in the specified directory (instead
+ of the Python stdlib test suite)
Special runs
@@ -125,6 +133,8 @@ resources to test. Currently only the following are defined:
all - Enable all special resources.
+ none - Disable all special resources (this is the default).
+
audio - Tests that use the audio device. (There are known
cases of broken audio drivers that can crash Python or
even the Linux kernel.)
@@ -155,8 +165,11 @@ example, to run all the tests except for the gui tests, give the
option '-uall,-gui'.
"""
+# We import importlib *ASAP* in order to test #15386
+import importlib
+
import builtins
-import errno
+import faulthandler
import getopt
import io
import json
@@ -166,6 +179,7 @@ import platform
import random
import re
import shutil
+import signal
import sys
import sysconfig
import tempfile
@@ -225,6 +239,7 @@ ENV_CHANGED = -1
SKIPPED = -2
RESOURCE_DENIED = -3
INTERRUPTED = -4
+CHILD_ERROR = -5 # error in a child process
from test import support
@@ -268,6 +283,18 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
on the command line.
"""
+ # Display the Python traceback on fatal errors (e.g. segfault)
+ faulthandler.enable(all_threads=True)
+
+ # Display the Python traceback on SIGALRM or SIGUSR1 signal
+ signals = []
+ if hasattr(signal, 'SIGALRM'):
+ signals.append(signal.SIGALRM)
+ if hasattr(signal, 'SIGUSR1'):
+ signals.append(signal.SIGUSR1)
+ for signum in signals:
+ faulthandler.register(signum, chain=True)
+
replace_stdout()
support.record_original_stdout(sys.stdout)
@@ -278,7 +305,8 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
'use=', 'threshold=', 'coverdir=', 'nocoverdir',
'runleaks', 'huntrleaks=', 'memlimit=', 'randseed=',
'multiprocess=', 'coverage', 'slaveargs=', 'forever', 'debug',
- 'start=', 'nowindows', 'header', 'failfast', 'match='])
+ 'start=', 'nowindows', 'header', 'testdir=', 'timeout=', 'wait',
+ 'failfast', 'match='])
except getopt.error as msg:
usage(msg)
@@ -289,6 +317,7 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
use_resources = []
debug = False
start = None
+ timeout = None
for o, a in opts:
if o in ('-h', '--help'):
print(__doc__)
@@ -332,7 +361,9 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
elif o in ('-T', '--coverage'):
trace = True
elif o in ('-D', '--coverdir'):
- coverdir = os.path.join(os.getcwd(), a)
+ # CWD is replaced with a temporary dir before calling main(), so we
+ # need join it with the saved CWD so it goes where the user expects.
+ coverdir = os.path.join(support.SAVEDCWD, a)
elif o in ('-N', '--nocoverdir'):
coverdir = None
elif o in ('-R', '--huntrleaks'):
@@ -350,9 +381,9 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
huntrleaks[1] = int(huntrleaks[1])
if len(huntrleaks) == 2 or not huntrleaks[2]:
huntrleaks[2:] = ["reflog.txt"]
- # Avoid false positives due to the character cache in
- # stringobject.c filling slowly with random data
- warm_char_cache()
+ # Avoid false positives due to various caches
+ # filling slowly with random data:
+ warm_caches()
elif o in ('-M', '--memlimit'):
support.set_memlimit(a)
elif o in ('-u', '--use'):
@@ -361,6 +392,9 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
if r == 'all':
use_resources[:] = RESOURCE_NAMES
continue
+ if r == 'none':
+ del use_resources[:]
+ continue
remove = False
if r[0] == '-':
remove = True
@@ -391,18 +425,45 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
forever = True
elif o in ('-j', '--multiprocess'):
use_mp = int(a)
+ if use_mp <= 0:
+ try:
+ import multiprocessing
+ # Use all cores + extras for tests that like to sleep
+ use_mp = 2 + multiprocessing.cpu_count()
+ except (ImportError, NotImplementedError):
+ use_mp = 3
+ if use_mp == 1:
+ use_mp = None
elif o == '--header':
header = True
elif o == '--slaveargs':
args, kwargs = json.loads(a)
try:
result = runtest(*args, **kwargs)
+ except KeyboardInterrupt:
+ result = INTERRUPTED, ''
except BaseException as e:
- result = INTERRUPTED, e.__class__.__name__
+ traceback.print_exc()
+ result = CHILD_ERROR, str(e)
sys.stdout.flush()
print() # Force a newline (just in case)
print(json.dumps(result))
sys.exit(0)
+ elif o == '--testdir':
+ # CWD is replaced with a temporary dir before calling main(), so we
+ # join it with the saved CWD so it ends up where the user expects.
+ testdir = os.path.join(support.SAVEDCWD, a)
+ elif o == '--timeout':
+ if hasattr(faulthandler, 'dump_tracebacks_later'):
+ timeout = float(a)
+ if timeout <= 0:
+ timeout = None
+ else:
+ print("Warning: The timeout option requires "
+ "faulthandler.dump_tracebacks_later")
+ timeout = None
+ elif o == '--wait':
+ input("Press any key to continue...")
else:
print(("No handler for option {}. Please report this as a bug "
"at http://bugs.python.org.").format(o), file=sys.stderr)
@@ -481,7 +542,13 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
print("== ", os.getcwd())
print("Testing with flags:", sys.flags)
- alltests = findtests(testdir, stdtests, nottests)
+ # if testdir is set, then we are not running the python tests suite, so
+ # don't add default tests to be executed or skipped (pass empty values)
+ if testdir:
+ alltests = findtests(testdir, list(), set())
+ else:
+ alltests = findtests(testdir, stdtests, nottests)
+
selected = tests or args or alltests
if single:
selected = selected[:1]
@@ -501,7 +568,7 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
random.shuffle(selected)
if trace:
import trace, tempfile
- tracer = trace.Trace(ignoredirs=[sys.prefix, sys.exec_prefix,
+ tracer = trace.Trace(ignoredirs=[sys.base_prefix, sys.base_exec_prefix,
tempfile.gettempdir()],
trace=False, count=True)
@@ -566,7 +633,8 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
(test, verbose, quiet),
dict(huntrleaks=huntrleaks, use_resources=use_resources,
debug=debug, output_on_failure=verbose3,
- failfast=failfast, match_tests=match_tests)
+ timeout=timeout, failfast=failfast,
+ match_tests=match_tests)
)
# -E is needed by some tests, e.g. test_import
# Running the child from the same working directory ensures
@@ -578,10 +646,15 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
close_fds=(os.name != 'nt'),
cwd=support.SAVEDCWD)
stdout, stderr = popen.communicate()
+ retcode = popen.wait()
# Strip last refcount output line if it exists, since it
# comes from the shutdown of the interpreter in the subcommand.
stderr = debug_output_pat.sub("", stderr)
stdout, _, result = stdout.strip().rpartition("\n")
+ if retcode != 0:
+ result = (CHILD_ERROR, "Exit code %s" % retcode)
+ output.put((test, stdout.rstrip(), stderr.rstrip(), result))
+ return
if not result:
output.put((None, None, None, None))
return
@@ -614,8 +687,9 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
sys.stdout.flush()
sys.stderr.flush()
if result[0] == INTERRUPTED:
- assert result[1] == 'KeyboardInterrupt'
- raise KeyboardInterrupt # What else?
+ raise KeyboardInterrupt
+ if result[0] == CHILD_ERROR:
+ raise Exception("Child error on {}: {}".format(test, result[1]))
test_index += 1
except KeyboardInterrupt:
interrupted = True
@@ -632,13 +706,14 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
if trace:
# If we're tracing code coverage, then we don't exit with status
# if on a false return value from main.
- tracer.runctx('runtest(test, verbose, quiet)',
+ tracer.runctx('runtest(test, verbose, quiet, timeout=timeout)',
globals=globals(), locals=vars())
else:
try:
result = runtest(test, verbose, quiet, huntrleaks, debug,
output_on_failure=verbose3,
- failfast=failfast, match_tests=match_tests)
+ timeout=timeout, failfast=failfast,
+ match_tests=match_tests)
accumulate_result(test, result)
except KeyboardInterrupt:
interrupted = True
@@ -709,7 +784,7 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
sys.stdout.flush()
try:
verbose = True
- ok = runtest(test, True, quiet, huntrleaks, debug)
+ ok = runtest(test, True, quiet, huntrleaks, debug, timeout=timeout)
except KeyboardInterrupt:
# print a newline separate from the ^C
print()
@@ -734,6 +809,8 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
sys.exit(len(bad) > 0 or interrupted)
+# small set of tests to determine if we have a basically functioning interpreter
+# (i.e. if any of these fail, then anything else is likely to follow)
STDTESTS = [
'test_grammar',
'test_opcodes',
@@ -744,12 +821,11 @@ STDTESTS = [
'test_unittest',
'test_doctest',
'test_doctest2',
+ 'test_support'
]
-NOTTESTS = {
- 'test_future1',
- 'test_future2',
-}
+# set of tests that we don't want to be executed when using regrtest
+NOTTESTS = set()
def findtests(testdir=None, stdtests=STDTESTS, nottests=NOTTESTS):
"""Return a list of all applicable test modules."""
@@ -758,9 +834,9 @@ def findtests(testdir=None, stdtests=STDTESTS, nottests=NOTTESTS):
tests = []
others = set(stdtests) | nottests
for name in names:
- modname, ext = os.path.splitext(name)
- if modname[:5] == "test_" and ext == ".py" and modname not in others:
- tests.append(modname)
+ mod, ext = os.path.splitext(name)
+ if mod[:5] == "test_" and ext in (".py", "") and mod not in others:
+ tests.append(mod)
return stdtests + sorted(tests)
# We do not use a generator so multiple threads can call next().
@@ -785,17 +861,14 @@ class MultiprocessTests(object):
def replace_stdout():
"""Set stdout encoder error handler to backslashreplace (as stderr error
handler) to avoid UnicodeEncodeError when printing a traceback"""
- if os.name == "nt":
- # Replace sys.stdout breaks the stdout newlines on Windows: issue #8533
- return
-
import atexit
stdout = sys.stdout
sys.stdout = open(stdout.fileno(), 'w',
encoding=stdout.encoding,
errors="backslashreplace",
- closefd=False)
+ closefd=False,
+ newline='\n')
def restore_stdout():
sys.stdout.close()
@@ -804,7 +877,8 @@ def replace_stdout():
def runtest(test, verbose, quiet,
huntrleaks=False, debug=False, use_resources=None,
- output_on_failure=False, failfast=False, match_tests=None):
+ output_on_failure=False, failfast=False, match_tests=None,
+ timeout=None):
"""Run a single test.
test -- the name of the test
@@ -814,6 +888,8 @@ def runtest(test, verbose, quiet,
huntrleaks -- run multiple times to test for leaks; requires a debug
build; a triple corresponding to -R's three arguments
output_on_failure -- if true, display test output on failure
+ timeout -- dump the traceback and exit if a test takes more than
+ timeout seconds
Returns one of the test result constants:
INTERRUPTED KeyboardInterrupt when run under -j
@@ -826,6 +902,9 @@ def runtest(test, verbose, quiet,
if use_resources is not None:
support.use_resources = use_resources
+ use_timeout = (timeout is not None)
+ if use_timeout:
+ faulthandler.dump_tracebacks_later(timeout, exit=True)
try:
support.match_tests = match_tests
if failfast:
@@ -864,6 +943,8 @@ def runtest(test, verbose, quiet,
display_failure=not verbose)
return result
finally:
+ if use_timeout:
+ faulthandler.cancel_dump_tracebacks_later()
cleanup_test_droppings(test, verbose)
runtest.stringio = None
@@ -909,10 +990,10 @@ class saved_test_environment:
resources = ('sys.argv', 'cwd', 'sys.stdin', 'sys.stdout', 'sys.stderr',
'os.environ', 'sys.path', 'sys.path_hooks', '__import__',
'warnings.filters', 'asyncore.socket_map',
- 'logging._handlers', 'logging._handlerList',
- 'shutil.archive_formats', 'shutil.unpack_formats',
+ 'logging._handlers', 'logging._handlerList', 'sys.gettrace',
'sys.warnoptions', 'threading._dangling',
'multiprocessing.process._dangling',
+ 'sysconfig._CONFIG_VARS', 'sysconfig._INSTALL_SCHEMES',
'support.TESTFN',
)
@@ -961,6 +1042,11 @@ class saved_test_environment:
sys.path_hooks = saved_hooks[1]
sys.path_hooks[:] = saved_hooks[2]
+ def get_sys_gettrace(self):
+ return sys.gettrace()
+ def restore_sys_gettrace(self, trace_fxn):
+ sys.settrace(trace_fxn)
+
def get___import__(self):
return builtins.__import__
def restore___import__(self, import_):
@@ -1044,6 +1130,24 @@ class saved_test_environment:
multiprocessing.process._dangling.clear()
multiprocessing.process._dangling.update(saved)
+ def get_sysconfig__CONFIG_VARS(self):
+ # make sure the dict is initialized
+ sysconfig.get_config_var('prefix')
+ return (id(sysconfig._CONFIG_VARS), sysconfig._CONFIG_VARS,
+ dict(sysconfig._CONFIG_VARS))
+ def restore_sysconfig__CONFIG_VARS(self, saved):
+ sysconfig._CONFIG_VARS = saved[1]
+ sysconfig._CONFIG_VARS.clear()
+ sysconfig._CONFIG_VARS.update(saved[2])
+
+ def get_sysconfig__INSTALL_SCHEMES(self):
+ return (id(sysconfig._INSTALL_SCHEMES), sysconfig._INSTALL_SCHEMES,
+ sysconfig._INSTALL_SCHEMES.copy())
+ def restore_sysconfig__INSTALL_SCHEMES(self, saved):
+ sysconfig._INSTALL_SCHEMES = saved[1]
+ sysconfig._INSTALL_SCHEMES.clear()
+ sysconfig._INSTALL_SCHEMES.update(saved[2])
+
def get_support_TESTFN(self):
if os.path.isfile(support.TESTFN):
result = 'f'
@@ -1108,14 +1212,15 @@ def runtest_inner(test, verbose, quiet,
start_time = time.time()
the_package = __import__(abstest, globals(), locals(), [])
the_module = getattr(the_package, test)
- # Old tests run to completion simply as a side-effect of
- # being imported. For tests based on unittest or doctest,
- # explicitly invoke their test_main() function (if it exists).
- indirect_test = getattr(the_module, "test_main", None)
- if indirect_test is not None:
- indirect_test()
+ # If the test has a test_main, that will run the appropriate
+ # tests. If not, use normal unittest test loading.
+ test_runner = getattr(the_module, "test_main", None)
+ if test_runner is None:
+ tests = unittest.TestLoader().loadTestsFromModule(the_module)
+ test_runner = lambda: support.run_unittest(tests)
+ test_runner()
if huntrleaks:
- refleak = dash_R(the_module, test, indirect_test,
+ refleak = dash_R(the_module, test, test_runner,
huntrleaks)
test_time = time.time() - start_time
except support.ResourceDenied as msg:
@@ -1198,7 +1303,8 @@ def dash_R(the_module, test, indirect_test, huntrleaks):
False if the test didn't leak references; True if we detected refleaks.
"""
# This code is hackish and inelegant, but it seems to do the job.
- import copyreg, _abcoll
+ import copyreg
+ import collections.abc
if not hasattr(sys, 'gettotalrefcount'):
raise Exception("Tracking reference leaks requires a debug build "
@@ -1215,7 +1321,7 @@ def dash_R(the_module, test, indirect_test, huntrleaks):
else:
zdc = zipimport._zip_directory_cache.copy()
abcs = {}
- for abc in [getattr(_abcoll, a) for a in _abcoll.__all__]:
+ for abc in [getattr(collections.abc, a) for a in collections.abc.__all__]:
if not isabstract(abc):
continue
for obj in abc.__subclasses__() + [abc]:
@@ -1261,7 +1367,7 @@ def dash_R_cleanup(fs, ps, pic, zdc, abcs):
import gc, copyreg
import _strptime, linecache
import urllib.parse, urllib.request, mimetypes, doctest
- import struct, filecmp, _abcoll
+ import struct, filecmp, collections.abc
from distutils.dir_util import _path_created
from weakref import WeakSet
@@ -1288,7 +1394,7 @@ def dash_R_cleanup(fs, ps, pic, zdc, abcs):
sys._clear_type_cache()
# Clear ABC registries, restoring previously saved ABC registries.
- for abc in [getattr(_abcoll, a) for a in _abcoll.__all__]:
+ for abc in [getattr(collections.abc, a) for a in collections.abc.__all__]:
if not isabstract(abc):
continue
for obj in abc.__subclasses__() + [abc]:
@@ -1324,10 +1430,15 @@ def dash_R_cleanup(fs, ps, pic, zdc, abcs):
# Collect cyclic trash.
gc.collect()
-def warm_char_cache():
+def warm_caches():
+ # char cache
s = bytes(range(256))
for i in range(256):
s[i:i+1]
+ # unicode cache
+ x = [chr(i) for i in range(256)]
+ # int cache
+ x = list(range(-5, 257))
def findtestdir(path=None):
return path or os.path.dirname(__file__) or os.curdir
@@ -1374,13 +1485,14 @@ def printlist(x, width=70, indent=4):
# Tests that are expected to be skipped everywhere except on one platform
# are also handled separately.
-_expectations = {
- 'win32':
+_expectations = (
+ ('win32',
"""
test__locale
test_crypt
test_curses
test_dbm
+ test_devpoll
test_fcntl
test_fork1
test_epoll
@@ -1403,15 +1515,16 @@ _expectations = {
test_threadsignals
test_wait3
test_wait4
- """,
- 'linux2':
+ """),
+ ('linux',
"""
test_curses
+ test_devpoll
test_largefile
test_kqueue
test_ossaudiodev
- """,
- 'unixware7':
+ """),
+ ('unixware',
"""
test_epoll
test_largefile
@@ -1421,8 +1534,8 @@ _expectations = {
test_pyexpat
test_sax
test_sundry
- """,
- 'openunix8':
+ """),
+ ('openunix',
"""
test_epoll
test_largefile
@@ -1432,8 +1545,8 @@ _expectations = {
test_pyexpat
test_sax
test_sundry
- """,
- 'sco_sv3':
+ """),
+ ('sco_sv',
"""
test_asynchat
test_fork1
@@ -1452,11 +1565,12 @@ _expectations = {
test_threaded_import
test_threadedtempfile
test_threading
- """,
- 'darwin':
+ """),
+ ('darwin',
"""
test__locale
test_curses
+ test_devpoll
test_epoll
test_dbm_gnu
test_gdb
@@ -1465,8 +1579,8 @@ _expectations = {
test_minidom
test_ossaudiodev
test_poll
- """,
- 'sunos5':
+ """),
+ ('sunos',
"""
test_curses
test_dbm
@@ -1477,8 +1591,8 @@ _expectations = {
test_openpty
test_zipfile
test_zlib
- """,
- 'hp-ux11':
+ """),
+ ('hp-ux',
"""
test_curses
test_epoll
@@ -1493,11 +1607,12 @@ _expectations = {
test_sax
test_zipfile
test_zlib
- """,
- 'cygwin':
+ """),
+ ('cygwin',
"""
test_curses
test_dbm
+ test_devpoll
test_epoll
test_ioctl
test_kqueue
@@ -1505,8 +1620,8 @@ _expectations = {
test_locale
test_ossaudiodev
test_socketserver
- """,
- 'os2emx':
+ """),
+ ('os2emx',
"""
test_audioop
test_curses
@@ -1519,9 +1634,10 @@ _expectations = {
test_pty
test_resource
test_signal
- """,
- 'freebsd4':
+ """),
+ ('freebsd',
"""
+ test_devpoll
test_epoll
test_dbm_gnu
test_locale
@@ -1536,8 +1652,8 @@ _expectations = {
test_timeout
test_urllibnet
test_multiprocessing
- """,
- 'aix5':
+ """),
+ ('aix',
"""
test_bz2
test_epoll
@@ -1551,10 +1667,11 @@ _expectations = {
test_ttk_textonly
test_zipimport
test_zlib
- """,
- 'openbsd3':
+ """),
+ ('openbsd',
"""
test_ctypes
+ test_devpoll
test_epoll
test_dbm_gnu
test_locale
@@ -1566,11 +1683,12 @@ _expectations = {
test_ttk_guionly
test_ttk_textonly
test_multiprocessing
- """,
- 'netbsd3':
+ """),
+ ('netbsd',
"""
test_ctypes
test_curses
+ test_devpoll
test_epoll
test_dbm_gnu
test_locale
@@ -1581,12 +1699,8 @@ _expectations = {
test_ttk_guionly
test_ttk_textonly
test_multiprocessing
- """,
-}
-_expectations['freebsd5'] = _expectations['freebsd4']
-_expectations['freebsd6'] = _expectations['freebsd4']
-_expectations['freebsd7'] = _expectations['freebsd4']
-_expectations['freebsd8'] = _expectations['freebsd4']
+ """),
+)
class _ExpectedSkips:
def __init__(self):
@@ -1594,9 +1708,13 @@ class _ExpectedSkips:
from test import test_timeout
self.valid = False
- if sys.platform in _expectations:
- s = _expectations[sys.platform]
- self.expected = set(s.split())
+ expected = None
+ for item in _expectations:
+ if sys.platform.startswith(item[0]):
+ expected = item[1]
+ break
+ if expected is not None:
+ self.expected = set(expected.split())
# These are broken tests, for now skipped on every platform.
# XXX Fix these!
@@ -1656,9 +1774,8 @@ def _make_temp_dir_for_build(TEMPDIR):
TEMPDIR = os.path.abspath(TEMPDIR)
try:
os.mkdir(TEMPDIR)
- except OSError as e:
- if e.errno != errno.EEXIST:
- raise
+ except FileExistsError:
+ pass
# Define a writable temp dir that will be used as cwd while running
# the tests. The name of the dir includes the pid to allow parallel
diff --git a/Lib/test/reperf.py b/Lib/test/reperf.py
index 7c68234..e93bacd 100644
--- a/Lib/test/reperf.py
+++ b/Lib/test/reperf.py
@@ -9,13 +9,13 @@ def main():
timefunc(10, p.findall, s)
def timefunc(n, func, *args, **kw):
- t0 = time.clock()
+ t0 = time.perf_counter()
try:
for i in range(n):
result = func(*args, **kw)
return result
finally:
- t1 = time.clock()
+ t1 = time.perf_counter()
if n > 1:
print(n, "times", end=' ')
print(func.__name__, "%.3f" % (t1-t0), "CPU seconds")
diff --git a/Lib/test/script_helper.py b/Lib/test/script_helper.py
index ba446cd..b09f4bf 100644
--- a/Lib/test/script_helper.py
+++ b/Lib/test/script_helper.py
@@ -1,6 +1,7 @@
# Common utility functions used by various script execution tests
# e.g. test_cmd_line, test_cmd_line_script and test_runpy
+import importlib
import sys
import os
import os.path
@@ -59,11 +60,12 @@ def assert_python_failure(*args, **env_vars):
"""
return _assert_python(False, *args, **env_vars)
-def spawn_python(*args):
+def spawn_python(*args, **kw):
cmd_line = [sys.executable, '-E']
cmd_line.extend(args)
return subprocess.Popen(cmd_line, stdin=subprocess.PIPE,
- stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+ stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
+ **kw)
def kill_python(p):
p.stdin.close()
@@ -92,6 +94,7 @@ def make_script(script_dir, script_basename, source):
script_file = open(script_name, 'w', encoding='utf-8')
script_file.write(source)
script_file.close()
+ importlib.invalidate_caches()
return script_name
def make_zip_script(zip_dir, zip_basename, script_name, name_in_zip=None):
diff --git a/Lib/test/seq_tests.py b/Lib/test/seq_tests.py
index f655c29..f185a82 100644
--- a/Lib/test/seq_tests.py
+++ b/Lib/test/seq_tests.py
@@ -4,6 +4,7 @@ Tests common to tuple, list and UserList.UserList
import unittest
import sys
+import pickle
# Various iterables
# This is used for checking the constructor (here and in test_deque.py)
@@ -388,3 +389,9 @@ class CommonTest(unittest.TestCase):
self.assertEqual(a.index(0, -4*sys.maxsize, 4*sys.maxsize), 2)
self.assertRaises(ValueError, a.index, 0, 4*sys.maxsize,-4*sys.maxsize)
self.assertRaises(ValueError, a.index, 2, 0, -10)
+
+ def test_pickle(self):
+ lst = self.type2test([4, 5, 6, 7])
+ lst2 = pickle.loads(pickle.dumps(lst))
+ self.assertEqual(lst2, lst)
+ self.assertNotEqual(id(lst2), id(lst))
diff --git a/Lib/test/sortperf.py b/Lib/test/sortperf.py
index 0ce88de..af7c0b4 100644
--- a/Lib/test/sortperf.py
+++ b/Lib/test/sortperf.py
@@ -57,9 +57,9 @@ def flush():
sys.stdout.flush()
def doit(L):
- t0 = time.clock()
+ t0 = time.perf_counter()
L.sort()
- t1 = time.clock()
+ t1 = time.perf_counter()
print("%6.2f" % (t1-t0), end=' ')
flush()
diff --git a/Lib/test/ssl_key.passwd.pem b/Lib/test/ssl_key.passwd.pem
new file mode 100644
index 0000000..2524672
--- /dev/null
+++ b/Lib/test/ssl_key.passwd.pem
@@ -0,0 +1,18 @@
+-----BEGIN RSA PRIVATE KEY-----
+Proc-Type: 4,ENCRYPTED
+DEK-Info: DES-EDE3-CBC,1A8D9D2A02EC698A
+
+kJYbfZ8L0sfe9Oty3gw0aloNnY5E8fegRfQLZlNoxTl6jNt0nIwI8kDJ36CZgR9c
+u3FDJm/KqrfUoz8vW+qEnWhSG7QPX2wWGPHd4K94Yz/FgrRzZ0DoK7XxXq9gOtVA
+AVGQhnz32p+6WhfGsCr9ArXEwRZrTk/FvzEPaU5fHcoSkrNVAGX8IpSVkSDwEDQr
+Gv17+cfk99UV1OCza6yKHoFkTtrC+PZU71LomBabivS2Oc4B9hYuSR2hF01wTHP+
+YlWNagZOOVtNz4oKK9x9eNQpmfQXQvPPTfusexKIbKfZrMvJoxcm1gfcZ0H/wK6P
+6wmXSG35qMOOztCZNtperjs1wzEBXznyK8QmLcAJBjkfarABJX9vBEzZV0OUKhy+
+noORFwHTllphbmydLhu6ehLUZMHPhzAS5UN7srtpSN81eerDMy0RMUAwA7/PofX1
+94Me85Q8jP0PC9ETdsJcPqLzAPETEYu0ELewKRcrdyWi+tlLFrpE5KT/s5ecbl9l
+7B61U4Kfd1PIXc/siINhU3A3bYK+845YyUArUOnKf1kEox7p1RpD7yFqVT04lRTo
+cibNKATBusXSuBrp2G6GNuhWEOSafWCKJQAzgCYIp6ZTV2khhMUGppc/2H3CF6cO
+zX0KtlPVZC7hLkB6HT8SxYUwF1zqWY7+/XPPdc37MeEZ87Q3UuZwqORLY+Z0hpgt
+L5JXBCoklZhCAaN2GqwFLXtGiRSRFGY7xXIhbDTlE65Wv1WGGgDLMKGE1gOz3yAo
+2jjG1+yAHJUdE69XTFHSqSkvaloA1W03LdMXZ9VuQJ/ySXCie6ABAQ==
+-----END RSA PRIVATE KEY-----
diff --git a/Lib/test/ssl_servers.py b/Lib/test/ssl_servers.py
index 77c0542..8686153 100644
--- a/Lib/test/ssl_servers.py
+++ b/Lib/test/ssl_servers.py
@@ -94,7 +94,12 @@ class StatsRequestHandler(BaseHTTPRequestHandler):
"""Serve a GET request."""
sock = self.rfile.raw._sock
context = sock.context
- body = pprint.pformat(context.session_stats())
+ stats = {
+ 'session_cache': context.session_stats(),
+ 'cipher': sock.cipher(),
+ 'compression': sock.compression(),
+ }
+ body = pprint.pformat(stats)
body = body.encode('utf-8')
self.send_response(200)
self.send_header("Content-type", "text/plain; charset=utf-8")
@@ -172,6 +177,11 @@ if __name__ == "__main__":
action='store_false', help='be less verbose')
parser.add_argument('-s', '--stats', dest='use_stats_handler', default=False,
action='store_true', help='always return stats page')
+ parser.add_argument('--curve-name', dest='curve_name', type=str,
+ action='store',
+ help='curve name for EC-based Diffie-Hellman')
+ parser.add_argument('--dh', dest='dh_file', type=str, action='store',
+ help='PEM file containing DH parameters')
args = parser.parse_args()
support.verbose = args.verbose
@@ -182,6 +192,10 @@ if __name__ == "__main__":
handler_class.root = os.getcwd()
context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
context.load_cert_chain(CERTFILE)
+ if args.curve_name:
+ context.set_ecdh_curve(args.curve_name)
+ if args.dh_file:
+ context.load_dh_params(args.dh_file)
server = HTTPSServer(("", args.port), handler_class, context)
if args.verbose:
diff --git a/Lib/test/string_tests.py b/Lib/test/string_tests.py
index 6792179..385b039 100644
--- a/Lib/test/string_tests.py
+++ b/Lib/test/string_tests.py
@@ -20,7 +20,7 @@ class BadSeq2(Sequence):
def __init__(self): self.seq = ['a', 'b', 'c']
def __len__(self): return 8
-class BaseTest(unittest.TestCase):
+class BaseTest:
# These tests are for buffers of values (bytes) and not
# specific to character interpretation, used for bytes objects
# and various string implementations
@@ -29,6 +29,11 @@ class BaseTest(unittest.TestCase):
# Change in subclasses to change the behaviour of fixtesttype()
type2test = None
+ # Whether the "contained items" of the container are integers in
+ # range(0, 256) (i.e. bytes, bytearray) or strings of length 1
+ # (str)
+ contains_bytes = False
+
# All tests pass their arguments to the testing methods
# as str objects. fixtesttype() can be used to propagate
# these arguments to the appropriate type
@@ -48,11 +53,12 @@ class BaseTest(unittest.TestCase):
return obj
# check that obj.method(*args) returns result
- def checkequal(self, result, obj, methodname, *args):
+ def checkequal(self, result, obj, methodname, *args, **kwargs):
result = self.fixtype(result)
obj = self.fixtype(obj)
args = self.fixtype(args)
- realresult = getattr(obj, methodname)(*args)
+ kwargs = {k: self.fixtype(v) for k,v in kwargs.items()}
+ realresult = getattr(obj, methodname)(*args, **kwargs)
self.assertEqual(
result,
realresult
@@ -117,7 +123,11 @@ class BaseTest(unittest.TestCase):
self.checkequal(0, '', 'count', 'xx', sys.maxsize, 0)
self.checkraises(TypeError, 'hello', 'count')
- self.checkraises(TypeError, 'hello', 'count', 42)
+
+ if self.contains_bytes:
+ self.checkequal(0, 'hello', 'count', 42)
+ else:
+ self.checkraises(TypeError, 'hello', 'count', 42)
# For a variety of combinations,
# verify that str.count() matches an equivalent function
@@ -163,7 +173,11 @@ class BaseTest(unittest.TestCase):
self.checkequal( 2, 'rrarrrrrrrrra', 'find', 'a', None, 6)
self.checkraises(TypeError, 'hello', 'find')
- self.checkraises(TypeError, 'hello', 'find', 42)
+
+ if self.contains_bytes:
+ self.checkequal(-1, 'hello', 'find', 42)
+ else:
+ self.checkraises(TypeError, 'hello', 'find', 42)
self.checkequal(0, '', 'find', '')
self.checkequal(-1, '', 'find', '', 1, 1)
@@ -217,7 +231,11 @@ class BaseTest(unittest.TestCase):
self.checkequal( 2, 'rrarrrrrrrrra', 'rfind', 'a', None, 6)
self.checkraises(TypeError, 'hello', 'rfind')
- self.checkraises(TypeError, 'hello', 'rfind', 42)
+
+ if self.contains_bytes:
+ self.checkequal(-1, 'hello', 'rfind', 42)
+ else:
+ self.checkraises(TypeError, 'hello', 'rfind', 42)
# For a variety of combinations,
# verify that str.rfind() matches __contains__
@@ -245,6 +263,9 @@ class BaseTest(unittest.TestCase):
# issue 7458
self.checkequal(-1, 'ab', 'rfind', 'xxx', sys.maxsize + 1, 0)
+ # issue #15534
+ self.checkequal(0, '<......\u043c...', "rfind", "<")
+
def test_index(self):
self.checkequal(0, 'abcdefghiabc', 'index', '')
self.checkequal(3, 'abcdefghiabc', 'index', 'def')
@@ -264,7 +285,11 @@ class BaseTest(unittest.TestCase):
self.checkequal( 2, 'rrarrrrrrrrra', 'index', 'a', None, 6)
self.checkraises(TypeError, 'hello', 'index')
- self.checkraises(TypeError, 'hello', 'index', 42)
+
+ if self.contains_bytes:
+ self.checkraises(ValueError, 'hello', 'index', 42)
+ else:
+ self.checkraises(TypeError, 'hello', 'index', 42)
def test_rindex(self):
self.checkequal(12, 'abcdefghiabc', 'rindex', '')
@@ -286,7 +311,11 @@ class BaseTest(unittest.TestCase):
self.checkequal( 2, 'rrarrrrrrrrra', 'rindex', 'a', None, 6)
self.checkraises(TypeError, 'hello', 'rindex')
- self.checkraises(TypeError, 'hello', 'rindex', 42)
+
+ if self.contains_bytes:
+ self.checkraises(ValueError, 'hello', 'rindex', 42)
+ else:
+ self.checkraises(TypeError, 'hello', 'rindex', 42)
def test_lower(self):
self.checkequal('hello', 'HeLLo', 'lower')
@@ -364,6 +393,17 @@ class BaseTest(unittest.TestCase):
self.checkequal(['a']*18 + ['aBLAHa'], ('aBLAH'*20)[:-4],
'split', 'BLAH', 18)
+ # with keyword args
+ self.checkequal(['a', 'b', 'c', 'd'], 'a|b|c|d', 'split', sep='|')
+ self.checkequal(['a', 'b|c|d'],
+ 'a|b|c|d', 'split', '|', maxsplit=1)
+ self.checkequal(['a', 'b|c|d'],
+ 'a|b|c|d', 'split', sep='|', maxsplit=1)
+ self.checkequal(['a', 'b|c|d'],
+ 'a|b|c|d', 'split', maxsplit=1, sep='|')
+ self.checkequal(['a', 'b c d'],
+ 'a b c d', 'split', maxsplit=1)
+
# argument type
self.checkraises(TypeError, 'hello', 'split', 42, 42, 42)
@@ -421,6 +461,17 @@ class BaseTest(unittest.TestCase):
self.checkequal(['aBLAHa'] + ['a']*18, ('aBLAH'*20)[:-4],
'rsplit', 'BLAH', 18)
+ # with keyword args
+ self.checkequal(['a', 'b', 'c', 'd'], 'a|b|c|d', 'rsplit', sep='|')
+ self.checkequal(['a|b|c', 'd'],
+ 'a|b|c|d', 'rsplit', '|', maxsplit=1)
+ self.checkequal(['a|b|c', 'd'],
+ 'a|b|c|d', 'rsplit', sep='|', maxsplit=1)
+ self.checkequal(['a|b|c', 'd'],
+ 'a|b|c|d', 'rsplit', maxsplit=1, sep='|')
+ self.checkequal(['a b c', 'd'],
+ 'a b c d', 'rsplit', maxsplit=1)
+
# argument type
self.checkraises(TypeError, 'hello', 'rsplit', 42, 42, 42)
@@ -550,6 +601,8 @@ class BaseTest(unittest.TestCase):
EQ("ReyKKjavik", "Reykjavik", "replace", "k", "KK", 1)
EQ("Reykjavik", "Reykjavik", "replace", "k", "KK", 0)
EQ("A----B----C----", "A.B.C.", "replace", ".", "----")
+ # issue #15534
+ EQ('...\u043c......&lt;', '...\u043c......<', "replace", "<", "&lt;")
EQ("Reykjavik", "Reykjavik", "replace", "q", "KK")
@@ -644,7 +697,7 @@ class CommonTest(BaseTest):
# check that titlecased chars are lowered correctly
# \u1ffc is the titlecased char
- self.checkequal('\u1ffc\u1ff3\u1ff3\u1ff3',
+ self.checkequal('\u03a9\u0399\u1ff3\u1ff3\u1ff3',
'\u1ff3\u1ff3\u1ffc\u1ffc', 'capitalize')
# check with cased non-letter chars
self.checkequal('\u24c5\u24e8\u24e3\u24d7\u24de\u24dd',
@@ -909,7 +962,14 @@ class MixinStrUnicodeUserStringTest:
self.checkequal(['abc', 'def', 'ghi'], "abc\ndef\r\nghi\n", 'splitlines')
self.checkequal(['abc', 'def', 'ghi', ''], "abc\ndef\r\nghi\n\r", 'splitlines')
self.checkequal(['', 'abc', 'def', 'ghi', ''], "\nabc\ndef\r\nghi\n\r", 'splitlines')
- self.checkequal(['\n', 'abc\n', 'def\r\n', 'ghi\n', '\r'], "\nabc\ndef\r\nghi\n\r", 'splitlines', 1)
+ self.checkequal(['', 'abc', 'def', 'ghi', ''],
+ "\nabc\ndef\r\nghi\n\r", 'splitlines', False)
+ self.checkequal(['\n', 'abc\n', 'def\r\n', 'ghi\n', '\r'],
+ "\nabc\ndef\r\nghi\n\r", 'splitlines', True)
+ self.checkequal(['', 'abc', 'def', 'ghi', ''], "\nabc\ndef\r\nghi\n\r",
+ 'splitlines', keepends=False)
+ self.checkequal(['\n', 'abc\n', 'def\r\n', 'ghi\n', '\r'],
+ "\nabc\ndef\r\nghi\n\r", 'splitlines', keepends=True)
self.checkraises(TypeError, 'abc', 'splitlines', 42, 42)
@@ -1143,6 +1203,10 @@ class MixinStrUnicodeUserStringTest:
self.checkraises(TypeError, '%10.*f', '__mod__', ('foo', 42.))
self.checkraises(ValueError, '%10', '__mod__', (42,))
+ # Outrageously large width or precision should raise ValueError.
+ self.checkraises(ValueError, '%%%df' % (2**64), '__mod__', (3.2))
+ self.checkraises(ValueError, '%%.%df' % (2**64), '__mod__', (3.2))
+
self.checkraises(OverflowError, '%*s', '__mod__',
(_testcapi.PY_SSIZE_T_MAX + 1, ''))
self.checkraises(OverflowError, '%.*f', '__mod__',
@@ -1271,6 +1335,9 @@ class MixinStrUnicodeUserStringTest:
self.assertRaisesRegex(TypeError, r'^endswith\(', s.endswith,
x, None, None, None)
+ # issue #15534
+ self.checkequal(10, "...\u043c......<", "find", "<")
+
class MixinStrUnicodeTest:
# Additional tests that only work with str and unicode.
diff --git a/Lib/test/support.py b/Lib/test/support.py
index 66ddf4b..d89e172 100644
--- a/Lib/test/support.py
+++ b/Lib/test/support.py
@@ -15,7 +15,7 @@ import shutil
import warnings
import unittest
import importlib
-import collections
+import collections.abc
import re
import subprocess
import imp
@@ -37,26 +37,41 @@ try:
except ImportError:
multiprocessing = None
+try:
+ import zlib
+except ImportError:
+ zlib = None
+
+try:
+ import bz2
+except ImportError:
+ bz2 = None
+
+try:
+ import lzma
+except ImportError:
+ lzma = None
__all__ = [
- "Error", "TestFailed", "ResourceDenied", "import_module",
- "verbose", "use_resources", "max_memuse", "record_original_stdout",
+ "Error", "TestFailed", "ResourceDenied", "import_module", "verbose",
+ "use_resources", "max_memuse", "record_original_stdout",
"get_original_stdout", "unload", "unlink", "rmtree", "forget",
- "is_resource_enabled", "requires", "requires_mac_ver",
- "find_unused_port", "bind_port",
- "fcmp", "is_jython", "TESTFN", "HOST", "FUZZ", "SAVEDCWD", "temp_cwd",
- "findfile", "sortdict", "check_syntax_error", "open_urlresource",
- "check_warnings", "CleanImport", "EnvironmentVarGuard",
- "TransientResource", "captured_output", "captured_stdout",
- "captured_stdin", "captured_stderr",
- "time_out", "socket_peer_reset", "ioerror_peer_reset",
- "run_with_locale", 'temp_umask', "transient_internet",
- "set_memlimit", "bigmemtest", "bigaddrspacetest", "BasicTestRunner",
- "run_unittest", "run_doctest", "threading_setup", "threading_cleanup",
- "reap_children", "cpython_only", "check_impl_detail", "get_attribute",
- "swap_item", "swap_attr", "requires_IEEE_754",
+ "is_resource_enabled", "requires", "requires_freebsd_version",
+ "requires_linux_version", "requires_mac_ver", "find_unused_port",
+ "bind_port", "IPV6_ENABLED", "is_jython", "TESTFN", "HOST", "SAVEDCWD",
+ "temp_cwd", "findfile", "create_empty_file", "sortdict",
+ "check_syntax_error", "open_urlresource", "check_warnings", "CleanImport",
+ "EnvironmentVarGuard", "TransientResource", "captured_stdout",
+ "captured_stdin", "captured_stderr", "time_out", "socket_peer_reset",
+ "ioerror_peer_reset", "run_with_locale", 'temp_umask',
+ "transient_internet", "set_memlimit", "bigmemtest", "bigaddrspacetest",
+ "BasicTestRunner", "run_unittest", "run_doctest", "threading_setup",
+ "threading_cleanup", "reap_children", "cpython_only", "check_impl_detail",
+ "get_attribute", "swap_item", "swap_attr", "requires_IEEE_754",
"TestHandler", "Matcher", "can_symlink", "skip_unless_symlink",
- "import_fresh_module", "failfast", "run_with_tz"
+ "skip_unless_xattr", "import_fresh_module", "requires_zlib",
+ "PIPE_MAX_SIZE", "failfast", "anticipate_failure", "run_with_tz",
+ "requires_bz2", "requires_lzma"
]
class Error(Exception):
@@ -127,6 +142,17 @@ def _save_and_block_module(name, orig_modules):
return saved
+def anticipate_failure(condition):
+ """Decorator to mark a test that is known to be broken in some cases
+
+ Any use of this decorator should have a comment identifying the
+ associated tracker issue.
+ """
+ if condition:
+ return unittest.expectedFailure
+ return lambda f: f
+
+
def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
"""Imports and returns a module, deliberately bypassing the sys.modules cache
and importing a fresh copy of the module. Once the import is complete,
@@ -170,8 +196,7 @@ def get_attribute(obj, name):
try:
attribute = getattr(obj, name)
except AttributeError:
- raise unittest.SkipTest("module %s has no attribute %s" % (
- obj.__name__, name))
+ raise unittest.SkipTest("object %r has no attribute %r" % (obj, name))
else:
return attribute
@@ -276,8 +301,7 @@ def rmtree(path):
try:
_rmtree(path)
except OSError as error:
- # Unix returns ENOENT, Windows returns ESRCH.
- if error.errno not in (errno.ENOENT, errno.ESRCH):
+ if error.errno != errno.ENOENT:
raise
def make_legacy_pyc(source):
@@ -362,9 +386,52 @@ def requires(resource, msg=None):
return
if not is_resource_enabled(resource):
if msg is None:
- msg = "Use of the `%s' resource not enabled" % resource
+ msg = "Use of the %r resource not enabled" % resource
raise ResourceDenied(msg)
+def _requires_unix_version(sysname, min_version):
+ """Decorator raising SkipTest if the OS is `sysname` and the version is less
+ than `min_version`.
+
+ For example, @_requires_unix_version('FreeBSD', (7, 2)) raises SkipTest if
+ the FreeBSD version is less than 7.2.
+ """
+ def decorator(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kw):
+ if platform.system() == sysname:
+ version_txt = platform.release().split('-', 1)[0]
+ try:
+ version = tuple(map(int, version_txt.split('.')))
+ except ValueError:
+ pass
+ else:
+ if version < min_version:
+ min_version_txt = '.'.join(map(str, min_version))
+ raise unittest.SkipTest(
+ "%s version %s or higher required, not %s"
+ % (sysname, min_version_txt, version_txt))
+ return wrapper
+ return decorator
+
+def requires_freebsd_version(*min_version):
+ """Decorator raising SkipTest if the OS is FreeBSD and the FreeBSD version is
+ less than `min_version`.
+
+ For example, @requires_freebsd_version(7, 2) raises SkipTest if the FreeBSD
+ version is less than 7.2.
+ """
+ return _requires_unix_version('FreeBSD', min_version)
+
+def requires_linux_version(*min_version):
+ """Decorator raising SkipTest if the OS is Linux and the Linux version is
+ less than `min_version`.
+
+ For example, @requires_linux_version(2, 6, 32) raises SkipTest if the Linux
+ version is less than 2.6.32.
+ """
+ return _requires_unix_version('Linux', min_version)
+
def requires_mac_ver(*min_version):
"""Decorator raising SkipTest if the OS is Mac OS X and the OS X
version if less than min_version.
@@ -392,6 +459,7 @@ def requires_mac_ver(*min_version):
return wrapper
return decorator
+
HOST = 'localhost'
def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM):
@@ -487,29 +555,41 @@ def bind_port(sock, host=HOST):
port = sock.getsockname()[1]
return port
-FUZZ = 1e-6
-
-def fcmp(x, y): # fuzzy comparison function
- if isinstance(x, float) or isinstance(y, float):
+def _is_ipv6_enabled():
+ """Check whether IPv6 is enabled on this host."""
+ if socket.has_ipv6:
+ sock = None
try:
- fuzz = (abs(x) + abs(y)) * FUZZ
- if abs(x-y) <= fuzz:
- return 0
- except:
+ sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+ sock.bind(('::1', 0))
+ return True
+ except (socket.error, socket.gaierror):
pass
- elif type(x) == type(y) and isinstance(x, (tuple, list)):
- for i in range(min(len(x), len(y))):
- outcome = fcmp(x[i], y[i])
- if outcome != 0:
- return outcome
- return (len(x) > len(y)) - (len(x) < len(y))
- return (x > y) - (x < y)
+ finally:
+ if sock:
+ sock.close()
+ return False
+
+IPV6_ENABLED = _is_ipv6_enabled()
+
+
+# A constant likely larger than the underlying OS pipe buffer size.
+# Windows limit seems to be around 512B, and most Unix kernels have a 64K pipe
+# buffer size: take 1M to be sure.
+PIPE_MAX_SIZE = 1024 * 1024
+
# decorator for skipping tests on non-IEEE 754 platforms
requires_IEEE_754 = unittest.skipUnless(
float.__getformat__("double").startswith("IEEE"),
"test requires IEEE 754 doubles")
+requires_zlib = unittest.skipUnless(zlib, 'requires zlib')
+
+requires_bz2 = unittest.skipUnless(bz2, 'requires bz2')
+
+requires_lzma = unittest.skipUnless(lzma, 'requires lzma')
+
is_jython = sys.platform.startswith('java')
# Filename used for testing
@@ -688,14 +768,15 @@ def temp_cwd(name='tempcwd', quiet=False, path=None):
rmtree(name)
-@contextlib.contextmanager
-def temp_umask(umask):
- """Context manager that temporarily sets the process umask."""
- oldmask = os.umask(umask)
- try:
- yield
- finally:
- os.umask(oldmask)
+if hasattr(os, "umask"):
+ @contextlib.contextmanager
+ def temp_umask(umask):
+ """Context manager that temporarily sets the process umask."""
+ oldmask = os.umask(umask)
+ try:
+ yield
+ finally:
+ os.umask(oldmask)
def findfile(file, here=__file__, subdir=None):
@@ -713,6 +794,11 @@ def findfile(file, here=__file__, subdir=None):
if os.path.exists(fn): return fn
return file
+def create_empty_file(filename):
+ """Create an empty file. If the file already exists, truncate it."""
+ fd = os.open(filename, os.O_WRONLY | os.O_CREAT | os.O_TRUNC)
+ os.close(fd)
+
def sortdict(dict):
"Like repr(dict), but in sorted order."
items = sorted(dict.items())
@@ -777,7 +863,7 @@ def open_urlresource(url, *args, **kw):
f = check_valid_file(fn)
if f is not None:
return f
- raise TestFailed('invalid resource "%s"' % fn)
+ raise TestFailed('invalid resource %r' % fn)
class WarningsRecorder(object):
@@ -898,7 +984,7 @@ class CleanImport(object):
sys.modules.update(self.original_modules)
-class EnvironmentVarGuard(collections.MutableMapping):
+class EnvironmentVarGuard(collections.abc.MutableMapping):
"""Class to help protect the environment variable properly. Can be used as
a context manager."""
@@ -1029,7 +1115,7 @@ def transient_internet(resource_name, *, timeout=30.0, errnos=()):
('WSANO_DATA', 11004),
]
- denied = ResourceDenied("Resource '%s' is not available" % resource_name)
+ denied = ResourceDenied("Resource %r is not available" % resource_name)
captured_errnos = errnos
gai_errnos = []
if not captured_errnos:
@@ -1118,6 +1204,16 @@ def gc_collect():
gc.collect()
gc.collect()
+@contextlib.contextmanager
+def disable_gc():
+ have_gc = gc.isenabled()
+ gc.disable()
+ try:
+ yield
+ finally:
+ if have_gc:
+ gc.enable()
+
def python_is_optimized():
"""Find if Python was built with optimizations."""
@@ -1126,19 +1222,21 @@ def python_is_optimized():
for opt in cflags.split():
if opt.startswith('-O'):
final_opt = opt
- return final_opt and final_opt != '-O0'
+ return final_opt != '' and final_opt != '-O0'
-_header = '2P'
+_header = 'nP'
+_align = '0n'
if hasattr(sys, "gettotalrefcount"):
_header = '2P' + _header
-_vheader = _header + 'P'
+ _align = '0P'
+_vheader = _header + 'n'
def calcobjsize(fmt):
- return struct.calcsize(_header + fmt + '0P')
+ return struct.calcsize(_header + fmt + _align)
def calcvobjsize(fmt):
- return struct.calcsize(_vheader + fmt + '0P')
+ return struct.calcsize(_vheader + fmt + _align)
_TPFLAGS_HAVE_GC = 1<<14
@@ -1212,7 +1310,7 @@ def run_with_tz(tz):
try:
return func(*args, **kwds)
finally:
- if orig_tz == None:
+ if orig_tz is None:
del os.environ['TZ']
else:
os.environ['TZ'] = orig_tz
@@ -1257,41 +1355,35 @@ def set_memlimit(limit):
raise ValueError('Memory limit %r too low to be useful' % (limit,))
max_memuse = memlimit
-def _memory_watchdog(start_evt, finish_evt, period=10.0):
- """A function which periodically watches the process' memory consumption
+class _MemoryWatchdog:
+ """An object which periodically watches the process' memory consumption
and prints it out.
"""
- # XXX: because of the GIL, and because the very long operations tested
- # in most bigmem tests are uninterruptible, the loop below gets woken up
- # much less often than expected.
- # The polling code should be rewritten in raw C, without holding the GIL,
- # and push results onto an anonymous pipe.
- try:
- page_size = os.sysconf('SC_PAGESIZE')
- except (ValueError, AttributeError):
+
+ def __init__(self):
+ self.procfile = '/proc/{pid}/statm'.format(pid=os.getpid())
+ self.started = False
+
+ def start(self):
try:
- page_size = os.sysconf('SC_PAGE_SIZE')
- except (ValueError, AttributeError):
- page_size = 4096
- procfile = '/proc/{pid}/statm'.format(pid=os.getpid())
- try:
- f = open(procfile, 'rb')
- except IOError as e:
- warnings.warn('/proc not available for stats: {}'.format(e),
- RuntimeWarning)
- sys.stderr.flush()
- return
- with f:
- start_evt.set()
- old_data = -1
- while not finish_evt.wait(period):
- f.seek(0)
- statm = f.read().decode('ascii')
- data = int(statm.split()[5])
- if data != old_data:
- old_data = data
- print(" ... process data size: {data:.1f}G"
- .format(data=data * page_size / (1024 ** 3)))
+ f = open(self.procfile, 'r')
+ except OSError as e:
+ warnings.warn('/proc not available for stats: {}'.format(e),
+ RuntimeWarning)
+ sys.stderr.flush()
+ return
+
+ watchdog_script = findfile("memory_watchdog.py")
+ self.mem_watchdog = subprocess.Popen([sys.executable, watchdog_script],
+ stdin=f, stderr=subprocess.DEVNULL)
+ f.close()
+ self.started = True
+
+ def stop(self):
+ if self.started:
+ self.mem_watchdog.terminate()
+ self.mem_watchdog.wait()
+
def bigmemtest(size, memuse, dry_run=True):
"""Decorator for bigmem tests.
@@ -1318,27 +1410,20 @@ def bigmemtest(size, memuse, dry_run=True):
"not enough memory: %.1fG minimum needed"
% (size * memuse / (1024 ** 3)))
- if real_max_memuse and verbose and threading:
+ if real_max_memuse and verbose:
print()
print(" ... expected peak memory use: {peak:.1f}G"
.format(peak=size * memuse / (1024 ** 3)))
- sys.stdout.flush()
- start_evt = threading.Event()
- finish_evt = threading.Event()
- t = threading.Thread(target=_memory_watchdog,
- args=(start_evt, finish_evt, 0.5))
- t.daemon = True
- t.start()
- start_evt.set()
+ watchdog = _MemoryWatchdog()
+ watchdog.start()
else:
- t = None
+ watchdog = None
try:
return f(self, maxsize)
finally:
- if t:
- finish_evt.set()
- t.join()
+ if watchdog:
+ watchdog.stop()
wrapper.size = size
wrapper.memuse = memuse
@@ -1420,6 +1505,33 @@ def check_impl_detail(**guards):
return guards.get(platform.python_implementation().lower(), default)
+def no_tracing(func):
+ """Decorator to temporarily turn off tracing for the duration of a test."""
+ if not hasattr(sys, 'gettrace'):
+ return func
+ else:
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ original_trace = sys.gettrace()
+ try:
+ sys.settrace(None)
+ return func(*args, **kwargs)
+ finally:
+ sys.settrace(original_trace)
+ return wrapper
+
+
+def refcount_test(test):
+ """Decorator for tests which involve reference counting.
+
+ To start, the decorator does not run the test if is not run by CPython.
+ After that, any trace function is unset during the test to prevent
+ unexpected refcounts caused by the trace function.
+
+ """
+ return no_tracing(cpython_only(test))
+
+
def _filter_suite(suite, pred):
"""Recursively filter test cases in a suite based on a predicate."""
newtests = []
@@ -1432,7 +1544,6 @@ def _filter_suite(suite, pred):
newtests.append(test)
suite._tests = newtests
-
def _run_suite(suite):
"""Run tests from a unittest.TestSuite-derived class."""
if verbose:
@@ -1491,7 +1602,7 @@ requires_docstrings = unittest.skipUnless(HAVE_DOCSTRINGS,
#=======================================================================
# doctest driver.
-def run_doctest(module, verbosity=None):
+def run_doctest(module, verbosity=None, optionflags=0):
"""Run doctest on the given module. Return (#failures, #tests).
If optional argument verbosity is not specified (or is None), pass
@@ -1506,7 +1617,7 @@ def run_doctest(module, verbosity=None):
else:
verbosity = None
- f, t = doctest.testmod(module, verbose=verbosity)
+ f, t = doctest.testmod(module, verbose=verbosity, optionflags=optionflags)
if f:
raise TestFailed("%d of %d doctests failed" % (f, t))
if verbose:
@@ -1664,28 +1775,13 @@ def strip_python_stderr(stderr):
This will typically be run on the result of the communicate() method
of a subprocess.Popen object.
"""
- stderr = re.sub(br"\[\d+ refs\]\r?\n?$", b"", stderr).strip()
+ stderr = re.sub(br"\[\d+ refs\]\r?\n?", b"", stderr).strip()
return stderr
def args_from_interpreter_flags():
"""Return a list of command-line arguments reproducing the current
- settings in sys.flags."""
- flag_opt_map = {
- 'bytes_warning': 'b',
- 'dont_write_bytecode': 'B',
- 'hash_randomization': 'R',
- 'ignore_environment': 'E',
- 'no_user_site': 's',
- 'no_site': 'S',
- 'optimize': 'O',
- 'verbose': 'v',
- }
- args = []
- for flag, opt in flag_opt_map.items():
- v = getattr(sys.flags, flag)
- if v > 0:
- args.append('-' + opt * v)
- return args
+ settings in sys.flags and sys.warnoptions."""
+ return subprocess._args_from_interpreter_flags()
#============================================================
# Support for assertions about logging.
@@ -1775,6 +1871,40 @@ def skip_unless_symlink(test):
msg = "Requires functional symlink implementation"
return test if ok else unittest.skip(msg)(test)
+_can_xattr = None
+def can_xattr():
+ global _can_xattr
+ if _can_xattr is not None:
+ return _can_xattr
+ if not hasattr(os, "setxattr"):
+ can = False
+ else:
+ tmp_fp, tmp_name = tempfile.mkstemp()
+ try:
+ with open(TESTFN, "wb") as fp:
+ try:
+ # TESTFN & tempfile may use different file systems with
+ # different capabilities
+ os.setxattr(tmp_fp, b"user.test", b"")
+ os.setxattr(fp.fileno(), b"user.test", b"")
+ # Kernels < 2.6.39 don't respect setxattr flags.
+ kernel_version = platform.release()
+ m = re.match("2.6.(\d{1,2})", kernel_version)
+ can = m is None or int(m.group(1)) >= 39
+ except OSError:
+ can = False
+ finally:
+ unlink(TESTFN)
+ unlink(tmp_name)
+ _can_xattr = can
+ return can
+
+def skip_unless_xattr(test):
+ """Skip decorator for tests that require functional extended attributes"""
+ ok = can_xattr()
+ msg = "no non-broken extended attribute support"
+ return test if ok else unittest.skip(msg)(test)
+
def patch(test_instance, object_to_patch, attr_name, new_value):
"""Override 'object_to_patch'.'attr_name' with 'new_value'.
diff --git a/Lib/test/test__locale.py b/Lib/test/test__locale.py
index 3fadb57..4231f37 100644
--- a/Lib/test/test__locale.py
+++ b/Lib/test/test__locale.py
@@ -1,23 +1,25 @@
-from test.support import run_unittest
from _locale import (setlocale, LC_ALL, LC_CTYPE, LC_NUMERIC, localeconv, Error)
try:
from _locale import (RADIXCHAR, THOUSEP, nl_langinfo)
except ImportError:
nl_langinfo = None
-import unittest
+import codecs
+import locale
import sys
+import unittest
from platform import uname
+from test.support import run_unittest
-if uname()[0] == "Darwin":
- maj, min, mic = [int(part) for part in uname()[2].split(".")]
+if uname().system == "Darwin":
+ maj, min, mic = [int(part) for part in uname().release.split(".")]
if (maj, min, mic) < (8, 0, 0):
raise unittest.SkipTest("locale support broken for OS X < 10.4")
candidate_locales = ['es_UY', 'fr_FR', 'fi_FI', 'es_CO', 'pt_PT', 'it_IT',
'et_EE', 'es_PY', 'no_NO', 'nl_NL', 'lv_LV', 'el_GR', 'be_BY', 'fr_BE',
'ro_RO', 'ru_UA', 'ru_RU', 'es_VE', 'ca_ES', 'se_NO', 'es_EC', 'id_ID',
- 'ka_GE', 'es_CL', 'hu_HU', 'wa_BE', 'lt_LT', 'sl_SI', 'hr_HR', 'es_AR',
+ 'ka_GE', 'es_CL', 'wa_BE', 'hu_HU', 'lt_LT', 'sl_SI', 'hr_HR', 'es_AR',
'es_ES', 'oc_FR', 'gl_ES', 'bg_BG', 'is_IS', 'mk_MK', 'de_AT', 'pt_BR',
'da_DK', 'nn_NO', 'cs_CZ', 'de_LU', 'es_BO', 'sq_AL', 'sk_SK', 'fr_CH',
'de_DE', 'sr_YU', 'br_FR', 'nl_BE', 'sv_FI', 'pl_PL', 'fr_CA', 'fo_FO',
@@ -25,6 +27,31 @@ candidate_locales = ['es_UY', 'fr_FR', 'fi_FI', 'es_CO', 'pt_PT', 'it_IT',
'eu_ES', 'vi_VN', 'af_ZA', 'nb_NO', 'en_DK', 'tg_TJ', 'en_US',
'es_ES.ISO8859-1', 'fr_FR.ISO8859-15', 'ru_RU.KOI8-R', 'ko_KR.eucKR']
+# Issue #13441: Skip some locales (e.g. cs_CZ and hu_HU) on Solaris to
+# workaround a mbstowcs() bug. For example, on Solaris, the hu_HU locale uses
+# the locale encoding ISO-8859-2, the thousauds separator is b'\xA0' and it is
+# decoded as U+30000020 (an invalid character) by mbstowcs().
+if sys.platform == 'sunos5':
+ old_locale = locale.setlocale(locale.LC_ALL)
+ try:
+ locales = []
+ for loc in candidate_locales:
+ try:
+ locale.setlocale(locale.LC_ALL, loc)
+ except Error:
+ continue
+ encoding = locale.getpreferredencoding(False)
+ try:
+ localeconv()
+ except Exception as err:
+ print("WARNING: Skip locale %s (encoding %s): [%s] %s"
+ % (loc, encoding, type(err), err))
+ else:
+ locales.append(loc)
+ candidate_locales = locales
+ finally:
+ locale.setlocale(locale.LC_ALL, old_locale)
+
# Workaround for MSVC6(debug) crash bug
if "MSC v.1200" in sys.version:
def accept(loc):
@@ -86,9 +113,10 @@ class _LocaleTests(unittest.TestCase):
setlocale(LC_CTYPE, loc)
except Error:
continue
+ formatting = localeconv()
for lc in ("decimal_point",
"thousands_sep"):
- self.numeric_tester('localeconv', localeconv()[lc], lc, loc)
+ self.numeric_tester('localeconv', formatting[lc], lc, loc)
@unittest.skipUnless(nl_langinfo, "nl_langinfo is not available")
def test_lc_numeric_basic(self):
diff --git a/Lib/test/test_abc.py b/Lib/test/test_abc.py
index d86f97c..653c957 100644
--- a/Lib/test/test_abc.py
+++ b/Lib/test/test_abc.py
@@ -10,14 +10,7 @@ import abc
from inspect import isabstract
-class TestABC(unittest.TestCase):
-
- def test_abstractmethod_basics(self):
- @abc.abstractmethod
- def foo(self): pass
- self.assertTrue(foo.__isabstractmethod__)
- def bar(self): pass
- self.assertFalse(hasattr(bar, "__isabstractmethod__"))
+class TestLegacyAPI(unittest.TestCase):
def test_abstractproperty_basics(self):
@abc.abstractproperty
@@ -29,10 +22,12 @@ class TestABC(unittest.TestCase):
class C(metaclass=abc.ABCMeta):
@abc.abstractproperty
def foo(self): return 3
+ self.assertRaises(TypeError, C)
class D(C):
@property
def foo(self): return super().foo
self.assertEqual(D().foo, 3)
+ self.assertFalse(getattr(D.foo, "__isabstractmethod__", False))
def test_abstractclassmethod_basics(self):
@abc.abstractclassmethod
@@ -40,7 +35,7 @@ class TestABC(unittest.TestCase):
self.assertTrue(foo.__isabstractmethod__)
@classmethod
def bar(cls): pass
- self.assertFalse(hasattr(bar, "__isabstractmethod__"))
+ self.assertFalse(getattr(bar, "__isabstractmethod__", False))
class C(metaclass=abc.ABCMeta):
@abc.abstractclassmethod
@@ -58,7 +53,7 @@ class TestABC(unittest.TestCase):
self.assertTrue(foo.__isabstractmethod__)
@staticmethod
def bar(): pass
- self.assertFalse(hasattr(bar, "__isabstractmethod__"))
+ self.assertFalse(getattr(bar, "__isabstractmethod__", False))
class C(metaclass=abc.ABCMeta):
@abc.abstractstaticmethod
@@ -98,6 +93,163 @@ class TestABC(unittest.TestCase):
self.assertRaises(TypeError, F) # because bar is abstract now
self.assertTrue(isabstract(F))
+
+class TestABC(unittest.TestCase):
+
+ def test_abstractmethod_basics(self):
+ @abc.abstractmethod
+ def foo(self): pass
+ self.assertTrue(foo.__isabstractmethod__)
+ def bar(self): pass
+ self.assertFalse(hasattr(bar, "__isabstractmethod__"))
+
+ def test_abstractproperty_basics(self):
+ @property
+ @abc.abstractmethod
+ def foo(self): pass
+ self.assertTrue(foo.__isabstractmethod__)
+ def bar(self): pass
+ self.assertFalse(getattr(bar, "__isabstractmethod__", False))
+
+ class C(metaclass=abc.ABCMeta):
+ @property
+ @abc.abstractmethod
+ def foo(self): return 3
+ self.assertRaises(TypeError, C)
+ class D(C):
+ @C.foo.getter
+ def foo(self): return super().foo
+ self.assertEqual(D().foo, 3)
+
+ def test_abstractclassmethod_basics(self):
+ @classmethod
+ @abc.abstractmethod
+ def foo(cls): pass
+ self.assertTrue(foo.__isabstractmethod__)
+ @classmethod
+ def bar(cls): pass
+ self.assertFalse(getattr(bar, "__isabstractmethod__", False))
+
+ class C(metaclass=abc.ABCMeta):
+ @classmethod
+ @abc.abstractmethod
+ def foo(cls): return cls.__name__
+ self.assertRaises(TypeError, C)
+ class D(C):
+ @classmethod
+ def foo(cls): return super().foo()
+ self.assertEqual(D.foo(), 'D')
+ self.assertEqual(D().foo(), 'D')
+
+ def test_abstractstaticmethod_basics(self):
+ @staticmethod
+ @abc.abstractmethod
+ def foo(): pass
+ self.assertTrue(foo.__isabstractmethod__)
+ @staticmethod
+ def bar(): pass
+ self.assertFalse(getattr(bar, "__isabstractmethod__", False))
+
+ class C(metaclass=abc.ABCMeta):
+ @staticmethod
+ @abc.abstractmethod
+ def foo(): return 3
+ self.assertRaises(TypeError, C)
+ class D(C):
+ @staticmethod
+ def foo(): return 4
+ self.assertEqual(D.foo(), 4)
+ self.assertEqual(D().foo(), 4)
+
+ def test_abstractmethod_integration(self):
+ for abstractthing in [abc.abstractmethod, abc.abstractproperty,
+ abc.abstractclassmethod,
+ abc.abstractstaticmethod]:
+ class C(metaclass=abc.ABCMeta):
+ @abstractthing
+ def foo(self): pass # abstract
+ def bar(self): pass # concrete
+ self.assertEqual(C.__abstractmethods__, {"foo"})
+ self.assertRaises(TypeError, C) # because foo is abstract
+ self.assertTrue(isabstract(C))
+ class D(C):
+ def bar(self): pass # concrete override of concrete
+ self.assertEqual(D.__abstractmethods__, {"foo"})
+ self.assertRaises(TypeError, D) # because foo is still abstract
+ self.assertTrue(isabstract(D))
+ class E(D):
+ def foo(self): pass
+ self.assertEqual(E.__abstractmethods__, set())
+ E() # now foo is concrete, too
+ self.assertFalse(isabstract(E))
+ class F(E):
+ @abstractthing
+ def bar(self): pass # abstract override of concrete
+ self.assertEqual(F.__abstractmethods__, {"bar"})
+ self.assertRaises(TypeError, F) # because bar is abstract now
+ self.assertTrue(isabstract(F))
+
+ def test_descriptors_with_abstractmethod(self):
+ class C(metaclass=abc.ABCMeta):
+ @property
+ @abc.abstractmethod
+ def foo(self): return 3
+ @foo.setter
+ @abc.abstractmethod
+ def foo(self, val): pass
+ self.assertRaises(TypeError, C)
+ class D(C):
+ @C.foo.getter
+ def foo(self): return super().foo
+ self.assertRaises(TypeError, D)
+ class E(D):
+ @D.foo.setter
+ def foo(self, val): pass
+ self.assertEqual(E().foo, 3)
+ # check that the property's __isabstractmethod__ descriptor does the
+ # right thing when presented with a value that fails truth testing:
+ class NotBool(object):
+ def __nonzero__(self):
+ raise ValueError()
+ __len__ = __nonzero__
+ with self.assertRaises(ValueError):
+ class F(C):
+ def bar(self):
+ pass
+ bar.__isabstractmethod__ = NotBool()
+ foo = property(bar)
+
+
+ def test_customdescriptors_with_abstractmethod(self):
+ class Descriptor:
+ def __init__(self, fget, fset=None):
+ self._fget = fget
+ self._fset = fset
+ def getter(self, callable):
+ return Descriptor(callable, self._fget)
+ def setter(self, callable):
+ return Descriptor(self._fget, callable)
+ @property
+ def __isabstractmethod__(self):
+ return (getattr(self._fget, '__isabstractmethod__', False)
+ or getattr(self._fset, '__isabstractmethod__', False))
+ class C(metaclass=abc.ABCMeta):
+ @Descriptor
+ @abc.abstractmethod
+ def foo(self): return 3
+ @foo.setter
+ @abc.abstractmethod
+ def foo(self, val): pass
+ self.assertRaises(TypeError, C)
+ class D(C):
+ @C.foo.getter
+ def foo(self): return super().foo
+ self.assertRaises(TypeError, D)
+ class E(D):
+ @D.foo.setter
+ def foo(self, val): pass
+ self.assertFalse(E.foo.__isabstractmethod__)
+
def test_metaclass_abc(self):
# Metaclasses can be ABCs, too.
class A(metaclass=abc.ABCMeta):
@@ -121,11 +273,32 @@ class TestABC(unittest.TestCase):
self.assertFalse(issubclass(B, (A,)))
self.assertNotIsInstance(b, A)
self.assertNotIsInstance(b, (A,))
- A.register(B)
+ B1 = A.register(B)
+ self.assertTrue(issubclass(B, A))
+ self.assertTrue(issubclass(B, (A,)))
+ self.assertIsInstance(b, A)
+ self.assertIsInstance(b, (A,))
+ self.assertIs(B1, B)
+ class C(B):
+ pass
+ c = C()
+ self.assertTrue(issubclass(C, A))
+ self.assertTrue(issubclass(C, (A,)))
+ self.assertIsInstance(c, A)
+ self.assertIsInstance(c, (A,))
+
+ def test_register_as_class_deco(self):
+ class A(metaclass=abc.ABCMeta):
+ pass
+ @A.register
+ class B(object):
+ pass
+ b = B()
self.assertTrue(issubclass(B, A))
self.assertTrue(issubclass(B, (A,)))
self.assertIsInstance(b, A)
self.assertIsInstance(b, (A,))
+ @A.register
class C(B):
pass
c = C()
@@ -133,6 +306,7 @@ class TestABC(unittest.TestCase):
self.assertTrue(issubclass(C, (A,)))
self.assertIsInstance(c, A)
self.assertIsInstance(c, (A,))
+ self.assertIs(C, A.register(C))
def test_isinstance_invalidation(self):
class A(metaclass=abc.ABCMeta):
diff --git a/Lib/test/test_abstract_numbers.py b/Lib/test/test_abstract_numbers.py
index 2a396cd..253e6f0 100644
--- a/Lib/test/test_abstract_numbers.py
+++ b/Lib/test/test_abstract_numbers.py
@@ -14,6 +14,7 @@ class TestNumbers(unittest.TestCase):
self.assertEqual(7, int(7).real)
self.assertEqual(0, int(7).imag)
self.assertEqual(7, int(7).conjugate())
+ self.assertEqual(-7, int(-7).conjugate())
self.assertEqual(7, int(7).numerator)
self.assertEqual(1, int(7).denominator)
@@ -24,6 +25,7 @@ class TestNumbers(unittest.TestCase):
self.assertEqual(7.3, float(7.3).real)
self.assertEqual(0, float(7.3).imag)
self.assertEqual(7.3, float(7.3).conjugate())
+ self.assertEqual(-7.3, float(-7.3).conjugate())
def test_complex(self):
self.assertFalse(issubclass(complex, Real))
diff --git a/Lib/test/test_aifc.py b/Lib/test/test_aifc.py
index 0b19af6..9c0e7b9 100644
--- a/Lib/test/test_aifc.py
+++ b/Lib/test/test_aifc.py
@@ -1,4 +1,4 @@
-from test.support import findfile, run_unittest, TESTFN, captured_stdout, unlink
+from test.support import findfile, run_unittest, TESTFN, unlink
import unittest
import os
import io
@@ -214,11 +214,8 @@ class AIFCLowLevelTest(unittest.TestCase):
b += b'COMM' + struct.pack('>LhlhhLL', 18, 0, 0, 0, 0, 0, 0)
b += b'SSND' + struct.pack('>L', 8) + b'\x00' * 8
b += b'MARK' + struct.pack('>LhB', 3, 1, 1)
- with captured_stdout() as s:
+ with self.assertWarns(UserWarning):
f = aifc.open(io.BytesIO(b))
- self.assertEqual(
- s.getvalue(),
- 'Warning: MARK chunk contains only 0 markers instead of 1\n')
self.assertEqual(f.getmarkers(), None)
def test_read_comm_kludge_compname_even(self):
@@ -226,9 +223,8 @@ class AIFCLowLevelTest(unittest.TestCase):
b += b'COMM' + struct.pack('>LhlhhLL', 18, 0, 0, 0, 0, 0, 0)
b += b'NONE' + struct.pack('B', 4) + b'even' + b'\x00'
b += b'SSND' + struct.pack('>L', 8) + b'\x00' * 8
- with captured_stdout() as s:
+ with self.assertWarns(UserWarning):
f = aifc.open(io.BytesIO(b))
- self.assertEqual(s.getvalue(), 'Warning: bad COMM chunk size\n')
self.assertEqual(f.getcompname(), b'even')
def test_read_comm_kludge_compname_odd(self):
@@ -236,9 +232,8 @@ class AIFCLowLevelTest(unittest.TestCase):
b += b'COMM' + struct.pack('>LhlhhLL', 18, 0, 0, 0, 0, 0, 0)
b += b'NONE' + struct.pack('B', 3) + b'odd'
b += b'SSND' + struct.pack('>L', 8) + b'\x00' * 8
- with captured_stdout() as s:
+ with self.assertWarns(UserWarning):
f = aifc.open(io.BytesIO(b))
- self.assertEqual(s.getvalue(), 'Warning: bad COMM chunk size\n')
self.assertEqual(f.getcompname(), b'odd')
def test_write_params_raises(self):
diff --git a/Lib/test/test_argparse.py b/Lib/test/test_argparse.py
index 22c26cc..c06c940 100644
--- a/Lib/test/test_argparse.py
+++ b/Lib/test/test_argparse.py
@@ -1323,20 +1323,21 @@ class TestParserDefaultSuppress(ParserTestCase):
class TestParserDefault42(ParserTestCase):
"""Test actions with a parser-level default of 42"""
- parser_signature = Sig(argument_default=42, version='1.0')
+ parser_signature = Sig(argument_default=42)
argument_signatures = [
+ Sig('--version', action='version', version='1.0'),
Sig('foo', nargs='?'),
Sig('bar', nargs='*'),
Sig('--baz', action='store_true'),
]
failures = ['-x']
successes = [
- ('', NS(foo=42, bar=42, baz=42)),
- ('a', NS(foo='a', bar=42, baz=42)),
- ('a b', NS(foo='a', bar=['b'], baz=42)),
- ('--baz', NS(foo=42, bar=42, baz=True)),
- ('a --baz', NS(foo='a', bar=42, baz=True)),
- ('--baz a b', NS(foo='a', bar=['b'], baz=True)),
+ ('', NS(foo=42, bar=42, baz=42, version=42)),
+ ('a', NS(foo='a', bar=42, baz=42, version=42)),
+ ('a b', NS(foo='a', bar=['b'], baz=42, version=42)),
+ ('--baz', NS(foo=42, bar=42, baz=True, version=42)),
+ ('a --baz', NS(foo='a', bar=42, baz=True, version=42)),
+ ('--baz a b', NS(foo='a', bar=['b'], baz=True, version=42)),
]
@@ -2927,10 +2928,9 @@ class TestHelpFormattingMetaclass(type):
parser_text = sfile.getvalue()
self._test(tester, parser_text)
- # add tests for {format,print}_{usage,help,version}
+ # add tests for {format,print}_{usage,help}
for func_suffix, std_name in [('usage', 'stdout'),
- ('help', 'stdout'),
- ('version', 'stderr')]:
+ ('help', 'stdout')]:
AddTests(cls, func_suffix, std_name)
bases = TestCase,
@@ -2941,8 +2941,9 @@ class TestHelpBiggerOptionals(HelpTestCase):
"""Make sure that argument help aligns when options are longer"""
parser_signature = Sig(prog='PROG', description='DESCRIPTION',
- epilog='EPILOG', version='0.1')
+ epilog='EPILOG')
argument_signatures = [
+ Sig('-v', '--version', action='version', version='0.1'),
Sig('-x', action='store_true', help='X HELP'),
Sig('--y', help='Y HELP'),
Sig('foo', help='FOO HELP'),
@@ -2977,8 +2978,9 @@ class TestHelpBiggerOptionalGroups(HelpTestCase):
"""Make sure that argument help aligns when options are longer"""
parser_signature = Sig(prog='PROG', description='DESCRIPTION',
- epilog='EPILOG', version='0.1')
+ epilog='EPILOG')
argument_signatures = [
+ Sig('-v', '--version', action='version', version='0.1'),
Sig('-x', action='store_true', help='X HELP'),
Sig('--y', help='Y HELP'),
Sig('foo', help='FOO HELP'),
@@ -3145,9 +3147,9 @@ HHAAHHH
class TestHelpWrappingLongNames(HelpTestCase):
"""Make sure that text after long names starts on the next line"""
- parser_signature = Sig(usage='USAGE', description= 'D D' * 30,
- version='V V'*30)
+ parser_signature = Sig(usage='USAGE', description= 'D D' * 30)
argument_signatures = [
+ Sig('-v', '--version', action='version', version='V V' * 30),
Sig('-x', metavar='X' * 25, help='XH XH' * 20),
Sig('y', metavar='y' * 25, help='YH YH' * 20),
]
@@ -3750,8 +3752,9 @@ class TestHelpNoHelpOptional(HelpTestCase):
class TestHelpVersionOptional(HelpTestCase):
"""Test that the --version argument can be suppressed help messages"""
- parser_signature = Sig(prog='PROG', version='1.0')
+ parser_signature = Sig(prog='PROG')
argument_signatures = [
+ Sig('-v', '--version', action='version', version='1.0'),
Sig('--foo', help='foo help'),
Sig('spam', help='spam help'),
]
@@ -3984,8 +3987,8 @@ class TestHelpVersionAction(HelpTestCase):
class TestHelpSubparsersOrdering(HelpTestCase):
"""Test ordering of subcommands in help matches the code"""
parser_signature = Sig(prog='PROG',
- description='display some subcommands',
- version='0.1')
+ description='display some subcommands')
+ argument_signatures = [Sig('-v', '--version', action='version', version='0.1')]
subparsers_signatures = [Sig(name=name)
for name in ('a', 'b', 'c', 'd', 'e')]
@@ -4013,8 +4016,8 @@ class TestHelpSubparsersOrdering(HelpTestCase):
class TestHelpSubparsersWithHelpOrdering(HelpTestCase):
"""Test ordering of subcommands in help matches the code"""
parser_signature = Sig(prog='PROG',
- description='display some subcommands',
- version='0.1')
+ description='display some subcommands')
+ argument_signatures = [Sig('-v', '--version', action='version', version='0.1')]
subcommand_data = (('a', 'a subcommand help'),
('b', 'b subcommand help'),
@@ -4052,6 +4055,37 @@ class TestHelpSubparsersWithHelpOrdering(HelpTestCase):
'''
+
+class TestHelpMetavarTypeFormatter(HelpTestCase):
+ """"""
+
+ def custom_type(string):
+ return string
+
+ parser_signature = Sig(prog='PROG', description='description',
+ formatter_class=argparse.MetavarTypeHelpFormatter)
+ argument_signatures = [Sig('a', type=int),
+ Sig('-b', type=custom_type),
+ Sig('-c', type=float, metavar='SOME FLOAT')]
+ argument_group_signatures = []
+ usage = '''\
+ usage: PROG [-h] [-b custom_type] [-c SOME FLOAT] int
+ '''
+ help = usage + '''\
+
+ description
+
+ positional arguments:
+ int
+
+ optional arguments:
+ -h, --help show this help message and exit
+ -b custom_type
+ -c SOME FLOAT
+ '''
+ version = ''
+
+
# =====================================
# Optional/Positional constructor tests
# =====================================
@@ -4280,32 +4314,28 @@ class TestOptionalsHelpVersionActions(TestCase):
parser.format_help(),
self._get_error(parser.parse_args, args_str.split()).stdout)
- def assertPrintVersionExit(self, parser, args_str):
- self.assertEqual(
- parser.format_version(),
- self._get_error(parser.parse_args, args_str.split()).stderr)
-
def assertArgumentParserError(self, parser, *args):
self.assertRaises(ArgumentParserError, parser.parse_args, args)
def test_version(self):
- parser = ErrorRaisingArgumentParser(version='1.0')
+ parser = ErrorRaisingArgumentParser()
+ parser.add_argument('-v', '--version', action='version', version='1.0')
self.assertPrintHelpExit(parser, '-h')
self.assertPrintHelpExit(parser, '--help')
- self.assertPrintVersionExit(parser, '-v')
- self.assertPrintVersionExit(parser, '--version')
+ self.assertRaises(AttributeError, getattr, parser, 'format_version')
def test_version_format(self):
- parser = ErrorRaisingArgumentParser(prog='PPP', version='%(prog)s 3.5')
+ parser = ErrorRaisingArgumentParser(prog='PPP')
+ parser.add_argument('-v', '--version', action='version', version='%(prog)s 3.5')
msg = self._get_error(parser.parse_args, ['-v']).stderr
self.assertEqual('PPP 3.5\n', msg)
def test_version_no_help(self):
- parser = ErrorRaisingArgumentParser(add_help=False, version='1.0')
+ parser = ErrorRaisingArgumentParser(add_help=False)
+ parser.add_argument('-v', '--version', action='version', version='1.0')
self.assertArgumentParserError(parser, '-h')
self.assertArgumentParserError(parser, '--help')
- self.assertPrintVersionExit(parser, '-v')
- self.assertPrintVersionExit(parser, '--version')
+ self.assertRaises(AttributeError, getattr, parser, 'format_version')
def test_version_action(self):
parser = ErrorRaisingArgumentParser(prog='XXX')
@@ -4325,12 +4355,13 @@ class TestOptionalsHelpVersionActions(TestCase):
parser.add_argument('-x', action='help')
parser.add_argument('-y', action='version')
self.assertPrintHelpExit(parser, '-x')
- self.assertPrintVersionExit(parser, '-y')
self.assertArgumentParserError(parser, '-v')
self.assertArgumentParserError(parser, '--version')
+ self.assertRaises(AttributeError, getattr, parser, 'format_version')
def test_help_version_extra_arguments(self):
- parser = ErrorRaisingArgumentParser(version='1.0')
+ parser = ErrorRaisingArgumentParser()
+ parser.add_argument('--version', action='version', version='1.0')
parser.add_argument('-x', action='store_true')
parser.add_argument('y')
@@ -4342,8 +4373,7 @@ class TestOptionalsHelpVersionActions(TestCase):
format = '%s %%s %s' % (prefix, suffix)
self.assertPrintHelpExit(parser, format % '-h')
self.assertPrintHelpExit(parser, format % '--help')
- self.assertPrintVersionExit(parser, format % '-v')
- self.assertPrintVersionExit(parser, format % '--version')
+ self.assertRaises(AttributeError, getattr, parser, 'format_version')
# ======================
@@ -4398,7 +4428,7 @@ class TestStrings(TestCase):
parser = argparse.ArgumentParser(prog='PROG')
string = (
"ArgumentParser(prog='PROG', usage=None, description=None, "
- "version=None, formatter_class=%r, conflict_handler='error', "
+ "formatter_class=%r, conflict_handler='error', "
"add_help=True)" % argparse.HelpFormatter)
self.assertStringEqual(parser, string)
@@ -4442,7 +4472,7 @@ class TestEncoding(TestCase):
def _test_module_encoding(self, path):
path, _ = os.path.splitext(path)
path += ".py"
- with codecs.open(path, 'r', 'utf8') as f:
+ with codecs.open(path, 'r', 'utf-8') as f:
f.read()
def test_argparse_module_encoding(self):
@@ -4484,6 +4514,67 @@ class TestArgumentTypeError(TestCase):
else:
self.fail()
+# =========================
+# MessageContentError tests
+# =========================
+
+class TestMessageContentError(TestCase):
+
+ def test_missing_argument_name_in_message(self):
+ parser = ErrorRaisingArgumentParser(prog='PROG', usage='')
+ parser.add_argument('req_pos', type=str)
+ parser.add_argument('-req_opt', type=int, required=True)
+ parser.add_argument('need_one', type=str, nargs='+')
+
+ with self.assertRaises(ArgumentParserError) as cm:
+ parser.parse_args([])
+ msg = str(cm.exception)
+ self.assertRegex(msg, 'req_pos')
+ self.assertRegex(msg, 'req_opt')
+ self.assertRegex(msg, 'need_one')
+ with self.assertRaises(ArgumentParserError) as cm:
+ parser.parse_args(['myXargument'])
+ msg = str(cm.exception)
+ self.assertNotIn(msg, 'req_pos')
+ self.assertRegex(msg, 'req_opt')
+ self.assertRegex(msg, 'need_one')
+ with self.assertRaises(ArgumentParserError) as cm:
+ parser.parse_args(['myXargument', '-req_opt=1'])
+ msg = str(cm.exception)
+ self.assertNotIn(msg, 'req_pos')
+ self.assertNotIn(msg, 'req_opt')
+ self.assertRegex(msg, 'need_one')
+
+ def test_optional_optional_not_in_message(self):
+ parser = ErrorRaisingArgumentParser(prog='PROG', usage='')
+ parser.add_argument('req_pos', type=str)
+ parser.add_argument('--req_opt', type=int, required=True)
+ parser.add_argument('--opt_opt', type=bool, nargs='?',
+ default=True)
+ with self.assertRaises(ArgumentParserError) as cm:
+ parser.parse_args([])
+ msg = str(cm.exception)
+ self.assertRegex(msg, 'req_pos')
+ self.assertRegex(msg, 'req_opt')
+ self.assertNotIn(msg, 'opt_opt')
+ with self.assertRaises(ArgumentParserError) as cm:
+ parser.parse_args(['--req_opt=1'])
+ msg = str(cm.exception)
+ self.assertRegex(msg, 'req_pos')
+ self.assertNotIn(msg, 'req_opt')
+ self.assertNotIn(msg, 'opt_opt')
+
+ def test_optional_positional_not_in_message(self):
+ parser = ErrorRaisingArgumentParser(prog='PROG', usage='')
+ parser.add_argument('req_pos')
+ parser.add_argument('optional_positional', nargs='?', default='eggs')
+ with self.assertRaises(ArgumentParserError) as cm:
+ parser.parse_args([])
+ msg = str(cm.exception)
+ self.assertRegex(msg, 'req_pos')
+ self.assertNotIn(msg, 'optional_positional')
+
+
# ================================================
# Check that the type function is called only once
# ================================================
@@ -4782,13 +4873,7 @@ class TestImportStar(TestCase):
self.assertEqual(sorted(items), sorted(argparse.__all__))
def test_main():
- # silence warnings about version argument - these are expected
- with support.check_warnings(
- ('The "version" argument to ArgumentParser is deprecated.',
- DeprecationWarning),
- ('The (format|print)_version method is deprecated',
- DeprecationWarning)):
- support.run_unittest(__name__)
+ support.run_unittest(__name__)
# Remove global references to avoid looking like we have refleaks.
RFile.seen = {}
WFile.seen = set()
diff --git a/Lib/test/test_array.py b/Lib/test/test_array.py
index e26e9ad..a532a9f 100755
--- a/Lib/test/test_array.py
+++ b/Lib/test/test_array.py
@@ -16,6 +16,13 @@ import warnings
import array
from array import _array_reconstructor as array_reconstructor
+try:
+ # Try to determine availability of long long independently
+ # of the array module under test
+ struct.calcsize('@q')
+ have_long_long = True
+except struct.error:
+ have_long_long = False
class ArraySubclass(array.array):
pass
@@ -24,8 +31,9 @@ class ArraySubclassWithKwargs(array.array):
def __init__(self, typecode, newarg=None):
array.array.__init__(self)
-tests = [] # list to accumulate all tests
typecodes = "ubBhHiIlLfd"
+if have_long_long:
+ typecodes += 'qQ'
class BadConstructorTest(unittest.TestCase):
@@ -35,7 +43,6 @@ class BadConstructorTest(unittest.TestCase):
self.assertRaises(TypeError, array.array, 'xx')
self.assertRaises(ValueError, array.array, 'x')
-tests.append(BadConstructorTest)
# Machine format codes.
#
@@ -165,10 +172,7 @@ class ArrayReconstructorTest(unittest.TestCase):
msg="{0!r} != {1!r}; testcase={2!r}".format(a, b, testcase))
-tests.append(ArrayReconstructorTest)
-
-
-class BaseTest(unittest.TestCase):
+class BaseTest:
# Required class attributes (provided by subclasses
# typecode: the typecode to test
# example: an initializer usable in the constructor for this type
@@ -209,10 +213,14 @@ class BaseTest(unittest.TestCase):
self.assertEqual(bi[1], len(a))
def test_byteswap(self):
- a = array.array(self.typecode, self.example)
+ if self.typecode == 'u':
+ example = '\U00100100'
+ else:
+ example = self.example
+ a = array.array(self.typecode, example)
self.assertRaises(TypeError, a.byteswap, 42)
if a.itemsize in (1, 2, 4, 8):
- b = array.array(self.typecode, self.example)
+ b = array.array(self.typecode, example)
b.byteswap()
if a.itemsize==1:
self.assertEqual(a, b)
@@ -272,6 +280,20 @@ class BaseTest(unittest.TestCase):
self.assertEqual(a.x, b.x)
self.assertEqual(type(a), type(b))
+ def test_iterator_pickle(self):
+ data = array.array(self.typecode, self.example)
+ orgit = iter(data)
+ d = pickle.dumps(orgit)
+ it = pickle.loads(d)
+ self.assertEqual(type(orgit), type(it))
+ self.assertEqual(list(it), list(data))
+
+ if len(data):
+ it = pickle.loads(d)
+ next(it)
+ d = pickle.dumps(it)
+ self.assertEqual(list(it), list(data)[1:])
+
def test_insert(self):
a = array.array(self.typecode, self.example)
a.insert(0, self.example[0])
@@ -991,14 +1013,14 @@ class BaseTest(unittest.TestCase):
@support.cpython_only
def test_sizeof_with_buffer(self):
a = array.array(self.typecode, self.example)
- basesize = support.calcvobjsize('4Pi')
+ basesize = support.calcvobjsize('Pn2Pi')
buffer_size = a.buffer_info()[1] * a.itemsize
support.check_sizeof(self, a, basesize + buffer_size)
@support.cpython_only
def test_sizeof_without_buffer(self):
a = array.array(self.typecode)
- basesize = support.calcvobjsize('4Pi')
+ basesize = support.calcvobjsize('Pn2Pi')
support.check_sizeof(self, a, basesize)
@@ -1009,7 +1031,7 @@ class StringTest(BaseTest):
a = array.array(self.typecode, self.example)
self.assertRaises(TypeError, a.__setitem__, 0, self.example[:2])
-class UnicodeTest(StringTest):
+class UnicodeTest(StringTest, unittest.TestCase):
typecode = 'u'
example = '\x01\u263a\x00\ufeff'
smallerexample = '\x01\u263a\x00\ufefe'
@@ -1018,6 +1040,16 @@ class UnicodeTest(StringTest):
minitemsize = 2
def test_unicode(self):
+ try:
+ import ctypes
+ sizeof_wchar = ctypes.sizeof(ctypes.c_wchar)
+ except ImportError:
+ import sys
+ if sys.platform == 'win32':
+ sizeof_wchar = 2
+ else:
+ sizeof_wchar = 4
+
self.assertRaises(TypeError, array.array, 'b', 'foo')
a = array.array('u', '\xa0\xc2\u1234')
@@ -1027,6 +1059,7 @@ class UnicodeTest(StringTest):
a.fromunicode('\x11abc\xff\u1234')
s = a.tounicode()
self.assertEqual(s, '\xa0\xc2\u1234 \x11abc\xff\u1234')
+ self.assertEqual(a.itemsize, sizeof_wchar)
s = '\x00="\'a\\b\x80\xff\u0000\u0001\u1234'
a = array.array('u', s)
@@ -1036,8 +1069,6 @@ class UnicodeTest(StringTest):
self.assertRaises(TypeError, a.fromunicode)
-tests.append(UnicodeTest)
-
class NumberTest(BaseTest):
def test_extslice(self):
@@ -1178,45 +1209,47 @@ class UnsignedNumberTest(NumberTest):
)
-class ByteTest(SignedNumberTest):
+class ByteTest(SignedNumberTest, unittest.TestCase):
typecode = 'b'
minitemsize = 1
-tests.append(ByteTest)
-class UnsignedByteTest(UnsignedNumberTest):
+class UnsignedByteTest(UnsignedNumberTest, unittest.TestCase):
typecode = 'B'
minitemsize = 1
-tests.append(UnsignedByteTest)
-class ShortTest(SignedNumberTest):
+class ShortTest(SignedNumberTest, unittest.TestCase):
typecode = 'h'
minitemsize = 2
-tests.append(ShortTest)
-class UnsignedShortTest(UnsignedNumberTest):
+class UnsignedShortTest(UnsignedNumberTest, unittest.TestCase):
typecode = 'H'
minitemsize = 2
-tests.append(UnsignedShortTest)
-class IntTest(SignedNumberTest):
+class IntTest(SignedNumberTest, unittest.TestCase):
typecode = 'i'
minitemsize = 2
-tests.append(IntTest)
-class UnsignedIntTest(UnsignedNumberTest):
+class UnsignedIntTest(UnsignedNumberTest, unittest.TestCase):
typecode = 'I'
minitemsize = 2
-tests.append(UnsignedIntTest)
-class LongTest(SignedNumberTest):
+class LongTest(SignedNumberTest, unittest.TestCase):
typecode = 'l'
minitemsize = 4
-tests.append(LongTest)
-class UnsignedLongTest(UnsignedNumberTest):
+class UnsignedLongTest(UnsignedNumberTest, unittest.TestCase):
typecode = 'L'
minitemsize = 4
-tests.append(UnsignedLongTest)
+
+@unittest.skipIf(not have_long_long, 'need long long support')
+class LongLongTest(SignedNumberTest, unittest.TestCase):
+ typecode = 'q'
+ minitemsize = 8
+
+@unittest.skipIf(not have_long_long, 'need long long support')
+class UnsignedLongLongTest(UnsignedNumberTest, unittest.TestCase):
+ typecode = 'Q'
+ minitemsize = 8
class FPTest(NumberTest):
example = [-42.0, 0, 42, 1e5, -1e10]
@@ -1243,12 +1276,11 @@ class FPTest(NumberTest):
b.byteswap()
self.assertEqual(a, b)
-class FloatTest(FPTest):
+class FloatTest(FPTest, unittest.TestCase):
typecode = 'f'
minitemsize = 4
-tests.append(FloatTest)
-class DoubleTest(FPTest):
+class DoubleTest(FPTest, unittest.TestCase):
typecode = 'd'
minitemsize = 8
@@ -1269,22 +1301,6 @@ class DoubleTest(FPTest):
else:
self.fail("Array of size > maxsize created - MemoryError expected")
-tests.append(DoubleTest)
-
-def test_main(verbose=None):
- import sys
-
- support.run_unittest(*tests)
-
- # verify reference counting
- if verbose and hasattr(sys, "gettotalrefcount"):
- import gc
- counts = [None] * 5
- for i in range(len(counts)):
- support.run_unittest(*tests)
- gc.collect()
- counts[i] = sys.gettotalrefcount()
- print(counts)
if __name__ == "__main__":
- test_main(verbose=True)
+ unittest.main()
diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py
index 2887092..dc24126 100644
--- a/Lib/test/test_ast.py
+++ b/Lib/test/test_ast.py
@@ -1,6 +1,10 @@
-import sys, unittest
-from test import support
+import os
+import sys
+import unittest
import ast
+import weakref
+
+from test import support
def to_tuple(t):
if t is None or isinstance(t, (str, int, complex)):
@@ -52,6 +56,9 @@ exec_tests = [
"while v:pass",
# If
"if v:pass",
+ # With
+ "with x as y: pass",
+ "with x as y, z as q: pass",
# Raise
"raise Exception('string')",
# TryExcept
@@ -191,6 +198,9 @@ class AST_Tests(unittest.TestCase):
def test_AST_objects(self):
x = ast.AST()
self.assertEqual(x._fields, ())
+ x.foobar = 42
+ self.assertEqual(x.foobar, 42)
+ self.assertEqual(x.__dict__["foobar"], 42)
with self.assertRaises(AttributeError):
x.vararg
@@ -199,6 +209,17 @@ class AST_Tests(unittest.TestCase):
# "_ast.AST constructor takes 0 positional arguments"
ast.AST(2)
+ def test_AST_garbage_collection(self):
+ class X:
+ pass
+ a = ast.AST()
+ a.x = X()
+ a.x.a = a
+ ref = weakref.ref(a.x)
+ del a
+ support.gc_collect()
+ self.assertIsNone(ref())
+
def test_snippets(self):
for input, output, kind in ((exec_tests, exec_results, "exec"),
(single_tests, single_results, "single"),
@@ -378,6 +399,14 @@ class AST_Tests(unittest.TestCase):
compile(m, "<test>", "exec")
self.assertIn("string must be of type str", str(cm.exception))
+ def test_empty_yield_from(self):
+ # Issue 16546: yield from value is not optional.
+ empty_yield_from = ast.parse("def f():\n yield from g()")
+ empty_yield_from.body[0].body[0].value.value = None
+ with self.assertRaises(ValueError) as cm:
+ compile(empty_yield_from, "<test>", "exec")
+ self.assertIn("field value is required", str(cm.exception))
+
class ASTHelpers_Test(unittest.TestCase):
@@ -390,7 +419,9 @@ class ASTHelpers_Test(unittest.TestCase):
try:
1/0
except Exception:
- self.assertRaises(SyntaxError, ast.parse, r"'\U'")
+ with self.assertRaises(SyntaxError) as e:
+ ast.literal_eval(r"'\U'")
+ self.assertIsNotNone(e.exception.__context__)
def test_dump(self):
node = ast.parse('spam(eggs, "and cheese")')
@@ -504,8 +535,413 @@ class ASTHelpers_Test(unittest.TestCase):
self.assertIn("invalid integer value: None", str(cm.exception))
+class ASTValidatorTests(unittest.TestCase):
+
+ def mod(self, mod, msg=None, mode="exec", *, exc=ValueError):
+ mod.lineno = mod.col_offset = 0
+ ast.fix_missing_locations(mod)
+ with self.assertRaises(exc) as cm:
+ compile(mod, "<test>", mode)
+ if msg is not None:
+ self.assertIn(msg, str(cm.exception))
+
+ def expr(self, node, msg=None, *, exc=ValueError):
+ mod = ast.Module([ast.Expr(node)])
+ self.mod(mod, msg, exc=exc)
+
+ def stmt(self, stmt, msg=None):
+ mod = ast.Module([stmt])
+ self.mod(mod, msg)
+
+ def test_module(self):
+ m = ast.Interactive([ast.Expr(ast.Name("x", ast.Store()))])
+ self.mod(m, "must have Load context", "single")
+ m = ast.Expression(ast.Name("x", ast.Store()))
+ self.mod(m, "must have Load context", "eval")
+
+ def _check_arguments(self, fac, check):
+ def arguments(args=None, vararg=None, varargannotation=None,
+ kwonlyargs=None, kwarg=None, kwargannotation=None,
+ defaults=None, kw_defaults=None):
+ if args is None:
+ args = []
+ if kwonlyargs is None:
+ kwonlyargs = []
+ if defaults is None:
+ defaults = []
+ if kw_defaults is None:
+ kw_defaults = []
+ args = ast.arguments(args, vararg, varargannotation, kwonlyargs,
+ kwarg, kwargannotation, defaults, kw_defaults)
+ return fac(args)
+ args = [ast.arg("x", ast.Name("x", ast.Store()))]
+ check(arguments(args=args), "must have Load context")
+ check(arguments(varargannotation=ast.Num(3)),
+ "varargannotation but no vararg")
+ check(arguments(varargannotation=ast.Name("x", ast.Store()), vararg="x"),
+ "must have Load context")
+ check(arguments(kwonlyargs=args), "must have Load context")
+ check(arguments(kwargannotation=ast.Num(42)),
+ "kwargannotation but no kwarg")
+ check(arguments(kwargannotation=ast.Name("x", ast.Store()),
+ kwarg="x"), "must have Load context")
+ check(arguments(defaults=[ast.Num(3)]),
+ "more positional defaults than args")
+ check(arguments(kw_defaults=[ast.Num(4)]),
+ "length of kwonlyargs is not the same as kw_defaults")
+ args = [ast.arg("x", ast.Name("x", ast.Load()))]
+ check(arguments(args=args, defaults=[ast.Name("x", ast.Store())]),
+ "must have Load context")
+ args = [ast.arg("a", ast.Name("x", ast.Load())),
+ ast.arg("b", ast.Name("y", ast.Load()))]
+ check(arguments(kwonlyargs=args,
+ kw_defaults=[None, ast.Name("x", ast.Store())]),
+ "must have Load context")
+
+ def test_funcdef(self):
+ a = ast.arguments([], None, None, [], None, None, [], [])
+ f = ast.FunctionDef("x", a, [], [], None)
+ self.stmt(f, "empty body on FunctionDef")
+ f = ast.FunctionDef("x", a, [ast.Pass()], [ast.Name("x", ast.Store())],
+ None)
+ self.stmt(f, "must have Load context")
+ f = ast.FunctionDef("x", a, [ast.Pass()], [],
+ ast.Name("x", ast.Store()))
+ self.stmt(f, "must have Load context")
+ def fac(args):
+ return ast.FunctionDef("x", args, [ast.Pass()], [], None)
+ self._check_arguments(fac, self.stmt)
+
+ def test_classdef(self):
+ def cls(bases=None, keywords=None, starargs=None, kwargs=None,
+ body=None, decorator_list=None):
+ if bases is None:
+ bases = []
+ if keywords is None:
+ keywords = []
+ if body is None:
+ body = [ast.Pass()]
+ if decorator_list is None:
+ decorator_list = []
+ return ast.ClassDef("myclass", bases, keywords, starargs,
+ kwargs, body, decorator_list)
+ self.stmt(cls(bases=[ast.Name("x", ast.Store())]),
+ "must have Load context")
+ self.stmt(cls(keywords=[ast.keyword("x", ast.Name("x", ast.Store()))]),
+ "must have Load context")
+ self.stmt(cls(starargs=ast.Name("x", ast.Store())),
+ "must have Load context")
+ self.stmt(cls(kwargs=ast.Name("x", ast.Store())),
+ "must have Load context")
+ self.stmt(cls(body=[]), "empty body on ClassDef")
+ self.stmt(cls(body=[None]), "None disallowed")
+ self.stmt(cls(decorator_list=[ast.Name("x", ast.Store())]),
+ "must have Load context")
+
+ def test_delete(self):
+ self.stmt(ast.Delete([]), "empty targets on Delete")
+ self.stmt(ast.Delete([None]), "None disallowed")
+ self.stmt(ast.Delete([ast.Name("x", ast.Load())]),
+ "must have Del context")
+
+ def test_assign(self):
+ self.stmt(ast.Assign([], ast.Num(3)), "empty targets on Assign")
+ self.stmt(ast.Assign([None], ast.Num(3)), "None disallowed")
+ self.stmt(ast.Assign([ast.Name("x", ast.Load())], ast.Num(3)),
+ "must have Store context")
+ self.stmt(ast.Assign([ast.Name("x", ast.Store())],
+ ast.Name("y", ast.Store())),
+ "must have Load context")
+
+ def test_augassign(self):
+ aug = ast.AugAssign(ast.Name("x", ast.Load()), ast.Add(),
+ ast.Name("y", ast.Load()))
+ self.stmt(aug, "must have Store context")
+ aug = ast.AugAssign(ast.Name("x", ast.Store()), ast.Add(),
+ ast.Name("y", ast.Store()))
+ self.stmt(aug, "must have Load context")
+
+ def test_for(self):
+ x = ast.Name("x", ast.Store())
+ y = ast.Name("y", ast.Load())
+ p = ast.Pass()
+ self.stmt(ast.For(x, y, [], []), "empty body on For")
+ self.stmt(ast.For(ast.Name("x", ast.Load()), y, [p], []),
+ "must have Store context")
+ self.stmt(ast.For(x, ast.Name("y", ast.Store()), [p], []),
+ "must have Load context")
+ e = ast.Expr(ast.Name("x", ast.Store()))
+ self.stmt(ast.For(x, y, [e], []), "must have Load context")
+ self.stmt(ast.For(x, y, [p], [e]), "must have Load context")
+
+ def test_while(self):
+ self.stmt(ast.While(ast.Num(3), [], []), "empty body on While")
+ self.stmt(ast.While(ast.Name("x", ast.Store()), [ast.Pass()], []),
+ "must have Load context")
+ self.stmt(ast.While(ast.Num(3), [ast.Pass()],
+ [ast.Expr(ast.Name("x", ast.Store()))]),
+ "must have Load context")
+
+ def test_if(self):
+ self.stmt(ast.If(ast.Num(3), [], []), "empty body on If")
+ i = ast.If(ast.Name("x", ast.Store()), [ast.Pass()], [])
+ self.stmt(i, "must have Load context")
+ i = ast.If(ast.Num(3), [ast.Expr(ast.Name("x", ast.Store()))], [])
+ self.stmt(i, "must have Load context")
+ i = ast.If(ast.Num(3), [ast.Pass()],
+ [ast.Expr(ast.Name("x", ast.Store()))])
+ self.stmt(i, "must have Load context")
+
+ def test_with(self):
+ p = ast.Pass()
+ self.stmt(ast.With([], [p]), "empty items on With")
+ i = ast.withitem(ast.Num(3), None)
+ self.stmt(ast.With([i], []), "empty body on With")
+ i = ast.withitem(ast.Name("x", ast.Store()), None)
+ self.stmt(ast.With([i], [p]), "must have Load context")
+ i = ast.withitem(ast.Num(3), ast.Name("x", ast.Load()))
+ self.stmt(ast.With([i], [p]), "must have Store context")
+
+ def test_raise(self):
+ r = ast.Raise(None, ast.Num(3))
+ self.stmt(r, "Raise with cause but no exception")
+ r = ast.Raise(ast.Name("x", ast.Store()), None)
+ self.stmt(r, "must have Load context")
+ r = ast.Raise(ast.Num(4), ast.Name("x", ast.Store()))
+ self.stmt(r, "must have Load context")
+
+ def test_try(self):
+ p = ast.Pass()
+ t = ast.Try([], [], [], [p])
+ self.stmt(t, "empty body on Try")
+ t = ast.Try([ast.Expr(ast.Name("x", ast.Store()))], [], [], [p])
+ self.stmt(t, "must have Load context")
+ t = ast.Try([p], [], [], [])
+ self.stmt(t, "Try has neither except handlers nor finalbody")
+ t = ast.Try([p], [], [p], [p])
+ self.stmt(t, "Try has orelse but no except handlers")
+ t = ast.Try([p], [ast.ExceptHandler(None, "x", [])], [], [])
+ self.stmt(t, "empty body on ExceptHandler")
+ e = [ast.ExceptHandler(ast.Name("x", ast.Store()), "y", [p])]
+ self.stmt(ast.Try([p], e, [], []), "must have Load context")
+ e = [ast.ExceptHandler(None, "x", [p])]
+ t = ast.Try([p], e, [ast.Expr(ast.Name("x", ast.Store()))], [p])
+ self.stmt(t, "must have Load context")
+ t = ast.Try([p], e, [p], [ast.Expr(ast.Name("x", ast.Store()))])
+ self.stmt(t, "must have Load context")
+
+ def test_assert(self):
+ self.stmt(ast.Assert(ast.Name("x", ast.Store()), None),
+ "must have Load context")
+ assrt = ast.Assert(ast.Name("x", ast.Load()),
+ ast.Name("y", ast.Store()))
+ self.stmt(assrt, "must have Load context")
+
+ def test_import(self):
+ self.stmt(ast.Import([]), "empty names on Import")
+
+ def test_importfrom(self):
+ imp = ast.ImportFrom(None, [ast.alias("x", None)], -42)
+ self.stmt(imp, "level less than -1")
+ self.stmt(ast.ImportFrom(None, [], 0), "empty names on ImportFrom")
+
+ def test_global(self):
+ self.stmt(ast.Global([]), "empty names on Global")
+
+ def test_nonlocal(self):
+ self.stmt(ast.Nonlocal([]), "empty names on Nonlocal")
+
+ def test_expr(self):
+ e = ast.Expr(ast.Name("x", ast.Store()))
+ self.stmt(e, "must have Load context")
+
+ def test_boolop(self):
+ b = ast.BoolOp(ast.And(), [])
+ self.expr(b, "less than 2 values")
+ b = ast.BoolOp(ast.And(), [ast.Num(3)])
+ self.expr(b, "less than 2 values")
+ b = ast.BoolOp(ast.And(), [ast.Num(4), None])
+ self.expr(b, "None disallowed")
+ b = ast.BoolOp(ast.And(), [ast.Num(4), ast.Name("x", ast.Store())])
+ self.expr(b, "must have Load context")
+
+ def test_unaryop(self):
+ u = ast.UnaryOp(ast.Not(), ast.Name("x", ast.Store()))
+ self.expr(u, "must have Load context")
+
+ def test_lambda(self):
+ a = ast.arguments([], None, None, [], None, None, [], [])
+ self.expr(ast.Lambda(a, ast.Name("x", ast.Store())),
+ "must have Load context")
+ def fac(args):
+ return ast.Lambda(args, ast.Name("x", ast.Load()))
+ self._check_arguments(fac, self.expr)
+
+ def test_ifexp(self):
+ l = ast.Name("x", ast.Load())
+ s = ast.Name("y", ast.Store())
+ for args in (s, l, l), (l, s, l), (l, l, s):
+ self.expr(ast.IfExp(*args), "must have Load context")
+
+ def test_dict(self):
+ d = ast.Dict([], [ast.Name("x", ast.Load())])
+ self.expr(d, "same number of keys as values")
+ d = ast.Dict([None], [ast.Name("x", ast.Load())])
+ self.expr(d, "None disallowed")
+ d = ast.Dict([ast.Name("x", ast.Load())], [None])
+ self.expr(d, "None disallowed")
+
+ def test_set(self):
+ self.expr(ast.Set([None]), "None disallowed")
+ s = ast.Set([ast.Name("x", ast.Store())])
+ self.expr(s, "must have Load context")
+
+ def _check_comprehension(self, fac):
+ self.expr(fac([]), "comprehension with no generators")
+ g = ast.comprehension(ast.Name("x", ast.Load()),
+ ast.Name("x", ast.Load()), [])
+ self.expr(fac([g]), "must have Store context")
+ g = ast.comprehension(ast.Name("x", ast.Store()),
+ ast.Name("x", ast.Store()), [])
+ self.expr(fac([g]), "must have Load context")
+ x = ast.Name("x", ast.Store())
+ y = ast.Name("y", ast.Load())
+ g = ast.comprehension(x, y, [None])
+ self.expr(fac([g]), "None disallowed")
+ g = ast.comprehension(x, y, [ast.Name("x", ast.Store())])
+ self.expr(fac([g]), "must have Load context")
+
+ def _simple_comp(self, fac):
+ g = ast.comprehension(ast.Name("x", ast.Store()),
+ ast.Name("x", ast.Load()), [])
+ self.expr(fac(ast.Name("x", ast.Store()), [g]),
+ "must have Load context")
+ def wrap(gens):
+ return fac(ast.Name("x", ast.Store()), gens)
+ self._check_comprehension(wrap)
+
+ def test_listcomp(self):
+ self._simple_comp(ast.ListComp)
+
+ def test_setcomp(self):
+ self._simple_comp(ast.SetComp)
+
+ def test_generatorexp(self):
+ self._simple_comp(ast.GeneratorExp)
+
+ def test_dictcomp(self):
+ g = ast.comprehension(ast.Name("y", ast.Store()),
+ ast.Name("p", ast.Load()), [])
+ c = ast.DictComp(ast.Name("x", ast.Store()),
+ ast.Name("y", ast.Load()), [g])
+ self.expr(c, "must have Load context")
+ c = ast.DictComp(ast.Name("x", ast.Load()),
+ ast.Name("y", ast.Store()), [g])
+ self.expr(c, "must have Load context")
+ def factory(comps):
+ k = ast.Name("x", ast.Load())
+ v = ast.Name("y", ast.Load())
+ return ast.DictComp(k, v, comps)
+ self._check_comprehension(factory)
+
+ def test_yield(self):
+ self.expr(ast.Yield(ast.Name("x", ast.Store())), "must have Load")
+ self.expr(ast.YieldFrom(ast.Name("x", ast.Store())), "must have Load")
+
+ def test_compare(self):
+ left = ast.Name("x", ast.Load())
+ comp = ast.Compare(left, [ast.In()], [])
+ self.expr(comp, "no comparators")
+ comp = ast.Compare(left, [ast.In()], [ast.Num(4), ast.Num(5)])
+ self.expr(comp, "different number of comparators and operands")
+ comp = ast.Compare(ast.Num("blah"), [ast.In()], [left])
+ self.expr(comp, "non-numeric", exc=TypeError)
+ comp = ast.Compare(left, [ast.In()], [ast.Num("blah")])
+ self.expr(comp, "non-numeric", exc=TypeError)
+
+ def test_call(self):
+ func = ast.Name("x", ast.Load())
+ args = [ast.Name("y", ast.Load())]
+ keywords = [ast.keyword("w", ast.Name("z", ast.Load()))]
+ stararg = ast.Name("p", ast.Load())
+ kwarg = ast.Name("q", ast.Load())
+ call = ast.Call(ast.Name("x", ast.Store()), args, keywords, stararg,
+ kwarg)
+ self.expr(call, "must have Load context")
+ call = ast.Call(func, [None], keywords, stararg, kwarg)
+ self.expr(call, "None disallowed")
+ bad_keywords = [ast.keyword("w", ast.Name("z", ast.Store()))]
+ call = ast.Call(func, args, bad_keywords, stararg, kwarg)
+ self.expr(call, "must have Load context")
+ call = ast.Call(func, args, keywords, ast.Name("z", ast.Store()), kwarg)
+ self.expr(call, "must have Load context")
+ call = ast.Call(func, args, keywords, stararg,
+ ast.Name("w", ast.Store()))
+ self.expr(call, "must have Load context")
+
+ def test_num(self):
+ class subint(int):
+ pass
+ class subfloat(float):
+ pass
+ class subcomplex(complex):
+ pass
+ for obj in "0", "hello", subint(), subfloat(), subcomplex():
+ self.expr(ast.Num(obj), "non-numeric", exc=TypeError)
+
+ def test_attribute(self):
+ attr = ast.Attribute(ast.Name("x", ast.Store()), "y", ast.Load())
+ self.expr(attr, "must have Load context")
+
+ def test_subscript(self):
+ sub = ast.Subscript(ast.Name("x", ast.Store()), ast.Index(ast.Num(3)),
+ ast.Load())
+ self.expr(sub, "must have Load context")
+ x = ast.Name("x", ast.Load())
+ sub = ast.Subscript(x, ast.Index(ast.Name("y", ast.Store())),
+ ast.Load())
+ self.expr(sub, "must have Load context")
+ s = ast.Name("x", ast.Store())
+ for args in (s, None, None), (None, s, None), (None, None, s):
+ sl = ast.Slice(*args)
+ self.expr(ast.Subscript(x, sl, ast.Load()),
+ "must have Load context")
+ sl = ast.ExtSlice([])
+ self.expr(ast.Subscript(x, sl, ast.Load()), "empty dims on ExtSlice")
+ sl = ast.ExtSlice([ast.Index(s)])
+ self.expr(ast.Subscript(x, sl, ast.Load()), "must have Load context")
+
+ def test_starred(self):
+ left = ast.List([ast.Starred(ast.Name("x", ast.Load()), ast.Store())],
+ ast.Store())
+ assign = ast.Assign([left], ast.Num(4))
+ self.stmt(assign, "must have Store context")
+
+ def _sequence(self, fac):
+ self.expr(fac([None], ast.Load()), "None disallowed")
+ self.expr(fac([ast.Name("x", ast.Store())], ast.Load()),
+ "must have Load context")
+
+ def test_list(self):
+ self._sequence(ast.List)
+
+ def test_tuple(self):
+ self._sequence(ast.Tuple)
+
+ def test_stdlib_validates(self):
+ stdlib = os.path.dirname(ast.__file__)
+ tests = [fn for fn in os.listdir(stdlib) if fn.endswith(".py")]
+ tests.extend(["test/test_grammar.py", "test/test_unpack_ex.py"])
+ for module in tests:
+ fn = os.path.join(stdlib, module)
+ with open(fn, "r", encoding="utf-8") as fp:
+ source = fp.read()
+ mod = ast.parse(source)
+ compile(mod, fn, "exec")
+
+
def test_main():
- support.run_unittest(AST_Tests, ASTHelpers_Test)
+ support.run_unittest(AST_Tests, ASTHelpers_Test, ASTValidatorTests)
def main():
if __name__ != '__main__':
@@ -539,9 +975,11 @@ exec_results = [
('Module', [('For', (1, 0), ('Name', (1, 4), 'v', ('Store',)), ('Name', (1, 9), 'v', ('Load',)), [('Pass', (1, 11))], [])]),
('Module', [('While', (1, 0), ('Name', (1, 6), 'v', ('Load',)), [('Pass', (1, 8))], [])]),
('Module', [('If', (1, 0), ('Name', (1, 3), 'v', ('Load',)), [('Pass', (1, 5))], [])]),
+('Module', [('With', (1, 0), [('withitem', ('Name', (1, 5), 'x', ('Load',)), ('Name', (1, 10), 'y', ('Store',)))], [('Pass', (1, 13))])]),
+('Module', [('With', (1, 0), [('withitem', ('Name', (1, 5), 'x', ('Load',)), ('Name', (1, 10), 'y', ('Store',))), ('withitem', ('Name', (1, 13), 'z', ('Load',)), ('Name', (1, 18), 'q', ('Store',)))], [('Pass', (1, 21))])]),
('Module', [('Raise', (1, 0), ('Call', (1, 6), ('Name', (1, 6), 'Exception', ('Load',)), [('Str', (1, 16), 'string')], [], None, None), None)]),
-('Module', [('TryExcept', (1, 0), [('Pass', (2, 2))], [('ExceptHandler', (3, 0), ('Name', (3, 7), 'Exception', ('Load',)), None, [('Pass', (4, 2))])], [])]),
-('Module', [('TryFinally', (1, 0), [('Pass', (2, 2))], [('Pass', (4, 2))])]),
+('Module', [('Try', (1, 0), [('Pass', (2, 2))], [('ExceptHandler', (3, 0), ('Name', (3, 7), 'Exception', ('Load',)), None, [('Pass', (4, 2))])], [], [])]),
+('Module', [('Try', (1, 0), [('Pass', (2, 2))], [], [], [('Pass', (4, 2))])]),
('Module', [('Assert', (1, 0), ('Name', (1, 7), 'v', ('Load',)), None)]),
('Module', [('Import', (1, 0), [('alias', 'sys', None)])]),
('Module', [('ImportFrom', (1, 0), 'sys', [('alias', 'v', None)], 0)]),
diff --git a/Lib/test/test_asyncore.py b/Lib/test/test_asyncore.py
index 8989a63..878b26c 100644
--- a/Lib/test/test_asyncore.py
+++ b/Lib/test/test_asyncore.py
@@ -21,6 +21,8 @@ except ImportError:
HOST = support.HOST
+HAS_UNIX_SOCKETS = hasattr(socket, 'AF_UNIX')
+
class dummysocket:
def __init__(self):
self.closed = False
@@ -72,15 +74,16 @@ def capture_server(evt, buf, serv):
pass
else:
n = 200
- while n > 0:
- r, w, e = select.select([conn], [], [])
+ start = time.time()
+ while n > 0 and time.time() - start < 3.0:
+ r, w, e = select.select([conn], [], [], 0.1)
if r:
+ n -= 1
data = conn.recv(10)
# keep everything except for the newline terminator
buf.write(data.replace(b'\n', b''))
if b'\n' in data:
break
- n -= 1
time.sleep(0.01)
conn.close()
@@ -88,6 +91,13 @@ def capture_server(evt, buf, serv):
serv.close()
evt.set()
+def bind_af_aware(sock, addr):
+ """Helper function to bind a socket according to its family."""
+ if HAS_UNIX_SOCKETS and sock.family == socket.AF_UNIX:
+ # Make sure the path doesn't exist.
+ unlink(addr)
+ sock.bind(addr)
+
class HelperFunctionTests(unittest.TestCase):
def test_readwriteexc(self):
@@ -353,7 +363,7 @@ class DispatcherWithSendTests(unittest.TestCase):
@support.reap_threads
def test_send(self):
evt = threading.Event()
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock = socket.socket()
sock.settimeout(3)
port = support.bind_port(sock)
@@ -368,7 +378,7 @@ class DispatcherWithSendTests(unittest.TestCase):
data = b"Suppose there isn't a 16-ton weight?"
d = dispatcherwithsend_noread()
- d.create_socket(socket.AF_INET, socket.SOCK_STREAM)
+ d.create_socket()
d.connect((HOST, port))
# give time for socket to connect
@@ -468,22 +478,22 @@ class BaseTestHandler(asyncore.dispatcher):
raise
-class TCPServer(asyncore.dispatcher):
+class BaseServer(asyncore.dispatcher):
"""A server which listens on an address and dispatches the
connection to a handler.
"""
- def __init__(self, handler=BaseTestHandler, host=HOST, port=0):
+ def __init__(self, family, addr, handler=BaseTestHandler):
asyncore.dispatcher.__init__(self)
- self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.create_socket(family)
self.set_reuse_addr()
- self.bind((host, port))
+ bind_af_aware(self.socket, addr)
self.listen(5)
self.handler = handler
@property
def address(self):
- return self.socket.getsockname()[:2]
+ return self.socket.getsockname()
def handle_accepted(self, sock, addr):
self.handler(sock)
@@ -494,16 +504,16 @@ class TCPServer(asyncore.dispatcher):
class BaseClient(BaseTestHandler):
- def __init__(self, address):
+ def __init__(self, family, address):
BaseTestHandler.__init__(self)
- self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.create_socket(family)
self.connect(address)
def handle_connect(self):
pass
-class BaseTestAPI(unittest.TestCase):
+class BaseTestAPI:
def tearDown(self):
asyncore.close_all()
@@ -526,8 +536,8 @@ class BaseTestAPI(unittest.TestCase):
def handle_connect(self):
self.flag = True
- server = TCPServer()
- client = TestClient(server.address)
+ server = BaseServer(self.family, self.addr)
+ client = TestClient(self.family, server.address)
self.loop_waiting_for_flag(client)
def test_handle_accept(self):
@@ -535,18 +545,18 @@ class BaseTestAPI(unittest.TestCase):
class TestListener(BaseTestHandler):
- def __init__(self):
+ def __init__(self, family, addr):
BaseTestHandler.__init__(self)
- self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
- self.bind((HOST, 0))
+ self.create_socket(family)
+ bind_af_aware(self.socket, addr)
self.listen(5)
- self.address = self.socket.getsockname()[:2]
+ self.address = self.socket.getsockname()
def handle_accept(self):
self.flag = True
- server = TestListener()
- client = BaseClient(server.address)
+ server = TestListener(self.family, self.addr)
+ client = BaseClient(self.family, server.address)
self.loop_waiting_for_flag(server)
def test_handle_accepted(self):
@@ -554,12 +564,12 @@ class BaseTestAPI(unittest.TestCase):
class TestListener(BaseTestHandler):
- def __init__(self):
+ def __init__(self, family, addr):
BaseTestHandler.__init__(self)
- self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
- self.bind((HOST, 0))
+ self.create_socket(family)
+ bind_af_aware(self.socket, addr)
self.listen(5)
- self.address = self.socket.getsockname()[:2]
+ self.address = self.socket.getsockname()
def handle_accept(self):
asyncore.dispatcher.handle_accept(self)
@@ -568,8 +578,8 @@ class BaseTestAPI(unittest.TestCase):
sock.close()
self.flag = True
- server = TestListener()
- client = BaseClient(server.address)
+ server = TestListener(self.family, self.addr)
+ client = BaseClient(self.family, server.address)
self.loop_waiting_for_flag(server)
@@ -585,8 +595,8 @@ class BaseTestAPI(unittest.TestCase):
BaseTestHandler.__init__(self, conn)
self.send(b'x' * 1024)
- server = TCPServer(TestHandler)
- client = TestClient(server.address)
+ server = BaseServer(self.family, self.addr, TestHandler)
+ client = TestClient(self.family, server.address)
self.loop_waiting_for_flag(client)
def test_handle_write(self):
@@ -596,8 +606,8 @@ class BaseTestAPI(unittest.TestCase):
def handle_write(self):
self.flag = True
- server = TCPServer()
- client = TestClient(server.address)
+ server = BaseServer(self.family, self.addr)
+ client = TestClient(self.family, server.address)
self.loop_waiting_for_flag(client)
def test_handle_close(self):
@@ -620,8 +630,40 @@ class BaseTestAPI(unittest.TestCase):
BaseTestHandler.__init__(self, conn)
self.close()
- server = TCPServer(TestHandler)
- client = TestClient(server.address)
+ server = BaseServer(self.family, self.addr, TestHandler)
+ client = TestClient(self.family, server.address)
+ self.loop_waiting_for_flag(client)
+
+ def test_handle_close_after_conn_broken(self):
+ # Check that ECONNRESET/EPIPE is correctly handled (issues #5661 and
+ # #11265).
+
+ data = b'\0' * 128
+
+ class TestClient(BaseClient):
+
+ def handle_write(self):
+ self.send(data)
+
+ def handle_close(self):
+ self.flag = True
+ self.close()
+
+ def handle_expt(self):
+ self.flag = True
+ self.close()
+
+ class TestHandler(BaseTestHandler):
+
+ def handle_read(self):
+ self.recv(len(data))
+ self.close()
+
+ def writable(self):
+ return False
+
+ server = BaseServer(self.family, self.addr, TestHandler)
+ client = TestClient(self.family, server.address)
self.loop_waiting_for_flag(client)
@unittest.skipIf(sys.platform.startswith("sunos"),
@@ -630,9 +672,12 @@ class BaseTestAPI(unittest.TestCase):
# Make sure handle_expt is called on OOB data received.
# Note: this might fail on some platforms as OOB data is
# tenuously supported and rarely used.
+ if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX:
+ self.skipTest("Not applicable to AF_UNIX sockets.")
class TestClient(BaseClient):
def handle_expt(self):
+ self.socket.recv(1024, socket.MSG_OOB)
self.flag = True
class TestHandler(BaseTestHandler):
@@ -640,8 +685,8 @@ class BaseTestAPI(unittest.TestCase):
BaseTestHandler.__init__(self, conn)
self.socket.send(bytes(chr(244), 'latin-1'), socket.MSG_OOB)
- server = TCPServer(TestHandler)
- client = TestClient(server.address)
+ server = BaseServer(self.family, self.addr, TestHandler)
+ client = TestClient(self.family, server.address)
self.loop_waiting_for_flag(client)
def test_handle_error(self):
@@ -658,13 +703,13 @@ class BaseTestAPI(unittest.TestCase):
else:
raise Exception("exception not raised")
- server = TCPServer()
- client = TestClient(server.address)
+ server = BaseServer(self.family, self.addr)
+ client = TestClient(self.family, server.address)
self.loop_waiting_for_flag(client)
def test_connection_attributes(self):
- server = TCPServer()
- client = BaseClient(server.address)
+ server = BaseServer(self.family, self.addr)
+ client = BaseClient(self.family, server.address)
# we start disconnected
self.assertFalse(server.connected)
@@ -694,25 +739,29 @@ class BaseTestAPI(unittest.TestCase):
def test_create_socket(self):
s = asyncore.dispatcher()
- s.create_socket(socket.AF_INET, socket.SOCK_STREAM)
- self.assertEqual(s.socket.family, socket.AF_INET)
+ s.create_socket(self.family)
+ self.assertEqual(s.socket.family, self.family)
SOCK_NONBLOCK = getattr(socket, 'SOCK_NONBLOCK', 0)
self.assertEqual(s.socket.type, socket.SOCK_STREAM | SOCK_NONBLOCK)
def test_bind(self):
+ if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX:
+ self.skipTest("Not applicable to AF_UNIX sockets.")
s1 = asyncore.dispatcher()
- s1.create_socket(socket.AF_INET, socket.SOCK_STREAM)
- s1.bind((HOST, 0))
+ s1.create_socket(self.family)
+ s1.bind(self.addr)
s1.listen(5)
port = s1.socket.getsockname()[1]
s2 = asyncore.dispatcher()
- s2.create_socket(socket.AF_INET, socket.SOCK_STREAM)
+ s2.create_socket(self.family)
# EADDRINUSE indicates the socket was correctly bound
- self.assertRaises(socket.error, s2.bind, (HOST, port))
+ self.assertRaises(socket.error, s2.bind, (self.addr[0], port))
def test_set_reuse_addr(self):
- sock = socket.socket()
+ if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX:
+ self.skipTest("Not applicable to AF_UNIX sockets.")
+ sock = socket.socket(self.family)
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
except socket.error:
@@ -720,11 +769,11 @@ class BaseTestAPI(unittest.TestCase):
else:
# if SO_REUSEADDR succeeded for sock we expect asyncore
# to do the same
- s = asyncore.dispatcher(socket.socket())
+ s = asyncore.dispatcher(socket.socket(self.family))
self.assertFalse(s.socket.getsockopt(socket.SOL_SOCKET,
socket.SO_REUSEADDR))
s.socket.close()
- s.create_socket(socket.AF_INET, socket.SOCK_STREAM)
+ s.create_socket(self.family)
s.set_reuse_addr()
self.assertTrue(s.socket.getsockopt(socket.SOL_SOCKET,
socket.SO_REUSEADDR))
@@ -735,13 +784,14 @@ class BaseTestAPI(unittest.TestCase):
@support.reap_threads
def test_quick_connect(self):
# see: http://bugs.python.org/issue10340
- server = TCPServer()
- t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1, count=500))
- t.start()
- self.addCleanup(t.join)
-
- for x in range(20):
- s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ if self.family in (socket.AF_INET, getattr(socket, "AF_INET6", object())):
+ server = BaseServer(self.family, self.addr)
+ t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1,
+ count=500))
+ t.start()
+ self.addCleanup(t.join)
+
+ s = socket.socket(self.family, socket.SOCK_STREAM)
s.settimeout(.2)
s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
struct.pack('ii', 1, 0))
@@ -752,19 +802,45 @@ class BaseTestAPI(unittest.TestCase):
finally:
s.close()
-class TestAPI_UseSelect(BaseTestAPI):
+class TestAPI_UseIPv4Sockets(BaseTestAPI):
+ family = socket.AF_INET
+ addr = (HOST, 0)
+
+@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 support required')
+class TestAPI_UseIPv6Sockets(BaseTestAPI):
+ family = socket.AF_INET6
+ addr = ('::1', 0)
+
+@unittest.skipUnless(HAS_UNIX_SOCKETS, 'Unix sockets required')
+class TestAPI_UseUnixSockets(BaseTestAPI):
+ if HAS_UNIX_SOCKETS:
+ family = socket.AF_UNIX
+ addr = support.TESTFN
+
+ def tearDown(self):
+ unlink(self.addr)
+ BaseTestAPI.tearDown(self)
+
+class TestAPI_UseIPv4Select(TestAPI_UseIPv4Sockets, unittest.TestCase):
use_poll = False
@unittest.skipUnless(hasattr(select, 'poll'), 'select.poll required')
-class TestAPI_UsePoll(BaseTestAPI):
+class TestAPI_UseIPv4Poll(TestAPI_UseIPv4Sockets, unittest.TestCase):
use_poll = True
+class TestAPI_UseIPv6Select(TestAPI_UseIPv6Sockets, unittest.TestCase):
+ use_poll = False
-def test_main():
- tests = [HelperFunctionTests, DispatcherTests, DispatcherWithSendTests,
- DispatcherWithSendTests_UsePoll, TestAPI_UseSelect,
- TestAPI_UsePoll, FileWrapperTest]
- run_unittest(*tests)
+@unittest.skipUnless(hasattr(select, 'poll'), 'select.poll required')
+class TestAPI_UseIPv6Poll(TestAPI_UseIPv6Sockets, unittest.TestCase):
+ use_poll = True
+
+class TestAPI_UseUnixSocketsSelect(TestAPI_UseUnixSockets, unittest.TestCase):
+ use_poll = False
+
+@unittest.skipUnless(hasattr(select, 'poll'), 'select.poll required')
+class TestAPI_UseUnixSocketsPoll(TestAPI_UseUnixSockets, unittest.TestCase):
+ use_poll = True
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_base64.py b/Lib/test/test_base64.py
index ca94504..2569476 100644
--- a/Lib/test/test_base64.py
+++ b/Lib/test/test_base64.py
@@ -103,44 +103,53 @@ class BaseXYTestCase(unittest.TestCase):
def test_b64decode(self):
eq = self.assertEqual
- eq(base64.b64decode(b"d3d3LnB5dGhvbi5vcmc="), b"www.python.org")
- eq(base64.b64decode(b'AA=='), b'\x00')
- eq(base64.b64decode(b"YQ=="), b"a")
- eq(base64.b64decode(b"YWI="), b"ab")
- eq(base64.b64decode(b"YWJj"), b"abc")
- eq(base64.b64decode(b"YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE"
- b"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0\nNT"
- b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ=="),
- b"abcdefghijklmnopqrstuvwxyz"
- b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
- b"0123456789!@#0^&*();:<>,. []{}")
- eq(base64.b64decode(b''), b'')
+
+ tests = {b"d3d3LnB5dGhvbi5vcmc=": b"www.python.org",
+ b'AA==': b'\x00',
+ b"YQ==": b"a",
+ b"YWI=": b"ab",
+ b"YWJj": b"abc",
+ b"YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE"
+ b"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0\nNT"
+ b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ==":
+
+ b"abcdefghijklmnopqrstuvwxyz"
+ b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ b"0123456789!@#0^&*();:<>,. []{}",
+ b'': b'',
+ }
+ for data, res in tests.items():
+ eq(base64.b64decode(data), res)
+ eq(base64.b64decode(data.decode('ascii')), res)
+
# Test with arbitrary alternative characters
- eq(base64.b64decode(b'01a*b$cd', altchars=b'*$'), b'\xd3V\xbeo\xf7\x1d')
- # Check if passing a str object raises an error
- self.assertRaises(TypeError, base64.b64decode, "")
- self.assertRaises(TypeError, base64.b64decode, b"", altchars="")
+ tests_altchars = {(b'01a*b$cd', b'*$'): b'\xd3V\xbeo\xf7\x1d',
+ }
+ for (data, altchars), res in tests_altchars.items():
+ data_str = data.decode('ascii')
+ altchars_str = altchars.decode('ascii')
+
+ eq(base64.b64decode(data, altchars=altchars), res)
+ eq(base64.b64decode(data_str, altchars=altchars), res)
+ eq(base64.b64decode(data, altchars=altchars_str), res)
+ eq(base64.b64decode(data_str, altchars=altchars_str), res)
+
# Test standard alphabet
- eq(base64.standard_b64decode(b"d3d3LnB5dGhvbi5vcmc="), b"www.python.org")
- eq(base64.standard_b64decode(b"YQ=="), b"a")
- eq(base64.standard_b64decode(b"YWI="), b"ab")
- eq(base64.standard_b64decode(b"YWJj"), b"abc")
- eq(base64.standard_b64decode(b""), b"")
- eq(base64.standard_b64decode(b"YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE"
- b"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0NT"
- b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ=="),
- b"abcdefghijklmnopqrstuvwxyz"
- b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
- b"0123456789!@#0^&*();:<>,. []{}")
- # Check if passing a str object raises an error
- self.assertRaises(TypeError, base64.standard_b64decode, "")
- self.assertRaises(TypeError, base64.standard_b64decode, b"", altchars="")
+ for data, res in tests.items():
+ eq(base64.standard_b64decode(data), res)
+ eq(base64.standard_b64decode(data.decode('ascii')), res)
+
# Test with 'URL safe' alternative characters
- eq(base64.urlsafe_b64decode(b'01a-b_cd'), b'\xd3V\xbeo\xf7\x1d')
- self.assertRaises(TypeError, base64.urlsafe_b64decode, "")
+ tests_urlsafe = {b'01a-b_cd': b'\xd3V\xbeo\xf7\x1d',
+ b'': b'',
+ }
+ for data, res in tests_urlsafe.items():
+ eq(base64.urlsafe_b64decode(data), res)
+ eq(base64.urlsafe_b64decode(data.decode('ascii')), res)
def test_b64decode_padding_error(self):
self.assertRaises(binascii.Error, base64.b64decode, b'abc')
+ self.assertRaises(binascii.Error, base64.b64decode, 'abc')
def test_b64decode_invalid_chars(self):
# issue 1466065: Test some invalid characters.
@@ -155,8 +164,11 @@ class BaseXYTestCase(unittest.TestCase):
(b'YWJj\nYWI=', b'abcab'))
for bstr, res in tests:
self.assertEqual(base64.b64decode(bstr), res)
+ self.assertEqual(base64.b64decode(bstr.decode('ascii')), res)
with self.assertRaises(binascii.Error):
base64.b64decode(bstr, validate=True)
+ with self.assertRaises(binascii.Error):
+ base64.b64decode(bstr.decode('ascii'), validate=True)
def test_b32encode(self):
eq = self.assertEqual
@@ -171,40 +183,63 @@ class BaseXYTestCase(unittest.TestCase):
def test_b32decode(self):
eq = self.assertEqual
- eq(base64.b32decode(b''), b'')
- eq(base64.b32decode(b'AA======'), b'\x00')
- eq(base64.b32decode(b'ME======'), b'a')
- eq(base64.b32decode(b'MFRA===='), b'ab')
- eq(base64.b32decode(b'MFRGG==='), b'abc')
- eq(base64.b32decode(b'MFRGGZA='), b'abcd')
- eq(base64.b32decode(b'MFRGGZDF'), b'abcde')
- self.assertRaises(TypeError, base64.b32decode, "")
+ tests = {b'': b'',
+ b'AA======': b'\x00',
+ b'ME======': b'a',
+ b'MFRA====': b'ab',
+ b'MFRGG===': b'abc',
+ b'MFRGGZA=': b'abcd',
+ b'MFRGGZDF': b'abcde',
+ }
+ for data, res in tests.items():
+ eq(base64.b32decode(data), res)
+ eq(base64.b32decode(data.decode('ascii')), res)
def test_b32decode_casefold(self):
eq = self.assertEqual
- eq(base64.b32decode(b'', True), b'')
- eq(base64.b32decode(b'ME======', True), b'a')
- eq(base64.b32decode(b'MFRA====', True), b'ab')
- eq(base64.b32decode(b'MFRGG===', True), b'abc')
- eq(base64.b32decode(b'MFRGGZA=', True), b'abcd')
- eq(base64.b32decode(b'MFRGGZDF', True), b'abcde')
- # Lower cases
- eq(base64.b32decode(b'me======', True), b'a')
- eq(base64.b32decode(b'mfra====', True), b'ab')
- eq(base64.b32decode(b'mfrgg===', True), b'abc')
- eq(base64.b32decode(b'mfrggza=', True), b'abcd')
- eq(base64.b32decode(b'mfrggzdf', True), b'abcde')
- # Expected exceptions
+ tests = {b'': b'',
+ b'ME======': b'a',
+ b'MFRA====': b'ab',
+ b'MFRGG===': b'abc',
+ b'MFRGGZA=': b'abcd',
+ b'MFRGGZDF': b'abcde',
+ # Lower cases
+ b'me======': b'a',
+ b'mfra====': b'ab',
+ b'mfrgg===': b'abc',
+ b'mfrggza=': b'abcd',
+ b'mfrggzdf': b'abcde',
+ }
+
+ for data, res in tests.items():
+ eq(base64.b32decode(data, True), res)
+ eq(base64.b32decode(data.decode('ascii'), True), res)
+
self.assertRaises(TypeError, base64.b32decode, b'me======')
+ self.assertRaises(TypeError, base64.b32decode, 'me======')
+
# Mapping zero and one
eq(base64.b32decode(b'MLO23456'), b'b\xdd\xad\xf3\xbe')
- eq(base64.b32decode(b'M1023456', map01=b'L'), b'b\xdd\xad\xf3\xbe')
- eq(base64.b32decode(b'M1023456', map01=b'I'), b'b\x1d\xad\xf3\xbe')
- self.assertRaises(TypeError, base64.b32decode, b"", map01="")
+ eq(base64.b32decode('MLO23456'), b'b\xdd\xad\xf3\xbe')
+
+ map_tests = {(b'M1023456', b'L'): b'b\xdd\xad\xf3\xbe',
+ (b'M1023456', b'I'): b'b\x1d\xad\xf3\xbe',
+ }
+ for (data, map01), res in map_tests.items():
+ data_str = data.decode('ascii')
+ map01_str = map01.decode('ascii')
+
+ eq(base64.b32decode(data, map01=map01), res)
+ eq(base64.b32decode(data_str, map01=map01), res)
+ eq(base64.b32decode(data, map01=map01_str), res)
+ eq(base64.b32decode(data_str, map01=map01_str), res)
def test_b32decode_error(self):
- self.assertRaises(binascii.Error, base64.b32decode, b'abc')
- self.assertRaises(binascii.Error, base64.b32decode, b'ABCDEF==')
+ for data in [b'abc', b'ABCDEF==']:
+ with self.assertRaises(binascii.Error):
+ base64.b32decode(data)
+ with self.assertRaises(binascii.Error):
+ base64.b32decode(data.decode('ascii'))
def test_b16encode(self):
eq = self.assertEqual
@@ -215,12 +250,24 @@ class BaseXYTestCase(unittest.TestCase):
def test_b16decode(self):
eq = self.assertEqual
eq(base64.b16decode(b'0102ABCDEF'), b'\x01\x02\xab\xcd\xef')
+ eq(base64.b16decode('0102ABCDEF'), b'\x01\x02\xab\xcd\xef')
eq(base64.b16decode(b'00'), b'\x00')
+ eq(base64.b16decode('00'), b'\x00')
# Lower case is not allowed without a flag
self.assertRaises(binascii.Error, base64.b16decode, b'0102abcdef')
+ self.assertRaises(binascii.Error, base64.b16decode, '0102abcdef')
# Case fold
eq(base64.b16decode(b'0102abcdef', True), b'\x01\x02\xab\xcd\xef')
- self.assertRaises(TypeError, base64.b16decode, "")
+ eq(base64.b16decode('0102abcdef', True), b'\x01\x02\xab\xcd\xef')
+
+ def test_decode_nonascii_str(self):
+ decode_funcs = (base64.b64decode,
+ base64.standard_b64decode,
+ base64.urlsafe_b64decode,
+ base64.b32decode,
+ base64.b16decode)
+ for f in decode_funcs:
+ self.assertRaises(ValueError, f, 'with non-ascii \xcb')
def test_ErrorHeritage(self):
self.assertTrue(issubclass(binascii.Error, ValueError))
diff --git a/Lib/test/test_bigmem.py b/Lib/test/test_bigmem.py
index f3c6ebb..0e54595 100644
--- a/Lib/test/test_bigmem.py
+++ b/Lib/test/test_bigmem.py
@@ -1,3 +1,13 @@
+"""Bigmem tests - tests for the 32-bit boundary in containers.
+
+These tests try to exercise the 32-bit boundary that is sometimes, if
+rarely, exceeded in practice, but almost never tested. They are really only
+meaningful on 64-bit builds on machines with a *lot* of memory, but the
+tests are always run, usually with very low memory limits to make sure the
+tests themselves don't suffer from bitrot. To run them for real, pass a
+high memory limit to regrtest, with the -M option.
+"""
+
from test import support
from test.support import bigmemtest, _1G, _2G, _4G
@@ -6,20 +16,35 @@ import operator
import sys
import functools
+# These tests all use one of the bigmemtest decorators to indicate how much
+# memory they use and how much memory they need to be even meaningful. The
+# decorators take two arguments: a 'memuse' indicator declaring
+# (approximate) bytes per size-unit the test will use (at peak usage), and a
+# 'minsize' indicator declaring a minimum *useful* size. A test that
+# allocates a bytestring to test various operations near the end will have a
+# minsize of at least 2Gb (or it wouldn't reach the 32-bit limit, so the
+# test wouldn't be very useful) and a memuse of 1 (one byte per size-unit,
+# if it allocates only one big string at a time.)
+#
+# When run with a memory limit set, both decorators skip tests that need
+# more memory than available to be meaningful. The precisionbigmemtest will
+# always pass minsize as size, even if there is much more memory available.
+# The bigmemtest decorator will scale size upward to fill available memory.
+#
# Bigmem testing houserules:
#
# - Try not to allocate too many large objects. It's okay to rely on
-# refcounting semantics, but don't forget that 's = create_largestring()'
+# refcounting semantics, and don't forget that 's = create_largestring()'
# doesn't release the old 's' (if it exists) until well after its new
# value has been created. Use 'del s' before the create_largestring call.
#
-# - Do *not* compare large objects using assertEqual or similar. It's a
-# lengthy operation and the errormessage will be utterly useless due to
-# its size. To make sure whether a result has the right contents, better
-# to use the strip or count methods, or compare meaningful slices.
+# - Do *not* compare large objects using assertEqual, assertIn or similar.
+# It's a lengthy operation and the errormessage will be utterly useless
+# due to its size. To make sure whether a result has the right contents,
+# better to use the strip or count methods, or compare meaningful slices.
#
# - Don't forget to test for large indices, offsets and results and such,
-# in addition to large sizes.
+# in addition to large sizes. Anything that probes the 32-bit boundary.
#
# - When repeating an object (say, a substring, or a small list) to create
# a large object, make the subobject of a length that is not a power of
@@ -37,13 +62,14 @@ import functools
# fail as well. I do not know whether it is due to memory fragmentation
# issues, or other specifics of the platform malloc() routine.
-character_size = 4 if sys.maxunicode > 0xFFFF else 2
+ascii_char_size = 1
+ucs2_char_size = 2
+ucs4_char_size = 4
class BaseStrTest:
- @bigmemtest(size=_2G, memuse=2)
- def test_capitalize(self, size):
+ def _test_capitalize(self, size):
_ = self.from_latin1
SUBSTR = self.from_latin1(' abc def ghi')
s = _('-') * size + SUBSTR
@@ -92,7 +118,7 @@ class BaseStrTest:
_ = self.from_latin1
s = _('-') * size
tabsize = 8
- self.assertEqual(s.expandtabs(), s)
+ self.assertTrue(s.expandtabs() == s)
del s
slen, remainder = divmod(size, tabsize)
s = _(' \t') * slen
@@ -347,7 +373,7 @@ class BaseStrTest:
# suffer for the list size. (Otherwise, it'd cost another 48 times
# size in bytes!) Nevertheless, a list of size takes
# 8*size bytes.
- @bigmemtest(size=_2G + 5, memuse=10)
+ @bigmemtest(size=_2G + 5, memuse=2 * ascii_char_size + 8)
def test_split_large(self, size):
_ = self.from_latin1
s = _(' a') * size + _(' ')
@@ -366,9 +392,9 @@ class BaseStrTest:
# take up an inordinate amount of memory
chunksize = int(size ** 0.5 + 2) // 2
SUBSTR = _(' ') * chunksize + _('\n') + _(' ') * chunksize + _('\r\n')
- s = SUBSTR * chunksize
+ s = SUBSTR * (chunksize * 2)
l = s.splitlines()
- self.assertEqual(len(l), chunksize * 2)
+ self.assertEqual(len(l), chunksize * 4)
expected = _(' ') * chunksize
for item in l:
self.assertEqual(item, expected)
@@ -394,8 +420,7 @@ class BaseStrTest:
self.assertEqual(len(s), size)
self.assertEqual(s.strip(), SUBSTR.strip())
- @bigmemtest(size=_2G, memuse=2)
- def test_swapcase(self, size):
+ def _test_swapcase(self, size):
_ = self.from_latin1
SUBSTR = _("aBcDeFG12.'\xa9\x00")
sublen = len(SUBSTR)
@@ -406,8 +431,7 @@ class BaseStrTest:
self.assertEqual(s[:sublen * 3], SUBSTR.swapcase() * 3)
self.assertEqual(s[-sublen * 3:], SUBSTR.swapcase() * 3)
- @bigmemtest(size=_2G, memuse=2)
- def test_title(self, size):
+ def _test_title(self, size):
_ = self.from_latin1
SUBSTR = _('SpaaHAaaAaham')
s = SUBSTR * (size // len(SUBSTR) + 2)
@@ -419,14 +443,7 @@ class BaseStrTest:
def test_translate(self, size):
_ = self.from_latin1
SUBSTR = _('aZz.z.Aaz.')
- if isinstance(SUBSTR, str):
- trans = {
- ord(_('.')): _('-'),
- ord(_('a')): _('!'),
- ord(_('Z')): _('$'),
- }
- else:
- trans = bytes.maketrans(b'.aZ', b'-!$')
+ trans = bytes.maketrans(b'.aZ', b'-!$')
sublen = len(SUBSTR)
repeats = size // sublen + 2
s = SUBSTR * repeats
@@ -519,19 +536,19 @@ class BaseStrTest:
edge = _('-') * (size // 2)
s = _('').join([edge, SUBSTR, edge])
del edge
- self.assertIn(SUBSTR, s)
- self.assertNotIn(SUBSTR * 2, s)
- self.assertIn(_('-'), s)
- self.assertNotIn(_('a'), s)
+ self.assertTrue(SUBSTR in s)
+ self.assertFalse(SUBSTR * 2 in s)
+ self.assertTrue(_('-') in s)
+ self.assertFalse(_('a') in s)
s += _('a')
- self.assertIn(_('a'), s)
+ self.assertTrue(_('a') in s)
@bigmemtest(size=_2G + 10, memuse=2)
def test_compare(self, size):
_ = self.from_latin1
s1 = _('-') * size
s2 = _('-') * size
- self.assertEqual(s1, s2)
+ self.assertTrue(s1 == s2)
del s2
s2 = s1 + _('a')
self.assertFalse(s1 == s2)
@@ -552,7 +569,7 @@ class BaseStrTest:
h1 = hash(s)
del s
s = _('\x00') * (size + 1)
- self.assertFalse(h1 == hash(s))
+ self.assertNotEqual(h1, hash(s))
class StrTest(unittest.TestCase, BaseStrTest):
@@ -563,7 +580,6 @@ class StrTest(unittest.TestCase, BaseStrTest):
def basic_encode_test(self, size, enc, c='.', expectedsize=None):
if expectedsize is None:
expectedsize = size
-
try:
s = c * size
self.assertEqual(len(s.encode(enc)), expectedsize)
@@ -582,48 +598,64 @@ class StrTest(unittest.TestCase, BaseStrTest):
memuse = meth.memuse
except AttributeError:
continue
- meth.memuse = character_size * memuse
+ meth.memuse = ascii_char_size * memuse
self._adjusted[name] = memuse
def tearDown(self):
for name, memuse in self._adjusted.items():
getattr(type(self), name).memuse = memuse
- # the utf8 encoder preallocates big time (4x the number of characters)
- @bigmemtest(size=_2G + 2, memuse=character_size + 4)
+ @bigmemtest(size=_2G, memuse=ucs4_char_size * 3)
+ def test_capitalize(self, size):
+ self._test_capitalize(size)
+
+ @bigmemtest(size=_2G, memuse=ucs4_char_size * 3)
+ def test_title(self, size):
+ self._test_title(size)
+
+ @bigmemtest(size=_2G, memuse=ucs4_char_size * 3)
+ def test_swapcase(self, size):
+ self._test_swapcase(size)
+
+ # Many codecs convert to the legacy representation first, explaining
+ # why we add 'ucs4_char_size' to the 'memuse' below.
+
+ @bigmemtest(size=_2G + 2, memuse=ascii_char_size + 1)
def test_encode(self, size):
return self.basic_encode_test(size, 'utf-8')
- @bigmemtest(size=_4G // 6 + 2, memuse=character_size + 1)
+ @bigmemtest(size=_4G // 6 + 2, memuse=ascii_char_size + ucs4_char_size + 1)
def test_encode_raw_unicode_escape(self, size):
try:
return self.basic_encode_test(size, 'raw_unicode_escape')
except MemoryError:
pass # acceptable on 32-bit
- @bigmemtest(size=_4G // 5 + 70, memuse=character_size + 1)
+ @bigmemtest(size=_4G // 5 + 70, memuse=ascii_char_size + ucs4_char_size + 1)
def test_encode_utf7(self, size):
try:
return self.basic_encode_test(size, 'utf7')
except MemoryError:
pass # acceptable on 32-bit
- @bigmemtest(size=_4G // 4 + 5, memuse=character_size + 4)
+ @bigmemtest(size=_4G // 4 + 5, memuse=ascii_char_size + ucs4_char_size + 4)
def test_encode_utf32(self, size):
try:
- return self.basic_encode_test(size, 'utf32', expectedsize=4*size+4)
+ return self.basic_encode_test(size, 'utf32', expectedsize=4 * size + 4)
except MemoryError:
pass # acceptable on 32-bit
- @bigmemtest(size=_2G - 1, memuse=character_size + 1)
+ @bigmemtest(size=_2G - 1, memuse=ascii_char_size + 1)
def test_encode_ascii(self, size):
return self.basic_encode_test(size, 'ascii', c='A')
- @bigmemtest(size=_2G + 10, memuse=character_size * 2)
+ # str % (...) uses a Py_UCS4 intermediate representation
+
+ @bigmemtest(size=_2G + 10, memuse=ascii_char_size * 2 + ucs4_char_size)
def test_format(self, size):
s = '-' * size
sf = '%s' % (s,)
- self.assertEqual(s, sf)
+ self.assertTrue(s == sf)
del sf
sf = '..%s..' % (s,)
self.assertEqual(len(sf), len(s) + 4)
@@ -640,7 +672,7 @@ class StrTest(unittest.TestCase, BaseStrTest):
self.assertEqual(s.count('.'), 3)
self.assertEqual(s.count('-'), size * 2)
- @bigmemtest(size=_2G + 10, memuse=character_size * 2)
+ @bigmemtest(size=_2G + 10, memuse=ascii_char_size * 2)
def test_repr_small(self, size):
s = '-' * size
s = repr(s)
@@ -661,7 +693,7 @@ class StrTest(unittest.TestCase, BaseStrTest):
self.assertEqual(s.count('\\'), size)
self.assertEqual(s.count('0'), size * 2)
- @bigmemtest(size=_2G + 10, memuse=character_size * 5)
+ @bigmemtest(size=_2G + 10, memuse=ascii_char_size * 5)
def test_repr_large(self, size):
s = '\x00' * size
s = repr(s)
@@ -671,7 +703,13 @@ class StrTest(unittest.TestCase, BaseStrTest):
self.assertEqual(s.count('\\'), size)
self.assertEqual(s.count('0'), size * 2)
- @bigmemtest(size=_2G // 5 + 1, memuse=character_size * 7)
+ # ascii() calls encode('ascii', 'backslashreplace'), which itself
+ # creates a temporary Py_UNICODE representation in addition to the
+ # original (Py_UCS2) one
+ # There's also some overallocation when resizing the ascii() result
+ # that isn't taken into account here.
+ @bigmemtest(size=_2G // 5 + 1, memuse=ucs2_char_size +
+ ucs4_char_size + ascii_char_size * 6)
def test_unicode_repr(self, size):
# Use an assigned, but not printable code point.
# It is in the range of the low surrogates \uDC00-\uDFFF.
@@ -686,9 +724,7 @@ class StrTest(unittest.TestCase, BaseStrTest):
finally:
r = s = None
- # The character takes 4 bytes even in UCS-2 builds because it will
- # be decomposed into surrogates.
- @bigmemtest(size=_2G // 5 + 1, memuse=4 + character_size * 9)
+ @bigmemtest(size=_2G // 5 + 1, memuse=ucs4_char_size * 2 + ascii_char_size * 10)
def test_unicode_repr_wide(self, size):
char = "\U0001DCBA"
s = char * size
@@ -701,39 +737,76 @@ class StrTest(unittest.TestCase, BaseStrTest):
finally:
r = s = None
- @bigmemtest(size=_4G // 5, memuse=character_size * (6 + 1))
- def _test_unicode_repr_overflow(self, size):
- # XXX not sure what this test is about
- char = "\uDCBA"
- s = char * size
- try:
- r = repr(s)
- self.assertTrue(s == eval(r))
- finally:
- r = s = None
+ # The original test_translate is overriden here, so as to get the
+ # correct size estimate: str.translate() uses an intermediate Py_UCS4
+ # representation.
+
+ @bigmemtest(size=_2G, memuse=ascii_char_size * 2 + ucs4_char_size)
+ def test_translate(self, size):
+ _ = self.from_latin1
+ SUBSTR = _('aZz.z.Aaz.')
+ trans = {
+ ord(_('.')): _('-'),
+ ord(_('a')): _('!'),
+ ord(_('Z')): _('$'),
+ }
+ sublen = len(SUBSTR)
+ repeats = size // sublen + 2
+ s = SUBSTR * repeats
+ s = s.translate(trans)
+ self.assertEqual(len(s), repeats * sublen)
+ self.assertEqual(s[:sublen], SUBSTR.translate(trans))
+ self.assertEqual(s[-sublen:], SUBSTR.translate(trans))
+ self.assertEqual(s.count(_('.')), 0)
+ self.assertEqual(s.count(_('!')), repeats * 2)
+ self.assertEqual(s.count(_('z')), repeats * 3)
class BytesTest(unittest.TestCase, BaseStrTest):
def from_latin1(self, s):
- return s.encode("latin1")
+ return s.encode("latin-1")
- @bigmemtest(size=_2G + 2, memuse=1 + character_size)
+ @bigmemtest(size=_2G + 2, memuse=1 + ascii_char_size)
def test_decode(self, size):
s = self.from_latin1('.') * size
self.assertEqual(len(s.decode('utf-8')), size)
+ @bigmemtest(size=_2G, memuse=2)
+ def test_capitalize(self, size):
+ self._test_capitalize(size)
+
+ @bigmemtest(size=_2G, memuse=2)
+ def test_title(self, size):
+ self._test_title(size)
+
+ @bigmemtest(size=_2G, memuse=2)
+ def test_swapcase(self, size):
+ self._test_swapcase(size)
+
class BytearrayTest(unittest.TestCase, BaseStrTest):
def from_latin1(self, s):
- return bytearray(s.encode("latin1"))
+ return bytearray(s.encode("latin-1"))
- @bigmemtest(size=_2G + 2, memuse=1 + character_size)
+ @bigmemtest(size=_2G + 2, memuse=1 + ascii_char_size)
def test_decode(self, size):
s = self.from_latin1('.') * size
self.assertEqual(len(s.decode('utf-8')), size)
+ @bigmemtest(size=_2G, memuse=2)
+ def test_capitalize(self, size):
+ self._test_capitalize(size)
+
+ @bigmemtest(size=_2G, memuse=2)
+ def test_title(self, size):
+ self._test_title(size)
+
+ @bigmemtest(size=_2G, memuse=2)
+ def test_swapcase(self, size):
+ self._test_swapcase(size)
+
test_hash = None
test_split_large = None
@@ -752,7 +825,7 @@ class TupleTest(unittest.TestCase):
def test_compare(self, size):
t1 = ('',) * size
t2 = ('',) * size
- self.assertEqual(t1, t2)
+ self.assertTrue(t1 == t2)
del t2
t2 = ('',) * (size + 1)
self.assertFalse(t1 == t2)
@@ -783,9 +856,9 @@ class TupleTest(unittest.TestCase):
def test_contains(self, size):
t = (1, 2, 3, 4, 5) * size
self.assertEqual(len(t), size * 5)
- self.assertIn(5, t)
- self.assertNotIn((1, 2, 3, 4, 5), t)
- self.assertNotIn(0, t)
+ self.assertTrue(5 in t)
+ self.assertFalse((1, 2, 3, 4, 5) in t)
+ self.assertFalse(0 in t)
@bigmemtest(size=_2G + 10, memuse=8)
def test_hash(self, size):
@@ -869,11 +942,11 @@ class TupleTest(unittest.TestCase):
self.assertEqual(s[-5:], '0, 0)')
self.assertEqual(s.count('0'), size)
- @bigmemtest(size=_2G // 3 + 2, memuse=8 + 3 * character_size)
+ @bigmemtest(size=_2G // 3 + 2, memuse=8 + 3 * ascii_char_size)
def test_repr_small(self, size):
return self.basic_test_repr(size)
- @bigmemtest(size=_2G + 2, memuse=8 + 3 * character_size)
+ @bigmemtest(size=_2G + 2, memuse=8 + 3 * ascii_char_size)
def test_repr_large(self, size):
return self.basic_test_repr(size)
@@ -888,7 +961,7 @@ class ListTest(unittest.TestCase):
def test_compare(self, size):
l1 = [''] * size
l2 = [''] * size
- self.assertEqual(l1, l2)
+ self.assertTrue(l1 == l2)
del l2
l2 = [''] * (size + 1)
self.assertFalse(l1 == l2)
@@ -934,9 +1007,9 @@ class ListTest(unittest.TestCase):
def test_contains(self, size):
l = [1, 2, 3, 4, 5] * size
self.assertEqual(len(l), size * 5)
- self.assertIn(5, l)
- self.assertNotIn([1, 2, 3, 4, 5], l)
- self.assertNotIn(0, l)
+ self.assertTrue(5 in l)
+ self.assertFalse([1, 2, 3, 4, 5] in l)
+ self.assertFalse(0 in l)
@bigmemtest(size=_2G + 10, memuse=8)
def test_hash(self, size):
@@ -1044,11 +1117,11 @@ class ListTest(unittest.TestCase):
self.assertEqual(s[-5:], '0, 0]')
self.assertEqual(s.count('0'), size)
- @bigmemtest(size=_2G // 3 + 2, memuse=8 + 3 * character_size)
+ @bigmemtest(size=_2G // 3 + 2, memuse=8 + 3 * ascii_char_size)
def test_repr_small(self, size):
return self.basic_test_repr(size)
- @bigmemtest(size=_2G + 2, memuse=8 + 3 * character_size)
+ @bigmemtest(size=_2G + 2, memuse=8 + 3 * ascii_char_size)
def test_repr_large(self, size):
return self.basic_test_repr(size)
diff --git a/Lib/test/test_binascii.py b/Lib/test/test_binascii.py
index 1e9e888..04d8f9d 100644
--- a/Lib/test/test_binascii.py
+++ b/Lib/test/test_binascii.py
@@ -208,9 +208,9 @@ class BinASCIITest(unittest.TestCase):
except Exception as err:
self.fail("{}({!r}) raises {!r}".format(func, empty, err))
- def test_unicode_strings(self):
- # Unicode strings are not accepted.
- for func in all_functions:
+ def test_unicode_b2a(self):
+ # Unicode strings are not accepted by b2a_* functions.
+ for func in set(all_functions) - set(a2b_functions) | {'rledecode_hqx'}:
try:
self.assertRaises(TypeError, getattr(binascii, func), "test")
except Exception as err:
@@ -218,6 +218,34 @@ class BinASCIITest(unittest.TestCase):
# crc_hqx needs 2 arguments
self.assertRaises(TypeError, binascii.crc_hqx, "test", 0)
+ def test_unicode_a2b(self):
+ # Unicode strings are accepted by a2b_* functions.
+ MAX_ALL = 45
+ raw = self.rawdata[:MAX_ALL]
+ for fa, fb in zip(a2b_functions, b2a_functions):
+ if fa == 'rledecode_hqx':
+ # Takes non-ASCII data
+ continue
+ a2b = getattr(binascii, fa)
+ b2a = getattr(binascii, fb)
+ try:
+ a = b2a(self.type2test(raw))
+ binary_res = a2b(a)
+ a = a.decode('ascii')
+ res = a2b(a)
+ except Exception as err:
+ self.fail("{}/{} conversion raises {!r}".format(fb, fa, err))
+ if fb == 'b2a_hqx':
+ # b2a_hqx returns a tuple
+ res, _ = res
+ binary_res, _ = binary_res
+ self.assertEqual(res, raw, "{}/{} conversion: "
+ "{!r} != {!r}".format(fb, fa, res, raw))
+ self.assertEqual(res, binary_res)
+ self.assertIsInstance(res, bytes)
+ # non-ASCII string
+ self.assertRaises(ValueError, a2b, "\x80")
+
class ArrayBinASCIITest(BinASCIITest):
def type2test(self, s):
diff --git a/Lib/test/test_bisect.py b/Lib/test/test_bisect.py
index f95ed63..7b9bd19 100644
--- a/Lib/test/test_bisect.py
+++ b/Lib/test/test_bisect.py
@@ -3,25 +3,8 @@ import unittest
from test import support
from collections import UserList
-# We do a bit of trickery here to be able to test both the C implementation
-# and the Python implementation of the module.
-
-# Make it impossible to import the C implementation anymore.
-sys.modules['_bisect'] = 0
-# We must also handle the case that bisect was imported before.
-if 'bisect' in sys.modules:
- del sys.modules['bisect']
-
-# Now we can import the module and get the pure Python implementation.
-import bisect as py_bisect
-
-# Restore everything to normal.
-del sys.modules['_bisect']
-del sys.modules['bisect']
-
-# This is now the module with the C implementation.
-import bisect as c_bisect
-
+py_bisect = support.import_fresh_module('bisect', blocked=['_bisect'])
+c_bisect = support.import_fresh_module('bisect', fresh=['_bisect'])
class Range(object):
"""A trivial range()-like object without any integer width limitations."""
@@ -45,9 +28,7 @@ class Range(object):
self.last_insert = idx, item
-class TestBisect(unittest.TestCase):
- module = None
-
+class TestBisect:
def setUp(self):
self.precomputedCases = [
(self.module.bisect_right, [], 1, 0),
@@ -218,17 +199,15 @@ class TestBisect(unittest.TestCase):
self.module.insort(a=data, x=25, lo=1, hi=3)
self.assertEqual(data, [10, 20, 25, 25, 25, 30, 40, 50])
-class TestBisectPython(TestBisect):
+class TestBisectPython(TestBisect, unittest.TestCase):
module = py_bisect
-class TestBisectC(TestBisect):
+class TestBisectC(TestBisect, unittest.TestCase):
module = c_bisect
#==============================================================================
-class TestInsort(unittest.TestCase):
- module = None
-
+class TestInsort:
def test_vsBuiltinSort(self, n=500):
from random import choice
for insorted in (list(), UserList()):
@@ -255,15 +234,14 @@ class TestInsort(unittest.TestCase):
self.module.insort_right(lst, 5)
self.assertEqual([5, 10], lst.data)
-class TestInsortPython(TestInsort):
+class TestInsortPython(TestInsort, unittest.TestCase):
module = py_bisect
-class TestInsortC(TestInsort):
+class TestInsortC(TestInsort, unittest.TestCase):
module = c_bisect
#==============================================================================
-
class LenOnly:
"Dummy sequence class defining __len__ but not __getitem__."
def __len__(self):
@@ -284,9 +262,7 @@ class CmpErr:
__eq__ = __lt__
__ne__ = __lt__
-class TestErrorHandling(unittest.TestCase):
- module = None
-
+class TestErrorHandling:
def test_non_sequence(self):
for f in (self.module.bisect_left, self.module.bisect_right,
self.module.insort_left, self.module.insort_right):
@@ -313,58 +289,40 @@ class TestErrorHandling(unittest.TestCase):
self.module.insort_left, self.module.insort_right):
self.assertRaises(TypeError, f, 10)
-class TestErrorHandlingPython(TestErrorHandling):
+class TestErrorHandlingPython(TestErrorHandling, unittest.TestCase):
module = py_bisect
-class TestErrorHandlingC(TestErrorHandling):
+class TestErrorHandlingC(TestErrorHandling, unittest.TestCase):
module = c_bisect
#==============================================================================
-libreftest = """
-Example from the Library Reference: Doc/library/bisect.rst
-
-The bisect() function is generally useful for categorizing numeric data.
-This example uses bisect() to look up a letter grade for an exam total
-(say) based on a set of ordered numeric breakpoints: 85 and up is an `A',
-75..84 is a `B', etc.
-
- >>> grades = "FEDCBA"
- >>> breakpoints = [30, 44, 66, 75, 85]
- >>> from bisect import bisect
- >>> def grade(total):
- ... return grades[bisect(breakpoints, total)]
- ...
- >>> grade(66)
- 'C'
- >>> list(map(grade, [33, 99, 77, 44, 12, 88]))
- ['E', 'A', 'B', 'D', 'F', 'A']
+class TestDocExample:
+ def test_grades(self):
+ def grade(score, breakpoints=[60, 70, 80, 90], grades='FDCBA'):
+ i = self.module.bisect(breakpoints, score)
+ return grades[i]
+
+ result = [grade(score) for score in [33, 99, 77, 70, 89, 90, 100]]
+ self.assertEqual(result, ['F', 'A', 'C', 'C', 'B', 'A', 'A'])
+
+ def test_colors(self):
+ data = [('red', 5), ('blue', 1), ('yellow', 8), ('black', 0)]
+ data.sort(key=lambda r: r[1])
+ keys = [r[1] for r in data]
+ bisect_left = self.module.bisect_left
+ self.assertEqual(data[bisect_left(keys, 0)], ('black', 0))
+ self.assertEqual(data[bisect_left(keys, 1)], ('blue', 1))
+ self.assertEqual(data[bisect_left(keys, 5)], ('red', 5))
+ self.assertEqual(data[bisect_left(keys, 8)], ('yellow', 8))
+
+class TestDocExamplePython(TestDocExample, unittest.TestCase):
+ module = py_bisect
-"""
+class TestDocExampleC(TestDocExample, unittest.TestCase):
+ module = c_bisect
#------------------------------------------------------------------------------
-__test__ = {'libreftest' : libreftest}
-
-def test_main(verbose=None):
- from test import test_bisect
-
- test_classes = [TestBisectPython, TestBisectC,
- TestInsortPython, TestInsortC,
- TestErrorHandlingPython, TestErrorHandlingC]
-
- support.run_unittest(*test_classes)
- support.run_doctest(test_bisect, verbose)
-
- # verify reference counting
- if verbose and hasattr(sys, "gettotalrefcount"):
- import gc
- counts = [None] * 5
- for i in range(len(counts)):
- support.run_unittest(*test_classes)
- gc.collect()
- counts[i] = sys.gettotalrefcount()
- print(counts)
-
if __name__ == "__main__":
- test_main(verbose=True)
+ unittest.main()
diff --git a/Lib/test/test_bool.py b/Lib/test/test_bool.py
index b296870..4bab28b 100644
--- a/Lib/test/test_bool.py
+++ b/Lib/test/test_bool.py
@@ -330,6 +330,16 @@ class BoolTest(unittest.TestCase):
except (Exception) as e_len:
self.assertEqual(str(e_bool), str(e_len))
+ def test_real_and_imag(self):
+ self.assertEqual(True.real, 1)
+ self.assertEqual(True.imag, 0)
+ self.assertIs(type(True.real), int)
+ self.assertIs(type(True.imag), int)
+ self.assertEqual(False.real, 0)
+ self.assertEqual(False.imag, 0)
+ self.assertIs(type(False.real), int)
+ self.assertIs(type(False.imag), int)
+
def test_main():
support.run_unittest(BoolTest)
diff --git a/Lib/test/test_buffer.py b/Lib/test/test_buffer.py
new file mode 100644
index 0000000..747e2a2
--- /dev/null
+++ b/Lib/test/test_buffer.py
@@ -0,0 +1,4291 @@
+#
+# The ndarray object from _testbuffer.c is a complete implementation of
+# a PEP-3118 buffer provider. It is independent from NumPy's ndarray
+# and the tests don't require NumPy.
+#
+# If NumPy is present, some tests check both ndarray implementations
+# against each other.
+#
+# Most ndarray tests also check that memoryview(ndarray) behaves in
+# the same way as the original. Thus, a substantial part of the
+# memoryview tests is now in this module.
+#
+
+import unittest
+from test import support
+from itertools import permutations, product
+from random import randrange, sample, choice
+from sysconfig import get_config_var
+import warnings
+import sys, array, io
+from decimal import Decimal
+from fractions import Fraction
+
+try:
+ from _testbuffer import *
+except ImportError:
+ ndarray = None
+
+try:
+ import struct
+except ImportError:
+ struct = None
+
+try:
+ import ctypes
+except ImportError:
+ ctypes = None
+
+try:
+ with warnings.catch_warnings():
+ from numpy import ndarray as numpy_array
+except ImportError:
+ numpy_array = None
+
+
+SHORT_TEST = True
+
+
+# ======================================================================
+# Random lists by format specifier
+# ======================================================================
+
+# Native format chars and their ranges.
+NATIVE = {
+ '?':0, 'c':0, 'b':0, 'B':0,
+ 'h':0, 'H':0, 'i':0, 'I':0,
+ 'l':0, 'L':0, 'n':0, 'N':0,
+ 'f':0, 'd':0, 'P':0
+}
+
+# NumPy does not have 'n' or 'N':
+if numpy_array:
+ del NATIVE['n']
+ del NATIVE['N']
+
+if struct:
+ try:
+ # Add "qQ" if present in native mode.
+ struct.pack('Q', 2**64-1)
+ NATIVE['q'] = 0
+ NATIVE['Q'] = 0
+ except struct.error:
+ pass
+
+# Standard format chars and their ranges.
+STANDARD = {
+ '?':(0, 2), 'c':(0, 1<<8),
+ 'b':(-(1<<7), 1<<7), 'B':(0, 1<<8),
+ 'h':(-(1<<15), 1<<15), 'H':(0, 1<<16),
+ 'i':(-(1<<31), 1<<31), 'I':(0, 1<<32),
+ 'l':(-(1<<31), 1<<31), 'L':(0, 1<<32),
+ 'q':(-(1<<63), 1<<63), 'Q':(0, 1<<64),
+ 'f':(-(1<<63), 1<<63), 'd':(-(1<<1023), 1<<1023)
+}
+
+def native_type_range(fmt):
+ """Return range of a native type."""
+ if fmt == 'c':
+ lh = (0, 256)
+ elif fmt == '?':
+ lh = (0, 2)
+ elif fmt == 'f':
+ lh = (-(1<<63), 1<<63)
+ elif fmt == 'd':
+ lh = (-(1<<1023), 1<<1023)
+ else:
+ for exp in (128, 127, 64, 63, 32, 31, 16, 15, 8, 7):
+ try:
+ struct.pack(fmt, (1<<exp)-1)
+ break
+ except struct.error:
+ pass
+ lh = (-(1<<exp), 1<<exp) if exp & 1 else (0, 1<<exp)
+ return lh
+
+fmtdict = {
+ '':NATIVE,
+ '@':NATIVE,
+ '<':STANDARD,
+ '>':STANDARD,
+ '=':STANDARD,
+ '!':STANDARD
+}
+
+if struct:
+ for fmt in fmtdict['@']:
+ fmtdict['@'][fmt] = native_type_range(fmt)
+
+MEMORYVIEW = NATIVE.copy()
+ARRAY = NATIVE.copy()
+for k in NATIVE:
+ if not k in "bBhHiIlLfd":
+ del ARRAY[k]
+
+BYTEFMT = NATIVE.copy()
+for k in NATIVE:
+ if not k in "Bbc":
+ del BYTEFMT[k]
+
+fmtdict['m'] = MEMORYVIEW
+fmtdict['@m'] = MEMORYVIEW
+fmtdict['a'] = ARRAY
+fmtdict['b'] = BYTEFMT
+fmtdict['@b'] = BYTEFMT
+
+# Capabilities of the test objects:
+MODE = 0
+MULT = 1
+cap = { # format chars # multiplier
+ 'ndarray': (['', '@', '<', '>', '=', '!'], ['', '1', '2', '3']),
+ 'array': (['a'], ['']),
+ 'numpy': ([''], ['']),
+ 'memoryview': (['@m', 'm'], ['']),
+ 'bytefmt': (['@b', 'b'], ['']),
+}
+
+def randrange_fmt(mode, char, obj):
+ """Return random item for a type specified by a mode and a single
+ format character."""
+ x = randrange(*fmtdict[mode][char])
+ if char == 'c':
+ x = bytes(chr(x), 'latin1')
+ if char == '?':
+ x = bool(x)
+ if char == 'f' or char == 'd':
+ x = struct.pack(char, x)
+ x = struct.unpack(char, x)[0]
+ if obj == 'numpy' and x == b'\x00':
+ # http://projects.scipy.org/numpy/ticket/1925
+ x = b'\x01'
+ return x
+
+def gen_item(fmt, obj):
+ """Return single random item."""
+ mode, chars = fmt.split('#')
+ x = []
+ for c in chars:
+ x.append(randrange_fmt(mode, c, obj))
+ return x[0] if len(x) == 1 else tuple(x)
+
+def gen_items(n, fmt, obj):
+ """Return a list of random items (or a scalar)."""
+ if n == 0:
+ return gen_item(fmt, obj)
+ lst = [0] * n
+ for i in range(n):
+ lst[i] = gen_item(fmt, obj)
+ return lst
+
+def struct_items(n, obj):
+ mode = choice(cap[obj][MODE])
+ xfmt = mode + '#'
+ fmt = mode.strip('amb')
+ nmemb = randrange(2, 10) # number of struct members
+ for _ in range(nmemb):
+ char = choice(tuple(fmtdict[mode]))
+ multiplier = choice(cap[obj][MULT])
+ xfmt += (char * int(multiplier if multiplier else 1))
+ fmt += (multiplier + char)
+ items = gen_items(n, xfmt, obj)
+ item = gen_item(xfmt, obj)
+ return fmt, items, item
+
+def randitems(n, obj='ndarray', mode=None, char=None):
+ """Return random format, items, item."""
+ if mode is None:
+ mode = choice(cap[obj][MODE])
+ if char is None:
+ char = choice(tuple(fmtdict[mode]))
+ multiplier = choice(cap[obj][MULT])
+ fmt = mode + '#' + char * int(multiplier if multiplier else 1)
+ items = gen_items(n, fmt, obj)
+ item = gen_item(fmt, obj)
+ fmt = mode.strip('amb') + multiplier + char
+ return fmt, items, item
+
+def iter_mode(n, obj='ndarray'):
+ """Iterate through supported mode/char combinations."""
+ for mode in cap[obj][MODE]:
+ for char in fmtdict[mode]:
+ yield randitems(n, obj, mode, char)
+
+def iter_format(nitems, testobj='ndarray'):
+ """Yield (format, items, item) for all possible modes and format
+ characters plus one random compound format string."""
+ for t in iter_mode(nitems, testobj):
+ yield t
+ if testobj != 'ndarray':
+ raise StopIteration
+ yield struct_items(nitems, testobj)
+
+
+def is_byte_format(fmt):
+ return 'c' in fmt or 'b' in fmt or 'B' in fmt
+
+def is_memoryview_format(fmt):
+ """format suitable for memoryview"""
+ x = len(fmt)
+ return ((x == 1 or (x == 2 and fmt[0] == '@')) and
+ fmt[x-1] in MEMORYVIEW)
+
+NON_BYTE_FORMAT = [c for c in fmtdict['@'] if not is_byte_format(c)]
+
+
+# ======================================================================
+# Multi-dimensional tolist(), slicing and slice assignments
+# ======================================================================
+
+def atomp(lst):
+ """Tuple items (representing structs) are regarded as atoms."""
+ return not isinstance(lst, list)
+
+def listp(lst):
+ return isinstance(lst, list)
+
+def prod(lst):
+ """Product of list elements."""
+ if len(lst) == 0:
+ return 0
+ x = lst[0]
+ for v in lst[1:]:
+ x *= v
+ return x
+
+def strides_from_shape(ndim, shape, itemsize, layout):
+ """Calculate strides of a contiguous array. Layout is 'C' or
+ 'F' (Fortran)."""
+ if ndim == 0:
+ return ()
+ if layout == 'C':
+ strides = list(shape[1:]) + [itemsize]
+ for i in range(ndim-2, -1, -1):
+ strides[i] *= strides[i+1]
+ else:
+ strides = [itemsize] + list(shape[:-1])
+ for i in range(1, ndim):
+ strides[i] *= strides[i-1]
+ return strides
+
+def _ca(items, s):
+ """Convert flat item list to the nested list representation of a
+ multidimensional C array with shape 's'."""
+ if atomp(items):
+ return items
+ if len(s) == 0:
+ return items[0]
+ lst = [0] * s[0]
+ stride = len(items) // s[0] if s[0] else 0
+ for i in range(s[0]):
+ start = i*stride
+ lst[i] = _ca(items[start:start+stride], s[1:])
+ return lst
+
+def _fa(items, s):
+ """Convert flat item list to the nested list representation of a
+ multidimensional Fortran array with shape 's'."""
+ if atomp(items):
+ return items
+ if len(s) == 0:
+ return items[0]
+ lst = [0] * s[0]
+ stride = s[0]
+ for i in range(s[0]):
+ lst[i] = _fa(items[i::stride], s[1:])
+ return lst
+
+def carray(items, shape):
+ if listp(items) and not 0 in shape and prod(shape) != len(items):
+ raise ValueError("prod(shape) != len(items)")
+ return _ca(items, shape)
+
+def farray(items, shape):
+ if listp(items) and not 0 in shape and prod(shape) != len(items):
+ raise ValueError("prod(shape) != len(items)")
+ return _fa(items, shape)
+
+def indices(shape):
+ """Generate all possible tuples of indices."""
+ iterables = [range(v) for v in shape]
+ return product(*iterables)
+
+def getindex(ndim, ind, strides):
+ """Convert multi-dimensional index to the position in the flat list."""
+ ret = 0
+ for i in range(ndim):
+ ret += strides[i] * ind[i]
+ return ret
+
+def transpose(src, shape):
+ """Transpose flat item list that is regarded as a multi-dimensional
+ matrix defined by shape: dest...[k][j][i] = src[i][j][k]... """
+ if not shape:
+ return src
+ ndim = len(shape)
+ sstrides = strides_from_shape(ndim, shape, 1, 'C')
+ dstrides = strides_from_shape(ndim, shape[::-1], 1, 'C')
+ dest = [0] * len(src)
+ for ind in indices(shape):
+ fr = getindex(ndim, ind, sstrides)
+ to = getindex(ndim, ind[::-1], dstrides)
+ dest[to] = src[fr]
+ return dest
+
+def _flatten(lst):
+ """flatten list"""
+ if lst == []:
+ return lst
+ if atomp(lst):
+ return [lst]
+ return _flatten(lst[0]) + _flatten(lst[1:])
+
+def flatten(lst):
+ """flatten list or return scalar"""
+ if atomp(lst): # scalar
+ return lst
+ return _flatten(lst)
+
+def slice_shape(lst, slices):
+ """Get the shape of lst after slicing: slices is a list of slice
+ objects."""
+ if atomp(lst):
+ return []
+ return [len(lst[slices[0]])] + slice_shape(lst[0], slices[1:])
+
+def multislice(lst, slices):
+ """Multi-dimensional slicing: slices is a list of slice objects."""
+ if atomp(lst):
+ return lst
+ return [multislice(sublst, slices[1:]) for sublst in lst[slices[0]]]
+
+def m_assign(llst, rlst, lslices, rslices):
+ """Multi-dimensional slice assignment: llst and rlst are the operands,
+ lslices and rslices are lists of slice objects. llst and rlst must
+ have the same structure.
+
+ For a two-dimensional example, this is not implemented in Python:
+
+ llst[0:3:2, 0:3:2] = rlst[1:3:1, 1:3:1]
+
+ Instead we write:
+
+ lslices = [slice(0,3,2), slice(0,3,2)]
+ rslices = [slice(1,3,1), slice(1,3,1)]
+ multislice_assign(llst, rlst, lslices, rslices)
+ """
+ if atomp(rlst):
+ return rlst
+ rlst = [m_assign(l, r, lslices[1:], rslices[1:])
+ for l, r in zip(llst[lslices[0]], rlst[rslices[0]])]
+ llst[lslices[0]] = rlst
+ return llst
+
+def cmp_structure(llst, rlst, lslices, rslices):
+ """Compare the structure of llst[lslices] and rlst[rslices]."""
+ lshape = slice_shape(llst, lslices)
+ rshape = slice_shape(rlst, rslices)
+ if (len(lshape) != len(rshape)):
+ return -1
+ for i in range(len(lshape)):
+ if lshape[i] != rshape[i]:
+ return -1
+ if lshape[i] == 0:
+ return 0
+ return 0
+
+def multislice_assign(llst, rlst, lslices, rslices):
+ """Return llst after assigning: llst[lslices] = rlst[rslices]"""
+ if cmp_structure(llst, rlst, lslices, rslices) < 0:
+ raise ValueError("lvalue and rvalue have different structures")
+ return m_assign(llst, rlst, lslices, rslices)
+
+
+# ======================================================================
+# Random structures
+# ======================================================================
+
+#
+# PEP-3118 is very permissive with respect to the contents of a
+# Py_buffer. In particular:
+#
+# - shape can be zero
+# - strides can be any integer, including zero
+# - offset can point to any location in the underlying
+# memory block, provided that it is a multiple of
+# itemsize.
+#
+# The functions in this section test and verify random structures
+# in full generality. A structure is valid iff it fits in the
+# underlying memory block.
+#
+# The structure 't' (short for 'tuple') is fully defined by:
+#
+# t = (memlen, itemsize, ndim, shape, strides, offset)
+#
+
+def verify_structure(memlen, itemsize, ndim, shape, strides, offset):
+ """Verify that the parameters represent a valid array within
+ the bounds of the allocated memory:
+ char *mem: start of the physical memory block
+ memlen: length of the physical memory block
+ offset: (char *)buf - mem
+ """
+ if offset % itemsize:
+ return False
+ if offset < 0 or offset+itemsize > memlen:
+ return False
+ if any(v % itemsize for v in strides):
+ return False
+
+ if ndim <= 0:
+ return ndim == 0 and not shape and not strides
+ if 0 in shape:
+ return True
+
+ imin = sum(strides[j]*(shape[j]-1) for j in range(ndim)
+ if strides[j] <= 0)
+ imax = sum(strides[j]*(shape[j]-1) for j in range(ndim)
+ if strides[j] > 0)
+
+ return 0 <= offset+imin and offset+imax+itemsize <= memlen
+
+def get_item(lst, indices):
+ for i in indices:
+ lst = lst[i]
+ return lst
+
+def memory_index(indices, t):
+ """Location of an item in the underlying memory."""
+ memlen, itemsize, ndim, shape, strides, offset = t
+ p = offset
+ for i in range(ndim):
+ p += strides[i]*indices[i]
+ return p
+
+def is_overlapping(t):
+ """The structure 't' is overlapping if at least one memory location
+ is visited twice while iterating through all possible tuples of
+ indices."""
+ memlen, itemsize, ndim, shape, strides, offset = t
+ visited = 1<<memlen
+ for ind in indices(shape):
+ i = memory_index(ind, t)
+ bit = 1<<i
+ if visited & bit:
+ return True
+ visited |= bit
+ return False
+
+def rand_structure(itemsize, valid, maxdim=5, maxshape=16, shape=()):
+ """Return random structure:
+ (memlen, itemsize, ndim, shape, strides, offset)
+ If 'valid' is true, the returned structure is valid, otherwise invalid.
+ If 'shape' is given, use that instead of creating a random shape.
+ """
+ if not shape:
+ ndim = randrange(maxdim+1)
+ if (ndim == 0):
+ if valid:
+ return itemsize, itemsize, ndim, (), (), 0
+ else:
+ nitems = randrange(1, 16+1)
+ memlen = nitems * itemsize
+ offset = -itemsize if randrange(2) == 0 else memlen
+ return memlen, itemsize, ndim, (), (), offset
+
+ minshape = 2
+ n = randrange(100)
+ if n >= 95 and valid:
+ minshape = 0
+ elif n >= 90:
+ minshape = 1
+ shape = [0] * ndim
+
+ for i in range(ndim):
+ shape[i] = randrange(minshape, maxshape+1)
+ else:
+ ndim = len(shape)
+
+ maxstride = 5
+ n = randrange(100)
+ zero_stride = True if n >= 95 and n & 1 else False
+
+ strides = [0] * ndim
+ strides[ndim-1] = itemsize * randrange(-maxstride, maxstride+1)
+ if not zero_stride and strides[ndim-1] == 0:
+ strides[ndim-1] = itemsize
+
+ for i in range(ndim-2, -1, -1):
+ maxstride *= shape[i+1] if shape[i+1] else 1
+ if zero_stride:
+ strides[i] = itemsize * randrange(-maxstride, maxstride+1)
+ else:
+ strides[i] = ((1,-1)[randrange(2)] *
+ itemsize * randrange(1, maxstride+1))
+
+ imin = imax = 0
+ if not 0 in shape:
+ imin = sum(strides[j]*(shape[j]-1) for j in range(ndim)
+ if strides[j] <= 0)
+ imax = sum(strides[j]*(shape[j]-1) for j in range(ndim)
+ if strides[j] > 0)
+
+ nitems = imax - imin
+ if valid:
+ offset = -imin * itemsize
+ memlen = offset + (imax+1) * itemsize
+ else:
+ memlen = (-imin + imax) * itemsize
+ offset = -imin-itemsize if randrange(2) == 0 else memlen
+ return memlen, itemsize, ndim, shape, strides, offset
+
+def randslice_from_slicelen(slicelen, listlen):
+ """Create a random slice of len slicelen that fits into listlen."""
+ maxstart = listlen - slicelen
+ start = randrange(maxstart+1)
+ maxstep = (listlen - start) // slicelen if slicelen else 1
+ step = randrange(1, maxstep+1)
+ stop = start + slicelen * step
+ s = slice(start, stop, step)
+ _, _, _, control = slice_indices(s, listlen)
+ if control != slicelen:
+ raise RuntimeError
+ return s
+
+def randslice_from_shape(ndim, shape):
+ """Create two sets of slices for an array x with shape 'shape'
+ such that shapeof(x[lslices]) == shapeof(x[rslices])."""
+ lslices = [0] * ndim
+ rslices = [0] * ndim
+ for n in range(ndim):
+ l = shape[n]
+ slicelen = randrange(1, l+1) if l > 0 else 0
+ lslices[n] = randslice_from_slicelen(slicelen, l)
+ rslices[n] = randslice_from_slicelen(slicelen, l)
+ return tuple(lslices), tuple(rslices)
+
+def rand_aligned_slices(maxdim=5, maxshape=16):
+ """Create (lshape, rshape, tuple(lslices), tuple(rslices)) such that
+ shapeof(x[lslices]) == shapeof(y[rslices]), where x is an array
+ with shape 'lshape' and y is an array with shape 'rshape'."""
+ ndim = randrange(1, maxdim+1)
+ minshape = 2
+ n = randrange(100)
+ if n >= 95:
+ minshape = 0
+ elif n >= 90:
+ minshape = 1
+ all_random = True if randrange(100) >= 80 else False
+ lshape = [0]*ndim; rshape = [0]*ndim
+ lslices = [0]*ndim; rslices = [0]*ndim
+
+ for n in range(ndim):
+ small = randrange(minshape, maxshape+1)
+ big = randrange(minshape, maxshape+1)
+ if big < small:
+ big, small = small, big
+
+ # Create a slice that fits the smaller value.
+ if all_random:
+ start = randrange(-small, small+1)
+ stop = randrange(-small, small+1)
+ step = (1,-1)[randrange(2)] * randrange(1, small+2)
+ s_small = slice(start, stop, step)
+ _, _, _, slicelen = slice_indices(s_small, small)
+ else:
+ slicelen = randrange(1, small+1) if small > 0 else 0
+ s_small = randslice_from_slicelen(slicelen, small)
+
+ # Create a slice of the same length for the bigger value.
+ s_big = randslice_from_slicelen(slicelen, big)
+ if randrange(2) == 0:
+ rshape[n], lshape[n] = big, small
+ rslices[n], lslices[n] = s_big, s_small
+ else:
+ rshape[n], lshape[n] = small, big
+ rslices[n], lslices[n] = s_small, s_big
+
+ return lshape, rshape, tuple(lslices), tuple(rslices)
+
+def randitems_from_structure(fmt, t):
+ """Return a list of random items for structure 't' with format
+ 'fmtchar'."""
+ memlen, itemsize, _, _, _, _ = t
+ return gen_items(memlen//itemsize, '#'+fmt, 'numpy')
+
+def ndarray_from_structure(items, fmt, t, flags=0):
+ """Return ndarray from the tuple returned by rand_structure()"""
+ memlen, itemsize, ndim, shape, strides, offset = t
+ return ndarray(items, shape=shape, strides=strides, format=fmt,
+ offset=offset, flags=ND_WRITABLE|flags)
+
+def numpy_array_from_structure(items, fmt, t):
+ """Return numpy_array from the tuple returned by rand_structure()"""
+ memlen, itemsize, ndim, shape, strides, offset = t
+ buf = bytearray(memlen)
+ for j, v in enumerate(items):
+ struct.pack_into(fmt, buf, j*itemsize, v)
+ return numpy_array(buffer=buf, shape=shape, strides=strides,
+ dtype=fmt, offset=offset)
+
+
+# ======================================================================
+# memoryview casts
+# ======================================================================
+
+def cast_items(exporter, fmt, itemsize, shape=None):
+ """Interpret the raw memory of 'exporter' as a list of items with
+ size 'itemsize'. If shape=None, the new structure is assumed to
+ be 1-D with n * itemsize = bytelen. If shape is given, the usual
+ constraint for contiguous arrays prod(shape) * itemsize = bytelen
+ applies. On success, return (items, shape). If the constraints
+ cannot be met, return (None, None). If a chunk of bytes is interpreted
+ as NaN as a result of float conversion, return ('nan', None)."""
+ bytelen = exporter.nbytes
+ if shape:
+ if prod(shape) * itemsize != bytelen:
+ return None, shape
+ elif shape == []:
+ if exporter.ndim == 0 or itemsize != bytelen:
+ return None, shape
+ else:
+ n, r = divmod(bytelen, itemsize)
+ shape = [n]
+ if r != 0:
+ return None, shape
+
+ mem = exporter.tobytes()
+ byteitems = [mem[i:i+itemsize] for i in range(0, len(mem), itemsize)]
+
+ items = []
+ for v in byteitems:
+ item = struct.unpack(fmt, v)[0]
+ if item != item:
+ return 'nan', shape
+ items.append(item)
+
+ return (items, shape) if shape != [] else (items[0], shape)
+
+def gencastshapes():
+ """Generate shapes to test casting."""
+ for n in range(32):
+ yield [n]
+ ndim = randrange(4, 6)
+ minshape = 1 if randrange(100) > 80 else 2
+ yield [randrange(minshape, 5) for _ in range(ndim)]
+ ndim = randrange(2, 4)
+ minshape = 1 if randrange(100) > 80 else 2
+ yield [randrange(minshape, 5) for _ in range(ndim)]
+
+
+# ======================================================================
+# Actual tests
+# ======================================================================
+
+def genslices(n):
+ """Generate all possible slices for a single dimension."""
+ return product(range(-n, n+1), range(-n, n+1), range(-n, n+1))
+
+def genslices_ndim(ndim, shape):
+ """Generate all possible slice tuples for 'shape'."""
+ iterables = [genslices(shape[n]) for n in range(ndim)]
+ return product(*iterables)
+
+def rslice(n, allow_empty=False):
+ """Generate random slice for a single dimension of length n.
+ If zero=True, the slices may be empty, otherwise they will
+ be non-empty."""
+ minlen = 0 if allow_empty or n == 0 else 1
+ slicelen = randrange(minlen, n+1)
+ return randslice_from_slicelen(slicelen, n)
+
+def rslices(n, allow_empty=False):
+ """Generate random slices for a single dimension."""
+ for _ in range(5):
+ yield rslice(n, allow_empty)
+
+def rslices_ndim(ndim, shape, iterations=5):
+ """Generate random slice tuples for 'shape'."""
+ # non-empty slices
+ for _ in range(iterations):
+ yield tuple(rslice(shape[n]) for n in range(ndim))
+ # possibly empty slices
+ for _ in range(iterations):
+ yield tuple(rslice(shape[n], allow_empty=True) for n in range(ndim))
+ # invalid slices
+ yield tuple(slice(0,1,0) for _ in range(ndim))
+
+def rpermutation(iterable, r=None):
+ pool = tuple(iterable)
+ r = len(pool) if r is None else r
+ yield tuple(sample(pool, r))
+
+def ndarray_print(nd):
+ """Print ndarray for debugging."""
+ try:
+ x = nd.tolist()
+ except (TypeError, NotImplementedError):
+ x = nd.tobytes()
+ if isinstance(nd, ndarray):
+ offset = nd.offset
+ flags = nd.flags
+ else:
+ offset = 'unknown'
+ flags = 'unknown'
+ print("ndarray(%s, shape=%s, strides=%s, suboffsets=%s, offset=%s, "
+ "format='%s', itemsize=%s, flags=%s)" %
+ (x, nd.shape, nd.strides, nd.suboffsets, offset,
+ nd.format, nd.itemsize, flags))
+ sys.stdout.flush()
+
+
+ITERATIONS = 100
+MAXDIM = 5
+MAXSHAPE = 10
+
+if SHORT_TEST:
+ ITERATIONS = 10
+ MAXDIM = 3
+ MAXSHAPE = 4
+ genslices = rslices
+ genslices_ndim = rslices_ndim
+ permutations = rpermutation
+
+
+@unittest.skipUnless(struct, 'struct module required for this test.')
+@unittest.skipUnless(ndarray, 'ndarray object required for this test')
+class TestBufferProtocol(unittest.TestCase):
+
+ def setUp(self):
+ # The suboffsets tests need sizeof(void *).
+ self.sizeof_void_p = get_sizeof_void_p()
+
+ def verify(self, result, obj=-1,
+ itemsize={1}, fmt=-1, readonly={1},
+ ndim={1}, shape=-1, strides=-1,
+ lst=-1, sliced=False, cast=False):
+ # Verify buffer contents against expected values. Default values
+ # are deliberately initialized to invalid types.
+ if shape:
+ expected_len = prod(shape)*itemsize
+ else:
+ if not fmt: # array has been implicitly cast to unsigned bytes
+ expected_len = len(lst)
+ else: # ndim = 0
+ expected_len = itemsize
+
+ # Reconstruct suboffsets from strides. Support for slicing
+ # could be added, but is currently only needed for test_getbuf().
+ suboffsets = ()
+ if result.suboffsets:
+ self.assertGreater(ndim, 0)
+
+ suboffset0 = 0
+ for n in range(1, ndim):
+ if shape[n] == 0:
+ break
+ if strides[n] <= 0:
+ suboffset0 += -strides[n] * (shape[n]-1)
+
+ suboffsets = [suboffset0] + [-1 for v in range(ndim-1)]
+
+ # Not correct if slicing has occurred in the first dimension.
+ stride0 = self.sizeof_void_p
+ if strides[0] < 0:
+ stride0 = -stride0
+ strides = [stride0] + list(strides[1:])
+
+ self.assertIs(result.obj, obj)
+ self.assertEqual(result.nbytes, expected_len)
+ self.assertEqual(result.itemsize, itemsize)
+ self.assertEqual(result.format, fmt)
+ self.assertEqual(result.readonly, readonly)
+ self.assertEqual(result.ndim, ndim)
+ self.assertEqual(result.shape, tuple(shape))
+ if not (sliced and suboffsets):
+ self.assertEqual(result.strides, tuple(strides))
+ self.assertEqual(result.suboffsets, tuple(suboffsets))
+
+ if isinstance(result, ndarray) or is_memoryview_format(fmt):
+ rep = result.tolist() if fmt else result.tobytes()
+ self.assertEqual(rep, lst)
+
+ if not fmt: # array has been cast to unsigned bytes,
+ return # the remaining tests won't work.
+
+ # PyBuffer_GetPointer() is the definition how to access an item.
+ # If PyBuffer_GetPointer(indices) is correct for all possible
+ # combinations of indices, the buffer is correct.
+ #
+ # Also test tobytes() against the flattened 'lst', with all items
+ # packed to bytes.
+ if not cast: # casts chop up 'lst' in different ways
+ b = bytearray()
+ buf_err = None
+ for ind in indices(shape):
+ try:
+ item1 = get_pointer(result, ind)
+ item2 = get_item(lst, ind)
+ if isinstance(item2, tuple):
+ x = struct.pack(fmt, *item2)
+ else:
+ x = struct.pack(fmt, item2)
+ b.extend(x)
+ except BufferError:
+ buf_err = True # re-exporter does not provide full buffer
+ break
+ self.assertEqual(item1, item2)
+
+ if not buf_err:
+ # test tobytes()
+ self.assertEqual(result.tobytes(), b)
+
+ # lst := expected multi-dimensional logical representation
+ # flatten(lst) := elements in C-order
+ ff = fmt if fmt else 'B'
+ flattened = flatten(lst)
+
+ # Rules for 'A': if the array is already contiguous, return
+ # the array unaltered. Otherwise, return a contiguous 'C'
+ # representation.
+ for order in ['C', 'F', 'A']:
+ expected = result
+ if order == 'F':
+ if not is_contiguous(result, 'A') or \
+ is_contiguous(result, 'C'):
+ # For constructing the ndarray, convert the
+ # flattened logical representation to Fortran order.
+ trans = transpose(flattened, shape)
+ expected = ndarray(trans, shape=shape, format=ff,
+ flags=ND_FORTRAN)
+ else: # 'C', 'A'
+ if not is_contiguous(result, 'A') or \
+ is_contiguous(result, 'F') and order == 'C':
+ # The flattened list is already in C-order.
+ expected = ndarray(flattened, shape=shape, format=ff)
+
+ contig = get_contiguous(result, PyBUF_READ, order)
+ self.assertEqual(contig.tobytes(), b)
+ self.assertTrue(cmp_contig(contig, expected))
+
+ if ndim == 0:
+ continue
+
+ nmemb = len(flattened)
+ ro = 0 if readonly else ND_WRITABLE
+
+ ### See comment in test_py_buffer_to_contiguous for an
+ ### explanation why these tests are valid.
+
+ # To 'C'
+ contig = py_buffer_to_contiguous(result, 'C', PyBUF_FULL_RO)
+ self.assertEqual(len(contig), nmemb * itemsize)
+ initlst = [struct.unpack_from(fmt, contig, n*itemsize)
+ for n in range(nmemb)]
+ if len(initlst[0]) == 1:
+ initlst = [v[0] for v in initlst]
+
+ y = ndarray(initlst, shape=shape, flags=ro, format=fmt)
+ self.assertEqual(memoryview(y), memoryview(result))
+
+ # To 'F'
+ contig = py_buffer_to_contiguous(result, 'F', PyBUF_FULL_RO)
+ self.assertEqual(len(contig), nmemb * itemsize)
+ initlst = [struct.unpack_from(fmt, contig, n*itemsize)
+ for n in range(nmemb)]
+ if len(initlst[0]) == 1:
+ initlst = [v[0] for v in initlst]
+
+ y = ndarray(initlst, shape=shape, flags=ro|ND_FORTRAN,
+ format=fmt)
+ self.assertEqual(memoryview(y), memoryview(result))
+
+ # To 'A'
+ contig = py_buffer_to_contiguous(result, 'A', PyBUF_FULL_RO)
+ self.assertEqual(len(contig), nmemb * itemsize)
+ initlst = [struct.unpack_from(fmt, contig, n*itemsize)
+ for n in range(nmemb)]
+ if len(initlst[0]) == 1:
+ initlst = [v[0] for v in initlst]
+
+ f = ND_FORTRAN if is_contiguous(result, 'F') else 0
+ y = ndarray(initlst, shape=shape, flags=f|ro, format=fmt)
+ self.assertEqual(memoryview(y), memoryview(result))
+
+ if is_memoryview_format(fmt):
+ try:
+ m = memoryview(result)
+ except BufferError: # re-exporter does not provide full information
+ return
+ ex = result.obj if isinstance(result, memoryview) else result
+ self.assertIs(m.obj, ex)
+ self.assertEqual(m.nbytes, expected_len)
+ self.assertEqual(m.itemsize, itemsize)
+ self.assertEqual(m.format, fmt)
+ self.assertEqual(m.readonly, readonly)
+ self.assertEqual(m.ndim, ndim)
+ self.assertEqual(m.shape, tuple(shape))
+ if not (sliced and suboffsets):
+ self.assertEqual(m.strides, tuple(strides))
+ self.assertEqual(m.suboffsets, tuple(suboffsets))
+
+ n = 1 if ndim == 0 else len(lst)
+ self.assertEqual(len(m), n)
+
+ rep = result.tolist() if fmt else result.tobytes()
+ self.assertEqual(rep, lst)
+ self.assertEqual(m, result)
+
+ def verify_getbuf(self, orig_ex, ex, req, sliced=False):
+ def simple_fmt(ex):
+ return ex.format == '' or ex.format == 'B'
+ def match(req, flag):
+ return ((req&flag) == flag)
+
+ if (# writable request to read-only exporter
+ (ex.readonly and match(req, PyBUF_WRITABLE)) or
+ # cannot match explicit contiguity request
+ (match(req, PyBUF_C_CONTIGUOUS) and not ex.c_contiguous) or
+ (match(req, PyBUF_F_CONTIGUOUS) and not ex.f_contiguous) or
+ (match(req, PyBUF_ANY_CONTIGUOUS) and not ex.contiguous) or
+ # buffer needs suboffsets
+ (not match(req, PyBUF_INDIRECT) and ex.suboffsets) or
+ # buffer without strides must be C-contiguous
+ (not match(req, PyBUF_STRIDES) and not ex.c_contiguous) or
+ # PyBUF_SIMPLE|PyBUF_FORMAT and PyBUF_WRITABLE|PyBUF_FORMAT
+ (not match(req, PyBUF_ND) and match(req, PyBUF_FORMAT))):
+
+ self.assertRaises(BufferError, ndarray, ex, getbuf=req)
+ return
+
+ if isinstance(ex, ndarray) or is_memoryview_format(ex.format):
+ lst = ex.tolist()
+ else:
+ nd = ndarray(ex, getbuf=PyBUF_FULL_RO)
+ lst = nd.tolist()
+
+ # The consumer may have requested default values or a NULL format.
+ ro = 0 if match(req, PyBUF_WRITABLE) else ex.readonly
+ fmt = ex.format
+ itemsize = ex.itemsize
+ ndim = ex.ndim
+ if not match(req, PyBUF_FORMAT):
+ # itemsize refers to the original itemsize before the cast.
+ # The equality product(shape) * itemsize = len still holds.
+ # The equality calcsize(format) = itemsize does _not_ hold.
+ fmt = ''
+ lst = orig_ex.tobytes() # Issue 12834
+ if not match(req, PyBUF_ND):
+ ndim = 1
+ shape = orig_ex.shape if match(req, PyBUF_ND) else ()
+ strides = orig_ex.strides if match(req, PyBUF_STRIDES) else ()
+
+ nd = ndarray(ex, getbuf=req)
+ self.verify(nd, obj=ex,
+ itemsize=itemsize, fmt=fmt, readonly=ro,
+ ndim=ndim, shape=shape, strides=strides,
+ lst=lst, sliced=sliced)
+
+ def test_ndarray_getbuf(self):
+ requests = (
+ # distinct flags
+ PyBUF_INDIRECT, PyBUF_STRIDES, PyBUF_ND, PyBUF_SIMPLE,
+ PyBUF_C_CONTIGUOUS, PyBUF_F_CONTIGUOUS, PyBUF_ANY_CONTIGUOUS,
+ # compound requests
+ PyBUF_FULL, PyBUF_FULL_RO,
+ PyBUF_RECORDS, PyBUF_RECORDS_RO,
+ PyBUF_STRIDED, PyBUF_STRIDED_RO,
+ PyBUF_CONTIG, PyBUF_CONTIG_RO,
+ )
+ # items and format
+ items_fmt = (
+ ([True if x % 2 else False for x in range(12)], '?'),
+ ([1,2,3,4,5,6,7,8,9,10,11,12], 'b'),
+ ([1,2,3,4,5,6,7,8,9,10,11,12], 'B'),
+ ([(2**31-x) if x % 2 else (-2**31+x) for x in range(12)], 'l')
+ )
+ # shape, strides, offset
+ structure = (
+ ([], [], 0),
+ ([12], [], 0),
+ ([12], [-1], 11),
+ ([6], [2], 0),
+ ([6], [-2], 11),
+ ([3, 4], [], 0),
+ ([3, 4], [-4, -1], 11),
+ ([2, 2], [4, 1], 4),
+ ([2, 2], [-4, -1], 8)
+ )
+ # ndarray creation flags
+ ndflags = (
+ 0, ND_WRITABLE, ND_FORTRAN, ND_FORTRAN|ND_WRITABLE,
+ ND_PIL, ND_PIL|ND_WRITABLE
+ )
+ # flags that can actually be used as flags
+ real_flags = (0, PyBUF_WRITABLE, PyBUF_FORMAT,
+ PyBUF_WRITABLE|PyBUF_FORMAT)
+
+ for items, fmt in items_fmt:
+ itemsize = struct.calcsize(fmt)
+ for shape, strides, offset in structure:
+ strides = [v * itemsize for v in strides]
+ offset *= itemsize
+ for flags in ndflags:
+
+ if strides and (flags&ND_FORTRAN):
+ continue
+ if not shape and (flags&ND_PIL):
+ continue
+
+ _items = items if shape else items[0]
+ ex1 = ndarray(_items, format=fmt, flags=flags,
+ shape=shape, strides=strides, offset=offset)
+ ex2 = ex1[::-2] if shape else None
+
+ m1 = memoryview(ex1)
+ if ex2:
+ m2 = memoryview(ex2)
+ if ex1.ndim == 0 or (ex1.ndim == 1 and shape and strides):
+ self.assertEqual(m1, ex1)
+ if ex2 and ex2.ndim == 1 and shape and strides:
+ self.assertEqual(m2, ex2)
+
+ for req in requests:
+ for bits in real_flags:
+ self.verify_getbuf(ex1, ex1, req|bits)
+ self.verify_getbuf(ex1, m1, req|bits)
+ if ex2:
+ self.verify_getbuf(ex2, ex2, req|bits,
+ sliced=True)
+ self.verify_getbuf(ex2, m2, req|bits,
+ sliced=True)
+
+ items = [1,2,3,4,5,6,7,8,9,10,11,12]
+
+ # ND_GETBUF_FAIL
+ ex = ndarray(items, shape=[12], flags=ND_GETBUF_FAIL)
+ self.assertRaises(BufferError, ndarray, ex)
+
+ # Request complex structure from a simple exporter. In this
+ # particular case the test object is not PEP-3118 compliant.
+ base = ndarray([9], [1])
+ ex = ndarray(base, getbuf=PyBUF_SIMPLE)
+ self.assertRaises(BufferError, ndarray, ex, getbuf=PyBUF_WRITABLE)
+ self.assertRaises(BufferError, ndarray, ex, getbuf=PyBUF_ND)
+ self.assertRaises(BufferError, ndarray, ex, getbuf=PyBUF_STRIDES)
+ self.assertRaises(BufferError, ndarray, ex, getbuf=PyBUF_C_CONTIGUOUS)
+ self.assertRaises(BufferError, ndarray, ex, getbuf=PyBUF_F_CONTIGUOUS)
+ self.assertRaises(BufferError, ndarray, ex, getbuf=PyBUF_ANY_CONTIGUOUS)
+ nd = ndarray(ex, getbuf=PyBUF_SIMPLE)
+
+ def test_ndarray_exceptions(self):
+ nd = ndarray([9], [1])
+ ndm = ndarray([9], [1], flags=ND_VAREXPORT)
+
+ # Initialization of a new ndarray or mutation of an existing array.
+ for c in (ndarray, nd.push, ndm.push):
+ # Invalid types.
+ self.assertRaises(TypeError, c, {1,2,3})
+ self.assertRaises(TypeError, c, [1,2,'3'])
+ self.assertRaises(TypeError, c, [1,2,(3,4)])
+ self.assertRaises(TypeError, c, [1,2,3], shape={3})
+ self.assertRaises(TypeError, c, [1,2,3], shape=[3], strides={1})
+ self.assertRaises(TypeError, c, [1,2,3], shape=[3], offset=[])
+ self.assertRaises(TypeError, c, [1], shape=[1], format={})
+ self.assertRaises(TypeError, c, [1], shape=[1], flags={})
+ self.assertRaises(TypeError, c, [1], shape=[1], getbuf={})
+
+ # ND_FORTRAN flag is only valid without strides.
+ self.assertRaises(TypeError, c, [1], shape=[1], strides=[1],
+ flags=ND_FORTRAN)
+
+ # ND_PIL flag is only valid with ndim > 0.
+ self.assertRaises(TypeError, c, [1], shape=[], flags=ND_PIL)
+
+ # Invalid items.
+ self.assertRaises(ValueError, c, [], shape=[1])
+ self.assertRaises(ValueError, c, ['XXX'], shape=[1], format="L")
+ # Invalid combination of items and format.
+ self.assertRaises(struct.error, c, [1000], shape=[1], format="B")
+ self.assertRaises(ValueError, c, [1,(2,3)], shape=[2], format="B")
+ self.assertRaises(ValueError, c, [1,2,3], shape=[3], format="QL")
+
+ # Invalid ndim.
+ n = ND_MAX_NDIM+1
+ self.assertRaises(ValueError, c, [1]*n, shape=[1]*n)
+
+ # Invalid shape.
+ self.assertRaises(ValueError, c, [1], shape=[-1])
+ self.assertRaises(ValueError, c, [1,2,3], shape=['3'])
+ self.assertRaises(OverflowError, c, [1], shape=[2**128])
+ # prod(shape) * itemsize != len(items)
+ self.assertRaises(ValueError, c, [1,2,3,4,5], shape=[2,2], offset=3)
+
+ # Invalid strides.
+ self.assertRaises(ValueError, c, [1,2,3], shape=[3], strides=['1'])
+ self.assertRaises(OverflowError, c, [1], shape=[1],
+ strides=[2**128])
+
+ # Invalid combination of strides and shape.
+ self.assertRaises(ValueError, c, [1,2], shape=[2,1], strides=[1])
+ # Invalid combination of strides and format.
+ self.assertRaises(ValueError, c, [1,2,3,4], shape=[2], strides=[3],
+ format="L")
+
+ # Invalid offset.
+ self.assertRaises(ValueError, c, [1,2,3], shape=[3], offset=4)
+ self.assertRaises(ValueError, c, [1,2,3], shape=[1], offset=3,
+ format="L")
+
+ # Invalid format.
+ self.assertRaises(ValueError, c, [1,2,3], shape=[3], format="")
+ self.assertRaises(struct.error, c, [(1,2,3)], shape=[1],
+ format="@#$")
+
+ # Striding out of the memory bounds.
+ items = [1,2,3,4,5,6,7,8,9,10]
+ self.assertRaises(ValueError, c, items, shape=[2,3],
+ strides=[-3, -2], offset=5)
+
+ # Constructing consumer: format argument invalid.
+ self.assertRaises(TypeError, c, bytearray(), format="Q")
+
+ # Constructing original base object: getbuf argument invalid.
+ self.assertRaises(TypeError, c, [1], shape=[1], getbuf=PyBUF_FULL)
+
+ # Shape argument is mandatory for original base objects.
+ self.assertRaises(TypeError, c, [1])
+
+
+ # PyBUF_WRITABLE request to read-only provider.
+ self.assertRaises(BufferError, ndarray, b'123', getbuf=PyBUF_WRITABLE)
+
+ # ND_VAREXPORT can only be specified during construction.
+ nd = ndarray([9], [1], flags=ND_VAREXPORT)
+ self.assertRaises(ValueError, nd.push, [1], [1], flags=ND_VAREXPORT)
+
+ # Invalid operation for consumers: push/pop
+ nd = ndarray(b'123')
+ self.assertRaises(BufferError, nd.push, [1], [1])
+ self.assertRaises(BufferError, nd.pop)
+
+ # ND_VAREXPORT not set: push/pop fail with exported buffers
+ nd = ndarray([9], [1])
+ nd.push([1], [1])
+ m = memoryview(nd)
+ self.assertRaises(BufferError, nd.push, [1], [1])
+ self.assertRaises(BufferError, nd.pop)
+ m.release()
+ nd.pop()
+
+ # Single remaining buffer: pop fails
+ self.assertRaises(BufferError, nd.pop)
+ del nd
+
+ # get_pointer()
+ self.assertRaises(TypeError, get_pointer, {}, [1,2,3])
+ self.assertRaises(TypeError, get_pointer, b'123', {})
+
+ nd = ndarray(list(range(100)), shape=[1]*100)
+ self.assertRaises(ValueError, get_pointer, nd, [5])
+
+ nd = ndarray(list(range(12)), shape=[3,4])
+ self.assertRaises(ValueError, get_pointer, nd, [2,3,4])
+ self.assertRaises(ValueError, get_pointer, nd, [3,3])
+ self.assertRaises(ValueError, get_pointer, nd, [-3,3])
+ self.assertRaises(OverflowError, get_pointer, nd, [1<<64,3])
+
+ # tolist() needs format
+ ex = ndarray([1,2,3], shape=[3], format='L')
+ nd = ndarray(ex, getbuf=PyBUF_SIMPLE)
+ self.assertRaises(ValueError, nd.tolist)
+
+ # memoryview_from_buffer()
+ ex1 = ndarray([1,2,3], shape=[3], format='L')
+ ex2 = ndarray(ex1)
+ nd = ndarray(ex2)
+ self.assertRaises(TypeError, nd.memoryview_from_buffer)
+
+ nd = ndarray([(1,)*200], shape=[1], format='L'*200)
+ self.assertRaises(TypeError, nd.memoryview_from_buffer)
+
+ n = ND_MAX_NDIM
+ nd = ndarray(list(range(n)), shape=[1]*n)
+ self.assertRaises(ValueError, nd.memoryview_from_buffer)
+
+ # get_contiguous()
+ nd = ndarray([1], shape=[1])
+ self.assertRaises(TypeError, get_contiguous, 1, 2, 3, 4, 5)
+ self.assertRaises(TypeError, get_contiguous, nd, "xyz", 'C')
+ self.assertRaises(OverflowError, get_contiguous, nd, 2**64, 'C')
+ self.assertRaises(TypeError, get_contiguous, nd, PyBUF_READ, 961)
+ self.assertRaises(UnicodeEncodeError, get_contiguous, nd, PyBUF_READ,
+ '\u2007')
+ self.assertRaises(ValueError, get_contiguous, nd, PyBUF_READ, 'Z')
+ self.assertRaises(ValueError, get_contiguous, nd, 255, 'A')
+
+ # cmp_contig()
+ nd = ndarray([1], shape=[1])
+ self.assertRaises(TypeError, cmp_contig, 1, 2, 3, 4, 5)
+ self.assertRaises(TypeError, cmp_contig, {}, nd)
+ self.assertRaises(TypeError, cmp_contig, nd, {})
+
+ # is_contiguous()
+ nd = ndarray([1], shape=[1])
+ self.assertRaises(TypeError, is_contiguous, 1, 2, 3, 4, 5)
+ self.assertRaises(TypeError, is_contiguous, {}, 'A')
+ self.assertRaises(TypeError, is_contiguous, nd, 201)
+
+ def test_ndarray_linked_list(self):
+ for perm in permutations(range(5)):
+ m = [0]*5
+ nd = ndarray([1,2,3], shape=[3], flags=ND_VAREXPORT)
+ m[0] = memoryview(nd)
+
+ for i in range(1, 5):
+ nd.push([1,2,3], shape=[3])
+ m[i] = memoryview(nd)
+
+ for i in range(5):
+ m[perm[i]].release()
+
+ self.assertRaises(BufferError, nd.pop)
+ del nd
+
+ def test_ndarray_format_scalar(self):
+ # ndim = 0: scalar
+ for fmt, scalar, _ in iter_format(0):
+ itemsize = struct.calcsize(fmt)
+ nd = ndarray(scalar, shape=(), format=fmt)
+ self.verify(nd, obj=None,
+ itemsize=itemsize, fmt=fmt, readonly=1,
+ ndim=0, shape=(), strides=(),
+ lst=scalar)
+
+ def test_ndarray_format_shape(self):
+ # ndim = 1, shape = [n]
+ nitems = randrange(1, 10)
+ for fmt, items, _ in iter_format(nitems):
+ itemsize = struct.calcsize(fmt)
+ for flags in (0, ND_PIL):
+ nd = ndarray(items, shape=[nitems], format=fmt, flags=flags)
+ self.verify(nd, obj=None,
+ itemsize=itemsize, fmt=fmt, readonly=1,
+ ndim=1, shape=(nitems,), strides=(itemsize,),
+ lst=items)
+
+ def test_ndarray_format_strides(self):
+ # ndim = 1, strides
+ nitems = randrange(1, 30)
+ for fmt, items, _ in iter_format(nitems):
+ itemsize = struct.calcsize(fmt)
+ for step in range(-5, 5):
+ if step == 0:
+ continue
+
+ shape = [len(items[::step])]
+ strides = [step*itemsize]
+ offset = itemsize*(nitems-1) if step < 0 else 0
+
+ for flags in (0, ND_PIL):
+ nd = ndarray(items, shape=shape, strides=strides,
+ format=fmt, offset=offset, flags=flags)
+ self.verify(nd, obj=None,
+ itemsize=itemsize, fmt=fmt, readonly=1,
+ ndim=1, shape=shape, strides=strides,
+ lst=items[::step])
+
+ def test_ndarray_fortran(self):
+ items = [1,2,3,4,5,6,7,8,9,10,11,12]
+ ex = ndarray(items, shape=(3, 4), strides=(1, 3))
+ nd = ndarray(ex, getbuf=PyBUF_F_CONTIGUOUS|PyBUF_FORMAT)
+ self.assertEqual(nd.tolist(), farray(items, (3, 4)))
+
+ def test_ndarray_multidim(self):
+ for ndim in range(5):
+ shape_t = [randrange(2, 10) for _ in range(ndim)]
+ nitems = prod(shape_t)
+ for shape in permutations(shape_t):
+
+ fmt, items, _ = randitems(nitems)
+ itemsize = struct.calcsize(fmt)
+
+ for flags in (0, ND_PIL):
+ if ndim == 0 and flags == ND_PIL:
+ continue
+
+ # C array
+ nd = ndarray(items, shape=shape, format=fmt, flags=flags)
+
+ strides = strides_from_shape(ndim, shape, itemsize, 'C')
+ lst = carray(items, shape)
+ self.verify(nd, obj=None,
+ itemsize=itemsize, fmt=fmt, readonly=1,
+ ndim=ndim, shape=shape, strides=strides,
+ lst=lst)
+
+ if is_memoryview_format(fmt):
+ # memoryview: reconstruct strides
+ ex = ndarray(items, shape=shape, format=fmt)
+ nd = ndarray(ex, getbuf=PyBUF_CONTIG_RO|PyBUF_FORMAT)
+ self.assertTrue(nd.strides == ())
+ mv = nd.memoryview_from_buffer()
+ self.verify(mv, obj=None,
+ itemsize=itemsize, fmt=fmt, readonly=1,
+ ndim=ndim, shape=shape, strides=strides,
+ lst=lst)
+
+ # Fortran array
+ nd = ndarray(items, shape=shape, format=fmt,
+ flags=flags|ND_FORTRAN)
+
+ strides = strides_from_shape(ndim, shape, itemsize, 'F')
+ lst = farray(items, shape)
+ self.verify(nd, obj=None,
+ itemsize=itemsize, fmt=fmt, readonly=1,
+ ndim=ndim, shape=shape, strides=strides,
+ lst=lst)
+
+ def test_ndarray_index_invalid(self):
+ # not writable
+ nd = ndarray([1], shape=[1])
+ self.assertRaises(TypeError, nd.__setitem__, 1, 8)
+ mv = memoryview(nd)
+ self.assertEqual(mv, nd)
+ self.assertRaises(TypeError, mv.__setitem__, 1, 8)
+
+ # cannot be deleted
+ nd = ndarray([1], shape=[1], flags=ND_WRITABLE)
+ self.assertRaises(TypeError, nd.__delitem__, 1)
+ mv = memoryview(nd)
+ self.assertEqual(mv, nd)
+ self.assertRaises(TypeError, mv.__delitem__, 1)
+
+ # overflow
+ nd = ndarray([1], shape=[1], flags=ND_WRITABLE)
+ self.assertRaises(OverflowError, nd.__getitem__, 1<<64)
+ self.assertRaises(OverflowError, nd.__setitem__, 1<<64, 8)
+ mv = memoryview(nd)
+ self.assertEqual(mv, nd)
+ self.assertRaises(IndexError, mv.__getitem__, 1<<64)
+ self.assertRaises(IndexError, mv.__setitem__, 1<<64, 8)
+
+ # format
+ items = [1,2,3,4,5,6,7,8]
+ nd = ndarray(items, shape=[len(items)], format="B", flags=ND_WRITABLE)
+ self.assertRaises(struct.error, nd.__setitem__, 2, 300)
+ self.assertRaises(ValueError, nd.__setitem__, 1, (100, 200))
+ mv = memoryview(nd)
+ self.assertEqual(mv, nd)
+ self.assertRaises(ValueError, mv.__setitem__, 2, 300)
+ self.assertRaises(TypeError, mv.__setitem__, 1, (100, 200))
+
+ items = [(1,2), (3,4), (5,6)]
+ nd = ndarray(items, shape=[len(items)], format="LQ", flags=ND_WRITABLE)
+ self.assertRaises(ValueError, nd.__setitem__, 2, 300)
+ self.assertRaises(struct.error, nd.__setitem__, 1, (b'\x001', 200))
+
+ def test_ndarray_index_scalar(self):
+ # scalar
+ nd = ndarray(1, shape=(), flags=ND_WRITABLE)
+ mv = memoryview(nd)
+ self.assertEqual(mv, nd)
+
+ x = nd[()]; self.assertEqual(x, 1)
+ x = nd[...]; self.assertEqual(x.tolist(), nd.tolist())
+
+ x = mv[()]; self.assertEqual(x, 1)
+ x = mv[...]; self.assertEqual(x.tolist(), nd.tolist())
+
+ self.assertRaises(TypeError, nd.__getitem__, 0)
+ self.assertRaises(TypeError, mv.__getitem__, 0)
+ self.assertRaises(TypeError, nd.__setitem__, 0, 8)
+ self.assertRaises(TypeError, mv.__setitem__, 0, 8)
+
+ self.assertEqual(nd.tolist(), 1)
+ self.assertEqual(mv.tolist(), 1)
+
+ nd[()] = 9; self.assertEqual(nd.tolist(), 9)
+ mv[()] = 9; self.assertEqual(mv.tolist(), 9)
+
+ nd[...] = 5; self.assertEqual(nd.tolist(), 5)
+ mv[...] = 5; self.assertEqual(mv.tolist(), 5)
+
+ def test_ndarray_index_null_strides(self):
+ ex = ndarray(list(range(2*4)), shape=[2, 4], flags=ND_WRITABLE)
+ nd = ndarray(ex, getbuf=PyBUF_CONTIG)
+
+ # Sub-views are only possible for full exporters.
+ self.assertRaises(BufferError, nd.__getitem__, 1)
+ # Same for slices.
+ self.assertRaises(BufferError, nd.__getitem__, slice(3,5,1))
+
+ def test_ndarray_index_getitem_single(self):
+ # getitem
+ for fmt, items, _ in iter_format(5):
+ nd = ndarray(items, shape=[5], format=fmt)
+ for i in range(-5, 5):
+ self.assertEqual(nd[i], items[i])
+
+ self.assertRaises(IndexError, nd.__getitem__, -6)
+ self.assertRaises(IndexError, nd.__getitem__, 5)
+
+ if is_memoryview_format(fmt):
+ mv = memoryview(nd)
+ self.assertEqual(mv, nd)
+ for i in range(-5, 5):
+ self.assertEqual(mv[i], items[i])
+
+ self.assertRaises(IndexError, mv.__getitem__, -6)
+ self.assertRaises(IndexError, mv.__getitem__, 5)
+
+ # getitem with null strides
+ for fmt, items, _ in iter_format(5):
+ ex = ndarray(items, shape=[5], flags=ND_WRITABLE, format=fmt)
+ nd = ndarray(ex, getbuf=PyBUF_CONTIG|PyBUF_FORMAT)
+
+ for i in range(-5, 5):
+ self.assertEqual(nd[i], items[i])
+
+ if is_memoryview_format(fmt):
+ mv = nd.memoryview_from_buffer()
+ self.assertIs(mv.__eq__(nd), NotImplemented)
+ for i in range(-5, 5):
+ self.assertEqual(mv[i], items[i])
+
+ # getitem with null format
+ items = [1,2,3,4,5]
+ ex = ndarray(items, shape=[5])
+ nd = ndarray(ex, getbuf=PyBUF_CONTIG_RO)
+ for i in range(-5, 5):
+ self.assertEqual(nd[i], items[i])
+
+ # getitem with null shape/strides/format
+ items = [1,2,3,4,5]
+ ex = ndarray(items, shape=[5])
+ nd = ndarray(ex, getbuf=PyBUF_SIMPLE)
+
+ for i in range(-5, 5):
+ self.assertEqual(nd[i], items[i])
+
+ def test_ndarray_index_setitem_single(self):
+ # assign single value
+ for fmt, items, single_item in iter_format(5):
+ nd = ndarray(items, shape=[5], format=fmt, flags=ND_WRITABLE)
+ for i in range(5):
+ items[i] = single_item
+ nd[i] = single_item
+ self.assertEqual(nd.tolist(), items)
+
+ self.assertRaises(IndexError, nd.__setitem__, -6, single_item)
+ self.assertRaises(IndexError, nd.__setitem__, 5, single_item)
+
+ if not is_memoryview_format(fmt):
+ continue
+
+ nd = ndarray(items, shape=[5], format=fmt, flags=ND_WRITABLE)
+ mv = memoryview(nd)
+ self.assertEqual(mv, nd)
+ for i in range(5):
+ items[i] = single_item
+ mv[i] = single_item
+ self.assertEqual(mv.tolist(), items)
+
+ self.assertRaises(IndexError, mv.__setitem__, -6, single_item)
+ self.assertRaises(IndexError, mv.__setitem__, 5, single_item)
+
+
+ # assign single value: lobject = robject
+ for fmt, items, single_item in iter_format(5):
+ nd = ndarray(items, shape=[5], format=fmt, flags=ND_WRITABLE)
+ for i in range(-5, 4):
+ items[i] = items[i+1]
+ nd[i] = nd[i+1]
+ self.assertEqual(nd.tolist(), items)
+
+ if not is_memoryview_format(fmt):
+ continue
+
+ nd = ndarray(items, shape=[5], format=fmt, flags=ND_WRITABLE)
+ mv = memoryview(nd)
+ self.assertEqual(mv, nd)
+ for i in range(-5, 4):
+ items[i] = items[i+1]
+ mv[i] = mv[i+1]
+ self.assertEqual(mv.tolist(), items)
+
+ def test_ndarray_index_getitem_multidim(self):
+ shape_t = (2, 3, 5)
+ nitems = prod(shape_t)
+ for shape in permutations(shape_t):
+
+ fmt, items, _ = randitems(nitems)
+
+ for flags in (0, ND_PIL):
+ # C array
+ nd = ndarray(items, shape=shape, format=fmt, flags=flags)
+ lst = carray(items, shape)
+
+ for i in range(-shape[0], shape[0]):
+ self.assertEqual(lst[i], nd[i].tolist())
+ for j in range(-shape[1], shape[1]):
+ self.assertEqual(lst[i][j], nd[i][j].tolist())
+ for k in range(-shape[2], shape[2]):
+ self.assertEqual(lst[i][j][k], nd[i][j][k])
+
+ # Fortran array
+ nd = ndarray(items, shape=shape, format=fmt,
+ flags=flags|ND_FORTRAN)
+ lst = farray(items, shape)
+
+ for i in range(-shape[0], shape[0]):
+ self.assertEqual(lst[i], nd[i].tolist())
+ for j in range(-shape[1], shape[1]):
+ self.assertEqual(lst[i][j], nd[i][j].tolist())
+ for k in range(shape[2], shape[2]):
+ self.assertEqual(lst[i][j][k], nd[i][j][k])
+
+ def test_ndarray_sequence(self):
+ nd = ndarray(1, shape=())
+ self.assertRaises(TypeError, eval, "1 in nd", locals())
+ mv = memoryview(nd)
+ self.assertEqual(mv, nd)
+ self.assertRaises(TypeError, eval, "1 in mv", locals())
+
+ for fmt, items, _ in iter_format(5):
+ nd = ndarray(items, shape=[5], format=fmt)
+ for i, v in enumerate(nd):
+ self.assertEqual(v, items[i])
+ self.assertTrue(v in nd)
+
+ if is_memoryview_format(fmt):
+ mv = memoryview(nd)
+ for i, v in enumerate(mv):
+ self.assertEqual(v, items[i])
+ self.assertTrue(v in mv)
+
+ def test_ndarray_slice_invalid(self):
+ items = [1,2,3,4,5,6,7,8]
+
+ # rvalue is not an exporter
+ xl = ndarray(items, shape=[8], flags=ND_WRITABLE)
+ ml = memoryview(xl)
+ self.assertRaises(TypeError, xl.__setitem__, slice(0,8,1), items)
+ self.assertRaises(TypeError, ml.__setitem__, slice(0,8,1), items)
+
+ # rvalue is not a full exporter
+ xl = ndarray(items, shape=[8], flags=ND_WRITABLE)
+ ex = ndarray(items, shape=[8], flags=ND_WRITABLE)
+ xr = ndarray(ex, getbuf=PyBUF_ND)
+ self.assertRaises(BufferError, xl.__setitem__, slice(0,8,1), xr)
+
+ # zero step
+ nd = ndarray(items, shape=[8], format="L", flags=ND_WRITABLE)
+ mv = memoryview(nd)
+ self.assertRaises(ValueError, nd.__getitem__, slice(0,1,0))
+ self.assertRaises(ValueError, mv.__getitem__, slice(0,1,0))
+
+ nd = ndarray(items, shape=[2,4], format="L", flags=ND_WRITABLE)
+ mv = memoryview(nd)
+
+ self.assertRaises(ValueError, nd.__getitem__,
+ (slice(0,1,1), slice(0,1,0)))
+ self.assertRaises(ValueError, nd.__getitem__,
+ (slice(0,1,0), slice(0,1,1)))
+ self.assertRaises(TypeError, nd.__getitem__, "@%$")
+ self.assertRaises(TypeError, nd.__getitem__, ("@%$", slice(0,1,1)))
+ self.assertRaises(TypeError, nd.__getitem__, (slice(0,1,1), {}))
+
+ # memoryview: not implemented
+ self.assertRaises(NotImplementedError, mv.__getitem__,
+ (slice(0,1,1), slice(0,1,0)))
+ self.assertRaises(TypeError, mv.__getitem__, "@%$")
+
+ # differing format
+ xl = ndarray(items, shape=[8], format="B", flags=ND_WRITABLE)
+ xr = ndarray(items, shape=[8], format="b")
+ ml = memoryview(xl)
+ mr = memoryview(xr)
+ self.assertRaises(ValueError, xl.__setitem__, slice(0,1,1), xr[7:8])
+ self.assertEqual(xl.tolist(), items)
+ self.assertRaises(ValueError, ml.__setitem__, slice(0,1,1), mr[7:8])
+ self.assertEqual(ml.tolist(), items)
+
+ # differing itemsize
+ xl = ndarray(items, shape=[8], format="B", flags=ND_WRITABLE)
+ yr = ndarray(items, shape=[8], format="L")
+ ml = memoryview(xl)
+ mr = memoryview(xr)
+ self.assertRaises(ValueError, xl.__setitem__, slice(0,1,1), xr[7:8])
+ self.assertEqual(xl.tolist(), items)
+ self.assertRaises(ValueError, ml.__setitem__, slice(0,1,1), mr[7:8])
+ self.assertEqual(ml.tolist(), items)
+
+ # differing ndim
+ xl = ndarray(items, shape=[2, 4], format="b", flags=ND_WRITABLE)
+ xr = ndarray(items, shape=[8], format="b")
+ ml = memoryview(xl)
+ mr = memoryview(xr)
+ self.assertRaises(ValueError, xl.__setitem__, slice(0,1,1), xr[7:8])
+ self.assertEqual(xl.tolist(), [[1,2,3,4], [5,6,7,8]])
+ self.assertRaises(NotImplementedError, ml.__setitem__, slice(0,1,1),
+ mr[7:8])
+
+ # differing shape
+ xl = ndarray(items, shape=[8], format="b", flags=ND_WRITABLE)
+ xr = ndarray(items, shape=[8], format="b")
+ ml = memoryview(xl)
+ mr = memoryview(xr)
+ self.assertRaises(ValueError, xl.__setitem__, slice(0,2,1), xr[7:8])
+ self.assertEqual(xl.tolist(), items)
+ self.assertRaises(ValueError, ml.__setitem__, slice(0,2,1), mr[7:8])
+ self.assertEqual(ml.tolist(), items)
+
+ # _testbuffer.c module functions
+ self.assertRaises(TypeError, slice_indices, slice(0,1,2), {})
+ self.assertRaises(TypeError, slice_indices, "###########", 1)
+ self.assertRaises(ValueError, slice_indices, slice(0,1,0), 4)
+
+ x = ndarray(items, shape=[8], format="b", flags=ND_PIL)
+ self.assertRaises(TypeError, x.add_suboffsets)
+
+ ex = ndarray(items, shape=[8], format="B")
+ x = ndarray(ex, getbuf=PyBUF_SIMPLE)
+ self.assertRaises(TypeError, x.add_suboffsets)
+
+ def test_ndarray_slice_zero_shape(self):
+ items = [1,2,3,4,5,6,7,8,9,10,11,12]
+
+ x = ndarray(items, shape=[12], format="L", flags=ND_WRITABLE)
+ y = ndarray(items, shape=[12], format="L")
+ x[4:4] = y[9:9]
+ self.assertEqual(x.tolist(), items)
+
+ ml = memoryview(x)
+ mr = memoryview(y)
+ self.assertEqual(ml, x)
+ self.assertEqual(ml, y)
+ ml[4:4] = mr[9:9]
+ self.assertEqual(ml.tolist(), items)
+
+ x = ndarray(items, shape=[3, 4], format="L", flags=ND_WRITABLE)
+ y = ndarray(items, shape=[4, 3], format="L")
+ x[1:2, 2:2] = y[1:2, 3:3]
+ self.assertEqual(x.tolist(), carray(items, [3, 4]))
+
+ def test_ndarray_slice_multidim(self):
+ shape_t = (2, 3, 5)
+ ndim = len(shape_t)
+ nitems = prod(shape_t)
+ for shape in permutations(shape_t):
+
+ fmt, items, _ = randitems(nitems)
+ itemsize = struct.calcsize(fmt)
+
+ for flags in (0, ND_PIL):
+ nd = ndarray(items, shape=shape, format=fmt, flags=flags)
+ lst = carray(items, shape)
+
+ for slices in rslices_ndim(ndim, shape):
+
+ listerr = None
+ try:
+ sliced = multislice(lst, slices)
+ except Exception as e:
+ listerr = e.__class__
+
+ nderr = None
+ try:
+ ndsliced = nd[slices]
+ except Exception as e:
+ nderr = e.__class__
+
+ if nderr or listerr:
+ self.assertIs(nderr, listerr)
+ else:
+ self.assertEqual(ndsliced.tolist(), sliced)
+
+ def test_ndarray_slice_redundant_suboffsets(self):
+ shape_t = (2, 3, 5, 2)
+ ndim = len(shape_t)
+ nitems = prod(shape_t)
+ for shape in permutations(shape_t):
+
+ fmt, items, _ = randitems(nitems)
+ itemsize = struct.calcsize(fmt)
+
+ nd = ndarray(items, shape=shape, format=fmt)
+ nd.add_suboffsets()
+ ex = ndarray(items, shape=shape, format=fmt)
+ ex.add_suboffsets()
+ mv = memoryview(ex)
+ lst = carray(items, shape)
+
+ for slices in rslices_ndim(ndim, shape):
+
+ listerr = None
+ try:
+ sliced = multislice(lst, slices)
+ except Exception as e:
+ listerr = e.__class__
+
+ nderr = None
+ try:
+ ndsliced = nd[slices]
+ except Exception as e:
+ nderr = e.__class__
+
+ if nderr or listerr:
+ self.assertIs(nderr, listerr)
+ else:
+ self.assertEqual(ndsliced.tolist(), sliced)
+
+ def test_ndarray_slice_assign_single(self):
+ for fmt, items, _ in iter_format(5):
+ for lslice in genslices(5):
+ for rslice in genslices(5):
+ for flags in (0, ND_PIL):
+
+ f = flags|ND_WRITABLE
+ nd = ndarray(items, shape=[5], format=fmt, flags=f)
+ ex = ndarray(items, shape=[5], format=fmt, flags=f)
+ mv = memoryview(ex)
+
+ lsterr = None
+ diff_structure = None
+ lst = items[:]
+ try:
+ lval = lst[lslice]
+ rval = lst[rslice]
+ lst[lslice] = lst[rslice]
+ diff_structure = len(lval) != len(rval)
+ except Exception as e:
+ lsterr = e.__class__
+
+ nderr = None
+ try:
+ nd[lslice] = nd[rslice]
+ except Exception as e:
+ nderr = e.__class__
+
+ if diff_structure: # ndarray cannot change shape
+ self.assertIs(nderr, ValueError)
+ else:
+ self.assertEqual(nd.tolist(), lst)
+ self.assertIs(nderr, lsterr)
+
+ if not is_memoryview_format(fmt):
+ continue
+
+ mverr = None
+ try:
+ mv[lslice] = mv[rslice]
+ except Exception as e:
+ mverr = e.__class__
+
+ if diff_structure: # memoryview cannot change shape
+ self.assertIs(mverr, ValueError)
+ else:
+ self.assertEqual(mv.tolist(), lst)
+ self.assertEqual(mv, nd)
+ self.assertIs(mverr, lsterr)
+ self.verify(mv, obj=ex,
+ itemsize=nd.itemsize, fmt=fmt, readonly=0,
+ ndim=nd.ndim, shape=nd.shape, strides=nd.strides,
+ lst=nd.tolist())
+
+ def test_ndarray_slice_assign_multidim(self):
+ shape_t = (2, 3, 5)
+ ndim = len(shape_t)
+ nitems = prod(shape_t)
+ for shape in permutations(shape_t):
+
+ fmt, items, _ = randitems(nitems)
+
+ for flags in (0, ND_PIL):
+ for _ in range(ITERATIONS):
+ lslices, rslices = randslice_from_shape(ndim, shape)
+
+ nd = ndarray(items, shape=shape, format=fmt,
+ flags=flags|ND_WRITABLE)
+ lst = carray(items, shape)
+
+ listerr = None
+ try:
+ result = multislice_assign(lst, lst, lslices, rslices)
+ except Exception as e:
+ listerr = e.__class__
+
+ nderr = None
+ try:
+ nd[lslices] = nd[rslices]
+ except Exception as e:
+ nderr = e.__class__
+
+ if nderr or listerr:
+ self.assertIs(nderr, listerr)
+ else:
+ self.assertEqual(nd.tolist(), result)
+
+ def test_ndarray_random(self):
+ # construction of valid arrays
+ for _ in range(ITERATIONS):
+ for fmt in fmtdict['@']:
+ itemsize = struct.calcsize(fmt)
+
+ t = rand_structure(itemsize, True, maxdim=MAXDIM,
+ maxshape=MAXSHAPE)
+ self.assertTrue(verify_structure(*t))
+ items = randitems_from_structure(fmt, t)
+
+ x = ndarray_from_structure(items, fmt, t)
+ xlist = x.tolist()
+
+ mv = memoryview(x)
+ if is_memoryview_format(fmt):
+ mvlist = mv.tolist()
+ self.assertEqual(mvlist, xlist)
+
+ if t[2] > 0:
+ # ndim > 0: test against suboffsets representation.
+ y = ndarray_from_structure(items, fmt, t, flags=ND_PIL)
+ ylist = y.tolist()
+ self.assertEqual(xlist, ylist)
+
+ mv = memoryview(y)
+ if is_memoryview_format(fmt):
+ self.assertEqual(mv, y)
+ mvlist = mv.tolist()
+ self.assertEqual(mvlist, ylist)
+
+ if numpy_array:
+ shape = t[3]
+ if 0 in shape:
+ continue # http://projects.scipy.org/numpy/ticket/1910
+ z = numpy_array_from_structure(items, fmt, t)
+ self.verify(x, obj=None,
+ itemsize=z.itemsize, fmt=fmt, readonly=0,
+ ndim=z.ndim, shape=z.shape, strides=z.strides,
+ lst=z.tolist())
+
+ def test_ndarray_random_invalid(self):
+ # exceptions during construction of invalid arrays
+ for _ in range(ITERATIONS):
+ for fmt in fmtdict['@']:
+ itemsize = struct.calcsize(fmt)
+
+ t = rand_structure(itemsize, False, maxdim=MAXDIM,
+ maxshape=MAXSHAPE)
+ self.assertFalse(verify_structure(*t))
+ items = randitems_from_structure(fmt, t)
+
+ nderr = False
+ try:
+ x = ndarray_from_structure(items, fmt, t)
+ except Exception as e:
+ nderr = e.__class__
+ self.assertTrue(nderr)
+
+ if numpy_array:
+ numpy_err = False
+ try:
+ y = numpy_array_from_structure(items, fmt, t)
+ except Exception as e:
+ numpy_err = e.__class__
+
+ if 0: # http://projects.scipy.org/numpy/ticket/1910
+ self.assertTrue(numpy_err)
+
+ def test_ndarray_random_slice_assign(self):
+ # valid slice assignments
+ for _ in range(ITERATIONS):
+ for fmt in fmtdict['@']:
+ itemsize = struct.calcsize(fmt)
+
+ lshape, rshape, lslices, rslices = \
+ rand_aligned_slices(maxdim=MAXDIM, maxshape=MAXSHAPE)
+ tl = rand_structure(itemsize, True, shape=lshape)
+ tr = rand_structure(itemsize, True, shape=rshape)
+ self.assertTrue(verify_structure(*tl))
+ self.assertTrue(verify_structure(*tr))
+ litems = randitems_from_structure(fmt, tl)
+ ritems = randitems_from_structure(fmt, tr)
+
+ xl = ndarray_from_structure(litems, fmt, tl)
+ xr = ndarray_from_structure(ritems, fmt, tr)
+ xl[lslices] = xr[rslices]
+ xllist = xl.tolist()
+ xrlist = xr.tolist()
+
+ ml = memoryview(xl)
+ mr = memoryview(xr)
+ self.assertEqual(ml.tolist(), xllist)
+ self.assertEqual(mr.tolist(), xrlist)
+
+ if tl[2] > 0 and tr[2] > 0:
+ # ndim > 0: test against suboffsets representation.
+ yl = ndarray_from_structure(litems, fmt, tl, flags=ND_PIL)
+ yr = ndarray_from_structure(ritems, fmt, tr, flags=ND_PIL)
+ yl[lslices] = yr[rslices]
+ yllist = yl.tolist()
+ yrlist = yr.tolist()
+ self.assertEqual(xllist, yllist)
+ self.assertEqual(xrlist, yrlist)
+
+ ml = memoryview(yl)
+ mr = memoryview(yr)
+ self.assertEqual(ml.tolist(), yllist)
+ self.assertEqual(mr.tolist(), yrlist)
+
+ if numpy_array:
+ if 0 in lshape or 0 in rshape:
+ continue # http://projects.scipy.org/numpy/ticket/1910
+
+ zl = numpy_array_from_structure(litems, fmt, tl)
+ zr = numpy_array_from_structure(ritems, fmt, tr)
+ zl[lslices] = zr[rslices]
+
+ if not is_overlapping(tl) and not is_overlapping(tr):
+ # Slice assignment of overlapping structures
+ # is undefined in NumPy.
+ self.verify(xl, obj=None,
+ itemsize=zl.itemsize, fmt=fmt, readonly=0,
+ ndim=zl.ndim, shape=zl.shape,
+ strides=zl.strides, lst=zl.tolist())
+
+ self.verify(xr, obj=None,
+ itemsize=zr.itemsize, fmt=fmt, readonly=0,
+ ndim=zr.ndim, shape=zr.shape,
+ strides=zr.strides, lst=zr.tolist())
+
+ def test_ndarray_re_export(self):
+ items = [1,2,3,4,5,6,7,8,9,10,11,12]
+
+ nd = ndarray(items, shape=[3,4], flags=ND_PIL)
+ ex = ndarray(nd)
+
+ self.assertTrue(ex.flags & ND_PIL)
+ self.assertIs(ex.obj, nd)
+ self.assertEqual(ex.suboffsets, (0, -1))
+ self.assertFalse(ex.c_contiguous)
+ self.assertFalse(ex.f_contiguous)
+ self.assertFalse(ex.contiguous)
+
+ def test_ndarray_zero_shape(self):
+ # zeros in shape
+ for flags in (0, ND_PIL):
+ nd = ndarray([1,2,3], shape=[0], flags=flags)
+ mv = memoryview(nd)
+ self.assertEqual(mv, nd)
+ self.assertEqual(nd.tolist(), [])
+ self.assertEqual(mv.tolist(), [])
+
+ nd = ndarray([1,2,3], shape=[0,3,3], flags=flags)
+ self.assertEqual(nd.tolist(), [])
+
+ nd = ndarray([1,2,3], shape=[3,0,3], flags=flags)
+ self.assertEqual(nd.tolist(), [[], [], []])
+
+ nd = ndarray([1,2,3], shape=[3,3,0], flags=flags)
+ self.assertEqual(nd.tolist(),
+ [[[], [], []], [[], [], []], [[], [], []]])
+
+ def test_ndarray_zero_strides(self):
+ # zero strides
+ for flags in (0, ND_PIL):
+ nd = ndarray([1], shape=[5], strides=[0], flags=flags)
+ mv = memoryview(nd)
+ self.assertEqual(mv, nd)
+ self.assertEqual(nd.tolist(), [1, 1, 1, 1, 1])
+ self.assertEqual(mv.tolist(), [1, 1, 1, 1, 1])
+
+ def test_ndarray_offset(self):
+ nd = ndarray(list(range(20)), shape=[3], offset=7)
+ self.assertEqual(nd.offset, 7)
+ self.assertEqual(nd.tolist(), [7,8,9])
+
+ def test_ndarray_memoryview_from_buffer(self):
+ for flags in (0, ND_PIL):
+ nd = ndarray(list(range(3)), shape=[3], flags=flags)
+ m = nd.memoryview_from_buffer()
+ self.assertEqual(m, nd)
+
+ def test_ndarray_get_pointer(self):
+ for flags in (0, ND_PIL):
+ nd = ndarray(list(range(3)), shape=[3], flags=flags)
+ for i in range(3):
+ self.assertEqual(nd[i], get_pointer(nd, [i]))
+
+ def test_ndarray_tolist_null_strides(self):
+ ex = ndarray(list(range(20)), shape=[2,2,5])
+
+ nd = ndarray(ex, getbuf=PyBUF_ND|PyBUF_FORMAT)
+ self.assertEqual(nd.tolist(), ex.tolist())
+
+ m = memoryview(ex)
+ self.assertEqual(m.tolist(), ex.tolist())
+
+ def test_ndarray_cmp_contig(self):
+
+ self.assertFalse(cmp_contig(b"123", b"456"))
+
+ x = ndarray(list(range(12)), shape=[3,4])
+ y = ndarray(list(range(12)), shape=[4,3])
+ self.assertFalse(cmp_contig(x, y))
+
+ x = ndarray([1], shape=[1], format="B")
+ self.assertTrue(cmp_contig(x, b'\x01'))
+ self.assertTrue(cmp_contig(b'\x01', x))
+
+ def test_ndarray_hash(self):
+
+ a = array.array('L', [1,2,3])
+ nd = ndarray(a)
+ self.assertRaises(ValueError, hash, nd)
+
+ # one-dimensional
+ b = bytes(list(range(12)))
+
+ nd = ndarray(list(range(12)), shape=[12])
+ self.assertEqual(hash(nd), hash(b))
+
+ # C-contiguous
+ nd = ndarray(list(range(12)), shape=[3,4])
+ self.assertEqual(hash(nd), hash(b))
+
+ nd = ndarray(list(range(12)), shape=[3,2,2])
+ self.assertEqual(hash(nd), hash(b))
+
+ # Fortran contiguous
+ b = bytes(transpose(list(range(12)), shape=[4,3]))
+ nd = ndarray(list(range(12)), shape=[3,4], flags=ND_FORTRAN)
+ self.assertEqual(hash(nd), hash(b))
+
+ b = bytes(transpose(list(range(12)), shape=[2,3,2]))
+ nd = ndarray(list(range(12)), shape=[2,3,2], flags=ND_FORTRAN)
+ self.assertEqual(hash(nd), hash(b))
+
+ # suboffsets
+ b = bytes(list(range(12)))
+ nd = ndarray(list(range(12)), shape=[2,2,3], flags=ND_PIL)
+ self.assertEqual(hash(nd), hash(b))
+
+ # non-byte formats
+ nd = ndarray(list(range(12)), shape=[2,2,3], format='L')
+ self.assertEqual(hash(nd), hash(nd.tobytes()))
+
+ def test_py_buffer_to_contiguous(self):
+
+ # The requests are used in _testbuffer.c:py_buffer_to_contiguous
+ # to generate buffers without full information for testing.
+ requests = (
+ # distinct flags
+ PyBUF_INDIRECT, PyBUF_STRIDES, PyBUF_ND, PyBUF_SIMPLE,
+ # compound requests
+ PyBUF_FULL, PyBUF_FULL_RO,
+ PyBUF_RECORDS, PyBUF_RECORDS_RO,
+ PyBUF_STRIDED, PyBUF_STRIDED_RO,
+ PyBUF_CONTIG, PyBUF_CONTIG_RO,
+ )
+
+ # no buffer interface
+ self.assertRaises(TypeError, py_buffer_to_contiguous, {}, 'F',
+ PyBUF_FULL_RO)
+
+ # scalar, read-only request
+ nd = ndarray(9, shape=(), format="L", flags=ND_WRITABLE)
+ for order in ['C', 'F', 'A']:
+ for request in requests:
+ b = py_buffer_to_contiguous(nd, order, request)
+ self.assertEqual(b, nd.tobytes())
+
+ # zeros in shape
+ nd = ndarray([1], shape=[0], format="L", flags=ND_WRITABLE)
+ for order in ['C', 'F', 'A']:
+ for request in requests:
+ b = py_buffer_to_contiguous(nd, order, request)
+ self.assertEqual(b, b'')
+
+ nd = ndarray(list(range(8)), shape=[2, 0, 7], format="L",
+ flags=ND_WRITABLE)
+ for order in ['C', 'F', 'A']:
+ for request in requests:
+ b = py_buffer_to_contiguous(nd, order, request)
+ self.assertEqual(b, b'')
+
+ ### One-dimensional arrays are trivial, since Fortran and C order
+ ### are the same.
+
+ # one-dimensional
+ for f in [0, ND_FORTRAN]:
+ nd = ndarray([1], shape=[1], format="h", flags=f|ND_WRITABLE)
+ ndbytes = nd.tobytes()
+ for order in ['C', 'F', 'A']:
+ for request in requests:
+ b = py_buffer_to_contiguous(nd, order, request)
+ self.assertEqual(b, ndbytes)
+
+ nd = ndarray([1, 2, 3], shape=[3], format="b", flags=f|ND_WRITABLE)
+ ndbytes = nd.tobytes()
+ for order in ['C', 'F', 'A']:
+ for request in requests:
+ b = py_buffer_to_contiguous(nd, order, request)
+ self.assertEqual(b, ndbytes)
+
+ # one-dimensional, non-contiguous input
+ nd = ndarray([1, 2, 3], shape=[2], strides=[2], flags=ND_WRITABLE)
+ ndbytes = nd.tobytes()
+ for order in ['C', 'F', 'A']:
+ for request in [PyBUF_STRIDES, PyBUF_FULL]:
+ b = py_buffer_to_contiguous(nd, order, request)
+ self.assertEqual(b, ndbytes)
+
+ nd = nd[::-1]
+ ndbytes = nd.tobytes()
+ for order in ['C', 'F', 'A']:
+ for request in requests:
+ try:
+ b = py_buffer_to_contiguous(nd, order, request)
+ except BufferError:
+ continue
+ self.assertEqual(b, ndbytes)
+
+ ###
+ ### Multi-dimensional arrays:
+ ###
+ ### The goal here is to preserve the logical representation of the
+ ### input array but change the physical representation if necessary.
+ ###
+ ### _testbuffer example:
+ ### ====================
+ ###
+ ### C input array:
+ ### --------------
+ ### >>> nd = ndarray(list(range(12)), shape=[3, 4])
+ ### >>> nd.tolist()
+ ### [[0, 1, 2, 3],
+ ### [4, 5, 6, 7],
+ ### [8, 9, 10, 11]]
+ ###
+ ### Fortran output:
+ ### ---------------
+ ### >>> py_buffer_to_contiguous(nd, 'F', PyBUF_FULL_RO)
+ ### >>> b'\x00\x04\x08\x01\x05\t\x02\x06\n\x03\x07\x0b'
+ ###
+ ### The return value corresponds to this input list for
+ ### _testbuffer's ndarray:
+ ### >>> nd = ndarray([0,4,8,1,5,9,2,6,10,3,7,11], shape=[3,4],
+ ### flags=ND_FORTRAN)
+ ### >>> nd.tolist()
+ ### [[0, 1, 2, 3],
+ ### [4, 5, 6, 7],
+ ### [8, 9, 10, 11]]
+ ###
+ ### The logical array is the same, but the values in memory are now
+ ### in Fortran order.
+ ###
+ ### NumPy example:
+ ### ==============
+ ### _testbuffer's ndarray takes lists to initialize the memory.
+ ### Here's the same sequence in NumPy:
+ ###
+ ### C input:
+ ### --------
+ ### >>> nd = ndarray(buffer=bytearray(list(range(12))),
+ ### shape=[3, 4], dtype='B')
+ ### >>> nd
+ ### array([[ 0, 1, 2, 3],
+ ### [ 4, 5, 6, 7],
+ ### [ 8, 9, 10, 11]], dtype=uint8)
+ ###
+ ### Fortran output:
+ ### ---------------
+ ### >>> fortran_buf = nd.tostring(order='F')
+ ### >>> fortran_buf
+ ### b'\x00\x04\x08\x01\x05\t\x02\x06\n\x03\x07\x0b'
+ ###
+ ### >>> nd = ndarray(buffer=fortran_buf, shape=[3, 4],
+ ### dtype='B', order='F')
+ ###
+ ### >>> nd
+ ### array([[ 0, 1, 2, 3],
+ ### [ 4, 5, 6, 7],
+ ### [ 8, 9, 10, 11]], dtype=uint8)
+ ###
+
+ # multi-dimensional, contiguous input
+ lst = list(range(12))
+ for f in [0, ND_FORTRAN]:
+ nd = ndarray(lst, shape=[3, 4], flags=f|ND_WRITABLE)
+ if numpy_array:
+ na = numpy_array(buffer=bytearray(lst),
+ shape=[3, 4], dtype='B',
+ order='C' if f == 0 else 'F')
+
+ # 'C' request
+ if f == ND_FORTRAN: # 'F' to 'C'
+ x = ndarray(transpose(lst, [4, 3]), shape=[3, 4],
+ flags=ND_WRITABLE)
+ expected = x.tobytes()
+ else:
+ expected = nd.tobytes()
+ for request in requests:
+ try:
+ b = py_buffer_to_contiguous(nd, 'C', request)
+ except BufferError:
+ continue
+
+ self.assertEqual(b, expected)
+
+ # Check that output can be used as the basis for constructing
+ # a C array that is logically identical to the input array.
+ y = ndarray([v for v in b], shape=[3, 4], flags=ND_WRITABLE)
+ self.assertEqual(memoryview(y), memoryview(nd))
+
+ if numpy_array:
+ self.assertEqual(b, na.tostring(order='C'))
+
+ # 'F' request
+ if f == 0: # 'C' to 'F'
+ x = ndarray(transpose(lst, [3, 4]), shape=[4, 3],
+ flags=ND_WRITABLE)
+ else:
+ x = ndarray(lst, shape=[3, 4], flags=ND_WRITABLE)
+ expected = x.tobytes()
+ for request in [PyBUF_FULL, PyBUF_FULL_RO, PyBUF_INDIRECT,
+ PyBUF_STRIDES, PyBUF_ND]:
+ try:
+ b = py_buffer_to_contiguous(nd, 'F', request)
+ except BufferError:
+ continue
+ self.assertEqual(b, expected)
+
+ # Check that output can be used as the basis for constructing
+ # a Fortran array that is logically identical to the input array.
+ y = ndarray([v for v in b], shape=[3, 4], flags=ND_FORTRAN|ND_WRITABLE)
+ self.assertEqual(memoryview(y), memoryview(nd))
+
+ if numpy_array:
+ self.assertEqual(b, na.tostring(order='F'))
+
+ # 'A' request
+ if f == ND_FORTRAN:
+ x = ndarray(lst, shape=[3, 4], flags=ND_WRITABLE)
+ expected = x.tobytes()
+ else:
+ expected = nd.tobytes()
+ for request in [PyBUF_FULL, PyBUF_FULL_RO, PyBUF_INDIRECT,
+ PyBUF_STRIDES, PyBUF_ND]:
+ try:
+ b = py_buffer_to_contiguous(nd, 'A', request)
+ except BufferError:
+ continue
+
+ self.assertEqual(b, expected)
+
+ # Check that output can be used as the basis for constructing
+ # an array with order=f that is logically identical to the input
+ # array.
+ y = ndarray([v for v in b], shape=[3, 4], flags=f|ND_WRITABLE)
+ self.assertEqual(memoryview(y), memoryview(nd))
+
+ if numpy_array:
+ self.assertEqual(b, na.tostring(order='A'))
+
+ # multi-dimensional, non-contiguous input
+ nd = ndarray(list(range(12)), shape=[3, 4], flags=ND_WRITABLE|ND_PIL)
+
+ # 'C'
+ b = py_buffer_to_contiguous(nd, 'C', PyBUF_FULL_RO)
+ self.assertEqual(b, nd.tobytes())
+ y = ndarray([v for v in b], shape=[3, 4], flags=ND_WRITABLE)
+ self.assertEqual(memoryview(y), memoryview(nd))
+
+ # 'F'
+ b = py_buffer_to_contiguous(nd, 'F', PyBUF_FULL_RO)
+ x = ndarray(transpose(lst, [3, 4]), shape=[4, 3], flags=ND_WRITABLE)
+ self.assertEqual(b, x.tobytes())
+ y = ndarray([v for v in b], shape=[3, 4], flags=ND_FORTRAN|ND_WRITABLE)
+ self.assertEqual(memoryview(y), memoryview(nd))
+
+ # 'A'
+ b = py_buffer_to_contiguous(nd, 'A', PyBUF_FULL_RO)
+ self.assertEqual(b, nd.tobytes())
+ y = ndarray([v for v in b], shape=[3, 4], flags=ND_WRITABLE)
+ self.assertEqual(memoryview(y), memoryview(nd))
+
+ def test_memoryview_construction(self):
+
+ items_shape = [(9, []), ([1,2,3], [3]), (list(range(2*3*5)), [2,3,5])]
+
+ # NumPy style, C-contiguous:
+ for items, shape in items_shape:
+
+ # From PEP-3118 compliant exporter:
+ ex = ndarray(items, shape=shape)
+ m = memoryview(ex)
+ self.assertTrue(m.c_contiguous)
+ self.assertTrue(m.contiguous)
+
+ ndim = len(shape)
+ strides = strides_from_shape(ndim, shape, 1, 'C')
+ lst = carray(items, shape)
+
+ self.verify(m, obj=ex,
+ itemsize=1, fmt='B', readonly=1,
+ ndim=ndim, shape=shape, strides=strides,
+ lst=lst)
+
+ # From memoryview:
+ m2 = memoryview(m)
+ self.verify(m2, obj=ex,
+ itemsize=1, fmt='B', readonly=1,
+ ndim=ndim, shape=shape, strides=strides,
+ lst=lst)
+
+ # PyMemoryView_FromBuffer(): no strides
+ nd = ndarray(ex, getbuf=PyBUF_CONTIG_RO|PyBUF_FORMAT)
+ self.assertEqual(nd.strides, ())
+ m = nd.memoryview_from_buffer()
+ self.verify(m, obj=None,
+ itemsize=1, fmt='B', readonly=1,
+ ndim=ndim, shape=shape, strides=strides,
+ lst=lst)
+
+ # PyMemoryView_FromBuffer(): no format, shape, strides
+ nd = ndarray(ex, getbuf=PyBUF_SIMPLE)
+ self.assertEqual(nd.format, '')
+ self.assertEqual(nd.shape, ())
+ self.assertEqual(nd.strides, ())
+ m = nd.memoryview_from_buffer()
+
+ lst = [items] if ndim == 0 else items
+ self.verify(m, obj=None,
+ itemsize=1, fmt='B', readonly=1,
+ ndim=1, shape=[ex.nbytes], strides=(1,),
+ lst=lst)
+
+ # NumPy style, Fortran contiguous:
+ for items, shape in items_shape:
+
+ # From PEP-3118 compliant exporter:
+ ex = ndarray(items, shape=shape, flags=ND_FORTRAN)
+ m = memoryview(ex)
+ self.assertTrue(m.f_contiguous)
+ self.assertTrue(m.contiguous)
+
+ ndim = len(shape)
+ strides = strides_from_shape(ndim, shape, 1, 'F')
+ lst = farray(items, shape)
+
+ self.verify(m, obj=ex,
+ itemsize=1, fmt='B', readonly=1,
+ ndim=ndim, shape=shape, strides=strides,
+ lst=lst)
+
+ # From memoryview:
+ m2 = memoryview(m)
+ self.verify(m2, obj=ex,
+ itemsize=1, fmt='B', readonly=1,
+ ndim=ndim, shape=shape, strides=strides,
+ lst=lst)
+
+ # PIL style:
+ for items, shape in items_shape[1:]:
+
+ # From PEP-3118 compliant exporter:
+ ex = ndarray(items, shape=shape, flags=ND_PIL)
+ m = memoryview(ex)
+
+ ndim = len(shape)
+ lst = carray(items, shape)
+
+ self.verify(m, obj=ex,
+ itemsize=1, fmt='B', readonly=1,
+ ndim=ndim, shape=shape, strides=ex.strides,
+ lst=lst)
+
+ # From memoryview:
+ m2 = memoryview(m)
+ self.verify(m2, obj=ex,
+ itemsize=1, fmt='B', readonly=1,
+ ndim=ndim, shape=shape, strides=ex.strides,
+ lst=lst)
+
+ # Invalid number of arguments:
+ self.assertRaises(TypeError, memoryview, b'9', 'x')
+ # Not a buffer provider:
+ self.assertRaises(TypeError, memoryview, {})
+ # Non-compliant buffer provider:
+ ex = ndarray([1,2,3], shape=[3])
+ nd = ndarray(ex, getbuf=PyBUF_SIMPLE)
+ self.assertRaises(BufferError, memoryview, nd)
+ nd = ndarray(ex, getbuf=PyBUF_CONTIG_RO|PyBUF_FORMAT)
+ self.assertRaises(BufferError, memoryview, nd)
+
+ # ndim > 64
+ nd = ndarray([1]*128, shape=[1]*128, format='L')
+ self.assertRaises(ValueError, memoryview, nd)
+ self.assertRaises(ValueError, nd.memoryview_from_buffer)
+ self.assertRaises(ValueError, get_contiguous, nd, PyBUF_READ, 'C')
+ self.assertRaises(ValueError, get_contiguous, nd, PyBUF_READ, 'F')
+ self.assertRaises(ValueError, get_contiguous, nd[::-1], PyBUF_READ, 'C')
+
+ def test_memoryview_cast_zero_shape(self):
+ # Casts are undefined if shape contains zeros. These arrays are
+ # regarded as C-contiguous by Numpy and PyBuffer_GetContiguous(),
+ # so they are not caught by the test for C-contiguity in memory_cast().
+ items = [1,2,3]
+ for shape in ([0,3,3], [3,0,3], [0,3,3]):
+ ex = ndarray(items, shape=shape)
+ self.assertTrue(ex.c_contiguous)
+ msrc = memoryview(ex)
+ self.assertRaises(TypeError, msrc.cast, 'c')
+
+ def test_memoryview_struct_module(self):
+
+ class INT(object):
+ def __init__(self, val):
+ self.val = val
+ def __int__(self):
+ return self.val
+
+ class IDX(object):
+ def __init__(self, val):
+ self.val = val
+ def __index__(self):
+ return self.val
+
+ def f(): return 7
+
+ values = [INT(9), IDX(9),
+ 2.2+3j, Decimal("-21.1"), 12.2, Fraction(5, 2),
+ [1,2,3], {4,5,6}, {7:8}, (), (9,),
+ True, False, None, NotImplemented,
+ b'a', b'abc', bytearray(b'a'), bytearray(b'abc'),
+ 'a', 'abc', r'a', r'abc',
+ f, lambda x: x]
+
+ for fmt, items, item in iter_format(10, 'memoryview'):
+ ex = ndarray(items, shape=[10], format=fmt, flags=ND_WRITABLE)
+ nd = ndarray(items, shape=[10], format=fmt, flags=ND_WRITABLE)
+ m = memoryview(ex)
+
+ struct.pack_into(fmt, nd, 0, item)
+ m[0] = item
+ self.assertEqual(m[0], nd[0])
+
+ itemsize = struct.calcsize(fmt)
+ if 'P' in fmt:
+ continue
+
+ for v in values:
+ struct_err = None
+ try:
+ struct.pack_into(fmt, nd, itemsize, v)
+ except struct.error:
+ struct_err = struct.error
+
+ mv_err = None
+ try:
+ m[1] = v
+ except (TypeError, ValueError) as e:
+ mv_err = e.__class__
+
+ if struct_err or mv_err:
+ self.assertIsNot(struct_err, None)
+ self.assertIsNot(mv_err, None)
+ else:
+ self.assertEqual(m[1], nd[1])
+
+ def test_memoryview_cast_zero_strides(self):
+ # Casts are undefined if strides contains zeros. These arrays are
+ # (sometimes!) regarded as C-contiguous by Numpy, but not by
+ # PyBuffer_GetContiguous().
+ ex = ndarray([1,2,3], shape=[3], strides=[0])
+ self.assertFalse(ex.c_contiguous)
+ msrc = memoryview(ex)
+ self.assertRaises(TypeError, msrc.cast, 'c')
+
+ def test_memoryview_cast_invalid(self):
+ # invalid format
+ for sfmt in NON_BYTE_FORMAT:
+ sformat = '@' + sfmt if randrange(2) else sfmt
+ ssize = struct.calcsize(sformat)
+ for dfmt in NON_BYTE_FORMAT:
+ dformat = '@' + dfmt if randrange(2) else dfmt
+ dsize = struct.calcsize(dformat)
+ ex = ndarray(list(range(32)), shape=[32//ssize], format=sformat)
+ msrc = memoryview(ex)
+ self.assertRaises(TypeError, msrc.cast, dfmt, [32//dsize])
+
+ for sfmt, sitems, _ in iter_format(1):
+ ex = ndarray(sitems, shape=[1], format=sfmt)
+ msrc = memoryview(ex)
+ for dfmt, _, _ in iter_format(1):
+ if (not is_memoryview_format(sfmt) or
+ not is_memoryview_format(dfmt)):
+ self.assertRaises(ValueError, msrc.cast, dfmt,
+ [32//dsize])
+ else:
+ if not is_byte_format(sfmt) and not is_byte_format(dfmt):
+ self.assertRaises(TypeError, msrc.cast, dfmt,
+ [32//dsize])
+
+ # invalid shape
+ size_h = struct.calcsize('h')
+ size_d = struct.calcsize('d')
+ ex = ndarray(list(range(2*2*size_d)), shape=[2,2,size_d], format='h')
+ msrc = memoryview(ex)
+ self.assertRaises(TypeError, msrc.cast, shape=[2,2,size_h], format='d')
+
+ ex = ndarray(list(range(120)), shape=[1,2,3,4,5])
+ m = memoryview(ex)
+
+ # incorrect number of args
+ self.assertRaises(TypeError, m.cast)
+ self.assertRaises(TypeError, m.cast, 1, 2, 3)
+
+ # incorrect dest format type
+ self.assertRaises(TypeError, m.cast, {})
+
+ # incorrect dest format
+ self.assertRaises(ValueError, m.cast, "X")
+ self.assertRaises(ValueError, m.cast, "@X")
+ self.assertRaises(ValueError, m.cast, "@XY")
+
+ # dest format not implemented
+ self.assertRaises(ValueError, m.cast, "=B")
+ self.assertRaises(ValueError, m.cast, "!L")
+ self.assertRaises(ValueError, m.cast, "<P")
+ self.assertRaises(ValueError, m.cast, ">l")
+ self.assertRaises(ValueError, m.cast, "BI")
+ self.assertRaises(ValueError, m.cast, "xBI")
+
+ # src format not implemented
+ ex = ndarray([(1,2), (3,4)], shape=[2], format="II")
+ m = memoryview(ex)
+ self.assertRaises(NotImplementedError, m.__getitem__, 0)
+ self.assertRaises(NotImplementedError, m.__setitem__, 0, 8)
+ self.assertRaises(NotImplementedError, m.tolist)
+
+ # incorrect shape type
+ ex = ndarray(list(range(120)), shape=[1,2,3,4,5])
+ m = memoryview(ex)
+ self.assertRaises(TypeError, m.cast, "B", shape={})
+
+ # incorrect shape elements
+ ex = ndarray(list(range(120)), shape=[2*3*4*5])
+ m = memoryview(ex)
+ self.assertRaises(OverflowError, m.cast, "B", shape=[2**64])
+ self.assertRaises(ValueError, m.cast, "B", shape=[-1])
+ self.assertRaises(ValueError, m.cast, "B", shape=[2,3,4,5,6,7,-1])
+ self.assertRaises(ValueError, m.cast, "B", shape=[2,3,4,5,6,7,0])
+ self.assertRaises(TypeError, m.cast, "B", shape=[2,3,4,5,6,7,'x'])
+
+ # N-D -> N-D cast
+ ex = ndarray(list([9 for _ in range(3*5*7*11)]), shape=[3,5,7,11])
+ m = memoryview(ex)
+ self.assertRaises(TypeError, m.cast, "I", shape=[2,3,4,5])
+
+ # cast with ndim > 64
+ nd = ndarray(list(range(128)), shape=[128], format='I')
+ m = memoryview(nd)
+ self.assertRaises(ValueError, m.cast, 'I', [1]*128)
+
+ # view->len not a multiple of itemsize
+ ex = ndarray(list([9 for _ in range(3*5*7*11)]), shape=[3*5*7*11])
+ m = memoryview(ex)
+ self.assertRaises(TypeError, m.cast, "I", shape=[2,3,4,5])
+
+ # product(shape) * itemsize != buffer size
+ ex = ndarray(list([9 for _ in range(3*5*7*11)]), shape=[3*5*7*11])
+ m = memoryview(ex)
+ self.assertRaises(TypeError, m.cast, "B", shape=[2,3,4,5])
+
+ # product(shape) * itemsize overflow
+ nd = ndarray(list(range(128)), shape=[128], format='I')
+ m1 = memoryview(nd)
+ nd = ndarray(list(range(128)), shape=[128], format='B')
+ m2 = memoryview(nd)
+ if sys.maxsize == 2**63-1:
+ self.assertRaises(TypeError, m1.cast, 'B',
+ [7, 7, 73, 127, 337, 92737, 649657])
+ self.assertRaises(ValueError, m1.cast, 'B',
+ [2**20, 2**20, 2**10, 2**10, 2**3])
+ self.assertRaises(ValueError, m2.cast, 'I',
+ [2**20, 2**20, 2**10, 2**10, 2**1])
+ else:
+ self.assertRaises(TypeError, m1.cast, 'B',
+ [1, 2147483647])
+ self.assertRaises(ValueError, m1.cast, 'B',
+ [2**10, 2**10, 2**5, 2**5, 2**1])
+ self.assertRaises(ValueError, m2.cast, 'I',
+ [2**10, 2**10, 2**5, 2**3, 2**1])
+
+ def test_memoryview_cast(self):
+ bytespec = (
+ ('B', lambda ex: list(ex.tobytes())),
+ ('b', lambda ex: [x-256 if x > 127 else x for x in list(ex.tobytes())]),
+ ('c', lambda ex: [bytes(chr(x), 'latin-1') for x in list(ex.tobytes())]),
+ )
+
+ def iter_roundtrip(ex, m, items, fmt):
+ srcsize = struct.calcsize(fmt)
+ for bytefmt, to_bytelist in bytespec:
+
+ m2 = m.cast(bytefmt)
+ lst = to_bytelist(ex)
+ self.verify(m2, obj=ex,
+ itemsize=1, fmt=bytefmt, readonly=0,
+ ndim=1, shape=[31*srcsize], strides=(1,),
+ lst=lst, cast=True)
+
+ m3 = m2.cast(fmt)
+ self.assertEqual(m3, ex)
+ lst = ex.tolist()
+ self.verify(m3, obj=ex,
+ itemsize=srcsize, fmt=fmt, readonly=0,
+ ndim=1, shape=[31], strides=(srcsize,),
+ lst=lst, cast=True)
+
+ # cast from ndim = 0 to ndim = 1
+ srcsize = struct.calcsize('I')
+ ex = ndarray(9, shape=[], format='I')
+ destitems, destshape = cast_items(ex, 'B', 1)
+ m = memoryview(ex)
+ m2 = m.cast('B')
+ self.verify(m2, obj=ex,
+ itemsize=1, fmt='B', readonly=1,
+ ndim=1, shape=destshape, strides=(1,),
+ lst=destitems, cast=True)
+
+ # cast from ndim = 1 to ndim = 0
+ destsize = struct.calcsize('I')
+ ex = ndarray([9]*destsize, shape=[destsize], format='B')
+ destitems, destshape = cast_items(ex, 'I', destsize, shape=[])
+ m = memoryview(ex)
+ m2 = m.cast('I', shape=[])
+ self.verify(m2, obj=ex,
+ itemsize=destsize, fmt='I', readonly=1,
+ ndim=0, shape=(), strides=(),
+ lst=destitems, cast=True)
+
+ # array.array: roundtrip to/from bytes
+ for fmt, items, _ in iter_format(31, 'array'):
+ ex = array.array(fmt, items)
+ m = memoryview(ex)
+ iter_roundtrip(ex, m, items, fmt)
+
+ # ndarray: roundtrip to/from bytes
+ for fmt, items, _ in iter_format(31, 'memoryview'):
+ ex = ndarray(items, shape=[31], format=fmt, flags=ND_WRITABLE)
+ m = memoryview(ex)
+ iter_roundtrip(ex, m, items, fmt)
+
+ def test_memoryview_cast_1D_ND(self):
+ # Cast between C-contiguous buffers. At least one buffer must
+ # be 1D, at least one format must be 'c', 'b' or 'B'.
+ for _tshape in gencastshapes():
+ for char in fmtdict['@']:
+ tfmt = ('', '@')[randrange(2)] + char
+ tsize = struct.calcsize(tfmt)
+ n = prod(_tshape) * tsize
+ obj = 'memoryview' if is_byte_format(tfmt) else 'bytefmt'
+ for fmt, items, _ in iter_format(n, obj):
+ size = struct.calcsize(fmt)
+ shape = [n] if n > 0 else []
+ tshape = _tshape + [size]
+
+ ex = ndarray(items, shape=shape, format=fmt)
+ m = memoryview(ex)
+
+ titems, tshape = cast_items(ex, tfmt, tsize, shape=tshape)
+
+ if titems is None:
+ self.assertRaises(TypeError, m.cast, tfmt, tshape)
+ continue
+ if titems == 'nan':
+ continue # NaNs in lists are a recipe for trouble.
+
+ # 1D -> ND
+ nd = ndarray(titems, shape=tshape, format=tfmt)
+
+ m2 = m.cast(tfmt, shape=tshape)
+ ndim = len(tshape)
+ strides = nd.strides
+ lst = nd.tolist()
+ self.verify(m2, obj=ex,
+ itemsize=tsize, fmt=tfmt, readonly=1,
+ ndim=ndim, shape=tshape, strides=strides,
+ lst=lst, cast=True)
+
+ # ND -> 1D
+ m3 = m2.cast(fmt)
+ m4 = m2.cast(fmt, shape=shape)
+ ndim = len(shape)
+ strides = ex.strides
+ lst = ex.tolist()
+
+ self.verify(m3, obj=ex,
+ itemsize=size, fmt=fmt, readonly=1,
+ ndim=ndim, shape=shape, strides=strides,
+ lst=lst, cast=True)
+
+ self.verify(m4, obj=ex,
+ itemsize=size, fmt=fmt, readonly=1,
+ ndim=ndim, shape=shape, strides=strides,
+ lst=lst, cast=True)
+
+ def test_memoryview_tolist(self):
+
+ # Most tolist() tests are in self.verify() etc.
+
+ a = array.array('h', list(range(-6, 6)))
+ m = memoryview(a)
+ self.assertEqual(m, a)
+ self.assertEqual(m.tolist(), a.tolist())
+
+ a = a[2::3]
+ m = m[2::3]
+ self.assertEqual(m, a)
+ self.assertEqual(m.tolist(), a.tolist())
+
+ ex = ndarray(list(range(2*3*5*7*11)), shape=[11,2,7,3,5], format='L')
+ m = memoryview(ex)
+ self.assertEqual(m.tolist(), ex.tolist())
+
+ ex = ndarray([(2, 5), (7, 11)], shape=[2], format='lh')
+ m = memoryview(ex)
+ self.assertRaises(NotImplementedError, m.tolist)
+
+ ex = ndarray([b'12345'], shape=[1], format="s")
+ m = memoryview(ex)
+ self.assertRaises(NotImplementedError, m.tolist)
+
+ ex = ndarray([b"a",b"b",b"c",b"d",b"e",b"f"], shape=[2,3], format='s')
+ m = memoryview(ex)
+ self.assertRaises(NotImplementedError, m.tolist)
+
+ def test_memoryview_repr(self):
+ m = memoryview(bytearray(9))
+ r = m.__repr__()
+ self.assertTrue(r.startswith("<memory"))
+
+ m.release()
+ r = m.__repr__()
+ self.assertTrue(r.startswith("<released"))
+
+ def test_memoryview_sequence(self):
+
+ for fmt in ('d', 'f'):
+ inf = float(3e400)
+ ex = array.array(fmt, [1.0, inf, 3.0])
+ m = memoryview(ex)
+ self.assertIn(1.0, m)
+ self.assertIn(5e700, m)
+ self.assertIn(3.0, m)
+
+ ex = ndarray(9.0, [], format='f')
+ m = memoryview(ex)
+ self.assertRaises(TypeError, eval, "9.0 in m", locals())
+
+ def test_memoryview_index(self):
+
+ # ndim = 0
+ ex = ndarray(12.5, shape=[], format='d')
+ m = memoryview(ex)
+ self.assertEqual(m[()], 12.5)
+ self.assertEqual(m[...], m)
+ self.assertEqual(m[...], ex)
+ self.assertRaises(TypeError, m.__getitem__, 0)
+
+ ex = ndarray((1,2,3), shape=[], format='iii')
+ m = memoryview(ex)
+ self.assertRaises(NotImplementedError, m.__getitem__, ())
+
+ # range
+ ex = ndarray(list(range(7)), shape=[7], flags=ND_WRITABLE)
+ m = memoryview(ex)
+
+ self.assertRaises(IndexError, m.__getitem__, 2**64)
+ self.assertRaises(TypeError, m.__getitem__, 2.0)
+ self.assertRaises(TypeError, m.__getitem__, 0.0)
+
+ # out of bounds
+ self.assertRaises(IndexError, m.__getitem__, -8)
+ self.assertRaises(IndexError, m.__getitem__, 8)
+
+ # Not implemented: multidimensional sub-views
+ ex = ndarray(list(range(12)), shape=[3,4], flags=ND_WRITABLE)
+ m = memoryview(ex)
+
+ self.assertRaises(NotImplementedError, m.__getitem__, 0)
+ self.assertRaises(NotImplementedError, m.__setitem__, 0, 9)
+ self.assertRaises(NotImplementedError, m.__getitem__, 0)
+
+ def test_memoryview_assign(self):
+
+ # ndim = 0
+ ex = ndarray(12.5, shape=[], format='f', flags=ND_WRITABLE)
+ m = memoryview(ex)
+ m[()] = 22.5
+ self.assertEqual(m[()], 22.5)
+ m[...] = 23.5
+ self.assertEqual(m[()], 23.5)
+ self.assertRaises(TypeError, m.__setitem__, 0, 24.7)
+
+ # read-only
+ ex = ndarray(list(range(7)), shape=[7])
+ m = memoryview(ex)
+ self.assertRaises(TypeError, m.__setitem__, 2, 10)
+
+ # range
+ ex = ndarray(list(range(7)), shape=[7], flags=ND_WRITABLE)
+ m = memoryview(ex)
+
+ self.assertRaises(IndexError, m.__setitem__, 2**64, 9)
+ self.assertRaises(TypeError, m.__setitem__, 2.0, 10)
+ self.assertRaises(TypeError, m.__setitem__, 0.0, 11)
+
+ # out of bounds
+ self.assertRaises(IndexError, m.__setitem__, -8, 20)
+ self.assertRaises(IndexError, m.__setitem__, 8, 25)
+
+ # pack_single() success:
+ for fmt in fmtdict['@']:
+ if fmt == 'c' or fmt == '?':
+ continue
+ ex = ndarray([1,2,3], shape=[3], format=fmt, flags=ND_WRITABLE)
+ m = memoryview(ex)
+ i = randrange(-3, 3)
+ m[i] = 8
+ self.assertEqual(m[i], 8)
+ self.assertEqual(m[i], ex[i])
+
+ ex = ndarray([b'1', b'2', b'3'], shape=[3], format='c',
+ flags=ND_WRITABLE)
+ m = memoryview(ex)
+ m[2] = b'9'
+ self.assertEqual(m[2], b'9')
+
+ ex = ndarray([True, False, True], shape=[3], format='?',
+ flags=ND_WRITABLE)
+ m = memoryview(ex)
+ m[1] = True
+ self.assertEqual(m[1], True)
+
+ # pack_single() exceptions:
+ nd = ndarray([b'x'], shape=[1], format='c', flags=ND_WRITABLE)
+ m = memoryview(nd)
+ self.assertRaises(TypeError, m.__setitem__, 0, 100)
+
+ ex = ndarray(list(range(120)), shape=[1,2,3,4,5], flags=ND_WRITABLE)
+ m1 = memoryview(ex)
+
+ for fmt, _range in fmtdict['@'].items():
+ if (fmt == '?'): # PyObject_IsTrue() accepts anything
+ continue
+ if fmt == 'c': # special case tested above
+ continue
+ m2 = m1.cast(fmt)
+ lo, hi = _range
+ if fmt == 'd' or fmt == 'f':
+ lo, hi = -2**1024, 2**1024
+ if fmt != 'P': # PyLong_AsVoidPtr() accepts negative numbers
+ self.assertRaises(ValueError, m2.__setitem__, 0, lo-1)
+ self.assertRaises(TypeError, m2.__setitem__, 0, "xyz")
+ self.assertRaises(ValueError, m2.__setitem__, 0, hi)
+
+ # invalid item
+ m2 = m1.cast('c')
+ self.assertRaises(ValueError, m2.__setitem__, 0, b'\xff\xff')
+
+ # format not implemented
+ ex = ndarray(list(range(1)), shape=[1], format="xL", flags=ND_WRITABLE)
+ m = memoryview(ex)
+ self.assertRaises(NotImplementedError, m.__setitem__, 0, 1)
+
+ ex = ndarray([b'12345'], shape=[1], format="s", flags=ND_WRITABLE)
+ m = memoryview(ex)
+ self.assertRaises(NotImplementedError, m.__setitem__, 0, 1)
+
+ # Not implemented: multidimensional sub-views
+ ex = ndarray(list(range(12)), shape=[3,4], flags=ND_WRITABLE)
+ m = memoryview(ex)
+
+ self.assertRaises(NotImplementedError, m.__setitem__, 0, [2, 3])
+
+ def test_memoryview_slice(self):
+
+ ex = ndarray(list(range(12)), shape=[12], flags=ND_WRITABLE)
+ m = memoryview(ex)
+
+ # zero step
+ self.assertRaises(ValueError, m.__getitem__, slice(0,2,0))
+ self.assertRaises(ValueError, m.__setitem__, slice(0,2,0),
+ bytearray([1,2]))
+
+ # invalid slice key
+ self.assertRaises(TypeError, m.__getitem__, ())
+
+ # multidimensional slices
+ ex = ndarray(list(range(12)), shape=[12], flags=ND_WRITABLE)
+ m = memoryview(ex)
+
+ self.assertRaises(NotImplementedError, m.__getitem__,
+ (slice(0,2,1), slice(0,2,1)))
+ self.assertRaises(NotImplementedError, m.__setitem__,
+ (slice(0,2,1), slice(0,2,1)), bytearray([1,2]))
+
+ # invalid slice tuple
+ self.assertRaises(TypeError, m.__getitem__, (slice(0,2,1), {}))
+ self.assertRaises(TypeError, m.__setitem__, (slice(0,2,1), {}),
+ bytearray([1,2]))
+
+ # rvalue is not an exporter
+ self.assertRaises(TypeError, m.__setitem__, slice(0,1,1), [1])
+
+ # non-contiguous slice assignment
+ for flags in (0, ND_PIL):
+ ex1 = ndarray(list(range(12)), shape=[12], strides=[-1], offset=11,
+ flags=ND_WRITABLE|flags)
+ ex2 = ndarray(list(range(24)), shape=[12], strides=[2], flags=flags)
+ m1 = memoryview(ex1)
+ m2 = memoryview(ex2)
+
+ ex1[2:5] = ex1[2:5]
+ m1[2:5] = m2[2:5]
+
+ self.assertEqual(m1, ex1)
+ self.assertEqual(m2, ex2)
+
+ ex1[1:3][::-1] = ex2[0:2][::1]
+ m1[1:3][::-1] = m2[0:2][::1]
+
+ self.assertEqual(m1, ex1)
+ self.assertEqual(m2, ex2)
+
+ ex1[4:1:-2][::-1] = ex1[1:4:2][::1]
+ m1[4:1:-2][::-1] = m1[1:4:2][::1]
+
+ self.assertEqual(m1, ex1)
+ self.assertEqual(m2, ex2)
+
+ def test_memoryview_array(self):
+
+ def cmptest(testcase, a, b, m, singleitem):
+ for i, _ in enumerate(a):
+ ai = a[i]
+ mi = m[i]
+ testcase.assertEqual(ai, mi)
+ a[i] = singleitem
+ if singleitem != ai:
+ testcase.assertNotEqual(a, m)
+ testcase.assertNotEqual(a, b)
+ else:
+ testcase.assertEqual(a, m)
+ testcase.assertEqual(a, b)
+ m[i] = singleitem
+ testcase.assertEqual(a, m)
+ testcase.assertEqual(b, m)
+ a[i] = ai
+ m[i] = mi
+
+ for n in range(1, 5):
+ for fmt, items, singleitem in iter_format(n, 'array'):
+ for lslice in genslices(n):
+ for rslice in genslices(n):
+
+ a = array.array(fmt, items)
+ b = array.array(fmt, items)
+ m = memoryview(b)
+
+ self.assertEqual(m, a)
+ self.assertEqual(m.tolist(), a.tolist())
+ self.assertEqual(m.tobytes(), a.tobytes())
+ self.assertEqual(len(m), len(a))
+
+ cmptest(self, a, b, m, singleitem)
+
+ array_err = None
+ have_resize = None
+ try:
+ al = a[lslice]
+ ar = a[rslice]
+ a[lslice] = a[rslice]
+ have_resize = len(al) != len(ar)
+ except Exception as e:
+ array_err = e.__class__
+
+ m_err = None
+ try:
+ m[lslice] = m[rslice]
+ except Exception as e:
+ m_err = e.__class__
+
+ if have_resize: # memoryview cannot change shape
+ self.assertIs(m_err, ValueError)
+ elif m_err or array_err:
+ self.assertIs(m_err, array_err)
+ else:
+ self.assertEqual(m, a)
+ self.assertEqual(m.tolist(), a.tolist())
+ self.assertEqual(m.tobytes(), a.tobytes())
+ cmptest(self, a, b, m, singleitem)
+
+ def test_memoryview_compare_special_cases(self):
+
+ a = array.array('L', [1, 2, 3])
+ b = array.array('L', [1, 2, 7])
+
+ # Ordering comparisons raise:
+ v = memoryview(a)
+ w = memoryview(b)
+ for attr in ('__lt__', '__le__', '__gt__', '__ge__'):
+ self.assertIs(getattr(v, attr)(w), NotImplemented)
+ self.assertIs(getattr(a, attr)(v), NotImplemented)
+
+ # Released views compare equal to themselves:
+ v = memoryview(a)
+ v.release()
+ self.assertEqual(v, v)
+ self.assertNotEqual(v, a)
+ self.assertNotEqual(a, v)
+
+ v = memoryview(a)
+ w = memoryview(a)
+ w.release()
+ self.assertNotEqual(v, w)
+ self.assertNotEqual(w, v)
+
+ # Operand does not implement the buffer protocol:
+ v = memoryview(a)
+ self.assertNotEqual(v, [1, 2, 3])
+
+ # NaNs
+ nd = ndarray([(0, 0)], shape=[1], format='l x d x', flags=ND_WRITABLE)
+ nd[0] = (-1, float('nan'))
+ self.assertNotEqual(memoryview(nd), nd)
+
+ # Depends on issue #15625: the struct module does not understand 'u'.
+ a = array.array('u', 'xyz')
+ v = memoryview(a)
+ self.assertNotEqual(a, v)
+ self.assertNotEqual(v, a)
+
+ # Some ctypes format strings are unknown to the struct module.
+ if ctypes:
+ # format: "T{>l:x:>l:y:}"
+ class BEPoint(ctypes.BigEndianStructure):
+ _fields_ = [("x", ctypes.c_long), ("y", ctypes.c_long)]
+ point = BEPoint(100, 200)
+ a = memoryview(point)
+ b = memoryview(point)
+ self.assertNotEqual(a, b)
+ self.assertNotEqual(a, point)
+ self.assertNotEqual(point, a)
+ self.assertRaises(NotImplementedError, a.tolist)
+
+ def test_memoryview_compare_ndim_zero(self):
+
+ nd1 = ndarray(1729, shape=[], format='@L')
+ nd2 = ndarray(1729, shape=[], format='L', flags=ND_WRITABLE)
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+ self.assertEqual(v, w)
+ self.assertEqual(w, v)
+ self.assertEqual(v, nd2)
+ self.assertEqual(nd2, v)
+ self.assertEqual(w, nd1)
+ self.assertEqual(nd1, w)
+
+ self.assertFalse(v.__ne__(w))
+ self.assertFalse(w.__ne__(v))
+
+ w[()] = 1728
+ self.assertNotEqual(v, w)
+ self.assertNotEqual(w, v)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(nd2, v)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(nd1, w)
+
+ self.assertFalse(v.__eq__(w))
+ self.assertFalse(w.__eq__(v))
+
+ nd = ndarray(list(range(12)), shape=[12], flags=ND_WRITABLE|ND_PIL)
+ ex = ndarray(list(range(12)), shape=[12], flags=ND_WRITABLE|ND_PIL)
+ m = memoryview(ex)
+
+ self.assertEqual(m, nd)
+ m[9] = 100
+ self.assertNotEqual(m, nd)
+
+ # struct module: equal
+ nd1 = ndarray((1729, 1.2, b'12345'), shape=[], format='Lf5s')
+ nd2 = ndarray((1729, 1.2, b'12345'), shape=[], format='hf5s',
+ flags=ND_WRITABLE)
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+ self.assertEqual(v, w)
+ self.assertEqual(w, v)
+ self.assertEqual(v, nd2)
+ self.assertEqual(nd2, v)
+ self.assertEqual(w, nd1)
+ self.assertEqual(nd1, w)
+
+ # struct module: not equal
+ nd1 = ndarray((1729, 1.2, b'12345'), shape=[], format='Lf5s')
+ nd2 = ndarray((-1729, 1.2, b'12345'), shape=[], format='hf5s',
+ flags=ND_WRITABLE)
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+ self.assertNotEqual(v, w)
+ self.assertNotEqual(w, v)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(nd2, v)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(nd1, w)
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+
+ def test_memoryview_compare_ndim_one(self):
+
+ # contiguous
+ nd1 = ndarray([-529, 576, -625, 676, -729], shape=[5], format='@h')
+ nd2 = ndarray([-529, 576, -625, 676, 729], shape=[5], format='@h')
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(v, w)
+
+ # contiguous, struct module
+ nd1 = ndarray([-529, 576, -625, 676, -729], shape=[5], format='<i')
+ nd2 = ndarray([-529, 576, -625, 676, 729], shape=[5], format='>h')
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(v, w)
+
+ # non-contiguous
+ nd1 = ndarray([-529, -625, -729], shape=[3], format='@h')
+ nd2 = ndarray([-529, 576, -625, 676, -729], shape=[5], format='@h')
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd2[::2])
+ self.assertEqual(w[::2], nd1)
+ self.assertEqual(v, w[::2])
+ self.assertEqual(v[::-1], w[::-2])
+
+ # non-contiguous, struct module
+ nd1 = ndarray([-529, -625, -729], shape=[3], format='!h')
+ nd2 = ndarray([-529, 576, -625, 676, -729], shape=[5], format='<l')
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd2[::2])
+ self.assertEqual(w[::2], nd1)
+ self.assertEqual(v, w[::2])
+ self.assertEqual(v[::-1], w[::-2])
+
+ # non-contiguous, suboffsets
+ nd1 = ndarray([-529, -625, -729], shape=[3], format='@h')
+ nd2 = ndarray([-529, 576, -625, 676, -729], shape=[5], format='@h',
+ flags=ND_PIL)
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd2[::2])
+ self.assertEqual(w[::2], nd1)
+ self.assertEqual(v, w[::2])
+ self.assertEqual(v[::-1], w[::-2])
+
+ # non-contiguous, suboffsets, struct module
+ nd1 = ndarray([-529, -625, -729], shape=[3], format='h 0c')
+ nd2 = ndarray([-529, 576, -625, 676, -729], shape=[5], format='> h',
+ flags=ND_PIL)
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd2[::2])
+ self.assertEqual(w[::2], nd1)
+ self.assertEqual(v, w[::2])
+ self.assertEqual(v[::-1], w[::-2])
+
+ def test_memoryview_compare_zero_shape(self):
+
+ # zeros in shape
+ nd1 = ndarray([900, 961], shape=[0], format='@h')
+ nd2 = ndarray([-900, -961], shape=[0], format='@h')
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertEqual(v, nd2)
+ self.assertEqual(w, nd1)
+ self.assertEqual(v, w)
+
+ # zeros in shape, struct module
+ nd1 = ndarray([900, 961], shape=[0], format='= h0c')
+ nd2 = ndarray([-900, -961], shape=[0], format='@ i')
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertEqual(v, nd2)
+ self.assertEqual(w, nd1)
+ self.assertEqual(v, w)
+
+ def test_memoryview_compare_zero_strides(self):
+
+ # zero strides
+ nd1 = ndarray([900, 900, 900, 900], shape=[4], format='@L')
+ nd2 = ndarray([900], shape=[4], strides=[0], format='L')
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertEqual(v, nd2)
+ self.assertEqual(w, nd1)
+ self.assertEqual(v, w)
+
+ # zero strides, struct module
+ nd1 = ndarray([(900, 900)]*4, shape=[4], format='@ Li')
+ nd2 = ndarray([(900, 900)], shape=[4], strides=[0], format='!L h')
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertEqual(v, nd2)
+ self.assertEqual(w, nd1)
+ self.assertEqual(v, w)
+
+ def test_memoryview_compare_random_formats(self):
+
+ # random single character native formats
+ n = 10
+ for char in fmtdict['@m']:
+ fmt, items, singleitem = randitems(n, 'memoryview', '@', char)
+ for flags in (0, ND_PIL):
+ nd = ndarray(items, shape=[n], format=fmt, flags=flags)
+ m = memoryview(nd)
+ self.assertEqual(m, nd)
+
+ nd = nd[::-3]
+ m = memoryview(nd)
+ self.assertEqual(m, nd)
+
+ # random formats
+ n = 10
+ for _ in range(100):
+ fmt, items, singleitem = randitems(n)
+ for flags in (0, ND_PIL):
+ nd = ndarray(items, shape=[n], format=fmt, flags=flags)
+ m = memoryview(nd)
+ self.assertEqual(m, nd)
+
+ nd = nd[::-3]
+ m = memoryview(nd)
+ self.assertEqual(m, nd)
+
+ def test_memoryview_compare_multidim_c(self):
+
+ # C-contiguous, different values
+ nd1 = ndarray(list(range(-15, 15)), shape=[3, 2, 5], format='@h')
+ nd2 = ndarray(list(range(0, 30)), shape=[3, 2, 5], format='@h')
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(v, w)
+
+ # C-contiguous, different values, struct module
+ nd1 = ndarray([(0, 1, 2)]*30, shape=[3, 2, 5], format='=f q xxL')
+ nd2 = ndarray([(-1.2, 1, 2)]*30, shape=[3, 2, 5], format='< f 2Q')
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(v, w)
+
+ # C-contiguous, different shape
+ nd1 = ndarray(list(range(30)), shape=[2, 3, 5], format='L')
+ nd2 = ndarray(list(range(30)), shape=[3, 2, 5], format='L')
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(v, w)
+
+ # C-contiguous, different shape, struct module
+ nd1 = ndarray([(0, 1, 2)]*21, shape=[3, 7], format='! b B xL')
+ nd2 = ndarray([(0, 1, 2)]*21, shape=[7, 3], format='= Qx l xxL')
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(v, w)
+
+ # C-contiguous, different format, struct module
+ nd1 = ndarray(list(range(30)), shape=[2, 3, 5], format='L')
+ nd2 = ndarray(list(range(30)), shape=[2, 3, 5], format='l')
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertEqual(v, nd2)
+ self.assertEqual(w, nd1)
+ self.assertEqual(v, w)
+
+ def test_memoryview_compare_multidim_fortran(self):
+
+ # Fortran-contiguous, different values
+ nd1 = ndarray(list(range(-15, 15)), shape=[5, 2, 3], format='@h',
+ flags=ND_FORTRAN)
+ nd2 = ndarray(list(range(0, 30)), shape=[5, 2, 3], format='@h',
+ flags=ND_FORTRAN)
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(v, w)
+
+ # Fortran-contiguous, different values, struct module
+ nd1 = ndarray([(2**64-1, -1)]*6, shape=[2, 3], format='=Qq',
+ flags=ND_FORTRAN)
+ nd2 = ndarray([(-1, 2**64-1)]*6, shape=[2, 3], format='=qQ',
+ flags=ND_FORTRAN)
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(v, w)
+
+ # Fortran-contiguous, different shape
+ nd1 = ndarray(list(range(-15, 15)), shape=[2, 3, 5], format='l',
+ flags=ND_FORTRAN)
+ nd2 = ndarray(list(range(-15, 15)), shape=[3, 2, 5], format='l',
+ flags=ND_FORTRAN)
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(v, w)
+
+ # Fortran-contiguous, different shape, struct module
+ nd1 = ndarray(list(range(-15, 15)), shape=[2, 3, 5], format='0ll',
+ flags=ND_FORTRAN)
+ nd2 = ndarray(list(range(-15, 15)), shape=[3, 2, 5], format='l',
+ flags=ND_FORTRAN)
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(v, w)
+
+ # Fortran-contiguous, different format, struct module
+ nd1 = ndarray(list(range(30)), shape=[5, 2, 3], format='@h',
+ flags=ND_FORTRAN)
+ nd2 = ndarray(list(range(30)), shape=[5, 2, 3], format='@b',
+ flags=ND_FORTRAN)
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertEqual(v, nd2)
+ self.assertEqual(w, nd1)
+ self.assertEqual(v, w)
+
+ def test_memoryview_compare_multidim_mixed(self):
+
+ # mixed C/Fortran contiguous
+ lst1 = list(range(-15, 15))
+ lst2 = transpose(lst1, [3, 2, 5])
+ nd1 = ndarray(lst1, shape=[3, 2, 5], format='@l')
+ nd2 = ndarray(lst2, shape=[3, 2, 5], format='l', flags=ND_FORTRAN)
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertEqual(v, w)
+
+ # mixed C/Fortran contiguous, struct module
+ lst1 = [(-3.3, -22, b'x')]*30
+ lst1[5] = (-2.2, -22, b'x')
+ lst2 = transpose(lst1, [3, 2, 5])
+ nd1 = ndarray(lst1, shape=[3, 2, 5], format='d b c')
+ nd2 = ndarray(lst2, shape=[3, 2, 5], format='d h c', flags=ND_FORTRAN)
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertEqual(v, w)
+
+ # different values, non-contiguous
+ ex1 = ndarray(list(range(40)), shape=[5, 8], format='@I')
+ nd1 = ex1[3:1:-1, ::-2]
+ ex2 = ndarray(list(range(40)), shape=[5, 8], format='I')
+ nd2 = ex2[1:3:1, ::-2]
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(v, w)
+
+ # same values, non-contiguous, struct module
+ ex1 = ndarray([(2**31-1, -2**31)]*22, shape=[11, 2], format='=ii')
+ nd1 = ex1[3:1:-1, ::-2]
+ ex2 = ndarray([(2**31-1, -2**31)]*22, shape=[11, 2], format='>ii')
+ nd2 = ex2[1:3:1, ::-2]
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertEqual(v, nd2)
+ self.assertEqual(w, nd1)
+ self.assertEqual(v, w)
+
+ # different shape
+ ex1 = ndarray(list(range(30)), shape=[2, 3, 5], format='b')
+ nd1 = ex1[1:3:, ::-2]
+ nd2 = ndarray(list(range(30)), shape=[3, 2, 5], format='b')
+ nd2 = ex2[1:3:, ::-2]
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(v, w)
+
+ # different shape, struct module
+ ex1 = ndarray(list(range(30)), shape=[2, 3, 5], format='B')
+ nd1 = ex1[1:3:, ::-2]
+ nd2 = ndarray(list(range(30)), shape=[3, 2, 5], format='b')
+ nd2 = ex2[1:3:, ::-2]
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(v, w)
+
+ # different format, struct module
+ ex1 = ndarray([(2, b'123')]*30, shape=[5, 3, 2], format='b3s')
+ nd1 = ex1[1:3:, ::-2]
+ nd2 = ndarray([(2, b'123')]*30, shape=[5, 3, 2], format='i3s')
+ nd2 = ex2[1:3:, ::-2]
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(v, w)
+
+ def test_memoryview_compare_multidim_zero_shape(self):
+
+ # zeros in shape
+ nd1 = ndarray(list(range(30)), shape=[0, 3, 2], format='i')
+ nd2 = ndarray(list(range(30)), shape=[5, 0, 2], format='@i')
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(v, w)
+
+ # zeros in shape, struct module
+ nd1 = ndarray(list(range(30)), shape=[0, 3, 2], format='i')
+ nd2 = ndarray(list(range(30)), shape=[5, 0, 2], format='@i')
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(v, w)
+
+ def test_memoryview_compare_multidim_zero_strides(self):
+
+ # zero strides
+ nd1 = ndarray([900]*80, shape=[4, 5, 4], format='@L')
+ nd2 = ndarray([900], shape=[4, 5, 4], strides=[0, 0, 0], format='L')
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertEqual(v, nd2)
+ self.assertEqual(w, nd1)
+ self.assertEqual(v, w)
+ self.assertEqual(v.tolist(), w.tolist())
+
+ # zero strides, struct module
+ nd1 = ndarray([(1, 2)]*10, shape=[2, 5], format='=lQ')
+ nd2 = ndarray([(1, 2)], shape=[2, 5], strides=[0, 0], format='<lQ')
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertEqual(v, nd2)
+ self.assertEqual(w, nd1)
+ self.assertEqual(v, w)
+
+ def test_memoryview_compare_multidim_suboffsets(self):
+
+ # suboffsets
+ ex1 = ndarray(list(range(40)), shape=[5, 8], format='@I')
+ nd1 = ex1[3:1:-1, ::-2]
+ ex2 = ndarray(list(range(40)), shape=[5, 8], format='I', flags=ND_PIL)
+ nd2 = ex2[1:3:1, ::-2]
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(v, w)
+
+ # suboffsets, struct module
+ ex1 = ndarray([(2**64-1, -1)]*40, shape=[5, 8], format='=Qq',
+ flags=ND_WRITABLE)
+ ex1[2][7] = (1, -2)
+ nd1 = ex1[3:1:-1, ::-2]
+
+ ex2 = ndarray([(2**64-1, -1)]*40, shape=[5, 8], format='>Qq',
+ flags=ND_PIL|ND_WRITABLE)
+ ex2[2][7] = (1, -2)
+ nd2 = ex2[1:3:1, ::-2]
+
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertEqual(v, nd2)
+ self.assertEqual(w, nd1)
+ self.assertEqual(v, w)
+
+ # suboffsets, different shape
+ ex1 = ndarray(list(range(30)), shape=[2, 3, 5], format='b',
+ flags=ND_PIL)
+ nd1 = ex1[1:3:, ::-2]
+ nd2 = ndarray(list(range(30)), shape=[3, 2, 5], format='b')
+ nd2 = ex2[1:3:, ::-2]
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(v, w)
+
+ # suboffsets, different shape, struct module
+ ex1 = ndarray([(2**8-1, -1)]*40, shape=[2, 3, 5], format='Bb',
+ flags=ND_PIL|ND_WRITABLE)
+ nd1 = ex1[1:2:, ::-2]
+
+ ex2 = ndarray([(2**8-1, -1)]*40, shape=[3, 2, 5], format='Bb')
+ nd2 = ex2[1:2:, ::-2]
+
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(v, w)
+
+ # suboffsets, different format
+ ex1 = ndarray(list(range(30)), shape=[5, 3, 2], format='i', flags=ND_PIL)
+ nd1 = ex1[1:3:, ::-2]
+ ex2 = ndarray(list(range(30)), shape=[5, 3, 2], format='@I', flags=ND_PIL)
+ nd2 = ex2[1:3:, ::-2]
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertEqual(v, nd2)
+ self.assertEqual(w, nd1)
+ self.assertEqual(v, w)
+
+ # suboffsets, different format, struct module
+ ex1 = ndarray([(b'hello', b'', 1)]*27, shape=[3, 3, 3], format='5s0sP',
+ flags=ND_PIL|ND_WRITABLE)
+ ex1[1][2][2] = (b'sushi', b'', 1)
+ nd1 = ex1[1:3:, ::-2]
+
+ ex2 = ndarray([(b'hello', b'', 1)]*27, shape=[3, 3, 3], format='5s0sP',
+ flags=ND_PIL|ND_WRITABLE)
+ ex1[1][2][2] = (b'sushi', b'', 1)
+ nd2 = ex2[1:3:, ::-2]
+
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertNotEqual(v, nd2)
+ self.assertNotEqual(w, nd1)
+ self.assertNotEqual(v, w)
+
+ # initialize mixed C/Fortran + suboffsets
+ lst1 = list(range(-15, 15))
+ lst2 = transpose(lst1, [3, 2, 5])
+ nd1 = ndarray(lst1, shape=[3, 2, 5], format='@l', flags=ND_PIL)
+ nd2 = ndarray(lst2, shape=[3, 2, 5], format='l', flags=ND_FORTRAN|ND_PIL)
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertEqual(v, w)
+
+ # initialize mixed C/Fortran + suboffsets, struct module
+ lst1 = [(b'sashimi', b'sliced', 20.05)]*30
+ lst1[11] = (b'ramen', b'spicy', 9.45)
+ lst2 = transpose(lst1, [3, 2, 5])
+
+ nd1 = ndarray(lst1, shape=[3, 2, 5], format='< 10p 9p d', flags=ND_PIL)
+ nd2 = ndarray(lst2, shape=[3, 2, 5], format='> 10p 9p d',
+ flags=ND_FORTRAN|ND_PIL)
+ v = memoryview(nd1)
+ w = memoryview(nd2)
+
+ self.assertEqual(v, nd1)
+ self.assertEqual(w, nd2)
+ self.assertEqual(v, w)
+
+ def test_memoryview_compare_not_equal(self):
+
+ # items not equal
+ for byteorder in ['=', '<', '>', '!']:
+ x = ndarray([2**63]*120, shape=[3,5,2,2,2], format=byteorder+'Q')
+ y = ndarray([2**63]*120, shape=[3,5,2,2,2], format=byteorder+'Q',
+ flags=ND_WRITABLE|ND_FORTRAN)
+ y[2][3][1][1][1] = 1
+ a = memoryview(x)
+ b = memoryview(y)
+ self.assertEqual(a, x)
+ self.assertEqual(b, y)
+ self.assertNotEqual(a, b)
+ self.assertNotEqual(a, y)
+ self.assertNotEqual(b, x)
+
+ x = ndarray([(2**63, 2**31, 2**15)]*120, shape=[3,5,2,2,2],
+ format=byteorder+'QLH')
+ y = ndarray([(2**63, 2**31, 2**15)]*120, shape=[3,5,2,2,2],
+ format=byteorder+'QLH', flags=ND_WRITABLE|ND_FORTRAN)
+ y[2][3][1][1][1] = (1, 1, 1)
+ a = memoryview(x)
+ b = memoryview(y)
+ self.assertEqual(a, x)
+ self.assertEqual(b, y)
+ self.assertNotEqual(a, b)
+ self.assertNotEqual(a, y)
+ self.assertNotEqual(b, x)
+
+ def test_memoryview_check_released(self):
+
+ a = array.array('d', [1.1, 2.2, 3.3])
+
+ m = memoryview(a)
+ m.release()
+
+ # PyMemoryView_FromObject()
+ self.assertRaises(ValueError, memoryview, m)
+ # memoryview.cast()
+ self.assertRaises(ValueError, m.cast, 'c')
+ # getbuffer()
+ self.assertRaises(ValueError, ndarray, m)
+ # memoryview.tolist()
+ self.assertRaises(ValueError, m.tolist)
+ # memoryview.tobytes()
+ self.assertRaises(ValueError, m.tobytes)
+ # sequence
+ self.assertRaises(ValueError, eval, "1.0 in m", locals())
+ # subscript
+ self.assertRaises(ValueError, m.__getitem__, 0)
+ # assignment
+ self.assertRaises(ValueError, m.__setitem__, 0, 1)
+
+ for attr in ('obj', 'nbytes', 'readonly', 'itemsize', 'format', 'ndim',
+ 'shape', 'strides', 'suboffsets', 'c_contiguous',
+ 'f_contiguous', 'contiguous'):
+ self.assertRaises(ValueError, m.__getattribute__, attr)
+
+ # richcompare
+ b = array.array('d', [1.1, 2.2, 3.3])
+ m1 = memoryview(a)
+ m2 = memoryview(b)
+
+ self.assertEqual(m1, m2)
+ m1.release()
+ self.assertNotEqual(m1, m2)
+ self.assertNotEqual(m1, a)
+ self.assertEqual(m1, m1)
+
+ def test_memoryview_tobytes(self):
+ # Many implicit tests are already in self.verify().
+
+ t = (-529, 576, -625, 676, -729)
+
+ nd = ndarray(t, shape=[5], format='@h')
+ m = memoryview(nd)
+ self.assertEqual(m, nd)
+ self.assertEqual(m.tobytes(), nd.tobytes())
+
+ nd = ndarray([t], shape=[1], format='>hQiLl')
+ m = memoryview(nd)
+ self.assertEqual(m, nd)
+ self.assertEqual(m.tobytes(), nd.tobytes())
+
+ nd = ndarray([t for _ in range(12)], shape=[2,2,3], format='=hQiLl')
+ m = memoryview(nd)
+ self.assertEqual(m, nd)
+ self.assertEqual(m.tobytes(), nd.tobytes())
+
+ nd = ndarray([t for _ in range(120)], shape=[5,2,2,3,2],
+ format='<hQiLl')
+ m = memoryview(nd)
+ self.assertEqual(m, nd)
+ self.assertEqual(m.tobytes(), nd.tobytes())
+
+ # Unknown formats are handled: tobytes() purely depends on itemsize.
+ if ctypes:
+ # format: "T{>l:x:>l:y:}"
+ class BEPoint(ctypes.BigEndianStructure):
+ _fields_ = [("x", ctypes.c_long), ("y", ctypes.c_long)]
+ point = BEPoint(100, 200)
+ a = memoryview(point)
+ self.assertEqual(a.tobytes(), bytes(point))
+
+ def test_memoryview_get_contiguous(self):
+ # Many implicit tests are already in self.verify().
+
+ # no buffer interface
+ self.assertRaises(TypeError, get_contiguous, {}, PyBUF_READ, 'F')
+
+ # writable request to read-only object
+ self.assertRaises(BufferError, get_contiguous, b'x', PyBUF_WRITE, 'C')
+
+ # writable request to non-contiguous object
+ nd = ndarray([1, 2, 3], shape=[2], strides=[2])
+ self.assertRaises(BufferError, get_contiguous, nd, PyBUF_WRITE, 'A')
+
+ # scalar, read-only request from read-only exporter
+ nd = ndarray(9, shape=(), format="L")
+ for order in ['C', 'F', 'A']:
+ m = get_contiguous(nd, PyBUF_READ, order)
+ self.assertEqual(m, nd)
+ self.assertEqual(m[()], 9)
+
+ # scalar, read-only request from writable exporter
+ nd = ndarray(9, shape=(), format="L", flags=ND_WRITABLE)
+ for order in ['C', 'F', 'A']:
+ m = get_contiguous(nd, PyBUF_READ, order)
+ self.assertEqual(m, nd)
+ self.assertEqual(m[()], 9)
+
+ # scalar, writable request
+ for order in ['C', 'F', 'A']:
+ nd[()] = 9
+ m = get_contiguous(nd, PyBUF_WRITE, order)
+ self.assertEqual(m, nd)
+ self.assertEqual(m[()], 9)
+
+ m[()] = 10
+ self.assertEqual(m[()], 10)
+ self.assertEqual(nd[()], 10)
+
+ # zeros in shape
+ nd = ndarray([1], shape=[0], format="L", flags=ND_WRITABLE)
+ for order in ['C', 'F', 'A']:
+ m = get_contiguous(nd, PyBUF_READ, order)
+ self.assertRaises(IndexError, m.__getitem__, 0)
+ self.assertEqual(m, nd)
+ self.assertEqual(m.tolist(), [])
+
+ nd = ndarray(list(range(8)), shape=[2, 0, 7], format="L",
+ flags=ND_WRITABLE)
+ for order in ['C', 'F', 'A']:
+ m = get_contiguous(nd, PyBUF_READ, order)
+ self.assertEqual(ndarray(m).tolist(), [[], []])
+
+ # one-dimensional
+ nd = ndarray([1], shape=[1], format="h", flags=ND_WRITABLE)
+ for order in ['C', 'F', 'A']:
+ m = get_contiguous(nd, PyBUF_WRITE, order)
+ self.assertEqual(m, nd)
+ self.assertEqual(m.tolist(), nd.tolist())
+
+ nd = ndarray([1, 2, 3], shape=[3], format="b", flags=ND_WRITABLE)
+ for order in ['C', 'F', 'A']:
+ m = get_contiguous(nd, PyBUF_WRITE, order)
+ self.assertEqual(m, nd)
+ self.assertEqual(m.tolist(), nd.tolist())
+
+ # one-dimensional, non-contiguous
+ nd = ndarray([1, 2, 3], shape=[2], strides=[2], flags=ND_WRITABLE)
+ for order in ['C', 'F', 'A']:
+ m = get_contiguous(nd, PyBUF_READ, order)
+ self.assertEqual(m, nd)
+ self.assertEqual(m.tolist(), nd.tolist())
+ self.assertRaises(TypeError, m.__setitem__, 1, 20)
+ self.assertEqual(m[1], 3)
+ self.assertEqual(nd[1], 3)
+
+ nd = nd[::-1]
+ for order in ['C', 'F', 'A']:
+ m = get_contiguous(nd, PyBUF_READ, order)
+ self.assertEqual(m, nd)
+ self.assertEqual(m.tolist(), nd.tolist())
+ self.assertRaises(TypeError, m.__setitem__, 1, 20)
+ self.assertEqual(m[1], 1)
+ self.assertEqual(nd[1], 1)
+
+ # multi-dimensional, contiguous input
+ nd = ndarray(list(range(12)), shape=[3, 4], flags=ND_WRITABLE)
+ for order in ['C', 'A']:
+ m = get_contiguous(nd, PyBUF_WRITE, order)
+ self.assertEqual(ndarray(m).tolist(), nd.tolist())
+
+ self.assertRaises(BufferError, get_contiguous, nd, PyBUF_WRITE, 'F')
+ m = get_contiguous(nd, PyBUF_READ, order)
+ self.assertEqual(ndarray(m).tolist(), nd.tolist())
+
+ nd = ndarray(list(range(12)), shape=[3, 4],
+ flags=ND_WRITABLE|ND_FORTRAN)
+ for order in ['F', 'A']:
+ m = get_contiguous(nd, PyBUF_WRITE, order)
+ self.assertEqual(ndarray(m).tolist(), nd.tolist())
+
+ self.assertRaises(BufferError, get_contiguous, nd, PyBUF_WRITE, 'C')
+ m = get_contiguous(nd, PyBUF_READ, order)
+ self.assertEqual(ndarray(m).tolist(), nd.tolist())
+
+ # multi-dimensional, non-contiguous input
+ nd = ndarray(list(range(12)), shape=[3, 4], flags=ND_WRITABLE|ND_PIL)
+ for order in ['C', 'F', 'A']:
+ self.assertRaises(BufferError, get_contiguous, nd, PyBUF_WRITE,
+ order)
+ m = get_contiguous(nd, PyBUF_READ, order)
+ self.assertEqual(ndarray(m).tolist(), nd.tolist())
+
+ # flags
+ nd = ndarray([1,2,3,4,5], shape=[3], strides=[2])
+ m = get_contiguous(nd, PyBUF_READ, 'C')
+ self.assertTrue(m.c_contiguous)
+
+ def test_memoryview_serializing(self):
+
+ # C-contiguous
+ size = struct.calcsize('i')
+ a = array.array('i', [1,2,3,4,5])
+ m = memoryview(a)
+ buf = io.BytesIO(m)
+ b = bytearray(5*size)
+ buf.readinto(b)
+ self.assertEqual(m.tobytes(), b)
+
+ # C-contiguous, multi-dimensional
+ size = struct.calcsize('L')
+ nd = ndarray(list(range(12)), shape=[2,3,2], format="L")
+ m = memoryview(nd)
+ buf = io.BytesIO(m)
+ b = bytearray(2*3*2*size)
+ buf.readinto(b)
+ self.assertEqual(m.tobytes(), b)
+
+ # Fortran contiguous, multi-dimensional
+ #size = struct.calcsize('L')
+ #nd = ndarray(list(range(12)), shape=[2,3,2], format="L",
+ # flags=ND_FORTRAN)
+ #m = memoryview(nd)
+ #buf = io.BytesIO(m)
+ #b = bytearray(2*3*2*size)
+ #buf.readinto(b)
+ #self.assertEqual(m.tobytes(), b)
+
+ def test_memoryview_hash(self):
+
+ # bytes exporter
+ b = bytes(list(range(12)))
+ m = memoryview(b)
+ self.assertEqual(hash(b), hash(m))
+
+ # C-contiguous
+ mc = m.cast('c', shape=[3,4])
+ self.assertEqual(hash(mc), hash(b))
+
+ # non-contiguous
+ mx = m[::-2]
+ b = bytes(list(range(12))[::-2])
+ self.assertEqual(hash(mx), hash(b))
+
+ # Fortran contiguous
+ nd = ndarray(list(range(30)), shape=[3,2,5], flags=ND_FORTRAN)
+ m = memoryview(nd)
+ self.assertEqual(hash(m), hash(nd))
+
+ # multi-dimensional slice
+ nd = ndarray(list(range(30)), shape=[3,2,5])
+ x = nd[::2, ::, ::-1]
+ m = memoryview(x)
+ self.assertEqual(hash(m), hash(x))
+
+ # multi-dimensional slice with suboffsets
+ nd = ndarray(list(range(30)), shape=[2,5,3], flags=ND_PIL)
+ x = nd[::2, ::, ::-1]
+ m = memoryview(x)
+ self.assertEqual(hash(m), hash(x))
+
+ # equality-hash invariant
+ x = ndarray(list(range(12)), shape=[12], format='B')
+ a = memoryview(x)
+
+ y = ndarray(list(range(12)), shape=[12], format='b')
+ b = memoryview(y)
+
+ self.assertEqual(a, b)
+ self.assertEqual(hash(a), hash(b))
+
+ # non-byte formats
+ nd = ndarray(list(range(12)), shape=[2,2,3], format='L')
+ m = memoryview(nd)
+ self.assertRaises(ValueError, m.__hash__)
+
+ nd = ndarray(list(range(-6, 6)), shape=[2,2,3], format='h')
+ m = memoryview(nd)
+ self.assertRaises(ValueError, m.__hash__)
+
+ nd = ndarray(list(range(12)), shape=[2,2,3], format='= L')
+ m = memoryview(nd)
+ self.assertRaises(ValueError, m.__hash__)
+
+ nd = ndarray(list(range(-6, 6)), shape=[2,2,3], format='< h')
+ m = memoryview(nd)
+ self.assertRaises(ValueError, m.__hash__)
+
+ def test_memoryview_release(self):
+
+ # Create re-exporter from getbuffer(memoryview), then release the view.
+ a = bytearray([1,2,3])
+ m = memoryview(a)
+ nd = ndarray(m) # re-exporter
+ self.assertRaises(BufferError, m.release)
+ del nd
+ m.release()
+
+ a = bytearray([1,2,3])
+ m = memoryview(a)
+ nd1 = ndarray(m, getbuf=PyBUF_FULL_RO, flags=ND_REDIRECT)
+ nd2 = ndarray(nd1, getbuf=PyBUF_FULL_RO, flags=ND_REDIRECT)
+ self.assertIs(nd2.obj, m)
+ self.assertRaises(BufferError, m.release)
+ del nd1, nd2
+ m.release()
+
+ # chained views
+ a = bytearray([1,2,3])
+ m1 = memoryview(a)
+ m2 = memoryview(m1)
+ nd = ndarray(m2) # re-exporter
+ m1.release()
+ self.assertRaises(BufferError, m2.release)
+ del nd
+ m2.release()
+
+ a = bytearray([1,2,3])
+ m1 = memoryview(a)
+ m2 = memoryview(m1)
+ nd1 = ndarray(m2, getbuf=PyBUF_FULL_RO, flags=ND_REDIRECT)
+ nd2 = ndarray(nd1, getbuf=PyBUF_FULL_RO, flags=ND_REDIRECT)
+ self.assertIs(nd2.obj, m2)
+ m1.release()
+ self.assertRaises(BufferError, m2.release)
+ del nd1, nd2
+ m2.release()
+
+ # Allow changing layout while buffers are exported.
+ nd = ndarray([1,2,3], shape=[3], flags=ND_VAREXPORT)
+ m1 = memoryview(nd)
+
+ nd.push([4,5,6,7,8], shape=[5]) # mutate nd
+ m2 = memoryview(nd)
+
+ x = memoryview(m1)
+ self.assertEqual(x.tolist(), m1.tolist())
+
+ y = memoryview(m2)
+ self.assertEqual(y.tolist(), m2.tolist())
+ self.assertEqual(y.tolist(), nd.tolist())
+ m2.release()
+ y.release()
+
+ nd.pop() # pop the current view
+ self.assertEqual(x.tolist(), nd.tolist())
+
+ del nd
+ m1.release()
+ x.release()
+
+ # If multiple memoryviews share the same managed buffer, implicit
+ # release() in the context manager's __exit__() method should still
+ # work.
+ def catch22(b):
+ with memoryview(b) as m2:
+ pass
+
+ x = bytearray(b'123')
+ with memoryview(x) as m1:
+ catch22(m1)
+ self.assertEqual(m1[0], ord(b'1'))
+
+ x = ndarray(list(range(12)), shape=[2,2,3], format='l')
+ y = ndarray(x, getbuf=PyBUF_FULL_RO, flags=ND_REDIRECT)
+ z = ndarray(y, getbuf=PyBUF_FULL_RO, flags=ND_REDIRECT)
+ self.assertIs(z.obj, x)
+ with memoryview(z) as m:
+ catch22(m)
+ self.assertEqual(m[0:1].tolist(), [[[0, 1, 2], [3, 4, 5]]])
+
+ # Test garbage collection.
+ for flags in (0, ND_REDIRECT):
+ x = bytearray(b'123')
+ with memoryview(x) as m1:
+ del x
+ y = ndarray(m1, getbuf=PyBUF_FULL_RO, flags=flags)
+ with memoryview(y) as m2:
+ del y
+ z = ndarray(m2, getbuf=PyBUF_FULL_RO, flags=flags)
+ with memoryview(z) as m3:
+ del z
+ catch22(m3)
+ catch22(m2)
+ catch22(m1)
+ self.assertEqual(m1[0], ord(b'1'))
+ self.assertEqual(m2[1], ord(b'2'))
+ self.assertEqual(m3[2], ord(b'3'))
+ del m3
+ del m2
+ del m1
+
+ x = bytearray(b'123')
+ with memoryview(x) as m1:
+ del x
+ y = ndarray(m1, getbuf=PyBUF_FULL_RO, flags=flags)
+ with memoryview(y) as m2:
+ del y
+ z = ndarray(m2, getbuf=PyBUF_FULL_RO, flags=flags)
+ with memoryview(z) as m3:
+ del z
+ catch22(m1)
+ catch22(m2)
+ catch22(m3)
+ self.assertEqual(m1[0], ord(b'1'))
+ self.assertEqual(m2[1], ord(b'2'))
+ self.assertEqual(m3[2], ord(b'3'))
+ del m1, m2, m3
+
+ # memoryview.release() fails if the view has exported buffers.
+ x = bytearray(b'123')
+ with self.assertRaises(BufferError):
+ with memoryview(x) as m:
+ ex = ndarray(m)
+ m[0] == ord(b'1')
+
+ def test_memoryview_redirect(self):
+
+ nd = ndarray([1.0 * x for x in range(12)], shape=[12], format='d')
+ a = array.array('d', [1.0 * x for x in range(12)])
+
+ for x in (nd, a):
+ y = ndarray(x, getbuf=PyBUF_FULL_RO, flags=ND_REDIRECT)
+ z = ndarray(y, getbuf=PyBUF_FULL_RO, flags=ND_REDIRECT)
+ m = memoryview(z)
+
+ self.assertIs(y.obj, x)
+ self.assertIs(z.obj, x)
+ self.assertIs(m.obj, x)
+
+ self.assertEqual(m, x)
+ self.assertEqual(m, y)
+ self.assertEqual(m, z)
+
+ self.assertEqual(m[1:3], x[1:3])
+ self.assertEqual(m[1:3], y[1:3])
+ self.assertEqual(m[1:3], z[1:3])
+ del y, z
+ self.assertEqual(m[1:3], x[1:3])
+
+ def test_memoryview_from_static_exporter(self):
+
+ fmt = 'B'
+ lst = [0,1,2,3,4,5,6,7,8,9,10,11]
+
+ # exceptions
+ self.assertRaises(TypeError, staticarray, 1, 2, 3)
+
+ # view.obj==x
+ x = staticarray()
+ y = memoryview(x)
+ self.verify(y, obj=x,
+ itemsize=1, fmt=fmt, readonly=1,
+ ndim=1, shape=[12], strides=[1],
+ lst=lst)
+ for i in range(12):
+ self.assertEqual(y[i], i)
+ del x
+ del y
+
+ x = staticarray()
+ y = memoryview(x)
+ del y
+ del x
+
+ x = staticarray()
+ y = ndarray(x, getbuf=PyBUF_FULL_RO)
+ z = ndarray(y, getbuf=PyBUF_FULL_RO)
+ m = memoryview(z)
+ self.assertIs(y.obj, x)
+ self.assertIs(m.obj, z)
+ self.verify(m, obj=z,
+ itemsize=1, fmt=fmt, readonly=1,
+ ndim=1, shape=[12], strides=[1],
+ lst=lst)
+ del x, y, z, m
+
+ x = staticarray()
+ y = ndarray(x, getbuf=PyBUF_FULL_RO, flags=ND_REDIRECT)
+ z = ndarray(y, getbuf=PyBUF_FULL_RO, flags=ND_REDIRECT)
+ m = memoryview(z)
+ self.assertIs(y.obj, x)
+ self.assertIs(z.obj, x)
+ self.assertIs(m.obj, x)
+ self.verify(m, obj=x,
+ itemsize=1, fmt=fmt, readonly=1,
+ ndim=1, shape=[12], strides=[1],
+ lst=lst)
+ del x, y, z, m
+
+ # view.obj==NULL
+ x = staticarray(legacy_mode=True)
+ y = memoryview(x)
+ self.verify(y, obj=None,
+ itemsize=1, fmt=fmt, readonly=1,
+ ndim=1, shape=[12], strides=[1],
+ lst=lst)
+ for i in range(12):
+ self.assertEqual(y[i], i)
+ del x
+ del y
+
+ x = staticarray(legacy_mode=True)
+ y = memoryview(x)
+ del y
+ del x
+
+ x = staticarray(legacy_mode=True)
+ y = ndarray(x, getbuf=PyBUF_FULL_RO)
+ z = ndarray(y, getbuf=PyBUF_FULL_RO)
+ m = memoryview(z)
+ self.assertIs(y.obj, None)
+ self.assertIs(m.obj, z)
+ self.verify(m, obj=z,
+ itemsize=1, fmt=fmt, readonly=1,
+ ndim=1, shape=[12], strides=[1],
+ lst=lst)
+ del x, y, z, m
+
+ x = staticarray(legacy_mode=True)
+ y = ndarray(x, getbuf=PyBUF_FULL_RO, flags=ND_REDIRECT)
+ z = ndarray(y, getbuf=PyBUF_FULL_RO, flags=ND_REDIRECT)
+ m = memoryview(z)
+ # Clearly setting view.obj==NULL is inferior, since it
+ # messes up the redirection chain:
+ self.assertIs(y.obj, None)
+ self.assertIs(z.obj, y)
+ self.assertIs(m.obj, y)
+ self.verify(m, obj=y,
+ itemsize=1, fmt=fmt, readonly=1,
+ ndim=1, shape=[12], strides=[1],
+ lst=lst)
+ del x, y, z, m
+
+ def test_memoryview_getbuffer_undefined(self):
+
+ # getbufferproc does not adhere to the new documentation
+ nd = ndarray([1,2,3], [3], flags=ND_GETBUF_FAIL|ND_GETBUF_UNDEFINED)
+ self.assertRaises(BufferError, memoryview, nd)
+
+ def test_issue_7385(self):
+ x = ndarray([1,2,3], shape=[3], flags=ND_GETBUF_FAIL)
+ self.assertRaises(BufferError, memoryview, x)
+
+
+def test_main():
+ support.run_unittest(TestBufferProtocol)
+
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_bufio.py b/Lib/test/test_bufio.py
index 5ab6f5a..6338ad8 100644
--- a/Lib/test/test_bufio.py
+++ b/Lib/test/test_bufio.py
@@ -11,7 +11,7 @@ import _pyio as pyio # Python implementation.
lengths = list(range(1, 257)) + [512, 1000, 1024, 2048, 4096, 8192, 10000,
16384, 32768, 65536, 1000000]
-class BufferSizeTest(unittest.TestCase):
+class BufferSizeTest:
def try_one(self, s):
# Write s + "\n" + s to file, then open it and ensure that successive
# .readline()s deliver what we wrote.
@@ -62,15 +62,12 @@ class BufferSizeTest(unittest.TestCase):
self.drive_one(bytes(1000))
-class CBufferSizeTest(BufferSizeTest):
+class CBufferSizeTest(BufferSizeTest, unittest.TestCase):
open = io.open
-class PyBufferSizeTest(BufferSizeTest):
+class PyBufferSizeTest(BufferSizeTest, unittest.TestCase):
open = staticmethod(pyio.open)
-def test_main():
- support.run_unittest(CBufferSizeTest, PyBufferSizeTest)
-
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py
index 42a1e35..bea88b4 100644
--- a/Lib/test/test_builtin.py
+++ b/Lib/test/test_builtin.py
@@ -1,19 +1,21 @@
# Python test set -- built-in functions
-import platform
-import unittest
-import sys
-import warnings
+import ast
+import builtins
import collections
import io
+import locale
import os
-import ast
-import types
-import builtins
+import pickle
+import platform
import random
+import sys
import traceback
-from test.support import fcmp, TESTFN, unlink, run_unittest, check_warnings
+import types
+import unittest
+import warnings
from operator import neg
+from test.support import TESTFN, unlink, run_unittest, check_warnings
try:
import pty, signal
except ImportError:
@@ -110,7 +112,30 @@ class TestFailingIter:
def __iter__(self):
raise RuntimeError
+def filter_char(arg):
+ return ord(arg) > ord("d")
+
+def map_char(arg):
+ return chr(ord(arg)+1)
+
class BuiltinTest(unittest.TestCase):
+ # Helper to check picklability
+ def check_iter_pickle(self, it, seq):
+ itorg = it
+ d = pickle.dumps(it)
+ it = pickle.loads(d)
+ self.assertEqual(type(itorg), type(it))
+ self.assertEqual(list(it), seq)
+
+ #test the iterator after dropping one from it
+ it = pickle.loads(d)
+ try:
+ next(it)
+ except StopIteration:
+ return
+ d = pickle.dumps(it)
+ it = pickle.loads(d)
+ self.assertEqual(list(it), seq[1:])
def test_import(self):
__import__('sys')
@@ -255,8 +280,7 @@ class BuiltinTest(unittest.TestCase):
self.assertEqual(chr(0xff), '\xff')
self.assertRaises(ValueError, chr, 1<<24)
self.assertEqual(chr(sys.maxunicode),
- str(('\\U%08x' % (sys.maxunicode)).encode("ascii"),
- 'unicode-escape'))
+ str('\\U0010ffff'.encode("ascii"), 'unicode-escape'))
self.assertRaises(TypeError, chr)
self.assertEqual(chr(0x0000FFFF), "\U0000FFFF")
self.assertEqual(chr(0x00010000), "\U00010000")
@@ -378,7 +402,15 @@ class BuiltinTest(unittest.TestCase):
f = Foo()
self.assertTrue(dir(f) == ["ga", "kan", "roo"])
- # dir(obj__dir__not_list)
+ # dir(obj__dir__tuple)
+ class Foo(object):
+ def __dir__(self):
+ return ("b", "c", "a")
+ res = dir(Foo())
+ self.assertIsInstance(res, list)
+ self.assertTrue(res == ["a", "b", "c"])
+
+ # dir(obj__dir__not_sequence)
class Foo(object):
def __dir__(self):
return 7
@@ -391,6 +423,8 @@ class BuiltinTest(unittest.TestCase):
except:
self.assertEqual(len(dir(sys.exc_info()[2])), 4)
+ # test that object has a __dir__()
+ self.assertEqual(sorted([].__dir__()), dir([]))
def test_divmod(self):
self.assertEqual(divmod(12, 7), (1, 5))
@@ -400,10 +434,13 @@ class BuiltinTest(unittest.TestCase):
self.assertEqual(divmod(-sys.maxsize-1, -1), (sys.maxsize+1, 0))
- self.assertTrue(not fcmp(divmod(3.25, 1.0), (3.0, 0.25)))
- self.assertTrue(not fcmp(divmod(-3.25, 1.0), (-4.0, 0.75)))
- self.assertTrue(not fcmp(divmod(3.25, -1.0), (-4.0, -0.75)))
- self.assertTrue(not fcmp(divmod(-3.25, -1.0), (3.0, -0.25)))
+ for num, denom, exp_result in [ (3.25, 1.0, (3.0, 0.25)),
+ (-3.25, 1.0, (-4.0, 0.75)),
+ (3.25, -1.0, (-4.0, -0.75)),
+ (-3.25, -1.0, (3.0, -0.25))]:
+ result = divmod(num, denom)
+ self.assertAlmostEqual(result[0], exp_result[0])
+ self.assertAlmostEqual(result[1], exp_result[1])
self.assertRaises(TypeError, divmod)
@@ -518,6 +555,39 @@ class BuiltinTest(unittest.TestCase):
del l['__builtins__']
self.assertEqual((g, l), ({'a': 1}, {'b': 2}))
+ def test_exec_globals(self):
+ code = compile("print('Hello World!')", "", "exec")
+ # no builtin function
+ self.assertRaisesRegex(NameError, "name 'print' is not defined",
+ exec, code, {'__builtins__': {}})
+ # __builtins__ must be a mapping type
+ self.assertRaises(TypeError,
+ exec, code, {'__builtins__': 123})
+
+ # no __build_class__ function
+ code = compile("class A: pass", "", "exec")
+ self.assertRaisesRegex(NameError, "__build_class__ not found",
+ exec, code, {'__builtins__': {}})
+
+ class frozendict_error(Exception):
+ pass
+
+ class frozendict(dict):
+ def __setitem__(self, key, value):
+ raise frozendict_error("frozendict is readonly")
+
+ # read-only builtins
+ frozen_builtins = frozendict(__builtins__)
+ code = compile("__builtins__['superglobal']=2; print(superglobal)", "test", "exec")
+ self.assertRaises(frozendict_error,
+ exec, code, {'__builtins__': frozen_builtins})
+
+ # read-only globals
+ namespace = frozendict({})
+ code = compile("x=1", "test", "exec")
+ self.assertRaises(frozendict_error,
+ exec, code, namespace)
+
def test_exec_redirected(self):
savestdout = sys.stdout
sys.stdout = None # Whatever that cannot flush()
@@ -554,6 +624,11 @@ class BuiltinTest(unittest.TestCase):
self.assertEqual(list(filter(lambda x: x>=3, (1, 2, 3, 4))), [3, 4])
self.assertRaises(TypeError, list, filter(42, (1, 2)))
+ def test_filter_pickle(self):
+ f1 = filter(filter_char, "abcdeabcde")
+ f2 = filter(filter_char, "abcdeabcde")
+ self.check_iter_pickle(f1, list(f2))
+
def test_getattr(self):
self.assertTrue(getattr(sys, 'stdout') is sys.stdout)
self.assertRaises(TypeError, getattr, sys, 1)
@@ -747,6 +822,11 @@ class BuiltinTest(unittest.TestCase):
raise RuntimeError
self.assertRaises(RuntimeError, list, map(badfunc, range(5)))
+ def test_map_pickle(self):
+ m1 = map(map_char, "Is this the real life?")
+ m2 = map(map_char, "Is this the real life?")
+ self.check_iter_pickle(m1, list(m2))
+
def test_max(self):
self.assertEqual(max('123123'), '3')
self.assertEqual(max(1, 2, 3), 3)
@@ -880,7 +960,29 @@ class BuiltinTest(unittest.TestCase):
self.assertEqual(fp.read(1000), 'YYY'*100)
finally:
fp.close()
- unlink(TESTFN)
+ unlink(TESTFN)
+
+ def test_open_default_encoding(self):
+ old_environ = dict(os.environ)
+ try:
+ # try to get a user preferred encoding different than the current
+ # locale encoding to check that open() uses the current locale
+ # encoding and not the user preferred encoding
+ for key in ('LC_ALL', 'LANG', 'LC_CTYPE'):
+ if key in os.environ:
+ del os.environ[key]
+
+ self.write_testfile()
+ current_locale_encoding = locale.getpreferredencoding(False)
+ fp = open(TESTFN, 'w')
+ try:
+ self.assertEqual(fp.encoding, current_locale_encoding)
+ finally:
+ fp.close()
+ unlink(TESTFN)
+ finally:
+ os.environ.clear()
+ os.environ.update(old_environ)
def test_ord(self):
self.assertEqual(ord(' '), 32)
@@ -1070,6 +1172,8 @@ class BuiltinTest(unittest.TestCase):
# Check stdin/stdout error handler is used when invoking GNU readline
self.check_input_tty("prompté", b"quux\xe9", "ascii")
+ # test_int(): see test_int.py for tests of built-in function int().
+
def test_repr(self):
self.assertEqual(repr(''), '\'\'')
self.assertEqual(repr(0), '0')
@@ -1198,6 +1302,9 @@ class BuiltinTest(unittest.TestCase):
self.assertRaises(TypeError, sum, 42)
self.assertRaises(TypeError, sum, ['a', 'b', 'c'])
self.assertRaises(TypeError, sum, ['a', 'b', 'c'], '')
+ self.assertRaises(TypeError, sum, [b'a', b'c'], b'')
+ values = [bytearray(b'a'), bytearray(b'b')]
+ self.assertRaises(TypeError, sum, values, bytearray(b''))
self.assertRaises(TypeError, sum, [[1], [2], [3]])
self.assertRaises(TypeError, sum, [{2:3}])
self.assertRaises(TypeError, sum, [{2:3}]*2, {2:3})
@@ -1286,6 +1393,13 @@ class BuiltinTest(unittest.TestCase):
return i
self.assertRaises(ValueError, list, zip(BadSeq(), BadSeq()))
+ def test_zip_pickle(self):
+ a = (1, 2, 3)
+ b = (4, 5, 6)
+ t = [(1, 4), (2, 5), (3, 6)]
+ z1 = zip(a, b)
+ self.check_iter_pickle(z1, t)
+
def test_format(self):
# Test the basic machinery of the format() builtin. Don't test
# the specifics of the various formatters
@@ -1359,14 +1473,14 @@ class BuiltinTest(unittest.TestCase):
# --------------------------------------------------------------------
# Issue #7994: object.__format__ with a non-empty format string is
- # pending deprecated
+ # deprecated
def test_deprecated_format_string(obj, fmt_str, should_raise_warning):
with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter("always", PendingDeprecationWarning)
+ warnings.simplefilter("always", DeprecationWarning)
format(obj, fmt_str)
if should_raise_warning:
self.assertEqual(len(w), 1)
- self.assertIsInstance(w[0].message, PendingDeprecationWarning)
+ self.assertIsInstance(w[0].message, DeprecationWarning)
self.assertIn('object.__format__ with a non-empty format '
'string', str(w[0].message))
else:
@@ -1410,6 +1524,13 @@ class BuiltinTest(unittest.TestCase):
self.assertRaises(ValueError, x.translate, b"1", 1)
self.assertRaises(TypeError, x.translate, b"1"*256, 1)
+ def test_construct_singletons(self):
+ for const in None, Ellipsis, NotImplemented:
+ tp = type(const)
+ self.assertIs(tp(), const)
+ self.assertRaises(TypeError, tp, 1, 2)
+ self.assertRaises(TypeError, tp, a=1, b=2)
+
class TestSorted(unittest.TestCase):
def test_basic(self):
diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py
index ba40915..3520e83 100644
--- a/Lib/test/test_bytes.py
+++ b/Lib/test/test_bytes.py
@@ -38,7 +38,7 @@ class Indexable:
return self.value
-class BaseBytesTest(unittest.TestCase):
+class BaseBytesTest:
def test_basics(self):
b = self.type2test()
@@ -188,24 +188,26 @@ class BaseBytesTest(unittest.TestCase):
def test_encoding(self):
sample = "Hello world\n\u1234\u5678\u9abc"
- for enc in ("utf8", "utf16"):
+ for enc in ("utf-8", "utf-16"):
b = self.type2test(sample, enc)
self.assertEqual(b, self.type2test(sample.encode(enc)))
- self.assertRaises(UnicodeEncodeError, self.type2test, sample, "latin1")
- b = self.type2test(sample, "latin1", "ignore")
+ self.assertRaises(UnicodeEncodeError, self.type2test, sample, "latin-1")
+ b = self.type2test(sample, "latin-1", "ignore")
self.assertEqual(b, self.type2test(sample[:-3], "utf-8"))
def test_decode(self):
sample = "Hello world\n\u1234\u5678\u9abc\def0\def0"
- for enc in ("utf8", "utf16"):
+ for enc in ("utf-8", "utf-16"):
b = self.type2test(sample, enc)
self.assertEqual(b.decode(enc), sample)
sample = "Hello world\n\x80\x81\xfe\xff"
- b = self.type2test(sample, "latin1")
- self.assertRaises(UnicodeDecodeError, b.decode, "utf8")
- self.assertEqual(b.decode("utf8", "ignore"), "Hello world\n")
- self.assertEqual(b.decode(errors="ignore", encoding="utf8"),
+ b = self.type2test(sample, "latin-1")
+ self.assertRaises(UnicodeDecodeError, b.decode, "utf-8")
+ self.assertEqual(b.decode("utf-8", "ignore"), "Hello world\n")
+ self.assertEqual(b.decode(errors="ignore", encoding="utf-8"),
"Hello world\n")
+ # Default encoding is utf-8
+ self.assertEqual(self.type2test(b'\xe2\x98\x83').decode(), '\u2603')
def test_from_int(self):
b = self.type2test(0)
@@ -291,10 +293,27 @@ class BaseBytesTest(unittest.TestCase):
def test_count(self):
b = self.type2test(b'mississippi')
+ i = 105
+ p = 112
+ w = 119
+
self.assertEqual(b.count(b'i'), 4)
self.assertEqual(b.count(b'ss'), 2)
self.assertEqual(b.count(b'w'), 0)
+ self.assertEqual(b.count(i), 4)
+ self.assertEqual(b.count(w), 0)
+
+ self.assertEqual(b.count(b'i', 6), 2)
+ self.assertEqual(b.count(b'p', 6), 2)
+ self.assertEqual(b.count(b'i', 1, 3), 1)
+ self.assertEqual(b.count(b'p', 7, 9), 1)
+
+ self.assertEqual(b.count(i, 6), 2)
+ self.assertEqual(b.count(p, 6), 2)
+ self.assertEqual(b.count(i, 1, 3), 1)
+ self.assertEqual(b.count(p, 7, 9), 1)
+
def test_startswith(self):
b = self.type2test(b'hello')
self.assertFalse(self.type2test().startswith(b"anything"))
@@ -325,35 +344,86 @@ class BaseBytesTest(unittest.TestCase):
def test_find(self):
b = self.type2test(b'mississippi')
+ i = 105
+ w = 119
+
self.assertEqual(b.find(b'ss'), 2)
+ self.assertEqual(b.find(b'w'), -1)
+ self.assertEqual(b.find(b'mississippian'), -1)
+
+ self.assertEqual(b.find(i), 1)
+ self.assertEqual(b.find(w), -1)
+
self.assertEqual(b.find(b'ss', 3), 5)
self.assertEqual(b.find(b'ss', 1, 7), 2)
self.assertEqual(b.find(b'ss', 1, 3), -1)
- self.assertEqual(b.find(b'w'), -1)
- self.assertEqual(b.find(b'mississippian'), -1)
+
+ self.assertEqual(b.find(i, 6), 7)
+ self.assertEqual(b.find(i, 1, 3), 1)
+ self.assertEqual(b.find(w, 1, 3), -1)
+
+ for index in (-1, 256, sys.maxsize + 1):
+ self.assertRaisesRegex(
+ ValueError, r'byte must be in range\(0, 256\)',
+ b.find, index)
def test_rfind(self):
b = self.type2test(b'mississippi')
+ i = 105
+ w = 119
+
self.assertEqual(b.rfind(b'ss'), 5)
- self.assertEqual(b.rfind(b'ss', 3), 5)
- self.assertEqual(b.rfind(b'ss', 0, 6), 2)
self.assertEqual(b.rfind(b'w'), -1)
self.assertEqual(b.rfind(b'mississippian'), -1)
+ self.assertEqual(b.rfind(i), 10)
+ self.assertEqual(b.rfind(w), -1)
+
+ self.assertEqual(b.rfind(b'ss', 3), 5)
+ self.assertEqual(b.rfind(b'ss', 0, 6), 2)
+
+ self.assertEqual(b.rfind(i, 1, 3), 1)
+ self.assertEqual(b.rfind(i, 3, 9), 7)
+ self.assertEqual(b.rfind(w, 1, 3), -1)
+
def test_index(self):
- b = self.type2test(b'world')
- self.assertEqual(b.index(b'w'), 0)
- self.assertEqual(b.index(b'orl'), 1)
- self.assertRaises(ValueError, b.index, b'worm')
- self.assertRaises(ValueError, b.index, b'ldo')
+ b = self.type2test(b'mississippi')
+ i = 105
+ w = 119
+
+ self.assertEqual(b.index(b'ss'), 2)
+ self.assertRaises(ValueError, b.index, b'w')
+ self.assertRaises(ValueError, b.index, b'mississippian')
+
+ self.assertEqual(b.index(i), 1)
+ self.assertRaises(ValueError, b.index, w)
+
+ self.assertEqual(b.index(b'ss', 3), 5)
+ self.assertEqual(b.index(b'ss', 1, 7), 2)
+ self.assertRaises(ValueError, b.index, b'ss', 1, 3)
+
+ self.assertEqual(b.index(i, 6), 7)
+ self.assertEqual(b.index(i, 1, 3), 1)
+ self.assertRaises(ValueError, b.index, w, 1, 3)
def test_rindex(self):
- # XXX could be more rigorous
- b = self.type2test(b'world')
- self.assertEqual(b.rindex(b'w'), 0)
- self.assertEqual(b.rindex(b'orl'), 1)
- self.assertRaises(ValueError, b.rindex, b'worm')
- self.assertRaises(ValueError, b.rindex, b'ldo')
+ b = self.type2test(b'mississippi')
+ i = 105
+ w = 119
+
+ self.assertEqual(b.rindex(b'ss'), 5)
+ self.assertRaises(ValueError, b.rindex, b'w')
+ self.assertRaises(ValueError, b.rindex, b'mississippian')
+
+ self.assertEqual(b.rindex(i), 10)
+ self.assertRaises(ValueError, b.rindex, w)
+
+ self.assertEqual(b.rindex(b'ss', 3), 5)
+ self.assertEqual(b.rindex(b'ss', 0, 6), 2)
+
+ self.assertEqual(b.rindex(i, 1, 3), 1)
+ self.assertEqual(b.rindex(i, 3, 9), 7)
+ self.assertRaises(ValueError, b.rindex, w, 1, 3)
def test_replace(self):
b = self.type2test(b'mississippi')
@@ -365,6 +435,14 @@ class BaseBytesTest(unittest.TestCase):
self.assertEqual(b.split(b'i'), [b'm', b'ss', b'ss', b'pp', b''])
self.assertEqual(b.split(b'ss'), [b'mi', b'i', b'ippi'])
self.assertEqual(b.split(b'w'), [b])
+ # with keyword args
+ b = self.type2test(b'a|b|c|d')
+ self.assertEqual(b.split(sep=b'|'), [b'a', b'b', b'c', b'd'])
+ self.assertEqual(b.split(b'|', maxsplit=1), [b'a', b'b|c|d'])
+ self.assertEqual(b.split(sep=b'|', maxsplit=1), [b'a', b'b|c|d'])
+ self.assertEqual(b.split(maxsplit=1, sep=b'|'), [b'a', b'b|c|d'])
+ b = self.type2test(b'a b c d')
+ self.assertEqual(b.split(maxsplit=1), [b'a', b'b c d'])
def test_split_whitespace(self):
for b in (b' arf barf ', b'arf\tbarf', b'arf\nbarf', b'arf\rbarf',
@@ -393,6 +471,14 @@ class BaseBytesTest(unittest.TestCase):
self.assertEqual(b.rsplit(b'i'), [b'm', b'ss', b'ss', b'pp', b''])
self.assertEqual(b.rsplit(b'ss'), [b'mi', b'i', b'ippi'])
self.assertEqual(b.rsplit(b'w'), [b])
+ # with keyword args
+ b = self.type2test(b'a|b|c|d')
+ self.assertEqual(b.rsplit(sep=b'|'), [b'a', b'b', b'c', b'd'])
+ self.assertEqual(b.rsplit(b'|', maxsplit=1), [b'a|b|c', b'd'])
+ self.assertEqual(b.rsplit(sep=b'|', maxsplit=1), [b'a|b|c', b'd'])
+ self.assertEqual(b.rsplit(maxsplit=1, sep=b'|'), [b'a|b|c', b'd'])
+ b = self.type2test(b'a b c d')
+ self.assertEqual(b.rsplit(maxsplit=1), [b'a b c', b'd'])
def test_rsplit_whitespace(self):
for b in (b' arf barf ', b'arf\tbarf', b'arf\nbarf', b'arf\rbarf',
@@ -432,6 +518,24 @@ class BaseBytesTest(unittest.TestCase):
q = pickle.loads(ps)
self.assertEqual(b, q)
+ def test_iterator_pickling(self):
+ for b in b"", b"a", b"abc", b"\xffab\x80", b"\0\0\377\0\0":
+ it = itorg = iter(self.type2test(b))
+ data = list(self.type2test(b))
+ d = pickle.dumps(it)
+ it = pickle.loads(d)
+ self.assertEqual(type(itorg), type(it))
+ self.assertEqual(list(it), data)
+
+ it = pickle.loads(d)
+ try:
+ next(it)
+ except StopIteration:
+ continue
+ d = pickle.dumps(it)
+ it = pickle.loads(d)
+ self.assertEqual(list(it), data[1:])
+
def test_strip(self):
b = self.type2test(b'mississippi')
self.assertEqual(b.strip(b'i'), b'mississipp')
@@ -473,6 +577,27 @@ class BaseBytesTest(unittest.TestCase):
self.assertRaises(TypeError, self.type2test(b'abc').lstrip, 'b')
self.assertRaises(TypeError, self.type2test(b'abc').rstrip, 'b')
+ def test_center(self):
+ # Fill character can be either bytes or bytearray (issue 12380)
+ b = self.type2test(b'abc')
+ for fill_type in (bytes, bytearray):
+ self.assertEqual(b.center(7, fill_type(b'-')),
+ self.type2test(b'--abc--'))
+
+ def test_ljust(self):
+ # Fill character can be either bytes or bytearray (issue 12380)
+ b = self.type2test(b'abc')
+ for fill_type in (bytes, bytearray):
+ self.assertEqual(b.ljust(7, fill_type(b'-')),
+ self.type2test(b'abc----'))
+
+ def test_rjust(self):
+ # Fill character can be either bytes or bytearray (issue 12380)
+ b = self.type2test(b'abc')
+ for fill_type in (bytes, bytearray):
+ self.assertEqual(b.rjust(7, fill_type(b'-')),
+ self.type2test(b'----abc'))
+
def test_ord(self):
b = self.type2test(b'\0A\x7f\x80\xff')
self.assertEqual([ord(b[i:i+1]) for i in range(len(b))],
@@ -529,6 +654,14 @@ class BaseBytesTest(unittest.TestCase):
self.assertEqual(True, b.startswith(h, None, -2))
self.assertEqual(False, b.startswith(x, None, None))
+ def test_integer_arguments_out_of_byte_range(self):
+ b = self.type2test(b'hello')
+
+ for method in (b.count, b.find, b.index, b.rfind, b.rindex):
+ self.assertRaises(ValueError, method, -1)
+ self.assertRaises(ValueError, method, 256)
+ self.assertRaises(ValueError, method, 9999)
+
def test_find_etc_raise_correct_error_messages(self):
# issue 11828
b = self.type2test(b'hello')
@@ -549,7 +682,7 @@ class BaseBytesTest(unittest.TestCase):
x, None, None, None)
-class BytesTest(BaseBytesTest):
+class BytesTest(BaseBytesTest, unittest.TestCase):
type2test = bytes
def test_buffer_is_readonly(self):
@@ -568,6 +701,12 @@ class BytesTest(BaseBytesTest):
def __bytes__(self):
return None
self.assertRaises(TypeError, bytes, A())
+ class A:
+ def __bytes__(self):
+ return b'a'
+ def __index__(self):
+ return 42
+ self.assertEqual(bytes(A()), b'a')
# Test PyBytes_FromFormat()
def test_from_format(self):
@@ -591,7 +730,7 @@ class BytesTest(BaseBytesTest):
b's:cstr')
-class ByteArrayTest(BaseBytesTest):
+class ByteArrayTest(BaseBytesTest, unittest.TestCase):
type2test = bytearray
def test_nohash(self):
@@ -634,6 +773,39 @@ class ByteArrayTest(BaseBytesTest):
b.reverse()
self.assertFalse(b)
+ def test_clear(self):
+ b = bytearray(b'python')
+ b.clear()
+ self.assertEqual(b, b'')
+
+ b = bytearray(b'')
+ b.clear()
+ self.assertEqual(b, b'')
+
+ b = bytearray(b'')
+ b.append(ord('r'))
+ b.clear()
+ b.append(ord('p'))
+ self.assertEqual(b, b'p')
+
+ def test_copy(self):
+ b = bytearray(b'abc')
+ bb = b.copy()
+ self.assertEqual(bb, b'abc')
+
+ b = bytearray(b'')
+ bb = b.copy()
+ self.assertEqual(bb, b'')
+
+ # test that it's indeed a copy and not a reference
+ b = bytearray(b'abc')
+ bb = b.copy()
+ self.assertEqual(b, bb)
+ self.assertIsNot(b, bb)
+ bb.append(ord('d'))
+ self.assertEqual(bb, b'abcd')
+ self.assertEqual(b, b'abc')
+
def test_regexps(self):
def by(s):
return bytearray(map(ord, s))
@@ -1122,14 +1294,16 @@ class FixedStringTest(test.string_tests.BaseTest):
def test_lower(self):
pass
-class ByteArrayAsStringTest(FixedStringTest):
+class ByteArrayAsStringTest(FixedStringTest, unittest.TestCase):
type2test = bytearray
+ contains_bytes = True
-class BytesAsStringTest(FixedStringTest):
+class BytesAsStringTest(FixedStringTest, unittest.TestCase):
type2test = bytes
+ contains_bytes = True
-class SubclassTest(unittest.TestCase):
+class SubclassTest:
def test_basic(self):
self.assertTrue(issubclass(self.subclass2test, self.type2test))
@@ -1201,7 +1375,7 @@ class ByteArraySubclass(bytearray):
class BytesSubclass(bytes):
pass
-class ByteArraySubclassTest(SubclassTest):
+class ByteArraySubclassTest(SubclassTest, unittest.TestCase):
type2test = bytearray
subclass2test = ByteArraySubclass
@@ -1216,16 +1390,10 @@ class ByteArraySubclassTest(SubclassTest):
self.assertEqual(x, b"abcd")
-class BytesSubclassTest(SubclassTest):
+class BytesSubclassTest(SubclassTest, unittest.TestCase):
type2test = bytes
subclass2test = BytesSubclass
-def test_main():
- test.support.run_unittest(
- BytesTest, AssortedBytesTest, BytesAsStringTest,
- ByteArrayTest, ByteArrayAsStringTest, BytesSubclassTest,
- ByteArraySubclassTest, BytearrayPEP3137Test)
-
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_bz2.py b/Lib/test/test_bz2.py
index fb104d7..df7e18c 100644
--- a/Lib/test/test_bz2.py
+++ b/Lib/test/test_bz2.py
@@ -1,10 +1,11 @@
#!/usr/bin/env python3
from test import support
-from test.support import TESTFN, _4G, bigmemtest, findfile
+from test.support import TESTFN, bigmemtest, _4G
import unittest
from io import BytesIO
import os
+import random
import subprocess
import sys
@@ -21,13 +22,39 @@ has_cmdline_bunzip2 = sys.platform not in ("win32", "os2emx")
class BaseTest(unittest.TestCase):
"Base for other testcases."
- TEXT = b'root:x:0:0:root:/root:/bin/bash\nbin:x:1:1:bin:/bin:\ndaemon:x:2:2:daemon:/sbin:\nadm:x:3:4:adm:/var/adm:\nlp:x:4:7:lp:/var/spool/lpd:\nsync:x:5:0:sync:/sbin:/bin/sync\nshutdown:x:6:0:shutdown:/sbin:/sbin/shutdown\nhalt:x:7:0:halt:/sbin:/sbin/halt\nmail:x:8:12:mail:/var/spool/mail:\nnews:x:9:13:news:/var/spool/news:\nuucp:x:10:14:uucp:/var/spool/uucp:\noperator:x:11:0:operator:/root:\ngames:x:12:100:games:/usr/games:\ngopher:x:13:30:gopher:/usr/lib/gopher-data:\nftp:x:14:50:FTP User:/var/ftp:/bin/bash\nnobody:x:65534:65534:Nobody:/home:\npostfix:x:100:101:postfix:/var/spool/postfix:\nniemeyer:x:500:500::/home/niemeyer:/bin/bash\npostgres:x:101:102:PostgreSQL Server:/var/lib/pgsql:/bin/bash\nmysql:x:102:103:MySQL server:/var/lib/mysql:/bin/bash\nwww:x:103:104::/var/www:/bin/false\n'
+ TEXT_LINES = [
+ b'root:x:0:0:root:/root:/bin/bash\n',
+ b'bin:x:1:1:bin:/bin:\n',
+ b'daemon:x:2:2:daemon:/sbin:\n',
+ b'adm:x:3:4:adm:/var/adm:\n',
+ b'lp:x:4:7:lp:/var/spool/lpd:\n',
+ b'sync:x:5:0:sync:/sbin:/bin/sync\n',
+ b'shutdown:x:6:0:shutdown:/sbin:/sbin/shutdown\n',
+ b'halt:x:7:0:halt:/sbin:/sbin/halt\n',
+ b'mail:x:8:12:mail:/var/spool/mail:\n',
+ b'news:x:9:13:news:/var/spool/news:\n',
+ b'uucp:x:10:14:uucp:/var/spool/uucp:\n',
+ b'operator:x:11:0:operator:/root:\n',
+ b'games:x:12:100:games:/usr/games:\n',
+ b'gopher:x:13:30:gopher:/usr/lib/gopher-data:\n',
+ b'ftp:x:14:50:FTP User:/var/ftp:/bin/bash\n',
+ b'nobody:x:65534:65534:Nobody:/home:\n',
+ b'postfix:x:100:101:postfix:/var/spool/postfix:\n',
+ b'niemeyer:x:500:500::/home/niemeyer:/bin/bash\n',
+ b'postgres:x:101:102:PostgreSQL Server:/var/lib/pgsql:/bin/bash\n',
+ b'mysql:x:102:103:MySQL server:/var/lib/mysql:/bin/bash\n',
+ b'www:x:103:104::/var/www:/bin/false\n',
+ ]
+ TEXT = b''.join(TEXT_LINES)
DATA = b'BZh91AY&SY.\xc8N\x18\x00\x01>_\x80\x00\x10@\x02\xff\xf0\x01\x07n\x00?\xe7\xff\xe00\x01\x99\xaa\x00\xc0\x03F\x86\x8c#&\x83F\x9a\x03\x06\xa6\xd0\xa6\x93M\x0fQ\xa7\xa8\x06\x804hh\x12$\x11\xa4i4\xf14S\xd2<Q\xb5\x0fH\xd3\xd4\xdd\xd5\x87\xbb\xf8\x94\r\x8f\xafI\x12\xe1\xc9\xf8/E\x00pu\x89\x12]\xc9\xbbDL\nQ\x0e\t1\x12\xdf\xa0\xc0\x97\xac2O9\x89\x13\x94\x0e\x1c7\x0ed\x95I\x0c\xaaJ\xa4\x18L\x10\x05#\x9c\xaf\xba\xbc/\x97\x8a#C\xc8\xe1\x8cW\xf9\xe2\xd0\xd6M\xa7\x8bXa<e\x84t\xcbL\xb3\xa7\xd9\xcd\xd1\xcb\x84.\xaf\xb3\xab\xab\xad`n}\xa0lh\tE,\x8eZ\x15\x17VH>\x88\xe5\xcd9gd6\x0b\n\xe9\x9b\xd5\x8a\x99\xf7\x08.K\x8ev\xfb\xf7xw\xbb\xdf\xa1\x92\xf1\xdd|/";\xa2\xba\x9f\xd5\xb1#A\xb6\xf6\xb3o\xc9\xc5y\\\xebO\xe7\x85\x9a\xbc\xb6f8\x952\xd5\xd7"%\x89>V,\xf7\xa6z\xe2\x9f\xa3\xdf\x11\x11"\xd6E)I\xa9\x13^\xca\xf3r\xd0\x03U\x922\xf26\xec\xb6\xed\x8b\xc3U\x13\x9d\xc5\x170\xa4\xfa^\x92\xacDF\x8a\x97\xd6\x19\xfe\xdd\xb8\xbd\x1a\x9a\x19\xa3\x80ankR\x8b\xe5\xd83]\xa9\xc6\x08\x82f\xf6\xb9"6l$\xb8j@\xc0\x8a\xb0l1..\xbak\x83ls\x15\xbc\xf4\xc1\x13\xbe\xf8E\xb8\x9d\r\xa8\x9dk\x84\xd3n\xfa\xacQ\x07\xb1%y\xaav\xb4\x08\xe0z\x1b\x16\xf5\x04\xe9\xcc\xb9\x08z\x1en7.G\xfc]\xc9\x14\xe1B@\xbb!8`'
- DATA_CRLF = b'BZh91AY&SY\xaez\xbbN\x00\x01H\xdf\x80\x00\x12@\x02\xff\xf0\x01\x07n\x00?\xe7\xff\xe0@\x01\xbc\xc6`\x86*\x8d=M\xa9\x9a\x86\xd0L@\x0fI\xa6!\xa1\x13\xc8\x88jdi\x8d@\x03@\x1a\x1a\x0c\x0c\x83 \x00\xc4h2\x19\x01\x82D\x84e\t\xe8\x99\x89\x19\x1ah\x00\r\x1a\x11\xaf\x9b\x0fG\xf5(\x1b\x1f?\t\x12\xcf\xb5\xfc\x95E\x00ps\x89\x12^\xa4\xdd\xa2&\x05(\x87\x04\x98\x89u\xe40%\xb6\x19\'\x8c\xc4\x89\xca\x07\x0e\x1b!\x91UIFU%C\x994!DI\xd2\xfa\xf0\xf1N8W\xde\x13A\xf5\x9cr%?\x9f3;I45A\xd1\x8bT\xb1<l\xba\xcb_\xc00xY\x17r\x17\x88\x08\x08@\xa0\ry@\x10\x04$)`\xf2\xce\x89z\xb0s\xec\x9b.iW\x9d\x81\xb5-+t\x9f\x1a\'\x97dB\xf5x\xb5\xbe.[.\xd7\x0e\x81\xe7\x08\x1cN`\x88\x10\xca\x87\xc3!"\x80\x92R\xa1/\xd1\xc0\xe6mf\xac\xbd\x99\xcca\xb3\x8780>\xa4\xc7\x8d\x1a\\"\xad\xa1\xabyBg\x15\xb9l\x88\x88\x91k"\x94\xa4\xd4\x89\xae*\xa6\x0b\x10\x0c\xd6\xd4m\xe86\xec\xb5j\x8a\x86j\';\xca.\x01I\xf2\xaaJ\xe8\x88\x8cU+t3\xfb\x0c\n\xa33\x13r2\r\x16\xe0\xb3(\xbf\x1d\x83r\xe7M\xf0D\x1365\xd8\x88\xd3\xa4\x92\xcb2\x06\x04\\\xc1\xb0\xea//\xbek&\xd8\xe6+t\xe5\xa1\x13\xada\x16\xder5"w]\xa2i\xb7[\x97R \xe2IT\xcd;Z\x04dk4\xad\x8a\t\xd3\x81z\x10\xf1:^`\xab\x1f\xc5\xdc\x91N\x14$+\x9e\xae\xd3\x80'
EMPTY_DATA = b'BZh9\x17rE8P\x90\x00\x00\x00\x00'
- with open(findfile("testbz2_bigmem.bz2"), "rb") as f:
- DATA_BIGMEM = f.read()
+ def setUp(self):
+ self.filename = TESTFN
+
+ def tearDown(self):
+ if os.path.isfile(self.filename):
+ os.unlink(self.filename)
if has_cmdline_bunzip2:
def decompress(self, data):
@@ -48,94 +75,152 @@ class BaseTest(unittest.TestCase):
def decompress(self, data):
return bz2.decompress(data)
-
class BZ2FileTest(BaseTest):
"Test BZ2File type miscellaneous methods."
- def setUp(self):
- self.filename = TESTFN
-
- def tearDown(self):
- if os.path.isfile(self.filename):
- os.unlink(self.filename)
-
- def createTempFile(self, crlf=0):
+ def createTempFile(self, streams=1):
with open(self.filename, "wb") as f:
- if crlf:
- data = self.DATA_CRLF
- else:
- data = self.DATA
- f.write(data)
+ f.write(self.DATA * streams)
+
+ def testBadArgs(self):
+ with self.assertRaises(TypeError):
+ BZ2File(123.456)
+ with self.assertRaises(ValueError):
+ BZ2File("/dev/null", "z")
+ with self.assertRaises(ValueError):
+ BZ2File("/dev/null", "rx")
+ with self.assertRaises(ValueError):
+ BZ2File("/dev/null", "rbt")
+ with self.assertRaises(ValueError):
+ BZ2File("/dev/null", compresslevel=0)
+ with self.assertRaises(ValueError):
+ BZ2File("/dev/null", compresslevel=10)
def testRead(self):
- # "Test BZ2File.read()"
self.createTempFile()
with BZ2File(self.filename) as bz2f:
self.assertRaises(TypeError, bz2f.read, None)
self.assertEqual(bz2f.read(), self.TEXT)
+ def testReadMultiStream(self):
+ self.createTempFile(streams=5)
+ with BZ2File(self.filename) as bz2f:
+ self.assertRaises(TypeError, bz2f.read, None)
+ self.assertEqual(bz2f.read(), self.TEXT * 5)
+
+ def testReadMonkeyMultiStream(self):
+ # Test BZ2File.read() on a multi-stream archive where a stream
+ # boundary coincides with the end of the raw read buffer.
+ buffer_size = bz2._BUFFER_SIZE
+ bz2._BUFFER_SIZE = len(self.DATA)
+ try:
+ self.createTempFile(streams=5)
+ with BZ2File(self.filename) as bz2f:
+ self.assertRaises(TypeError, bz2f.read, None)
+ self.assertEqual(bz2f.read(), self.TEXT * 5)
+ finally:
+ bz2._BUFFER_SIZE = buffer_size
+
def testRead0(self):
- # Test BBZ2File.read(0)"
self.createTempFile()
with BZ2File(self.filename) as bz2f:
self.assertRaises(TypeError, bz2f.read, None)
self.assertEqual(bz2f.read(0), b"")
def testReadChunk10(self):
- # "Test BZ2File.read() in chunks of 10 bytes"
self.createTempFile()
with BZ2File(self.filename) as bz2f:
text = b''
- while 1:
+ while True:
str = bz2f.read(10)
if not str:
break
text += str
self.assertEqual(text, self.TEXT)
+ def testReadChunk10MultiStream(self):
+ self.createTempFile(streams=5)
+ with BZ2File(self.filename) as bz2f:
+ text = b''
+ while True:
+ str = bz2f.read(10)
+ if not str:
+ break
+ text += str
+ self.assertEqual(text, self.TEXT * 5)
+
def testRead100(self):
- # "Test BZ2File.read(100)"
self.createTempFile()
with BZ2File(self.filename) as bz2f:
self.assertEqual(bz2f.read(100), self.TEXT[:100])
+ def testPeek(self):
+ self.createTempFile()
+ with BZ2File(self.filename) as bz2f:
+ pdata = bz2f.peek()
+ self.assertNotEqual(len(pdata), 0)
+ self.assertTrue(self.TEXT.startswith(pdata))
+ self.assertEqual(bz2f.read(), self.TEXT)
+
+ def testReadInto(self):
+ self.createTempFile()
+ with BZ2File(self.filename) as bz2f:
+ n = 128
+ b = bytearray(n)
+ self.assertEqual(bz2f.readinto(b), n)
+ self.assertEqual(b, self.TEXT[:n])
+ n = len(self.TEXT) - n
+ b = bytearray(len(self.TEXT))
+ self.assertEqual(bz2f.readinto(b), n)
+ self.assertEqual(b[:n], self.TEXT[-n:])
+
def testReadLine(self):
- # "Test BZ2File.readline()"
self.createTempFile()
with BZ2File(self.filename) as bz2f:
self.assertRaises(TypeError, bz2f.readline, None)
- sio = BytesIO(self.TEXT)
- for line in sio.readlines():
+ for line in self.TEXT_LINES:
+ self.assertEqual(bz2f.readline(), line)
+
+ def testReadLineMultiStream(self):
+ self.createTempFile(streams=5)
+ with BZ2File(self.filename) as bz2f:
+ self.assertRaises(TypeError, bz2f.readline, None)
+ for line in self.TEXT_LINES * 5:
self.assertEqual(bz2f.readline(), line)
def testReadLines(self):
- # "Test BZ2File.readlines()"
self.createTempFile()
with BZ2File(self.filename) as bz2f:
self.assertRaises(TypeError, bz2f.readlines, None)
- sio = BytesIO(self.TEXT)
- self.assertEqual(bz2f.readlines(), sio.readlines())
+ self.assertEqual(bz2f.readlines(), self.TEXT_LINES)
+
+ def testReadLinesMultiStream(self):
+ self.createTempFile(streams=5)
+ with BZ2File(self.filename) as bz2f:
+ self.assertRaises(TypeError, bz2f.readlines, None)
+ self.assertEqual(bz2f.readlines(), self.TEXT_LINES * 5)
def testIterator(self):
- # "Test iter(BZ2File)"
self.createTempFile()
with BZ2File(self.filename) as bz2f:
- sio = BytesIO(self.TEXT)
- self.assertEqual(list(iter(bz2f)), sio.readlines())
+ self.assertEqual(list(iter(bz2f)), self.TEXT_LINES)
+
+ def testIteratorMultiStream(self):
+ self.createTempFile(streams=5)
+ with BZ2File(self.filename) as bz2f:
+ self.assertEqual(list(iter(bz2f)), self.TEXT_LINES * 5)
def testClosedIteratorDeadlock(self):
- # "Test that iteration on a closed bz2file releases the lock."
- # http://bugs.python.org/issue3309
+ # Issue #3309: Iteration on a closed BZ2File should release the lock.
self.createTempFile()
bz2f = BZ2File(self.filename)
bz2f.close()
self.assertRaises(ValueError, bz2f.__next__)
- # This call will deadlock of the above .__next__ call failed to
+ # This call will deadlock if the above .__next__ call failed to
# release the lock.
self.assertRaises(ValueError, bz2f.readlines)
def testWrite(self):
- # "Test BZ2File.write()"
with BZ2File(self.filename, "w") as bz2f:
self.assertRaises(TypeError, bz2f.write)
bz2f.write(self.TEXT)
@@ -143,10 +228,9 @@ class BZ2FileTest(BaseTest):
self.assertEqual(self.decompress(f.read()), self.TEXT)
def testWriteChunks10(self):
- # "Test BZ2File.write() with chunks of 10 bytes"
with BZ2File(self.filename, "w") as bz2f:
n = 0
- while 1:
+ while True:
str = self.TEXT[n*10:(n+1)*10]
if not str:
break
@@ -155,13 +239,19 @@ class BZ2FileTest(BaseTest):
with open(self.filename, 'rb') as f:
self.assertEqual(self.decompress(f.read()), self.TEXT)
+ def testWriteNonDefaultCompressLevel(self):
+ expected = bz2.compress(self.TEXT, compresslevel=5)
+ with BZ2File(self.filename, "w", compresslevel=5) as bz2f:
+ bz2f.write(self.TEXT)
+ with open(self.filename, "rb") as f:
+ self.assertEqual(f.read(), expected)
+
def testWriteLines(self):
- # "Test BZ2File.writelines()"
with BZ2File(self.filename, "w") as bz2f:
self.assertRaises(TypeError, bz2f.writelines)
- sio = BytesIO(self.TEXT)
- bz2f.writelines(sio.readlines())
- # patch #1535500
+ bz2f.writelines(self.TEXT_LINES)
+ # Issue #1535500: Calling writelines() on a closed BZ2File
+ # should raise an exception.
self.assertRaises(ValueError, bz2f.writelines, ["a"])
with open(self.filename, 'rb') as f:
self.assertEqual(self.decompress(f.read()), self.TEXT)
@@ -174,39 +264,73 @@ class BZ2FileTest(BaseTest):
self.assertRaises(IOError, bz2f.write, b"a")
self.assertRaises(IOError, bz2f.writelines, [b"a"])
+ def testAppend(self):
+ with BZ2File(self.filename, "w") as bz2f:
+ self.assertRaises(TypeError, bz2f.write)
+ bz2f.write(self.TEXT)
+ with BZ2File(self.filename, "a") as bz2f:
+ self.assertRaises(TypeError, bz2f.write)
+ bz2f.write(self.TEXT)
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(self.decompress(f.read()), self.TEXT * 2)
+
def testSeekForward(self):
- # "Test BZ2File.seek(150, 0)"
self.createTempFile()
with BZ2File(self.filename) as bz2f:
self.assertRaises(TypeError, bz2f.seek)
bz2f.seek(150)
self.assertEqual(bz2f.read(), self.TEXT[150:])
+ def testSeekForwardAcrossStreams(self):
+ self.createTempFile(streams=2)
+ with BZ2File(self.filename) as bz2f:
+ self.assertRaises(TypeError, bz2f.seek)
+ bz2f.seek(len(self.TEXT) + 150)
+ self.assertEqual(bz2f.read(), self.TEXT[150:])
+
def testSeekBackwards(self):
- # "Test BZ2File.seek(-150, 1)"
self.createTempFile()
with BZ2File(self.filename) as bz2f:
bz2f.read(500)
bz2f.seek(-150, 1)
self.assertEqual(bz2f.read(), self.TEXT[500-150:])
+ def testSeekBackwardsAcrossStreams(self):
+ self.createTempFile(streams=2)
+ with BZ2File(self.filename) as bz2f:
+ readto = len(self.TEXT) + 100
+ while readto > 0:
+ readto -= len(bz2f.read(readto))
+ bz2f.seek(-150, 1)
+ self.assertEqual(bz2f.read(), self.TEXT[100-150:] + self.TEXT)
+
def testSeekBackwardsFromEnd(self):
- # "Test BZ2File.seek(-150, 2)"
self.createTempFile()
with BZ2File(self.filename) as bz2f:
bz2f.seek(-150, 2)
self.assertEqual(bz2f.read(), self.TEXT[len(self.TEXT)-150:])
+ def testSeekBackwardsFromEndAcrossStreams(self):
+ self.createTempFile(streams=2)
+ with BZ2File(self.filename) as bz2f:
+ bz2f.seek(-1000, 2)
+ self.assertEqual(bz2f.read(), (self.TEXT * 2)[-1000:])
+
def testSeekPostEnd(self):
- # "Test BZ2File.seek(150000)"
self.createTempFile()
with BZ2File(self.filename) as bz2f:
bz2f.seek(150000)
self.assertEqual(bz2f.tell(), len(self.TEXT))
self.assertEqual(bz2f.read(), b"")
+ def testSeekPostEndMultiStream(self):
+ self.createTempFile(streams=5)
+ with BZ2File(self.filename) as bz2f:
+ bz2f.seek(150000)
+ self.assertEqual(bz2f.tell(), len(self.TEXT) * 5)
+ self.assertEqual(bz2f.read(), b"")
+
def testSeekPostEndTwice(self):
- # "Test BZ2File.seek(150000) twice"
self.createTempFile()
with BZ2File(self.filename) as bz2f:
bz2f.seek(150000)
@@ -214,27 +338,109 @@ class BZ2FileTest(BaseTest):
self.assertEqual(bz2f.tell(), len(self.TEXT))
self.assertEqual(bz2f.read(), b"")
+ def testSeekPostEndTwiceMultiStream(self):
+ self.createTempFile(streams=5)
+ with BZ2File(self.filename) as bz2f:
+ bz2f.seek(150000)
+ bz2f.seek(150000)
+ self.assertEqual(bz2f.tell(), len(self.TEXT) * 5)
+ self.assertEqual(bz2f.read(), b"")
+
def testSeekPreStart(self):
- # "Test BZ2File.seek(-150, 0)"
self.createTempFile()
with BZ2File(self.filename) as bz2f:
bz2f.seek(-150)
self.assertEqual(bz2f.tell(), 0)
self.assertEqual(bz2f.read(), self.TEXT)
+ def testSeekPreStartMultiStream(self):
+ self.createTempFile(streams=2)
+ with BZ2File(self.filename) as bz2f:
+ bz2f.seek(-150)
+ self.assertEqual(bz2f.tell(), 0)
+ self.assertEqual(bz2f.read(), self.TEXT * 2)
+
+ def testFileno(self):
+ self.createTempFile()
+ with open(self.filename, 'rb') as rawf:
+ bz2f = BZ2File(rawf)
+ try:
+ self.assertEqual(bz2f.fileno(), rawf.fileno())
+ finally:
+ bz2f.close()
+ self.assertRaises(ValueError, bz2f.fileno)
+
+ def testSeekable(self):
+ bz2f = BZ2File(BytesIO(self.DATA))
+ try:
+ self.assertTrue(bz2f.seekable())
+ bz2f.read()
+ self.assertTrue(bz2f.seekable())
+ finally:
+ bz2f.close()
+ self.assertRaises(ValueError, bz2f.seekable)
+
+ bz2f = BZ2File(BytesIO(), mode="w")
+ try:
+ self.assertFalse(bz2f.seekable())
+ finally:
+ bz2f.close()
+ self.assertRaises(ValueError, bz2f.seekable)
+
+ src = BytesIO(self.DATA)
+ src.seekable = lambda: False
+ bz2f = BZ2File(src)
+ try:
+ self.assertFalse(bz2f.seekable())
+ finally:
+ bz2f.close()
+ self.assertRaises(ValueError, bz2f.seekable)
+
+ def testReadable(self):
+ bz2f = BZ2File(BytesIO(self.DATA))
+ try:
+ self.assertTrue(bz2f.readable())
+ bz2f.read()
+ self.assertTrue(bz2f.readable())
+ finally:
+ bz2f.close()
+ self.assertRaises(ValueError, bz2f.readable)
+
+ bz2f = BZ2File(BytesIO(), mode="w")
+ try:
+ self.assertFalse(bz2f.readable())
+ finally:
+ bz2f.close()
+ self.assertRaises(ValueError, bz2f.readable)
+
+ def testWritable(self):
+ bz2f = BZ2File(BytesIO(self.DATA))
+ try:
+ self.assertFalse(bz2f.writable())
+ bz2f.read()
+ self.assertFalse(bz2f.writable())
+ finally:
+ bz2f.close()
+ self.assertRaises(ValueError, bz2f.writable)
+
+ bz2f = BZ2File(BytesIO(), mode="w")
+ try:
+ self.assertTrue(bz2f.writable())
+ finally:
+ bz2f.close()
+ self.assertRaises(ValueError, bz2f.writable)
+
def testOpenDel(self):
- # "Test opening and deleting a file many times"
self.createTempFile()
for i in range(10000):
o = BZ2File(self.filename)
del o
def testOpenNonexistent(self):
- # "Test opening a nonexistent file"
self.assertRaises(IOError, BZ2File, "/non/existent")
- def testBug1191043(self):
- # readlines() for files containing no newline
+ def testReadlinesNoNewline(self):
+ # Issue #1191043: readlines() fails on a file containing no newline.
data = b'BZh91AY&SY\xd9b\x89]\x00\x00\x00\x03\x80\x04\x00\x02\x00\x0c\x00 \x00!\x9ah3M\x13<]\xc9\x14\xe1BCe\x8a%t'
with open(self.filename, "wb") as f:
f.write(data)
@@ -246,7 +452,6 @@ class BZ2FileTest(BaseTest):
self.assertEqual(xlines, [b'Test'])
def testContextProtocol(self):
- # BZ2File supports the context management protocol
f = None
with BZ2File(self.filename, "wb") as f:
f.write(b"xxx")
@@ -269,7 +474,7 @@ class BZ2FileTest(BaseTest):
@unittest.skipUnless(threading, 'Threading required for this test.')
def testThreading(self):
- # Using a BZ2File from several threads doesn't deadlock (issue #7205).
+ # Issue #7205: Using a BZ2File from several threads shouldn't deadlock.
data = b"1" * 2**20
nthreads = 10
with bz2.BZ2File(self.filename, 'wb') as f:
@@ -282,40 +487,112 @@ class BZ2FileTest(BaseTest):
for t in threads:
t.join()
- def testMixedIterationReads(self):
- # Issue #8397: mixed iteration and reads should be forbidden.
- with bz2.BZ2File(self.filename, 'wb') as f:
- # The internal buffer size is hard-wired to 8192 bytes, we must
- # write out more than that for the test to stop half through
- # the buffer.
- f.write(self.TEXT * 100)
- with bz2.BZ2File(self.filename, 'rb') as f:
- next(f)
- self.assertRaises(ValueError, f.read)
- self.assertRaises(ValueError, f.readline)
- self.assertRaises(ValueError, f.readlines)
+ def testWithoutThreading(self):
+ bz2 = support.import_fresh_module("bz2", blocked=("threading",))
+ with bz2.BZ2File(self.filename, "wb") as f:
+ f.write(b"abc")
+ with bz2.BZ2File(self.filename, "rb") as f:
+ self.assertEqual(f.read(), b"abc")
+
+ def testMixedIterationAndReads(self):
+ self.createTempFile()
+ linelen = len(self.TEXT_LINES[0])
+ halflen = linelen // 2
+ with bz2.BZ2File(self.filename) as bz2f:
+ bz2f.read(halflen)
+ self.assertEqual(next(bz2f), self.TEXT_LINES[0][halflen:])
+ self.assertEqual(bz2f.read(), self.TEXT[linelen:])
+ with bz2.BZ2File(self.filename) as bz2f:
+ bz2f.readline()
+ self.assertEqual(next(bz2f), self.TEXT_LINES[1])
+ self.assertEqual(bz2f.readline(), self.TEXT_LINES[2])
+ with bz2.BZ2File(self.filename) as bz2f:
+ bz2f.readlines()
+ with self.assertRaises(StopIteration):
+ next(bz2f)
+ self.assertEqual(bz2f.readlines(), [])
+
+ def testMultiStreamOrdering(self):
+ # Test the ordering of streams when reading a multi-stream archive.
+ data1 = b"foo" * 1000
+ data2 = b"bar" * 1000
+ with BZ2File(self.filename, "w") as bz2f:
+ bz2f.write(data1)
+ with BZ2File(self.filename, "a") as bz2f:
+ bz2f.write(data2)
+ with BZ2File(self.filename) as bz2f:
+ self.assertEqual(bz2f.read(), data1 + data2)
+
+ def testOpenBytesFilename(self):
+ str_filename = self.filename
+ try:
+ bytes_filename = str_filename.encode("ascii")
+ except UnicodeEncodeError:
+ self.skipTest("Temporary file name needs to be ASCII")
+ with BZ2File(bytes_filename, "wb") as f:
+ f.write(self.DATA)
+ with BZ2File(bytes_filename, "rb") as f:
+ self.assertEqual(f.read(), self.DATA)
+ # Sanity check that we are actually operating on the right file.
+ with BZ2File(str_filename, "rb") as f:
+ self.assertEqual(f.read(), self.DATA)
+
+
+ # Tests for a BZ2File wrapping another file object:
+
+ def testReadBytesIO(self):
+ with BytesIO(self.DATA) as bio:
+ with BZ2File(bio) as bz2f:
+ self.assertRaises(TypeError, bz2f.read, None)
+ self.assertEqual(bz2f.read(), self.TEXT)
+ self.assertFalse(bio.closed)
+
+ def testPeekBytesIO(self):
+ with BytesIO(self.DATA) as bio:
+ with BZ2File(bio) as bz2f:
+ pdata = bz2f.peek()
+ self.assertNotEqual(len(pdata), 0)
+ self.assertTrue(self.TEXT.startswith(pdata))
+ self.assertEqual(bz2f.read(), self.TEXT)
+
+ def testWriteBytesIO(self):
+ with BytesIO() as bio:
+ with BZ2File(bio, "w") as bz2f:
+ self.assertRaises(TypeError, bz2f.write)
+ bz2f.write(self.TEXT)
+ self.assertEqual(self.decompress(bio.getvalue()), self.TEXT)
+ self.assertFalse(bio.closed)
+
+ def testSeekForwardBytesIO(self):
+ with BytesIO(self.DATA) as bio:
+ with BZ2File(bio) as bz2f:
+ self.assertRaises(TypeError, bz2f.seek)
+ bz2f.seek(150)
+ self.assertEqual(bz2f.read(), self.TEXT[150:])
+
+ def testSeekBackwardsBytesIO(self):
+ with BytesIO(self.DATA) as bio:
+ with BZ2File(bio) as bz2f:
+ bz2f.read(500)
+ bz2f.seek(-150, 1)
+ self.assertEqual(bz2f.read(), self.TEXT[500-150:])
def test_read_truncated(self):
# Drop the eos_magic field (6 bytes) and CRC (4 bytes).
truncated = self.DATA[:-10]
- with open(self.filename, 'wb') as f:
- f.write(truncated)
- with BZ2File(self.filename) as f:
+ with BZ2File(BytesIO(truncated)) as f:
self.assertRaises(EOFError, f.read)
- with BZ2File(self.filename) as f:
+ with BZ2File(BytesIO(truncated)) as f:
self.assertEqual(f.read(len(self.TEXT)), self.TEXT)
self.assertRaises(EOFError, f.read, 1)
# Incomplete 4-byte file header, and block header of at least 146 bits.
for i in range(22):
- with open(self.filename, 'wb') as f:
- f.write(truncated[:i])
- with BZ2File(self.filename) as f:
+ with BZ2File(BytesIO(truncated[:i])) as f:
self.assertRaises(EOFError, f.read, 1)
class BZ2CompressorTest(BaseTest):
def testCompress(self):
- # "Test BZ2Compressor.compress()/flush()"
bz2c = BZ2Compressor()
self.assertRaises(TypeError, bz2c.compress)
data = bz2c.compress(self.TEXT)
@@ -329,11 +606,10 @@ class BZ2CompressorTest(BaseTest):
self.assertEqual(data, self.EMPTY_DATA)
def testCompressChunks10(self):
- # "Test BZ2Compressor.compress()/flush() with chunks of 10 bytes"
bz2c = BZ2Compressor()
n = 0
data = b''
- while 1:
+ while True:
str = self.TEXT[n*10:(n+1)*10]
if not str:
break
@@ -342,34 +618,38 @@ class BZ2CompressorTest(BaseTest):
data += bz2c.flush()
self.assertEqual(self.decompress(data), self.TEXT)
- @bigmemtest(size=_4G, memuse=1.25)
- def testBigmem(self, size):
- text = b"a" * size
- bz2c = bz2.BZ2Compressor()
- data = bz2c.compress(text) + bz2c.flush()
- del text
- text = self.decompress(data)
- self.assertEqual(len(text), size)
- self.assertEqual(text.strip(b"a"), b"")
-
+ @bigmemtest(size=_4G + 100, memuse=2)
+ def testCompress4G(self, size):
+ # "Test BZ2Compressor.compress()/flush() with >4GiB input"
+ bz2c = BZ2Compressor()
+ data = b"x" * size
+ try:
+ compressed = bz2c.compress(data)
+ compressed += bz2c.flush()
+ finally:
+ data = None # Release memory
+ data = bz2.decompress(compressed)
+ try:
+ self.assertEqual(len(data), size)
+ self.assertEqual(len(data.strip(b"x")), 0)
+ finally:
+ data = None
class BZ2DecompressorTest(BaseTest):
def test_Constructor(self):
self.assertRaises(TypeError, BZ2Decompressor, 42)
def testDecompress(self):
- # "Test BZ2Decompressor.decompress()"
bz2d = BZ2Decompressor()
self.assertRaises(TypeError, bz2d.decompress)
text = bz2d.decompress(self.DATA)
self.assertEqual(text, self.TEXT)
def testDecompressChunks10(self):
- # "Test BZ2Decompressor.decompress() with chunks of 10 bytes"
bz2d = BZ2Decompressor()
text = b''
n = 0
- while 1:
+ while True:
str = self.DATA[n*10:(n+1)*10]
if not str:
break
@@ -378,7 +658,6 @@ class BZ2DecompressorTest(BaseTest):
self.assertEqual(text, self.TEXT)
def testDecompressUnusedData(self):
- # "Test BZ2Decompressor.decompress() with unused data"
bz2d = BZ2Decompressor()
unused_data = b"this is unused data"
text = bz2d.decompress(self.DATA+unused_data)
@@ -386,25 +665,30 @@ class BZ2DecompressorTest(BaseTest):
self.assertEqual(bz2d.unused_data, unused_data)
def testEOFError(self):
- # "Calling BZ2Decompressor.decompress() after EOS must raise EOFError"
bz2d = BZ2Decompressor()
text = bz2d.decompress(self.DATA)
self.assertRaises(EOFError, bz2d.decompress, b"anything")
self.assertRaises(EOFError, bz2d.decompress, b"")
- @bigmemtest(size=_4G, memuse=1.25, dry_run=False)
- def testBigmem(self, unused_size):
- # Issue #14398: decompression fails when output data is >=2GB.
- text = bz2.BZ2Decompressor().decompress(self.DATA_BIGMEM)
- self.assertEqual(len(text), _4G)
- self.assertEqual(text.strip(b"\0"), b"")
-
-
-class FuncTest(BaseTest):
- "Test module functions"
-
+ @bigmemtest(size=_4G + 100, memuse=3)
+ def testDecompress4G(self, size):
+ # "Test BZ2Decompressor.decompress() with >4GiB input"
+ blocksize = 10 * 1024 * 1024
+ block = random.getrandbits(blocksize * 8).to_bytes(blocksize, 'little')
+ try:
+ data = block * (size // blocksize + 1)
+ compressed = bz2.compress(data)
+ bz2d = BZ2Decompressor()
+ decompressed = bz2d.decompress(compressed)
+ self.assertTrue(decompressed == data)
+ finally:
+ data = None
+ compressed = None
+ decompressed = None
+
+
+class CompressDecompressTest(BaseTest):
def testCompress(self):
- # "Test compress() function"
data = bz2.compress(self.TEXT)
self.assertEqual(self.decompress(data), self.TEXT)
@@ -413,12 +697,10 @@ class FuncTest(BaseTest):
self.assertEqual(text, self.EMPTY_DATA)
def testDecompress(self):
- # "Test decompress() function"
text = bz2.decompress(self.DATA)
self.assertEqual(text, self.TEXT)
def testDecompressEmpty(self):
- # "Test decompress() function with empty string"
text = bz2.decompress(b"")
self.assertEqual(text, b"")
@@ -427,35 +709,117 @@ class FuncTest(BaseTest):
self.assertEqual(text, b'')
def testDecompressIncomplete(self):
- # "Test decompress() function with incomplete data"
self.assertRaises(ValueError, bz2.decompress, self.DATA[:-10])
- @bigmemtest(size=_4G, memuse=1.25)
- def testCompressBigmem(self, size):
- text = b"a" * size
- data = bz2.compress(text)
- del text
- text = self.decompress(data)
- self.assertEqual(len(text), size)
- self.assertEqual(text.strip(b"a"), b"")
-
- @bigmemtest(size=_4G, memuse=1.25, dry_run=False)
- def testDecompressBigmem(self, unused_size):
- # Issue #14398: decompression fails when output data is >=2GB.
- text = bz2.decompress(self.DATA_BIGMEM)
- self.assertEqual(len(text), _4G)
- self.assertEqual(text.strip(b"\0"), b"")
+ def testDecompressMultiStream(self):
+ text = bz2.decompress(self.DATA * 5)
+ self.assertEqual(text, self.TEXT * 5)
+
+
+class OpenTest(BaseTest):
+ def test_binary_modes(self):
+ with bz2.open(self.filename, "wb") as f:
+ f.write(self.TEXT)
+ with open(self.filename, "rb") as f:
+ file_data = bz2.decompress(f.read())
+ self.assertEqual(file_data, self.TEXT)
+ with bz2.open(self.filename, "rb") as f:
+ self.assertEqual(f.read(), self.TEXT)
+ with bz2.open(self.filename, "ab") as f:
+ f.write(self.TEXT)
+ with open(self.filename, "rb") as f:
+ file_data = bz2.decompress(f.read())
+ self.assertEqual(file_data, self.TEXT * 2)
+
+ def test_implicit_binary_modes(self):
+ # Test implicit binary modes (no "b" or "t" in mode string).
+ with bz2.open(self.filename, "w") as f:
+ f.write(self.TEXT)
+ with open(self.filename, "rb") as f:
+ file_data = bz2.decompress(f.read())
+ self.assertEqual(file_data, self.TEXT)
+ with bz2.open(self.filename, "r") as f:
+ self.assertEqual(f.read(), self.TEXT)
+ with bz2.open(self.filename, "a") as f:
+ f.write(self.TEXT)
+ with open(self.filename, "rb") as f:
+ file_data = bz2.decompress(f.read())
+ self.assertEqual(file_data, self.TEXT * 2)
+
+ def test_text_modes(self):
+ text = self.TEXT.decode("ascii")
+ text_native_eol = text.replace("\n", os.linesep)
+ with bz2.open(self.filename, "wt") as f:
+ f.write(text)
+ with open(self.filename, "rb") as f:
+ file_data = bz2.decompress(f.read()).decode("ascii")
+ self.assertEqual(file_data, text_native_eol)
+ with bz2.open(self.filename, "rt") as f:
+ self.assertEqual(f.read(), text)
+ with bz2.open(self.filename, "at") as f:
+ f.write(text)
+ with open(self.filename, "rb") as f:
+ file_data = bz2.decompress(f.read()).decode("ascii")
+ self.assertEqual(file_data, text_native_eol * 2)
+
+ def test_fileobj(self):
+ with bz2.open(BytesIO(self.DATA), "r") as f:
+ self.assertEqual(f.read(), self.TEXT)
+ with bz2.open(BytesIO(self.DATA), "rb") as f:
+ self.assertEqual(f.read(), self.TEXT)
+ text = self.TEXT.decode("ascii")
+ with bz2.open(BytesIO(self.DATA), "rt") as f:
+ self.assertEqual(f.read(), text)
+
+ def test_bad_params(self):
+ # Test invalid parameter combinations.
+ with self.assertRaises(ValueError):
+ bz2.open(self.filename, "wbt")
+ with self.assertRaises(ValueError):
+ bz2.open(self.filename, "rb", encoding="utf-8")
+ with self.assertRaises(ValueError):
+ bz2.open(self.filename, "rb", errors="ignore")
+ with self.assertRaises(ValueError):
+ bz2.open(self.filename, "rb", newline="\n")
+
+ def test_encoding(self):
+ # Test non-default encoding.
+ text = self.TEXT.decode("ascii")
+ text_native_eol = text.replace("\n", os.linesep)
+ with bz2.open(self.filename, "wt", encoding="utf-16-le") as f:
+ f.write(text)
+ with open(self.filename, "rb") as f:
+ file_data = bz2.decompress(f.read()).decode("utf-16-le")
+ self.assertEqual(file_data, text_native_eol)
+ with bz2.open(self.filename, "rt", encoding="utf-16-le") as f:
+ self.assertEqual(f.read(), text)
+
+ def test_encoding_error_handler(self):
+ # Test with non-default encoding error handler.
+ with bz2.open(self.filename, "wb") as f:
+ f.write(b"foo\xffbar")
+ with bz2.open(self.filename, "rt", encoding="ascii", errors="ignore") \
+ as f:
+ self.assertEqual(f.read(), "foobar")
+
+ def test_newline(self):
+ # Test with explicit newline (universal newline mode disabled).
+ text = self.TEXT.decode("ascii")
+ with bz2.open(self.filename, "wt", newline="\n") as f:
+ f.write(text)
+ with bz2.open(self.filename, "rt", newline="\r") as f:
+ self.assertEqual(f.readlines(), [text])
+
def test_main():
support.run_unittest(
BZ2FileTest,
BZ2CompressorTest,
BZ2DecompressorTest,
- FuncTest
+ CompressDecompressTest,
+ OpenTest,
)
support.reap_children()
if __name__ == '__main__':
test_main()
-
-# vim:ts=4:sw=4
diff --git a/Lib/test/test_calendar.py b/Lib/test/test_calendar.py
index 37b2818..8adf5bd 100644
--- a/Lib/test/test_calendar.py
+++ b/Lib/test/test_calendar.py
@@ -5,8 +5,19 @@ from test import support
from test.script_helper import assert_python_ok
import time
import locale
+import sys
import datetime
+result_2004_01_text = """
+ January 2004
+Mo Tu We Th Fr Sa Su
+ 1 2 3 4
+ 5 6 7 8 9 10 11
+12 13 14 15 16 17 18
+19 20 21 22 23 24 25
+26 27 28 29 30 31
+"""
+
result_2004_text = """
2004
@@ -46,11 +57,11 @@ Mo Tu We Th Fr Sa Su Mo Tu We Th Fr Sa Su Mo Tu We Th Fr Sa Su
"""
result_2004_html = """
-<?xml version="1.0" encoding="ascii"?>
+<?xml version="1.0" encoding="%(e)s"?>
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">
<html>
<head>
-<meta http-equiv="Content-Type" content="text/html; charset=ascii" />
+<meta http-equiv="Content-Type" content="text/html; charset=%(e)s" />
<link rel="stylesheet" type="text/css" href="calendar.css" />
<title>Calendar for 2004</title>
</head>
@@ -170,6 +181,135 @@ result_2004_html = """
</html>
"""
+result_2004_days = [
+ [[[0, 0, 0, 1, 2, 3, 4],
+ [5, 6, 7, 8, 9, 10, 11],
+ [12, 13, 14, 15, 16, 17, 18],
+ [19, 20, 21, 22, 23, 24, 25],
+ [26, 27, 28, 29, 30, 31, 0]],
+ [[0, 0, 0, 0, 0, 0, 1],
+ [2, 3, 4, 5, 6, 7, 8],
+ [9, 10, 11, 12, 13, 14, 15],
+ [16, 17, 18, 19, 20, 21, 22],
+ [23, 24, 25, 26, 27, 28, 29]],
+ [[1, 2, 3, 4, 5, 6, 7],
+ [8, 9, 10, 11, 12, 13, 14],
+ [15, 16, 17, 18, 19, 20, 21],
+ [22, 23, 24, 25, 26, 27, 28],
+ [29, 30, 31, 0, 0, 0, 0]]],
+ [[[0, 0, 0, 1, 2, 3, 4],
+ [5, 6, 7, 8, 9, 10, 11],
+ [12, 13, 14, 15, 16, 17, 18],
+ [19, 20, 21, 22, 23, 24, 25],
+ [26, 27, 28, 29, 30, 0, 0]],
+ [[0, 0, 0, 0, 0, 1, 2],
+ [3, 4, 5, 6, 7, 8, 9],
+ [10, 11, 12, 13, 14, 15, 16],
+ [17, 18, 19, 20, 21, 22, 23],
+ [24, 25, 26, 27, 28, 29, 30],
+ [31, 0, 0, 0, 0, 0, 0]],
+ [[0, 1, 2, 3, 4, 5, 6],
+ [7, 8, 9, 10, 11, 12, 13],
+ [14, 15, 16, 17, 18, 19, 20],
+ [21, 22, 23, 24, 25, 26, 27],
+ [28, 29, 30, 0, 0, 0, 0]]],
+ [[[0, 0, 0, 1, 2, 3, 4],
+ [5, 6, 7, 8, 9, 10, 11],
+ [12, 13, 14, 15, 16, 17, 18],
+ [19, 20, 21, 22, 23, 24, 25],
+ [26, 27, 28, 29, 30, 31, 0]],
+ [[0, 0, 0, 0, 0, 0, 1],
+ [2, 3, 4, 5, 6, 7, 8],
+ [9, 10, 11, 12, 13, 14, 15],
+ [16, 17, 18, 19, 20, 21, 22],
+ [23, 24, 25, 26, 27, 28, 29],
+ [30, 31, 0, 0, 0, 0, 0]],
+ [[0, 0, 1, 2, 3, 4, 5],
+ [6, 7, 8, 9, 10, 11, 12],
+ [13, 14, 15, 16, 17, 18, 19],
+ [20, 21, 22, 23, 24, 25, 26],
+ [27, 28, 29, 30, 0, 0, 0]]],
+ [[[0, 0, 0, 0, 1, 2, 3],
+ [4, 5, 6, 7, 8, 9, 10],
+ [11, 12, 13, 14, 15, 16, 17],
+ [18, 19, 20, 21, 22, 23, 24],
+ [25, 26, 27, 28, 29, 30, 31]],
+ [[1, 2, 3, 4, 5, 6, 7],
+ [8, 9, 10, 11, 12, 13, 14],
+ [15, 16, 17, 18, 19, 20, 21],
+ [22, 23, 24, 25, 26, 27, 28],
+ [29, 30, 0, 0, 0, 0, 0]],
+ [[0, 0, 1, 2, 3, 4, 5],
+ [6, 7, 8, 9, 10, 11, 12],
+ [13, 14, 15, 16, 17, 18, 19],
+ [20, 21, 22, 23, 24, 25, 26],
+ [27, 28, 29, 30, 31, 0, 0]]]
+]
+
+result_2004_dates = \
+ [[['12/29/03 12/30/03 12/31/03 01/01/04 01/02/04 01/03/04 01/04/04',
+ '01/05/04 01/06/04 01/07/04 01/08/04 01/09/04 01/10/04 01/11/04',
+ '01/12/04 01/13/04 01/14/04 01/15/04 01/16/04 01/17/04 01/18/04',
+ '01/19/04 01/20/04 01/21/04 01/22/04 01/23/04 01/24/04 01/25/04',
+ '01/26/04 01/27/04 01/28/04 01/29/04 01/30/04 01/31/04 02/01/04'],
+ ['01/26/04 01/27/04 01/28/04 01/29/04 01/30/04 01/31/04 02/01/04',
+ '02/02/04 02/03/04 02/04/04 02/05/04 02/06/04 02/07/04 02/08/04',
+ '02/09/04 02/10/04 02/11/04 02/12/04 02/13/04 02/14/04 02/15/04',
+ '02/16/04 02/17/04 02/18/04 02/19/04 02/20/04 02/21/04 02/22/04',
+ '02/23/04 02/24/04 02/25/04 02/26/04 02/27/04 02/28/04 02/29/04'],
+ ['03/01/04 03/02/04 03/03/04 03/04/04 03/05/04 03/06/04 03/07/04',
+ '03/08/04 03/09/04 03/10/04 03/11/04 03/12/04 03/13/04 03/14/04',
+ '03/15/04 03/16/04 03/17/04 03/18/04 03/19/04 03/20/04 03/21/04',
+ '03/22/04 03/23/04 03/24/04 03/25/04 03/26/04 03/27/04 03/28/04',
+ '03/29/04 03/30/04 03/31/04 04/01/04 04/02/04 04/03/04 04/04/04']],
+ [['03/29/04 03/30/04 03/31/04 04/01/04 04/02/04 04/03/04 04/04/04',
+ '04/05/04 04/06/04 04/07/04 04/08/04 04/09/04 04/10/04 04/11/04',
+ '04/12/04 04/13/04 04/14/04 04/15/04 04/16/04 04/17/04 04/18/04',
+ '04/19/04 04/20/04 04/21/04 04/22/04 04/23/04 04/24/04 04/25/04',
+ '04/26/04 04/27/04 04/28/04 04/29/04 04/30/04 05/01/04 05/02/04'],
+ ['04/26/04 04/27/04 04/28/04 04/29/04 04/30/04 05/01/04 05/02/04',
+ '05/03/04 05/04/04 05/05/04 05/06/04 05/07/04 05/08/04 05/09/04',
+ '05/10/04 05/11/04 05/12/04 05/13/04 05/14/04 05/15/04 05/16/04',
+ '05/17/04 05/18/04 05/19/04 05/20/04 05/21/04 05/22/04 05/23/04',
+ '05/24/04 05/25/04 05/26/04 05/27/04 05/28/04 05/29/04 05/30/04',
+ '05/31/04 06/01/04 06/02/04 06/03/04 06/04/04 06/05/04 06/06/04'],
+ ['05/31/04 06/01/04 06/02/04 06/03/04 06/04/04 06/05/04 06/06/04',
+ '06/07/04 06/08/04 06/09/04 06/10/04 06/11/04 06/12/04 06/13/04',
+ '06/14/04 06/15/04 06/16/04 06/17/04 06/18/04 06/19/04 06/20/04',
+ '06/21/04 06/22/04 06/23/04 06/24/04 06/25/04 06/26/04 06/27/04',
+ '06/28/04 06/29/04 06/30/04 07/01/04 07/02/04 07/03/04 07/04/04']],
+ [['06/28/04 06/29/04 06/30/04 07/01/04 07/02/04 07/03/04 07/04/04',
+ '07/05/04 07/06/04 07/07/04 07/08/04 07/09/04 07/10/04 07/11/04',
+ '07/12/04 07/13/04 07/14/04 07/15/04 07/16/04 07/17/04 07/18/04',
+ '07/19/04 07/20/04 07/21/04 07/22/04 07/23/04 07/24/04 07/25/04',
+ '07/26/04 07/27/04 07/28/04 07/29/04 07/30/04 07/31/04 08/01/04'],
+ ['07/26/04 07/27/04 07/28/04 07/29/04 07/30/04 07/31/04 08/01/04',
+ '08/02/04 08/03/04 08/04/04 08/05/04 08/06/04 08/07/04 08/08/04',
+ '08/09/04 08/10/04 08/11/04 08/12/04 08/13/04 08/14/04 08/15/04',
+ '08/16/04 08/17/04 08/18/04 08/19/04 08/20/04 08/21/04 08/22/04',
+ '08/23/04 08/24/04 08/25/04 08/26/04 08/27/04 08/28/04 08/29/04',
+ '08/30/04 08/31/04 09/01/04 09/02/04 09/03/04 09/04/04 09/05/04'],
+ ['08/30/04 08/31/04 09/01/04 09/02/04 09/03/04 09/04/04 09/05/04',
+ '09/06/04 09/07/04 09/08/04 09/09/04 09/10/04 09/11/04 09/12/04',
+ '09/13/04 09/14/04 09/15/04 09/16/04 09/17/04 09/18/04 09/19/04',
+ '09/20/04 09/21/04 09/22/04 09/23/04 09/24/04 09/25/04 09/26/04',
+ '09/27/04 09/28/04 09/29/04 09/30/04 10/01/04 10/02/04 10/03/04']],
+ [['09/27/04 09/28/04 09/29/04 09/30/04 10/01/04 10/02/04 10/03/04',
+ '10/04/04 10/05/04 10/06/04 10/07/04 10/08/04 10/09/04 10/10/04',
+ '10/11/04 10/12/04 10/13/04 10/14/04 10/15/04 10/16/04 10/17/04',
+ '10/18/04 10/19/04 10/20/04 10/21/04 10/22/04 10/23/04 10/24/04',
+ '10/25/04 10/26/04 10/27/04 10/28/04 10/29/04 10/30/04 10/31/04'],
+ ['11/01/04 11/02/04 11/03/04 11/04/04 11/05/04 11/06/04 11/07/04',
+ '11/08/04 11/09/04 11/10/04 11/11/04 11/12/04 11/13/04 11/14/04',
+ '11/15/04 11/16/04 11/17/04 11/18/04 11/19/04 11/20/04 11/21/04',
+ '11/22/04 11/23/04 11/24/04 11/25/04 11/26/04 11/27/04 11/28/04',
+ '11/29/04 11/30/04 12/01/04 12/02/04 12/03/04 12/04/04 12/05/04'],
+ ['11/29/04 11/30/04 12/01/04 12/02/04 12/03/04 12/04/04 12/05/04',
+ '12/06/04 12/07/04 12/08/04 12/09/04 12/10/04 12/11/04 12/12/04',
+ '12/13/04 12/14/04 12/15/04 12/16/04 12/17/04 12/18/04 12/19/04',
+ '12/20/04 12/21/04 12/22/04 12/23/04 12/24/04 12/25/04 12/26/04',
+ '12/27/04 12/28/04 12/29/04 12/30/04 12/31/04 01/01/05 01/02/05']]]
+
class OutputTestCase(unittest.TestCase):
def normalize_calendar(self, s):
@@ -178,12 +318,19 @@ class OutputTestCase(unittest.TestCase):
return not c.isspace() and not c.isdigit()
lines = []
- for line in s.splitlines(False):
+ for line in s.splitlines(keepends=False):
# Drop texts, as they are locale dependent
if line and not filter(neitherspacenordigit, line):
lines.append(line)
return lines
+ def check_htmlcalendar_encoding(self, req, res):
+ cal = calendar.HTMLCalendar()
+ self.assertEqual(
+ cal.formatyearpage(2004, encoding=req).strip(b' \t\n'),
+ (result_2004_html % {'e': res}).strip(' \t\n').encode(res)
+ )
+
def test_output(self):
self.assertEqual(
self.normalize_calendar(calendar.calendar(2004)),
@@ -196,12 +343,60 @@ class OutputTestCase(unittest.TestCase):
result_2004_text.strip()
)
- def test_output_htmlcalendar(self):
- encoding = 'ascii'
- cal = calendar.HTMLCalendar()
+ def test_output_htmlcalendar_encoding_ascii(self):
+ self.check_htmlcalendar_encoding('ascii', 'ascii')
+
+ def test_output_htmlcalendar_encoding_utf8(self):
+ self.check_htmlcalendar_encoding('utf-8', 'utf-8')
+
+ def test_output_htmlcalendar_encoding_default(self):
+ self.check_htmlcalendar_encoding(None, sys.getdefaultencoding())
+
+ def test_yeardatescalendar(self):
+ def shrink(cal):
+ return [[[' '.join('{:02d}/{:02d}/{}'.format(
+ d.month, d.day, str(d.year)[-2:]) for d in z)
+ for z in y] for y in x] for x in cal]
+ self.assertEqual(
+ shrink(calendar.Calendar().yeardatescalendar(2004)),
+ result_2004_dates
+ )
+
+ def test_yeardayscalendar(self):
self.assertEqual(
- cal.formatyearpage(2004, encoding=encoding).strip(b' \t\n'),
- result_2004_html.strip(' \t\n').encode(encoding)
+ calendar.Calendar().yeardayscalendar(2004),
+ result_2004_days
+ )
+
+ def test_formatweekheader_short(self):
+ self.assertEqual(
+ calendar.TextCalendar().formatweekheader(2),
+ 'Mo Tu We Th Fr Sa Su'
+ )
+
+ def test_formatweekheader_long(self):
+ self.assertEqual(
+ calendar.TextCalendar().formatweekheader(9),
+ ' Monday Tuesday Wednesday Thursday '
+ ' Friday Saturday Sunday '
+ )
+
+ def test_formatmonth(self):
+ self.assertEqual(
+ calendar.TextCalendar().formatmonth(2004, 1).strip(),
+ result_2004_01_text.strip()
+ )
+
+ def test_formatmonthname_with_year(self):
+ self.assertEqual(
+ calendar.HTMLCalendar().formatmonthname(2004, 1, withyear=True),
+ '<tr><th colspan="7" class="month">January 2004</th></tr>'
+ )
+
+ def test_formatmonthname_without_year(self):
+ self.assertEqual(
+ calendar.HTMLCalendar().formatmonthname(2004, 1, withyear=False),
+ '<tr><th colspan="7" class="month">January</th></tr>'
)
@@ -227,7 +422,11 @@ class CalendarTestCase(unittest.TestCase):
self.assertEqual(calendar.firstweekday(), calendar.MONDAY)
calendar.setfirstweekday(orig)
- def test_enumerateweekdays(self):
+ def test_illegal_weekday_reported(self):
+ with self.assertRaisesRegex(calendar.IllegalWeekdayError, '123'):
+ calendar.setfirstweekday(123)
+
+ def test_enumerate_weekdays(self):
self.assertRaises(IndexError, calendar.day_abbr.__getitem__, -10)
self.assertRaises(IndexError, calendar.day_name.__getitem__, 10)
self.assertEqual(len([d for d in calendar.day_abbr]), 7)
@@ -253,7 +452,7 @@ class CalendarTestCase(unittest.TestCase):
# verify it "acts like a sequence" in two forms of iteration
self.assertEqual(value[::-1], list(reversed(value)))
- def test_localecalendars(self):
+ def test_locale_calendars(self):
# ensure that Locale{Text,HTML}Calendar resets the locale properly
# (it is still not thread-safe though)
old_october = calendar.TextCalendar().formatmonthname(2010, 10, 10)
@@ -447,6 +646,10 @@ class MonthRangeTestCase(unittest.TestCase):
with self.assertRaises(calendar.IllegalMonthError):
calendar.monthrange(2004, 13)
+ def test_illegal_month_reported(self):
+ with self.assertRaisesRegex(calendar.IllegalMonthError, '65'):
+ calendar.monthrange(2004, 65)
+
class LeapdaysTestCase(unittest.TestCase):
def test_no_range(self):
# test when no range i.e. two identical years as args
diff --git a/Lib/test/test_capi.py b/Lib/test/test_capi.py
index e7afa55..5002c77 100644
--- a/Lib/test/test_capi.py
+++ b/Lib/test/test_capi.py
@@ -52,13 +52,36 @@ class CAPITest(unittest.TestCase):
(out, err) = p.communicate()
self.assertEqual(out, b'')
# This used to cause an infinite loop.
- self.assertEqual(err.rstrip(),
+ self.assertTrue(err.rstrip().startswith(
b'Fatal Python error:'
- b' PyThreadState_Get: no current thread')
+ b' PyThreadState_Get: no current thread'))
def test_memoryview_from_NULL_pointer(self):
self.assertRaises(ValueError, _testcapi.make_memoryview_from_NULL_pointer)
+ def test_exc_info(self):
+ raised_exception = ValueError("5")
+ new_exc = TypeError("TEST")
+ try:
+ raise raised_exception
+ except ValueError as e:
+ tb = e.__traceback__
+ orig_sys_exc_info = sys.exc_info()
+ orig_exc_info = _testcapi.set_exc_info(new_exc.__class__, new_exc, None)
+ new_sys_exc_info = sys.exc_info()
+ new_exc_info = _testcapi.set_exc_info(*orig_exc_info)
+ reset_sys_exc_info = sys.exc_info()
+
+ self.assertEqual(orig_exc_info[1], e)
+
+ self.assertSequenceEqual(orig_exc_info, (raised_exception.__class__, raised_exception, tb))
+ self.assertSequenceEqual(orig_sys_exc_info, orig_exc_info)
+ self.assertSequenceEqual(reset_sys_exc_info, orig_exc_info)
+ self.assertSequenceEqual(new_exc_info, (new_exc.__class__, new_exc, None))
+ self.assertSequenceEqual(new_sys_exc_info, new_exc_info)
+ else:
+ self.assertTrue(False)
+
@unittest.skipUnless(_posixsubprocess, '_posixsubprocess required for this test.')
def test_seq_bytes_to_charp_array(self):
# Issue #15732: crash in _PySequence_BytesToCharpArray()
@@ -221,9 +244,91 @@ class EmbeddingTest(unittest.TestCase):
finally:
os.chdir(oldcwd)
+class SkipitemTest(unittest.TestCase):
+
+ def test_skipitem(self):
+ """
+ If this test failed, you probably added a new "format unit"
+ in Python/getargs.c, but neglected to update our poor friend
+ skipitem() in the same file. (If so, shame on you!)
+
+ With a few exceptions**, this function brute-force tests all
+ printable ASCII*** characters (32 to 126 inclusive) as format units,
+ checking to see that PyArg_ParseTupleAndKeywords() return consistent
+ errors both when the unit is attempted to be used and when it is
+ skipped. If the format unit doesn't exist, we'll get one of two
+ specific error messages (one for used, one for skipped); if it does
+ exist we *won't* get that error--we'll get either no error or some
+ other error. If we get the specific "does not exist" error for one
+ test and not for the other, there's a mismatch, and the test fails.
+
+ ** Some format units have special funny semantics and it would
+ be difficult to accomodate them here. Since these are all
+ well-established and properly skipped in skipitem() we can
+ get away with not testing them--this test is really intended
+ to catch *new* format units.
+
+ *** Python C source files must be ASCII. Therefore it's impossible
+ to have non-ASCII format units.
+
+ """
+ empty_tuple = ()
+ tuple_1 = (0,)
+ dict_b = {'b':1}
+ keywords = ["a", "b"]
+
+ for i in range(32, 127):
+ c = chr(i)
+
+ # skip parentheses, the error reporting is inconsistent about them
+ # skip 'e', it's always a two-character code
+ # skip '|' and '$', they don't represent arguments anyway
+ if c in '()e|$':
+ continue
+
+ # test the format unit when not skipped
+ format = c + "i"
+ try:
+ # (note: the format string must be bytes!)
+ _testcapi.parse_tuple_and_keywords(tuple_1, dict_b,
+ format.encode("ascii"), keywords)
+ when_not_skipped = False
+ except TypeError as e:
+ s = "argument 1 must be impossible<bad format char>, not int"
+ when_not_skipped = (str(e) == s)
+ except RuntimeError as e:
+ when_not_skipped = False
+
+ # test the format unit when skipped
+ optional_format = "|" + format
+ try:
+ _testcapi.parse_tuple_and_keywords(empty_tuple, dict_b,
+ optional_format.encode("ascii"), keywords)
+ when_skipped = False
+ except RuntimeError as e:
+ s = "impossible<bad format char>: '{}'".format(format)
+ when_skipped = (str(e) == s)
+
+ message = ("test_skipitem_parity: "
+ "detected mismatch between convertsimple and skipitem "
+ "for format unit '{}' ({}), not skipped {}, skipped {}".format(
+ c, i, when_skipped, when_not_skipped))
+ self.assertIs(when_skipped, when_not_skipped, message)
+
+ def test_parse_tuple_and_keywords(self):
+ # parse_tuple_and_keywords error handling tests
+ self.assertRaises(TypeError, _testcapi.parse_tuple_and_keywords,
+ (), {}, 42, [])
+ self.assertRaises(ValueError, _testcapi.parse_tuple_and_keywords,
+ (), {}, b'', 42)
+ self.assertRaises(ValueError, _testcapi.parse_tuple_and_keywords,
+ (), {}, b'', [''] * 42)
+ self.assertRaises(ValueError, _testcapi.parse_tuple_and_keywords,
+ (), {}, b'', [42])
def test_main():
- support.run_unittest(CAPITest, TestPendingCalls, Test6012, EmbeddingTest)
+ support.run_unittest(CAPITest, TestPendingCalls,
+ Test6012, EmbeddingTest, SkipitemTest)
for name in dir(_testcapi):
if name.startswith('test_'):
@@ -240,18 +345,17 @@ def test_main():
idents = []
def callback():
- idents.append(_thread.get_ident())
+ idents.append(threading.get_ident())
_testcapi._test_thread_state(callback)
a = b = callback
time.sleep(1)
# Check our main thread is in the list exactly 3 times.
- if idents.count(_thread.get_ident()) != 3:
+ if idents.count(threading.get_ident()) != 3:
raise support.TestFailed(
"Couldn't find main thread correctly in the list")
if threading:
- import _thread
import time
TestThreadState()
t = threading.Thread(target=TestThreadState)
diff --git a/Lib/test/test_cgi.py b/Lib/test/test_cgi.py
index 07e760b..cb28aa8 100644
--- a/Lib/test/test_cgi.py
+++ b/Lib/test/test_cgi.py
@@ -4,6 +4,7 @@ import os
import sys
import tempfile
import unittest
+import warnings
from collections import namedtuple
from io import StringIO, BytesIO
@@ -137,9 +138,13 @@ class CgiTests(unittest.TestCase):
self.assertTrue(fs)
def test_escape(self):
- self.assertEqual("test &amp; string", cgi.escape("test & string"))
- self.assertEqual("&lt;test string&gt;", cgi.escape("<test string>"))
- self.assertEqual("&quot;test string&quot;", cgi.escape('"test string"', True))
+ # cgi.escape() is deprecated.
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore', 'cgi\.escape',
+ DeprecationWarning)
+ self.assertEqual("test &amp; string", cgi.escape("test & string"))
+ self.assertEqual("&lt;test string&gt;", cgi.escape("<test string>"))
+ self.assertEqual("&quot;test string&quot;", cgi.escape('"test string"', True))
def test_strict(self):
for orig, expect in parse_strict_test_cases:
@@ -169,8 +174,7 @@ class CgiTests(unittest.TestCase):
def test_log(self):
cgi.log("Testing")
- cgi.logfile = "fail/"
- cgi.initlog("%s", "Testing initlog")
+
cgi.logfp = StringIO()
cgi.initlog("%s", "Testing initlog 1")
cgi.log("%s", "Testing log 2")
@@ -179,13 +183,7 @@ class CgiTests(unittest.TestCase):
cgi.logfp = None
cgi.logfile = "/dev/null"
cgi.initlog("%s", "Testing log 3")
- def log_cleanup():
- """Restore the global state of the log vars."""
- cgi.logfile = ''
- cgi.logfp.close()
- cgi.logfp = None
- cgi.log = cgi.initlog
- self.addCleanup(log_cleanup)
+ self.addCleanup(cgi.closelog)
cgi.log("Testing log 4")
def test_fieldstorage_readline(self):
diff --git a/Lib/test/test_cgitb.py b/Lib/test/test_cgitb.py
new file mode 100644
index 0000000..2e072a9
--- /dev/null
+++ b/Lib/test/test_cgitb.py
@@ -0,0 +1,70 @@
+from test.support import run_unittest
+from test.script_helper import assert_python_failure, temp_dir
+import unittest
+import sys
+import cgitb
+
+class TestCgitb(unittest.TestCase):
+
+ def test_fonts(self):
+ text = "Hello Robbie!"
+ self.assertEqual(cgitb.small(text), "<small>{}</small>".format(text))
+ self.assertEqual(cgitb.strong(text), "<strong>{}</strong>".format(text))
+ self.assertEqual(cgitb.grey(text),
+ '<font color="#909090">{}</font>'.format(text))
+
+ def test_blanks(self):
+ self.assertEqual(cgitb.small(""), "")
+ self.assertEqual(cgitb.strong(""), "")
+ self.assertEqual(cgitb.grey(""), "")
+
+ def test_html(self):
+ try:
+ raise ValueError("Hello World")
+ except ValueError as err:
+ # If the html was templated we could do a bit more here.
+ # At least check that we get details on what we just raised.
+ html = cgitb.html(sys.exc_info())
+ self.assertIn("ValueError", html)
+ self.assertIn(str(err), html)
+
+ def test_text(self):
+ try:
+ raise ValueError("Hello World")
+ except ValueError as err:
+ text = cgitb.text(sys.exc_info())
+ self.assertIn("ValueError", text)
+ self.assertIn("Hello World", text)
+
+ def test_syshook_no_logdir_default_format(self):
+ with temp_dir() as tracedir:
+ rc, out, err = assert_python_failure(
+ '-c',
+ ('import cgitb; cgitb.enable(logdir=%s); '
+ 'raise ValueError("Hello World")') % repr(tracedir))
+ out = out.decode(sys.getfilesystemencoding())
+ self.assertIn("ValueError", out)
+ self.assertIn("Hello World", out)
+ # By default we emit HTML markup.
+ self.assertIn('<p>', out)
+ self.assertIn('</p>', out)
+
+ def test_syshook_no_logdir_text_format(self):
+ # Issue 12890: we were emitting the <p> tag in text mode.
+ with temp_dir() as tracedir:
+ rc, out, err = assert_python_failure(
+ '-c',
+ ('import cgitb; cgitb.enable(format="text", logdir=%s); '
+ 'raise ValueError("Hello World")') % repr(tracedir))
+ out = out.decode(sys.getfilesystemencoding())
+ self.assertIn("ValueError", out)
+ self.assertIn("Hello World", out)
+ self.assertNotIn('<p>', out)
+ self.assertNotIn('</p>', out)
+
+
+def test_main():
+ run_unittest(TestCgitb)
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_cmd.py b/Lib/test/test_cmd.py
index 3a46355..6618535 100644
--- a/Lib/test/test_cmd.py
+++ b/Lib/test/test_cmd.py
@@ -228,7 +228,7 @@ def test_main(verbose=None):
def test_coverage(coverdir):
trace = support.import_module('trace')
- tracer=trace.Trace(ignoredirs=[sys.prefix, sys.exec_prefix,],
+ tracer=trace.Trace(ignoredirs=[sys.base_prefix, sys.base_exec_prefix,],
trace=0, count=1)
tracer.run('reload(cmd);test_main()')
r=tracer.results()
diff --git a/Lib/test/test_cmd_line.py b/Lib/test/test_cmd_line.py
index 67375cd..a89d7e4 100644
--- a/Lib/test/test_cmd_line.py
+++ b/Lib/test/test_cmd_line.py
@@ -32,12 +32,6 @@ class CmdLineTest(unittest.TestCase):
self.verify_valid_flag('-O')
self.verify_valid_flag('-OO')
- def test_q(self):
- self.verify_valid_flag('-Qold')
- self.verify_valid_flag('-Qnew')
- self.verify_valid_flag('-Qwarn')
- self.verify_valid_flag('-Qwarnall')
-
def test_site_flag(self):
self.verify_valid_flag('-S')
@@ -148,7 +142,7 @@ class CmdLineTest(unittest.TestCase):
@unittest.skipUnless(sys.platform == 'darwin', 'test specific to Mac OS X')
def test_osx_utf8(self):
def check_output(text):
- decoded = text.decode('utf8', 'surrogateescape')
+ decoded = text.decode('utf-8', 'surrogateescape')
expected = ascii(decoded).encode('ascii') + b'\n'
env = os.environ.copy()
@@ -220,7 +214,7 @@ class CmdLineTest(unittest.TestCase):
self.assertIn(path2.encode('ascii'), out)
def test_displayhook_unencodable(self):
- for encoding in ('ascii', 'latin1', 'utf8'):
+ for encoding in ('ascii', 'latin-1', 'utf-8'):
env = os.environ.copy()
env['PYTHONIOENCODING'] = encoding
p = subprocess.Popen(
@@ -296,7 +290,7 @@ class CmdLineTest(unittest.TestCase):
rc, out, err = assert_python_ok('-c', code)
self.assertEqual(b'', out)
self.assertRegex(err.decode('ascii', 'ignore'),
- 'Exception IOError: .* ignored')
+ 'Exception OSError: .* ignored')
def test_closed_stdout(self):
# Issue #13444: if stdout has been explicitly closed, we should
@@ -350,14 +344,14 @@ class CmdLineTest(unittest.TestCase):
hashes = []
for i in range(2):
code = 'print(hash("spam"))'
- rc, out, err = assert_python_ok('-R', '-c', code)
+ rc, out, err = assert_python_ok('-c', code)
self.assertEqual(rc, 0)
hashes.append(out)
self.assertNotEqual(hashes[0], hashes[1])
# Verify that sys.flags contains hash_randomization
code = 'import sys; print("random is", sys.flags.hash_randomization)'
- rc, out, err = assert_python_ok('-R', '-c', code)
+ rc, out, err = assert_python_ok('-c', code)
self.assertEqual(rc, 0)
self.assertIn(b'random is 1', out)
diff --git a/Lib/test/test_cmd_line_script.py b/Lib/test/test_cmd_line_script.py
index 70f7d1e..6051e18 100644
--- a/Lib/test/test_cmd_line_script.py
+++ b/Lib/test/test_cmd_line_script.py
@@ -1,15 +1,20 @@
# tests command line execution of scripts
+import importlib
+import importlib.machinery
+import zipimport
import unittest
import sys
import os
import os.path
import py_compile
+import textwrap
from test import support
from test.script_helper import (
make_pkg, make_script, make_zip_pkg, make_zip_script,
- assert_python_ok, assert_python_failure, temp_dir)
+ assert_python_ok, assert_python_failure, temp_dir,
+ spawn_python, kill_python)
verbose = support.verbose
@@ -32,6 +37,9 @@ f()
assertEqual(result, ['Top level assignment', 'Lower level reference'])
# Check population of magic variables
assertEqual(__name__, '__main__')
+from importlib.machinery import BuiltinImporter
+_loader = __loader__ if __loader__ is BuiltinImporter else type(__loader__)
+print('__loader__==%a' % _loader)
print('__file__==%a' % __file__)
assertEqual(__cached__, None)
print('__package__==%r' % __package__)
@@ -49,12 +57,16 @@ print('cwd==%a' % os.getcwd())
"""
def _make_test_script(script_dir, script_basename, source=test_source):
- return make_script(script_dir, script_basename, source)
+ to_return = make_script(script_dir, script_basename, source)
+ importlib.invalidate_caches()
+ return to_return
def _make_test_zip_pkg(zip_dir, zip_basename, pkg_name, script_basename,
source=test_source, depth=1):
- return make_zip_pkg(zip_dir, zip_basename, pkg_name, script_basename,
- source, depth)
+ to_return = make_zip_pkg(zip_dir, zip_basename, pkg_name, script_basename,
+ source, depth)
+ importlib.invalidate_caches()
+ return to_return
# There's no easy way to pass the script directory in to get
# -m to work (avoiding that is the whole point of making
@@ -72,16 +84,20 @@ def _make_launch_script(script_dir, script_basename, module_name, path=None):
else:
path = repr(path)
source = launch_source % (path, module_name)
- return make_script(script_dir, script_basename, source)
+ to_return = make_script(script_dir, script_basename, source)
+ importlib.invalidate_caches()
+ return to_return
class CmdLineTest(unittest.TestCase):
def _check_output(self, script_name, exit_code, data,
expected_file, expected_argv0,
- expected_path0, expected_package):
+ expected_path0, expected_package,
+ expected_loader):
if verbose > 1:
print("Output from test script %r:" % script_name)
print(data)
self.assertEqual(exit_code, 0)
+ printed_loader = '__loader__==%a' % expected_loader
printed_file = '__file__==%a' % expected_file
printed_package = '__package__==%r' % expected_package
printed_argv0 = 'sys.argv[0]==%a' % expected_argv0
@@ -93,6 +109,7 @@ class CmdLineTest(unittest.TestCase):
print(printed_package)
print(printed_argv0)
print(printed_cwd)
+ self.assertIn(printed_loader.encode('utf-8'), data)
self.assertIn(printed_file.encode('utf-8'), data)
self.assertIn(printed_package.encode('utf-8'), data)
self.assertIn(printed_argv0.encode('utf-8'), data)
@@ -101,14 +118,15 @@ class CmdLineTest(unittest.TestCase):
def _check_script(self, script_name, expected_file,
expected_argv0, expected_path0,
- expected_package,
+ expected_package, expected_loader,
*cmd_line_switches):
if not __debug__:
cmd_line_switches += ('-' + 'O' * sys.flags.optimize,)
run_args = cmd_line_switches + (script_name,) + tuple(example_args)
rc, out, err = assert_python_ok(*run_args)
self._check_output(script_name, rc, out + err, expected_file,
- expected_argv0, expected_path0, expected_package)
+ expected_argv0, expected_path0,
+ expected_package, expected_loader)
def _check_import_error(self, script_name, expected_msg,
*cmd_line_switches):
@@ -120,11 +138,30 @@ class CmdLineTest(unittest.TestCase):
print('Expected output: %r' % expected_msg)
self.assertIn(expected_msg.encode('utf-8'), err)
+ def test_dash_c_loader(self):
+ rc, out, err = assert_python_ok("-c", "print(__loader__)")
+ expected = repr(importlib.machinery.BuiltinImporter).encode("utf-8")
+ self.assertIn(expected, out)
+
+ def test_stdin_loader(self):
+ # Unfortunately, there's no way to automatically test the fully
+ # interactive REPL, since that code path only gets executed when
+ # stdin is an interactive tty.
+ p = spawn_python()
+ try:
+ p.stdin.write(b"print(__loader__)\n")
+ p.stdin.flush()
+ finally:
+ out = kill_python(p)
+ expected = repr(importlib.machinery.BuiltinImporter).encode("utf-8")
+ self.assertIn(expected, out)
+
def test_basic_script(self):
with temp_dir() as script_dir:
script_name = _make_test_script(script_dir, 'script')
self._check_script(script_name, script_name, script_name,
- script_dir, None)
+ script_dir, None,
+ importlib.machinery.SourceFileLoader)
def test_script_compiled(self):
with temp_dir() as script_dir:
@@ -133,13 +170,15 @@ class CmdLineTest(unittest.TestCase):
os.remove(script_name)
pyc_file = support.make_legacy_pyc(script_name)
self._check_script(pyc_file, pyc_file,
- pyc_file, script_dir, None)
+ pyc_file, script_dir, None,
+ importlib.machinery.SourcelessFileLoader)
def test_directory(self):
with temp_dir() as script_dir:
script_name = _make_test_script(script_dir, '__main__')
self._check_script(script_dir, script_name, script_dir,
- script_dir, '')
+ script_dir, '',
+ importlib.machinery.SourceFileLoader)
def test_directory_compiled(self):
with temp_dir() as script_dir:
@@ -148,7 +187,8 @@ class CmdLineTest(unittest.TestCase):
os.remove(script_name)
pyc_file = support.make_legacy_pyc(script_name)
self._check_script(script_dir, pyc_file, script_dir,
- script_dir, '')
+ script_dir, '',
+ importlib.machinery.SourcelessFileLoader)
def test_directory_error(self):
with temp_dir() as script_dir:
@@ -159,14 +199,16 @@ class CmdLineTest(unittest.TestCase):
with temp_dir() as script_dir:
script_name = _make_test_script(script_dir, '__main__')
zip_name, run_name = make_zip_script(script_dir, 'test_zip', script_name)
- self._check_script(zip_name, run_name, zip_name, zip_name, '')
+ self._check_script(zip_name, run_name, zip_name, zip_name, '',
+ zipimport.zipimporter)
def test_zipfile_compiled(self):
with temp_dir() as script_dir:
script_name = _make_test_script(script_dir, '__main__')
compiled_name = py_compile.compile(script_name, doraise=True)
zip_name, run_name = make_zip_script(script_dir, 'test_zip', compiled_name)
- self._check_script(zip_name, run_name, zip_name, zip_name, '')
+ self._check_script(zip_name, run_name, zip_name, zip_name, '',
+ zipimport.zipimporter)
def test_zipfile_error(self):
with temp_dir() as script_dir:
@@ -181,19 +223,24 @@ class CmdLineTest(unittest.TestCase):
make_pkg(pkg_dir)
script_name = _make_test_script(pkg_dir, 'script')
launch_name = _make_launch_script(script_dir, 'launch', 'test_pkg.script')
- self._check_script(launch_name, script_name, script_name, script_dir, 'test_pkg')
+ self._check_script(launch_name, script_name, script_name,
+ script_dir, 'test_pkg',
+ importlib.machinery.SourceFileLoader)
def test_module_in_package_in_zipfile(self):
with temp_dir() as script_dir:
zip_name, run_name = _make_test_zip_pkg(script_dir, 'test_zip', 'test_pkg', 'script')
launch_name = _make_launch_script(script_dir, 'launch', 'test_pkg.script', zip_name)
- self._check_script(launch_name, run_name, run_name, zip_name, 'test_pkg')
+ self._check_script(launch_name, run_name, run_name,
+ zip_name, 'test_pkg', zipimport.zipimporter)
def test_module_in_subpackage_in_zipfile(self):
with temp_dir() as script_dir:
zip_name, run_name = _make_test_zip_pkg(script_dir, 'test_zip', 'test_pkg', 'script', depth=2)
launch_name = _make_launch_script(script_dir, 'launch', 'test_pkg.test_pkg.script', zip_name)
- self._check_script(launch_name, run_name, run_name, zip_name, 'test_pkg.test_pkg')
+ self._check_script(launch_name, run_name, run_name,
+ zip_name, 'test_pkg.test_pkg',
+ zipimport.zipimporter)
def test_package(self):
with temp_dir() as script_dir:
@@ -202,7 +249,8 @@ class CmdLineTest(unittest.TestCase):
script_name = _make_test_script(pkg_dir, '__main__')
launch_name = _make_launch_script(script_dir, 'launch', 'test_pkg')
self._check_script(launch_name, script_name,
- script_name, script_dir, 'test_pkg')
+ script_name, script_dir, 'test_pkg',
+ importlib.machinery.SourceFileLoader)
def test_package_compiled(self):
with temp_dir() as script_dir:
@@ -214,7 +262,8 @@ class CmdLineTest(unittest.TestCase):
pyc_file = support.make_legacy_pyc(script_name)
launch_name = _make_launch_script(script_dir, 'launch', 'test_pkg')
self._check_script(launch_name, pyc_file,
- pyc_file, script_dir, 'test_pkg')
+ pyc_file, script_dir, 'test_pkg',
+ importlib.machinery.SourcelessFileLoader)
def test_package_error(self):
with temp_dir() as script_dir:
@@ -251,7 +300,8 @@ class CmdLineTest(unittest.TestCase):
expected = "init_argv0==%r" % '-m'
self.assertIn(expected.encode('utf-8'), out)
self._check_output(script_name, rc, out,
- script_name, script_name, '', 'test_pkg')
+ script_name, script_name, '', 'test_pkg',
+ importlib.machinery.SourceFileLoader)
def test_issue8202_dash_c_file_ignored(self):
# Make sure a "-c" file in the current directory
@@ -277,7 +327,8 @@ class CmdLineTest(unittest.TestCase):
f.write("data")
rc, out, err = assert_python_ok('-m', 'other', *example_args)
self._check_output(script_name, rc, out,
- script_name, script_name, '', '')
+ script_name, script_name, '', '',
+ importlib.machinery.SourceFileLoader)
def test_dash_m_error_code_is_one(self):
# If a module is invoked with the -m command line flag
@@ -294,6 +345,24 @@ class CmdLineTest(unittest.TestCase):
print(out)
self.assertEqual(rc, 1)
+ def test_pep_409_verbiage(self):
+ # Make sure PEP 409 syntax properly suppresses
+ # the context of an exception
+ script = textwrap.dedent("""\
+ try:
+ raise ValueError
+ except:
+ raise NameError from None
+ """)
+ with temp_dir() as script_dir:
+ script_name = _make_test_script(script_dir, 'script', script)
+ exitcode, stdout, stderr = assert_python_failure(script_name)
+ text = stderr.decode('ascii').split('\n')
+ self.assertEqual(len(text), 4)
+ self.assertTrue(text[0].startswith('Traceback'))
+ self.assertTrue(text[1].startswith(' File '))
+ self.assertTrue(text[3].startswith('NameError'))
+
def test_non_ascii(self):
# Mac OS X denies the creation of a file with an invalid UTF-8 name.
# Windows allows to create a name with an arbitrary bytes name, but
diff --git a/Lib/test/test_code.py b/Lib/test/test_code.py
index e1c7a78..3377a7b 100644
--- a/Lib/test/test_code.py
+++ b/Lib/test/test_code.py
@@ -16,7 +16,7 @@ cellvars: ('x',)
freevars: ()
nlocals: 2
flags: 3
-consts: ('None', '<code object g>')
+consts: ('None', '<code object g>', "'f.<locals>.g'")
>>> dump(f(4).__code__)
name: g
diff --git a/Lib/test/test_code_module.py b/Lib/test/test_code_module.py
new file mode 100644
index 0000000..adef170
--- /dev/null
+++ b/Lib/test/test_code_module.py
@@ -0,0 +1,72 @@
+"Test InteractiveConsole and InteractiveInterpreter from code module"
+import sys
+import unittest
+from contextlib import ExitStack
+from unittest import mock
+from test import support
+
+code = support.import_module('code')
+
+
+class TestInteractiveConsole(unittest.TestCase):
+
+ def setUp(self):
+ self.console = code.InteractiveConsole()
+ self.mock_sys()
+
+ def mock_sys(self):
+ "Mock system environment for InteractiveConsole"
+ # use exit stack to match patch context managers to addCleanup
+ stack = ExitStack()
+ self.addCleanup(stack.close)
+ self.infunc = stack.enter_context(mock.patch('code.input',
+ create=True))
+ self.stdout = stack.enter_context(mock.patch('code.sys.stdout'))
+ self.stderr = stack.enter_context(mock.patch('code.sys.stderr'))
+ prepatch = mock.patch('code.sys', wraps=code.sys, spec=code.sys)
+ self.sysmod = stack.enter_context(prepatch)
+ if sys.excepthook is sys.__excepthook__:
+ self.sysmod.excepthook = self.sysmod.__excepthook__
+
+ def test_ps1(self):
+ self.infunc.side_effect = EOFError('Finished')
+ self.console.interact()
+ self.assertEqual(self.sysmod.ps1, '>>> ')
+
+ def test_ps2(self):
+ self.infunc.side_effect = EOFError('Finished')
+ self.console.interact()
+ self.assertEqual(self.sysmod.ps2, '... ')
+
+ def test_console_stderr(self):
+ self.infunc.side_effect = ["'antioch'", "", EOFError('Finished')]
+ self.console.interact()
+ for call in list(self.stdout.method_calls):
+ if 'antioch' in ''.join(call[1]):
+ break
+ else:
+ raise AssertionError("no console stdout")
+
+ def test_syntax_error(self):
+ self.infunc.side_effect = ["undefined", EOFError('Finished')]
+ self.console.interact()
+ for call in self.stderr.method_calls:
+ if 'NameError:' in ''.join(call[1]):
+ break
+ else:
+ raise AssertionError("No syntax error from console")
+
+ def test_sysexcepthook(self):
+ self.infunc.side_effect = ["raise ValueError('')",
+ EOFError('Finished')]
+ hook = mock.Mock()
+ self.sysmod.excepthook = hook
+ self.console.interact()
+ self.assertTrue(hook.called)
+
+
+def test_main():
+ support.run_unittest(TestInteractiveConsole)
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/test/test_codeccallbacks.py b/Lib/test/test_codeccallbacks.py
index 3b78b4a..fd88505 100644
--- a/Lib/test/test_codeccallbacks.py
+++ b/Lib/test/test_codeccallbacks.py
@@ -1,5 +1,18 @@
-import test.support, unittest
-import sys, codecs, html.entities, unicodedata
+import codecs
+import html.entities
+import sys
+import test.support
+import unicodedata
+import unittest
+import warnings
+
+try:
+ import ctypes
+except ImportError:
+ ctypes = None
+ SIZEOF_WCHAR_T = -1
+else:
+ SIZEOF_WCHAR_T = ctypes.sizeof(ctypes.c_wchar)
class PosReturn:
# this can be used for configurable callbacks
@@ -135,22 +148,14 @@ class CodecCallbackTest(unittest.TestCase):
def test_backslashescape(self):
# Does the same as the "unicode-escape" encoding, but with different
# base encodings.
- sin = "a\xac\u1234\u20ac\u8000"
- if sys.maxunicode > 0xffff:
- sin += chr(sys.maxunicode)
- sout = b"a\\xac\\u1234\\u20ac\\u8000"
- if sys.maxunicode > 0xffff:
- sout += bytes("\\U%08x" % sys.maxunicode, "ascii")
+ sin = "a\xac\u1234\u20ac\u8000\U0010ffff"
+ sout = b"a\\xac\\u1234\\u20ac\\u8000\\U0010ffff"
self.assertEqual(sin.encode("ascii", "backslashreplace"), sout)
- sout = b"a\xac\\u1234\\u20ac\\u8000"
- if sys.maxunicode > 0xffff:
- sout += bytes("\\U%08x" % sys.maxunicode, "ascii")
+ sout = b"a\xac\\u1234\\u20ac\\u8000\\U0010ffff"
self.assertEqual(sin.encode("latin-1", "backslashreplace"), sout)
- sout = b"a\xac\\u1234\xa4\\u8000"
- if sys.maxunicode > 0xffff:
- sout += bytes("\\U%08x" % sys.maxunicode, "ascii")
+ sout = b"a\xac\\u1234\xa4\\u8000\\U0010ffff"
self.assertEqual(sin.encode("iso-8859-15", "backslashreplace"), sout)
def test_decoding_callbacks(self):
@@ -200,33 +205,37 @@ class CodecCallbackTest(unittest.TestCase):
self.assertRaises(TypeError, codecs.charmap_encode, sin, "replace", charmap)
def test_decodeunicodeinternal(self):
- self.assertRaises(
- UnicodeDecodeError,
- b"\x00\x00\x00\x00\x00".decode,
- "unicode-internal",
- )
- if sys.maxunicode > 0xffff:
+ with test.support.check_warnings(('unicode_internal codec has been '
+ 'deprecated', DeprecationWarning)):
+ self.assertRaises(
+ UnicodeDecodeError,
+ b"\x00\x00\x00\x00\x00".decode,
+ "unicode-internal",
+ )
+ if SIZEOF_WCHAR_T == 4:
def handler_unicodeinternal(exc):
if not isinstance(exc, UnicodeDecodeError):
raise TypeError("don't know how to handle %r" % exc)
return ("\x01", 1)
- self.assertEqual(
- b"\x00\x00\x00\x00\x00".decode("unicode-internal", "ignore"),
- "\u0000"
- )
+ with test.support.check_warnings(('unicode_internal codec has been '
+ 'deprecated', DeprecationWarning)):
+ self.assertEqual(
+ b"\x00\x00\x00\x00\x00".decode("unicode-internal", "ignore"),
+ "\u0000"
+ )
- self.assertEqual(
- b"\x00\x00\x00\x00\x00".decode("unicode-internal", "replace"),
- "\u0000\ufffd"
- )
+ self.assertEqual(
+ b"\x00\x00\x00\x00\x00".decode("unicode-internal", "replace"),
+ "\u0000\ufffd"
+ )
- codecs.register_error("test.hui", handler_unicodeinternal)
+ codecs.register_error("test.hui", handler_unicodeinternal)
- self.assertEqual(
- b"\x00\x00\x00\x00\x00".decode("unicode-internal", "test.hui"),
- "\u0000\u0001\u0000"
- )
+ self.assertEqual(
+ b"\x00\x00\x00\x00\x00".decode("unicode-internal", "test.hui"),
+ "\u0000\u0001\u0000"
+ )
def test_callbacks(self):
def handler1(exc):
@@ -355,7 +364,7 @@ class CodecCallbackTest(unittest.TestCase):
["ascii", "\uffffx", 0, 1, "ouch"],
"'ascii' codec can't encode character '\\uffff' in position 0: ouch"
)
- if sys.maxunicode > 0xffff:
+ if SIZEOF_WCHAR_T == 4:
self.check_exceptionobjectargs(
UnicodeEncodeError,
["ascii", "\U00010000x", 0, 1, "ouch"],
@@ -390,7 +399,7 @@ class CodecCallbackTest(unittest.TestCase):
["g\uffffrk", 1, 2, "ouch"],
"can't translate character '\\uffff' in position 1: ouch"
)
- if sys.maxunicode > 0xffff:
+ if SIZEOF_WCHAR_T == 4:
self.check_exceptionobjectargs(
UnicodeTranslateError,
["g\U00010000rk", 1, 2, "ouch"],
@@ -577,31 +586,30 @@ class CodecCallbackTest(unittest.TestCase):
UnicodeEncodeError("ascii", "\uffff", 0, 1, "ouch")),
("\\uffff", 1)
)
- # 1 on UCS-4 builds, 2 on UCS-2
- len_wide = len("\U00010000")
- self.assertEqual(
- codecs.backslashreplace_errors(
- UnicodeEncodeError("ascii", "\U00010000",
- 0, len_wide, "ouch")),
- ("\\U00010000", len_wide)
- )
- self.assertEqual(
- codecs.backslashreplace_errors(
- UnicodeEncodeError("ascii", "\U0010ffff",
- 0, len_wide, "ouch")),
- ("\\U0010ffff", len_wide)
- )
- # Lone surrogates (regardless of unicode width)
- self.assertEqual(
- codecs.backslashreplace_errors(
- UnicodeEncodeError("ascii", "\ud800", 0, 1, "ouch")),
- ("\\ud800", 1)
- )
- self.assertEqual(
- codecs.backslashreplace_errors(
- UnicodeEncodeError("ascii", "\udfff", 0, 1, "ouch")),
- ("\\udfff", 1)
- )
+ if SIZEOF_WCHAR_T > 0:
+ self.assertEqual(
+ codecs.backslashreplace_errors(
+ UnicodeEncodeError("ascii", "\U00010000",
+ 0, 1, "ouch")),
+ ("\\U00010000", 1)
+ )
+ self.assertEqual(
+ codecs.backslashreplace_errors(
+ UnicodeEncodeError("ascii", "\U0010ffff",
+ 0, 1, "ouch")),
+ ("\\U0010ffff", 1)
+ )
+ # Lone surrogates (regardless of unicode width)
+ self.assertEqual(
+ codecs.backslashreplace_errors(
+ UnicodeEncodeError("ascii", "\ud800", 0, 1, "ouch")),
+ ("\\ud800", 1)
+ )
+ self.assertEqual(
+ codecs.backslashreplace_errors(
+ UnicodeEncodeError("ascii", "\udfff", 0, 1, "ouch")),
+ ("\\udfff", 1)
+ )
def test_badhandlerresults(self):
results = ( 42, "foo", (1,2,3), ("foo", 1, 3), ("foo", None), ("foo",), ("foo", 1, 3), ("foo", None), ("foo",) )
@@ -622,12 +630,14 @@ class CodecCallbackTest(unittest.TestCase):
("utf-7", b"+x-"),
("unicode-internal", b"\x00"),
):
- self.assertRaises(
- TypeError,
- bytes.decode,
- enc,
- "test.badhandler"
- )
+ with test.support.check_warnings():
+ # unicode-internal has been deprecated
+ self.assertRaises(
+ TypeError,
+ bytes.decode,
+ enc,
+ "test.badhandler"
+ )
def test_lookup(self):
self.assertEqual(codecs.strict_errors, codecs.lookup_error("strict"))
@@ -679,7 +689,7 @@ class CodecCallbackTest(unittest.TestCase):
# Python/codecs.c::PyCodec_XMLCharRefReplaceErrors()
# and inline implementations
v = (1, 5, 10, 50, 100, 500, 1000, 5000, 10000, 50000)
- if sys.maxunicode>=100000:
+ if SIZEOF_WCHAR_T == 4:
v += (100000, 500000, 1000000)
s = "".join([chr(x) for x in v])
codecs.register_error("test.xmlcharrefreplace", codecs.xmlcharrefreplace_errors)
@@ -744,7 +754,7 @@ class CodecCallbackTest(unittest.TestCase):
raise ValueError
self.assertRaises(UnicodeError, codecs.charmap_decode, b"\xff", "strict", {0xff: None})
self.assertRaises(ValueError, codecs.charmap_decode, b"\xff", "strict", D())
- self.assertRaises(TypeError, codecs.charmap_decode, b"\xff", "strict", {0xff: 0x110000})
+ self.assertRaises(TypeError, codecs.charmap_decode, b"\xff", "strict", {0xff: sys.maxunicode+1})
def test_encodehelper(self):
# enhance coverage of:
@@ -843,8 +853,12 @@ class CodecCallbackTest(unittest.TestCase):
else:
raise TypeError("don't know how to handle %r" % exc)
codecs.register_error("test.replacing", replacing)
- for (encoding, data) in baddata:
- self.assertRaises(TypeError, data.decode, encoding, "test.replacing")
+
+ with test.support.check_warnings():
+ # unicode-internal has been deprecated
+ for (encoding, data) in baddata:
+ with self.assertRaises(TypeError):
+ data.decode(encoding, "test.replacing")
def mutating(exc):
if isinstance(exc, UnicodeDecodeError):
@@ -855,8 +869,11 @@ class CodecCallbackTest(unittest.TestCase):
codecs.register_error("test.mutating", mutating)
# If the decoder doesn't pick up the modified input the following
# will lead to an endless loop
- for (encoding, data) in baddata:
- self.assertRaises(TypeError, data.decode, encoding, "test.replacing")
+ with test.support.check_warnings():
+ # unicode-internal has been deprecated
+ for (encoding, data) in baddata:
+ with self.assertRaises(TypeError):
+ data.decode(encoding, "test.replacing")
def test_main():
test.support.run_unittest(CodecCallbackTest)
diff --git a/Lib/test/test_codecencodings_cn.py b/Lib/test/test_codecencodings_cn.py
index dca9f10..b08c5fc 100644
--- a/Lib/test/test_codecencodings_cn.py
+++ b/Lib/test/test_codecencodings_cn.py
@@ -5,54 +5,57 @@
#
from test import support
-from test import test_multibytecodec_support
+from test import multibytecodec_support
import unittest
-class Test_GB2312(test_multibytecodec_support.TestBase, unittest.TestCase):
+class Test_GB2312(multibytecodec_support.TestBase, unittest.TestCase):
encoding = 'gb2312'
- tstring = test_multibytecodec_support.load_teststring('gb2312')
+ tstring = multibytecodec_support.load_teststring('gb2312')
codectests = (
# invalid bytes
(b"abc\x81\x81\xc1\xc4", "strict", None),
(b"abc\xc8", "strict", None),
- (b"abc\x81\x81\xc1\xc4", "replace", "abc\ufffd\u804a"),
- (b"abc\x81\x81\xc1\xc4\xc8", "replace", "abc\ufffd\u804a\ufffd"),
+ (b"abc\x81\x81\xc1\xc4", "replace", "abc\ufffd\ufffd\u804a"),
+ (b"abc\x81\x81\xc1\xc4\xc8", "replace", "abc\ufffd\ufffd\u804a\ufffd"),
(b"abc\x81\x81\xc1\xc4", "ignore", "abc\u804a"),
(b"\xc1\x64", "strict", None),
)
-class Test_GBK(test_multibytecodec_support.TestBase, unittest.TestCase):
+class Test_GBK(multibytecodec_support.TestBase, unittest.TestCase):
encoding = 'gbk'
- tstring = test_multibytecodec_support.load_teststring('gbk')
+ tstring = multibytecodec_support.load_teststring('gbk')
codectests = (
# invalid bytes
(b"abc\x80\x80\xc1\xc4", "strict", None),
(b"abc\xc8", "strict", None),
- (b"abc\x80\x80\xc1\xc4", "replace", "abc\ufffd\u804a"),
- (b"abc\x80\x80\xc1\xc4\xc8", "replace", "abc\ufffd\u804a\ufffd"),
+ (b"abc\x80\x80\xc1\xc4", "replace", "abc\ufffd\ufffd\u804a"),
+ (b"abc\x80\x80\xc1\xc4\xc8", "replace", "abc\ufffd\ufffd\u804a\ufffd"),
(b"abc\x80\x80\xc1\xc4", "ignore", "abc\u804a"),
(b"\x83\x34\x83\x31", "strict", None),
("\u30fb", "strict", None),
)
-class Test_GB18030(test_multibytecodec_support.TestBase, unittest.TestCase):
+class Test_GB18030(multibytecodec_support.TestBase, unittest.TestCase):
encoding = 'gb18030'
- tstring = test_multibytecodec_support.load_teststring('gb18030')
+ tstring = multibytecodec_support.load_teststring('gb18030')
codectests = (
# invalid bytes
(b"abc\x80\x80\xc1\xc4", "strict", None),
(b"abc\xc8", "strict", None),
- (b"abc\x80\x80\xc1\xc4", "replace", "abc\ufffd\u804a"),
- (b"abc\x80\x80\xc1\xc4\xc8", "replace", "abc\ufffd\u804a\ufffd"),
+ (b"abc\x80\x80\xc1\xc4", "replace", "abc\ufffd\ufffd\u804a"),
+ (b"abc\x80\x80\xc1\xc4\xc8", "replace", "abc\ufffd\ufffd\u804a\ufffd"),
(b"abc\x80\x80\xc1\xc4", "ignore", "abc\u804a"),
- (b"abc\x84\x39\x84\x39\xc1\xc4", "replace", "abc\ufffd\u804a"),
+ (b"abc\x84\x39\x84\x39\xc1\xc4", "replace", "abc\ufffd9\ufffd9\u804a"),
("\u30fb", "strict", b"\x819\xa79"),
+ (b"abc\x84\x32\x80\x80def", "replace", 'abc\ufffd2\ufffd\ufffddef'),
+ (b"abc\x81\x30\x81\x30def", "strict", 'abc\x80def'),
+ (b"abc\x86\x30\x81\x30def", "replace", 'abc\ufffd0\ufffd0def'),
)
has_iso10646 = True
-class Test_HZ(test_multibytecodec_support.TestBase, unittest.TestCase):
+class Test_HZ(multibytecodec_support.TestBase, unittest.TestCase):
encoding = 'hz'
- tstring = test_multibytecodec_support.load_teststring('hz')
+ tstring = multibytecodec_support.load_teststring('hz')
codectests = (
# test '~\n' (3 lines)
(b'This sentence is in ASCII.\n'
@@ -74,9 +77,11 @@ class Test_HZ(test_multibytecodec_support.TestBase, unittest.TestCase):
'\u5df1\u6240\u4e0d\u6b32\uff0c\u52ff\u65bd\u65bc\u4eba\u3002'
'Bye.\n'),
# invalid bytes
- (b'ab~cd', 'replace', 'ab\uFFFDd'),
+ (b'ab~cd', 'replace', 'ab\uFFFDcd'),
(b'ab\xffcd', 'replace', 'ab\uFFFDcd'),
(b'ab~{\x81\x81\x41\x44~}cd', 'replace', 'ab\uFFFD\uFFFD\u804Acd'),
+ (b'ab~{\x41\x44~}cd', 'replace', 'ab\u804Acd'),
+ (b"ab~{\x79\x79\x41\x44~}cd", "replace", "ab\ufffd\ufffd\u804acd"),
)
def test_main():
diff --git a/Lib/test/test_codecencodings_hk.py b/Lib/test/test_codecencodings_hk.py
index ccdc0b4..31363f4 100644
--- a/Lib/test/test_codecencodings_hk.py
+++ b/Lib/test/test_codecencodings_hk.py
@@ -5,18 +5,18 @@
#
from test import support
-from test import test_multibytecodec_support
+from test import multibytecodec_support
import unittest
-class Test_Big5HKSCS(test_multibytecodec_support.TestBase, unittest.TestCase):
+class Test_Big5HKSCS(multibytecodec_support.TestBase, unittest.TestCase):
encoding = 'big5hkscs'
- tstring = test_multibytecodec_support.load_teststring('big5hkscs')
+ tstring = multibytecodec_support.load_teststring('big5hkscs')
codectests = (
# invalid bytes
(b"abc\x80\x80\xc1\xc4", "strict", None),
(b"abc\xc8", "strict", None),
- (b"abc\x80\x80\xc1\xc4", "replace", "abc\ufffd\u8b10"),
- (b"abc\x80\x80\xc1\xc4\xc8", "replace", "abc\ufffd\u8b10\ufffd"),
+ (b"abc\x80\x80\xc1\xc4", "replace", "abc\ufffd\ufffd\u8b10"),
+ (b"abc\x80\x80\xc1\xc4\xc8", "replace", "abc\ufffd\ufffd\u8b10\ufffd"),
(b"abc\x80\x80\xc1\xc4", "ignore", "abc\u8b10"),
)
diff --git a/Lib/test/test_codecencodings_iso2022.py b/Lib/test/test_codecencodings_iso2022.py
index 8c6e8a5..e4c1839 100644
--- a/Lib/test/test_codecencodings_iso2022.py
+++ b/Lib/test/test_codecencodings_iso2022.py
@@ -3,7 +3,7 @@
# Codec encoding tests for ISO 2022 encodings.
from test import support
-from test import test_multibytecodec_support
+from test import multibytecodec_support
import unittest
COMMON_CODEC_TESTS = (
@@ -13,23 +13,23 @@ COMMON_CODEC_TESTS = (
(b'ab\x1B$def', 'replace', 'ab\uFFFD'),
)
-class Test_ISO2022_JP(test_multibytecodec_support.TestBase, unittest.TestCase):
+class Test_ISO2022_JP(multibytecodec_support.TestBase, unittest.TestCase):
encoding = 'iso2022_jp'
- tstring = test_multibytecodec_support.load_teststring('iso2022_jp')
+ tstring = multibytecodec_support.load_teststring('iso2022_jp')
codectests = COMMON_CODEC_TESTS + (
(b'ab\x1BNdef', 'replace', 'ab\x1BNdef'),
)
-class Test_ISO2022_JP2(test_multibytecodec_support.TestBase, unittest.TestCase):
+class Test_ISO2022_JP2(multibytecodec_support.TestBase, unittest.TestCase):
encoding = 'iso2022_jp_2'
- tstring = test_multibytecodec_support.load_teststring('iso2022_jp')
+ tstring = multibytecodec_support.load_teststring('iso2022_jp')
codectests = COMMON_CODEC_TESTS + (
(b'ab\x1BNdef', 'replace', 'abdef'),
)
-class Test_ISO2022_KR(test_multibytecodec_support.TestBase, unittest.TestCase):
+class Test_ISO2022_KR(multibytecodec_support.TestBase, unittest.TestCase):
encoding = 'iso2022_kr'
- tstring = test_multibytecodec_support.load_teststring('iso2022_kr')
+ tstring = multibytecodec_support.load_teststring('iso2022_kr')
codectests = COMMON_CODEC_TESTS + (
(b'ab\x1BNdef', 'replace', 'ab\x1BNdef'),
)
diff --git a/Lib/test/test_codecencodings_jp.py b/Lib/test/test_codecencodings_jp.py
index f56a373..30c9e19 100644
--- a/Lib/test/test_codecencodings_jp.py
+++ b/Lib/test/test_codecencodings_jp.py
@@ -5,60 +5,67 @@
#
from test import support
-from test import test_multibytecodec_support
+from test import multibytecodec_support
import unittest
-class Test_CP932(test_multibytecodec_support.TestBase, unittest.TestCase):
+class Test_CP932(multibytecodec_support.TestBase, unittest.TestCase):
encoding = 'cp932'
- tstring = test_multibytecodec_support.load_teststring('shift_jis')
+ tstring = multibytecodec_support.load_teststring('shift_jis')
codectests = (
# invalid bytes
(b"abc\x81\x00\x81\x00\x82\x84", "strict", None),
(b"abc\xf8", "strict", None),
- (b"abc\x81\x00\x82\x84", "replace", "abc\ufffd\uff44"),
- (b"abc\x81\x00\x82\x84\x88", "replace", "abc\ufffd\uff44\ufffd"),
- (b"abc\x81\x00\x82\x84", "ignore", "abc\uff44"),
+ (b"abc\x81\x00\x82\x84", "replace", "abc\ufffd\x00\uff44"),
+ (b"abc\x81\x00\x82\x84\x88", "replace", "abc\ufffd\x00\uff44\ufffd"),
+ (b"abc\x81\x00\x82\x84", "ignore", "abc\x00\uff44"),
+ (b"ab\xEBxy", "replace", "ab\uFFFDxy"),
+ (b"ab\xF0\x39xy", "replace", "ab\uFFFD9xy"),
+ (b"ab\xEA\xF0xy", "replace", 'ab\ufffd\ue038y'),
# sjis vs cp932
(b"\\\x7e", "replace", "\\\x7e"),
(b"\x81\x5f\x81\x61\x81\x7c", "replace", "\uff3c\u2225\uff0d"),
)
-class Test_EUC_JISX0213(test_multibytecodec_support.TestBase,
+euc_commontests = (
+ # invalid bytes
+ (b"abc\x80\x80\xc1\xc4", "strict", None),
+ (b"abc\x80\x80\xc1\xc4", "replace", "abc\ufffd\ufffd\u7956"),
+ (b"abc\x80\x80\xc1\xc4\xc8", "replace", "abc\ufffd\ufffd\u7956\ufffd"),
+ (b"abc\x80\x80\xc1\xc4", "ignore", "abc\u7956"),
+ (b"abc\xc8", "strict", None),
+ (b"abc\x8f\x83\x83", "replace", "abc\ufffd\ufffd\ufffd"),
+ (b"\x82\xFCxy", "replace", "\ufffd\ufffdxy"),
+ (b"\xc1\x64", "strict", None),
+ (b"\xa1\xc0", "strict", "\uff3c"),
+ (b"\xa1\xc0\\", "strict", "\uff3c\\"),
+ (b"\x8eXY", "replace", "\ufffdXY"),
+)
+
+class Test_EUC_JIS_2004(multibytecodec_support.TestBase,
unittest.TestCase):
- encoding = 'euc_jisx0213'
- tstring = test_multibytecodec_support.load_teststring('euc_jisx0213')
- codectests = (
- # invalid bytes
- (b"abc\x80\x80\xc1\xc4", "strict", None),
- (b"abc\xc8", "strict", None),
- (b"abc\x80\x80\xc1\xc4", "replace", "abc\ufffd\u7956"),
- (b"abc\x80\x80\xc1\xc4\xc8", "replace", "abc\ufffd\u7956\ufffd"),
- (b"abc\x80\x80\xc1\xc4", "ignore", "abc\u7956"),
- (b"abc\x8f\x83\x83", "replace", "abc\ufffd"),
- (b"\xc1\x64", "strict", None),
- (b"\xa1\xc0", "strict", "\uff3c"),
- )
+ encoding = 'euc_jis_2004'
+ tstring = multibytecodec_support.load_teststring('euc_jisx0213')
+ codectests = euc_commontests
xmlcharnametest = (
"\xab\u211c\xbb = \u2329\u1234\u232a",
b"\xa9\xa8&real;\xa9\xb2 = &lang;&#4660;&rang;"
)
-eucjp_commontests = (
- (b"abc\x80\x80\xc1\xc4", "strict", None),
- (b"abc\xc8", "strict", None),
- (b"abc\x80\x80\xc1\xc4", "replace", "abc\ufffd\u7956"),
- (b"abc\x80\x80\xc1\xc4\xc8", "replace", "abc\ufffd\u7956\ufffd"),
- (b"abc\x80\x80\xc1\xc4", "ignore", "abc\u7956"),
- (b"abc\x8f\x83\x83", "replace", "abc\ufffd"),
- (b"\xc1\x64", "strict", None),
-)
+class Test_EUC_JISX0213(multibytecodec_support.TestBase,
+ unittest.TestCase):
+ encoding = 'euc_jisx0213'
+ tstring = multibytecodec_support.load_teststring('euc_jisx0213')
+ codectests = euc_commontests
+ xmlcharnametest = (
+ "\xab\u211c\xbb = \u2329\u1234\u232a",
+ b"\xa9\xa8&real;\xa9\xb2 = &lang;&#4660;&rang;"
+ )
-class Test_EUC_JP_COMPAT(test_multibytecodec_support.TestBase,
+class Test_EUC_JP_COMPAT(multibytecodec_support.TestBase,
unittest.TestCase):
encoding = 'euc_jp'
- tstring = test_multibytecodec_support.load_teststring('euc_jp')
- codectests = eucjp_commontests + (
- (b"\xa1\xc0\\", "strict", "\uff3c\\"),
+ tstring = multibytecodec_support.load_teststring('euc_jp')
+ codectests = euc_commontests + (
("\xa5", "strict", b"\x5c"),
("\u203e", "strict", b"\x7e"),
)
@@ -66,29 +73,48 @@ class Test_EUC_JP_COMPAT(test_multibytecodec_support.TestBase,
shiftjis_commonenctests = (
(b"abc\x80\x80\x82\x84", "strict", None),
(b"abc\xf8", "strict", None),
- (b"abc\x80\x80\x82\x84", "replace", "abc\ufffd\uff44"),
- (b"abc\x80\x80\x82\x84\x88", "replace", "abc\ufffd\uff44\ufffd"),
(b"abc\x80\x80\x82\x84def", "ignore", "abc\uff44def"),
)
-class Test_SJIS_COMPAT(test_multibytecodec_support.TestBase, unittest.TestCase):
+class Test_SJIS_COMPAT(multibytecodec_support.TestBase, unittest.TestCase):
encoding = 'shift_jis'
- tstring = test_multibytecodec_support.load_teststring('shift_jis')
+ tstring = multibytecodec_support.load_teststring('shift_jis')
codectests = shiftjis_commonenctests + (
+ (b"abc\x80\x80\x82\x84", "replace", "abc\ufffd\ufffd\uff44"),
+ (b"abc\x80\x80\x82\x84\x88", "replace", "abc\ufffd\ufffd\uff44\ufffd"),
+
(b"\\\x7e", "strict", "\\\x7e"),
(b"\x81\x5f\x81\x61\x81\x7c", "strict", "\uff3c\u2016\u2212"),
+ (b"abc\x81\x39", "replace", "abc\ufffd9"),
+ (b"abc\xEA\xFC", "replace", "abc\ufffd\ufffd"),
+ (b"abc\xFF\x58", "replace", "abc\ufffdX"),
)
-class Test_SJISX0213(test_multibytecodec_support.TestBase, unittest.TestCase):
+class Test_SJIS_2004(multibytecodec_support.TestBase, unittest.TestCase):
+ encoding = 'shift_jis_2004'
+ tstring = multibytecodec_support.load_teststring('shift_jis')
+ codectests = shiftjis_commonenctests + (
+ (b"\\\x7e", "strict", "\xa5\u203e"),
+ (b"\x81\x5f\x81\x61\x81\x7c", "strict", "\\\u2016\u2212"),
+ (b"abc\xEA\xFC", "strict", "abc\u64bf"),
+ (b"\x81\x39xy", "replace", "\ufffd9xy"),
+ (b"\xFF\x58xy", "replace", "\ufffdXxy"),
+ (b"\x80\x80\x82\x84xy", "replace", "\ufffd\ufffd\uff44xy"),
+ (b"\x80\x80\x82\x84\x88xy", "replace", "\ufffd\ufffd\uff44\u5864y"),
+ (b"\xFC\xFBxy", "replace", '\ufffd\u95b4y'),
+ )
+ xmlcharnametest = (
+ "\xab\u211c\xbb = \u2329\u1234\u232a",
+ b"\x85G&real;\x85Q = &lang;&#4660;&rang;"
+ )
+
+class Test_SJISX0213(multibytecodec_support.TestBase, unittest.TestCase):
encoding = 'shift_jisx0213'
- tstring = test_multibytecodec_support.load_teststring('shift_jisx0213')
- codectests = (
- # invalid bytes
- (b"abc\x80\x80\x82\x84", "strict", None),
- (b"abc\xf8", "strict", None),
- (b"abc\x80\x80\x82\x84", "replace", "abc\ufffd\uff44"),
- (b"abc\x80\x80\x82\x84\x88", "replace", "abc\ufffd\uff44\ufffd"),
- (b"abc\x80\x80\x82\x84def", "ignore", "abc\uff44def"),
+ tstring = multibytecodec_support.load_teststring('shift_jisx0213')
+ codectests = shiftjis_commonenctests + (
+ (b"abc\x80\x80\x82\x84", "replace", "abc\ufffd\ufffd\uff44"),
+ (b"abc\x80\x80\x82\x84\x88", "replace", "abc\ufffd\ufffd\uff44\ufffd"),
+
# sjis vs cp932
(b"\\\x7e", "replace", "\xa5\u203e"),
(b"\x81\x5f\x81\x61\x81\x7c", "replace", "\x5c\u2016\u2212"),
diff --git a/Lib/test/test_codecencodings_kr.py b/Lib/test/test_codecencodings_kr.py
index de4da7f..4dd6049 100644
--- a/Lib/test/test_codecencodings_kr.py
+++ b/Lib/test/test_codecencodings_kr.py
@@ -5,30 +5,30 @@
#
from test import support
-from test import test_multibytecodec_support
+from test import multibytecodec_support
import unittest
-class Test_CP949(test_multibytecodec_support.TestBase, unittest.TestCase):
+class Test_CP949(multibytecodec_support.TestBase, unittest.TestCase):
encoding = 'cp949'
- tstring = test_multibytecodec_support.load_teststring('cp949')
+ tstring = multibytecodec_support.load_teststring('cp949')
codectests = (
# invalid bytes
(b"abc\x80\x80\xc1\xc4", "strict", None),
(b"abc\xc8", "strict", None),
- (b"abc\x80\x80\xc1\xc4", "replace", "abc\ufffd\uc894"),
- (b"abc\x80\x80\xc1\xc4\xc8", "replace", "abc\ufffd\uc894\ufffd"),
+ (b"abc\x80\x80\xc1\xc4", "replace", "abc\ufffd\ufffd\uc894"),
+ (b"abc\x80\x80\xc1\xc4\xc8", "replace", "abc\ufffd\ufffd\uc894\ufffd"),
(b"abc\x80\x80\xc1\xc4", "ignore", "abc\uc894"),
)
-class Test_EUCKR(test_multibytecodec_support.TestBase, unittest.TestCase):
+class Test_EUCKR(multibytecodec_support.TestBase, unittest.TestCase):
encoding = 'euc_kr'
- tstring = test_multibytecodec_support.load_teststring('euc_kr')
+ tstring = multibytecodec_support.load_teststring('euc_kr')
codectests = (
# invalid bytes
(b"abc\x80\x80\xc1\xc4", "strict", None),
(b"abc\xc8", "strict", None),
- (b"abc\x80\x80\xc1\xc4", "replace", "abc\ufffd\uc894"),
- (b"abc\x80\x80\xc1\xc4\xc8", "replace", "abc\ufffd\uc894\ufffd"),
+ (b"abc\x80\x80\xc1\xc4", "replace", 'abc\ufffd\ufffd\uc894'),
+ (b"abc\x80\x80\xc1\xc4\xc8", "replace", "abc\ufffd\ufffd\uc894\ufffd"),
(b"abc\x80\x80\xc1\xc4", "ignore", "abc\uc894"),
# composed make-up sequence errors
@@ -40,26 +40,31 @@ class Test_EUCKR(test_multibytecodec_support.TestBase, unittest.TestCase):
(b"\xa4\xd4\xa4\xb6\xa4\xd0\xa4", "strict", None),
(b"\xa4\xd4\xa4\xb6\xa4\xd0\xa4\xd4", "strict", "\uc4d4"),
(b"\xa4\xd4\xa4\xb6\xa4\xd0\xa4\xd4x", "strict", "\uc4d4x"),
- (b"a\xa4\xd4\xa4\xb6\xa4", "replace", "a\ufffd"),
+ (b"a\xa4\xd4\xa4\xb6\xa4", "replace", 'a\ufffd'),
(b"\xa4\xd4\xa3\xb6\xa4\xd0\xa4\xd4", "strict", None),
(b"\xa4\xd4\xa4\xb6\xa3\xd0\xa4\xd4", "strict", None),
(b"\xa4\xd4\xa4\xb6\xa4\xd0\xa3\xd4", "strict", None),
- (b"\xa4\xd4\xa4\xff\xa4\xd0\xa4\xd4", "replace", "\ufffd"),
- (b"\xa4\xd4\xa4\xb6\xa4\xff\xa4\xd4", "replace", "\ufffd"),
- (b"\xa4\xd4\xa4\xb6\xa4\xd0\xa4\xff", "replace", "\ufffd"),
+ (b"\xa4\xd4\xa4\xff\xa4\xd0\xa4\xd4", "replace", '\ufffd\u6e21\ufffd\u3160\ufffd'),
+ (b"\xa4\xd4\xa4\xb6\xa4\xff\xa4\xd4", "replace", '\ufffd\u6e21\ub544\ufffd\ufffd'),
+ (b"\xa4\xd4\xa4\xb6\xa4\xd0\xa4\xff", "replace", '\ufffd\u6e21\ub544\u572d\ufffd'),
+ (b"\xa4\xd4\xff\xa4\xd4\xa4\xb6\xa4\xd0\xa4\xd4", "replace", '\ufffd\ufffd\ufffd\uc4d4'),
(b"\xc1\xc4", "strict", "\uc894"),
)
-class Test_JOHAB(test_multibytecodec_support.TestBase, unittest.TestCase):
+class Test_JOHAB(multibytecodec_support.TestBase, unittest.TestCase):
encoding = 'johab'
- tstring = test_multibytecodec_support.load_teststring('johab')
+ tstring = multibytecodec_support.load_teststring('johab')
codectests = (
# invalid bytes
(b"abc\x80\x80\xc1\xc4", "strict", None),
(b"abc\xc8", "strict", None),
- (b"abc\x80\x80\xc1\xc4", "replace", "abc\ufffd\ucd27"),
- (b"abc\x80\x80\xc1\xc4\xc8", "replace", "abc\ufffd\ucd27\ufffd"),
+ (b"abc\x80\x80\xc1\xc4", "replace", "abc\ufffd\ufffd\ucd27"),
+ (b"abc\x80\x80\xc1\xc4\xc8", "replace", "abc\ufffd\ufffd\ucd27\ufffd"),
(b"abc\x80\x80\xc1\xc4", "ignore", "abc\ucd27"),
+ (b"\xD8abc", "replace", "\uFFFDabc"),
+ (b"\xD8\xFFabc", "replace", "\uFFFD\uFFFDabc"),
+ (b"\x84bxy", "replace", "\uFFFDbxy"),
+ (b"\x8CBxy", "replace", "\uFFFDBxy"),
)
def test_main():
diff --git a/Lib/test/test_codecencodings_tw.py b/Lib/test/test_codecencodings_tw.py
index 12d3c9f..96245b7 100644
--- a/Lib/test/test_codecencodings_tw.py
+++ b/Lib/test/test_codecencodings_tw.py
@@ -5,18 +5,18 @@
#
from test import support
-from test import test_multibytecodec_support
+from test import multibytecodec_support
import unittest
-class Test_Big5(test_multibytecodec_support.TestBase, unittest.TestCase):
+class Test_Big5(multibytecodec_support.TestBase, unittest.TestCase):
encoding = 'big5'
- tstring = test_multibytecodec_support.load_teststring('big5')
+ tstring = multibytecodec_support.load_teststring('big5')
codectests = (
# invalid bytes
(b"abc\x80\x80\xc1\xc4", "strict", None),
(b"abc\xc8", "strict", None),
- (b"abc\x80\x80\xc1\xc4", "replace", "abc\ufffd\u8b10"),
- (b"abc\x80\x80\xc1\xc4\xc8", "replace", "abc\ufffd\u8b10\ufffd"),
+ (b"abc\x80\x80\xc1\xc4", "replace", "abc\ufffd\ufffd\u8b10"),
+ (b"abc\x80\x80\xc1\xc4\xc8", "replace", "abc\ufffd\ufffd\u8b10\ufffd"),
(b"abc\x80\x80\xc1\xc4", "ignore", "abc\u8b10"),
)
diff --git a/Lib/test/test_codecmaps_cn.py b/Lib/test/test_codecmaps_cn.py
index 063919d..1a761cf 100644
--- a/Lib/test/test_codecmaps_cn.py
+++ b/Lib/test/test_codecmaps_cn.py
@@ -5,21 +5,21 @@
#
from test import support
-from test import test_multibytecodec_support
+from test import multibytecodec_support
import unittest
-class TestGB2312Map(test_multibytecodec_support.TestBase_Mapping,
+class TestGB2312Map(multibytecodec_support.TestBase_Mapping,
unittest.TestCase):
encoding = 'gb2312'
mapfileurl = 'http://people.freebsd.org/~perky/i18n/EUC-CN.TXT'
-class TestGBKMap(test_multibytecodec_support.TestBase_Mapping,
+class TestGBKMap(multibytecodec_support.TestBase_Mapping,
unittest.TestCase):
encoding = 'gbk'
mapfileurl = 'http://www.unicode.org/Public/MAPPINGS/VENDORS/' \
'MICSFT/WINDOWS/CP936.TXT'
-class TestGB18030Map(test_multibytecodec_support.TestBase_Mapping,
+class TestGB18030Map(multibytecodec_support.TestBase_Mapping,
unittest.TestCase):
encoding = 'gb18030'
mapfileurl = 'http://source.icu-project.org/repos/icu/data/' \
diff --git a/Lib/test/test_codecmaps_hk.py b/Lib/test/test_codecmaps_hk.py
index bbe1f2f..5f4e7c7 100644
--- a/Lib/test/test_codecmaps_hk.py
+++ b/Lib/test/test_codecmaps_hk.py
@@ -5,10 +5,10 @@
#
from test import support
-from test import test_multibytecodec_support
+from test import multibytecodec_support
import unittest
-class TestBig5HKSCSMap(test_multibytecodec_support.TestBase_Mapping,
+class TestBig5HKSCSMap(multibytecodec_support.TestBase_Mapping,
unittest.TestCase):
encoding = 'big5hkscs'
mapfileurl = 'http://people.freebsd.org/~perky/i18n/BIG5HKSCS-2004.TXT'
diff --git a/Lib/test/test_codecmaps_jp.py b/Lib/test/test_codecmaps_jp.py
index 652bd81..1fdbf63 100644
--- a/Lib/test/test_codecmaps_jp.py
+++ b/Lib/test/test_codecmaps_jp.py
@@ -5,10 +5,10 @@
#
from test import support
-from test import test_multibytecodec_support
+from test import multibytecodec_support
import unittest
-class TestCP932Map(test_multibytecodec_support.TestBase_Mapping,
+class TestCP932Map(multibytecodec_support.TestBase_Mapping,
unittest.TestCase):
encoding = 'cp932'
mapfileurl = 'http://www.unicode.org/Public/MAPPINGS/VENDORS/MICSFT/' \
@@ -24,14 +24,14 @@ class TestCP932Map(test_multibytecodec_support.TestBase_Mapping,
supmaps.append((bytes([i]), chr(i+0xfec0)))
-class TestEUCJPCOMPATMap(test_multibytecodec_support.TestBase_Mapping,
+class TestEUCJPCOMPATMap(multibytecodec_support.TestBase_Mapping,
unittest.TestCase):
encoding = 'euc_jp'
mapfilename = 'EUC-JP.TXT'
mapfileurl = 'http://people.freebsd.org/~perky/i18n/EUC-JP.TXT'
-class TestSJISCOMPATMap(test_multibytecodec_support.TestBase_Mapping,
+class TestSJISCOMPATMap(multibytecodec_support.TestBase_Mapping,
unittest.TestCase):
encoding = 'shift_jis'
mapfilename = 'SHIFTJIS.TXT'
@@ -46,14 +46,14 @@ class TestSJISCOMPATMap(test_multibytecodec_support.TestBase_Mapping,
(b'\x81_', '\\'),
]
-class TestEUCJISX0213Map(test_multibytecodec_support.TestBase_Mapping,
+class TestEUCJISX0213Map(multibytecodec_support.TestBase_Mapping,
unittest.TestCase):
encoding = 'euc_jisx0213'
mapfilename = 'EUC-JISX0213.TXT'
mapfileurl = 'http://people.freebsd.org/~perky/i18n/EUC-JISX0213.TXT'
-class TestSJISX0213Map(test_multibytecodec_support.TestBase_Mapping,
+class TestSJISX0213Map(multibytecodec_support.TestBase_Mapping,
unittest.TestCase):
encoding = 'shift_jisx0213'
mapfilename = 'SHIFT_JISX0213.TXT'
diff --git a/Lib/test/test_codecmaps_kr.py b/Lib/test/test_codecmaps_kr.py
index d909c8b..0356402 100644
--- a/Lib/test/test_codecmaps_kr.py
+++ b/Lib/test/test_codecmaps_kr.py
@@ -5,17 +5,17 @@
#
from test import support
-from test import test_multibytecodec_support
+from test import multibytecodec_support
import unittest
-class TestCP949Map(test_multibytecodec_support.TestBase_Mapping,
+class TestCP949Map(multibytecodec_support.TestBase_Mapping,
unittest.TestCase):
encoding = 'cp949'
mapfileurl = 'http://www.unicode.org/Public/MAPPINGS/VENDORS/MICSFT' \
'/WINDOWS/CP949.TXT'
-class TestEUCKRMap(test_multibytecodec_support.TestBase_Mapping,
+class TestEUCKRMap(multibytecodec_support.TestBase_Mapping,
unittest.TestCase):
encoding = 'euc_kr'
mapfileurl = 'http://people.freebsd.org/~perky/i18n/EUC-KR.TXT'
@@ -25,7 +25,7 @@ class TestEUCKRMap(test_multibytecodec_support.TestBase_Mapping,
pass_dectest = [(b'\xa4\xd4', '\u3164')]
-class TestJOHABMap(test_multibytecodec_support.TestBase_Mapping,
+class TestJOHABMap(multibytecodec_support.TestBase_Mapping,
unittest.TestCase):
encoding = 'johab'
mapfileurl = 'http://www.unicode.org/Public/MAPPINGS/OBSOLETE/EASTASIA/' \
diff --git a/Lib/test/test_codecmaps_tw.py b/Lib/test/test_codecmaps_tw.py
index 6db5091..44467e3 100644
--- a/Lib/test/test_codecmaps_tw.py
+++ b/Lib/test/test_codecmaps_tw.py
@@ -5,16 +5,16 @@
#
from test import support
-from test import test_multibytecodec_support
+from test import multibytecodec_support
import unittest
-class TestBIG5Map(test_multibytecodec_support.TestBase_Mapping,
+class TestBIG5Map(multibytecodec_support.TestBase_Mapping,
unittest.TestCase):
encoding = 'big5'
mapfileurl = 'http://www.unicode.org/Public/MAPPINGS/OBSOLETE/' \
'EASTASIA/OTHER/BIG5.TXT'
-class TestCP950Map(test_multibytecodec_support.TestBase_Mapping,
+class TestCP950Map(multibytecodec_support.TestBase_Mapping,
unittest.TestCase):
encoding = 'cp950'
mapfileurl = 'http://www.unicode.org/Public/MAPPINGS/VENDORS/MICSFT/' \
@@ -23,6 +23,9 @@ class TestCP950Map(test_multibytecodec_support.TestBase_Mapping,
(b'\xa2\xcc', '\u5341'),
(b'\xa2\xce', '\u5345'),
]
+ codectests = (
+ (b"\xFFxy", "replace", "\ufffdxy"),
+ )
def test_main():
support.run_unittest(__name__)
diff --git a/Lib/test/test_codecs.py b/Lib/test/test_codecs.py
index 43886fc..c68088e 100644
--- a/Lib/test/test_codecs.py
+++ b/Lib/test/test_codecs.py
@@ -1,8 +1,25 @@
-from test import support
-import unittest
+import _testcapi
import codecs
+import io
import locale
-import sys, _testcapi, io
+import sys
+import unittest
+import warnings
+
+from test import support
+
+if sys.platform == 'win32':
+ VISTA_OR_LATER = (sys.getwindowsversion().major >= 6)
+else:
+ VISTA_OR_LATER = False
+
+try:
+ import ctypes
+except ImportError:
+ ctypes = None
+ SIZEOF_WCHAR_T = -1
+else:
+ SIZEOF_WCHAR_T = ctypes.sizeof(ctypes.c_wchar)
def coding_checker(self, coder):
def check(input, expect):
@@ -62,7 +79,7 @@ class MixInCheckStateHandling:
part2 = d.encode(u[i:], True)
self.assertEqual(s, part1+part2)
-class ReadTest(unittest.TestCase, MixInCheckStateHandling):
+class ReadTest(MixInCheckStateHandling):
def check_partial(self, input, partialresults):
# get a StreamReader for the encoding and feed the bytestring version
# of input to the reader byte by byte. Read everything available from
@@ -282,7 +299,7 @@ class ReadTest(unittest.TestCase, MixInCheckStateHandling):
self.assertEqual(reader.readline(), s5)
self.assertEqual(reader.readline(), "")
-class UTF32Test(ReadTest):
+class UTF32Test(ReadTest, unittest.TestCase):
encoding = "utf-32"
spamle = (b'\xff\xfe\x00\x00'
@@ -373,7 +390,7 @@ class UTF32Test(ReadTest):
self.assertEqual('\U00010000' * 1024,
codecs.utf_32_decode(encoded_be)[0])
-class UTF32LETest(ReadTest):
+class UTF32LETest(ReadTest, unittest.TestCase):
encoding = "utf-32-le"
def test_partial(self):
@@ -417,7 +434,7 @@ class UTF32LETest(ReadTest):
self.assertEqual('\U00010000' * 1024,
codecs.utf_32_le_decode(encoded)[0])
-class UTF32BETest(ReadTest):
+class UTF32BETest(ReadTest, unittest.TestCase):
encoding = "utf-32-be"
def test_partial(self):
@@ -462,7 +479,7 @@ class UTF32BETest(ReadTest):
codecs.utf_32_be_decode(encoded)[0])
-class UTF16Test(ReadTest):
+class UTF16Test(ReadTest, unittest.TestCase):
encoding = "utf-16"
spamle = b'\xff\xfes\x00p\x00a\x00m\x00s\x00p\x00a\x00m\x00'
@@ -542,7 +559,7 @@ class UTF16Test(ReadTest):
with codecs.open(support.TESTFN, 'U', encoding=self.encoding) as reader:
self.assertEqual(reader.read(), s1)
-class UTF16LETest(ReadTest):
+class UTF16LETest(ReadTest, unittest.TestCase):
encoding = "utf-16-le"
def test_partial(self):
@@ -585,7 +602,7 @@ class UTF16LETest(ReadTest):
self.assertEqual(b'\x00\xd8\x03\xde'.decode(self.encoding),
"\U00010203")
-class UTF16BETest(ReadTest):
+class UTF16BETest(ReadTest, unittest.TestCase):
encoding = "utf-16-be"
def test_partial(self):
@@ -628,7 +645,7 @@ class UTF16BETest(ReadTest):
self.assertEqual(b'\xd8\x00\xde\x03'.decode(self.encoding),
"\U00010203")
-class UTF8Test(ReadTest):
+class UTF8Test(ReadTest, unittest.TestCase):
encoding = "utf-8"
def test_partial(self):
@@ -677,13 +694,118 @@ class UTF8Test(ReadTest):
b"abc\xed\xa0\x80def")
self.assertEqual(b"abc\xed\xa0\x80def".decode("utf-8", "surrogatepass"),
"abc\ud800def")
+ self.assertEqual("\U00010fff\uD800".encode("utf-8", "surrogatepass"),
+ b"\xf0\x90\xbf\xbf\xed\xa0\x80")
+ self.assertEqual(b"\xf0\x90\xbf\xbf\xed\xa0\x80".decode("utf-8", "surrogatepass"),
+ "\U00010fff\uD800")
self.assertTrue(codecs.lookup_error("surrogatepass"))
with self.assertRaises(UnicodeDecodeError):
b"abc\xed\xa0".decode("utf-8", "surrogatepass")
with self.assertRaises(UnicodeDecodeError):
b"abc\xed\xa0z".decode("utf-8", "surrogatepass")
-class UTF7Test(ReadTest):
+@unittest.skipUnless(sys.platform == 'win32',
+ 'cp65001 is a Windows-only codec')
+class CP65001Test(ReadTest, unittest.TestCase):
+ encoding = "cp65001"
+
+ def test_encode(self):
+ tests = [
+ ('abc', 'strict', b'abc'),
+ ('\xe9\u20ac', 'strict', b'\xc3\xa9\xe2\x82\xac'),
+ ('\U0010ffff', 'strict', b'\xf4\x8f\xbf\xbf'),
+ ]
+ if VISTA_OR_LATER:
+ tests.extend((
+ ('\udc80', 'strict', None),
+ ('\udc80', 'ignore', b''),
+ ('\udc80', 'replace', b'?'),
+ ('\udc80', 'backslashreplace', b'\\udc80'),
+ ('\udc80', 'surrogatepass', b'\xed\xb2\x80'),
+ ))
+ else:
+ tests.append(('\udc80', 'strict', b'\xed\xb2\x80'))
+ for text, errors, expected in tests:
+ if expected is not None:
+ try:
+ encoded = text.encode('cp65001', errors)
+ except UnicodeEncodeError as err:
+ self.fail('Unable to encode %a to cp65001 with '
+ 'errors=%r: %s' % (text, errors, err))
+ self.assertEqual(encoded, expected,
+ '%a.encode("cp65001", %r)=%a != %a'
+ % (text, errors, encoded, expected))
+ else:
+ self.assertRaises(UnicodeEncodeError,
+ text.encode, "cp65001", errors)
+
+ def test_decode(self):
+ tests = [
+ (b'abc', 'strict', 'abc'),
+ (b'\xc3\xa9\xe2\x82\xac', 'strict', '\xe9\u20ac'),
+ (b'\xf4\x8f\xbf\xbf', 'strict', '\U0010ffff'),
+ (b'\xef\xbf\xbd', 'strict', '\ufffd'),
+ (b'[\xc3\xa9]', 'strict', '[\xe9]'),
+ # invalid bytes
+ (b'[\xff]', 'strict', None),
+ (b'[\xff]', 'ignore', '[]'),
+ (b'[\xff]', 'replace', '[\ufffd]'),
+ (b'[\xff]', 'surrogateescape', '[\udcff]'),
+ ]
+ if VISTA_OR_LATER:
+ tests.extend((
+ (b'[\xed\xb2\x80]', 'strict', None),
+ (b'[\xed\xb2\x80]', 'ignore', '[]'),
+ (b'[\xed\xb2\x80]', 'replace', '[\ufffd\ufffd\ufffd]'),
+ ))
+ else:
+ tests.extend((
+ (b'[\xed\xb2\x80]', 'strict', '[\udc80]'),
+ ))
+ for raw, errors, expected in tests:
+ if expected is not None:
+ try:
+ decoded = raw.decode('cp65001', errors)
+ except UnicodeDecodeError as err:
+ self.fail('Unable to decode %a from cp65001 with '
+ 'errors=%r: %s' % (raw, errors, err))
+ self.assertEqual(decoded, expected,
+ '%a.decode("cp65001", %r)=%a != %a'
+ % (raw, errors, decoded, expected))
+ else:
+ self.assertRaises(UnicodeDecodeError,
+ raw.decode, 'cp65001', errors)
+
+ @unittest.skipUnless(VISTA_OR_LATER, 'require Windows Vista or later')
+ def test_lone_surrogates(self):
+ self.assertRaises(UnicodeEncodeError, "\ud800".encode, "cp65001")
+ self.assertRaises(UnicodeDecodeError, b"\xed\xa0\x80".decode, "cp65001")
+ self.assertEqual("[\uDC80]".encode("cp65001", "backslashreplace"),
+ b'[\\udc80]')
+ self.assertEqual("[\uDC80]".encode("cp65001", "xmlcharrefreplace"),
+ b'[&#56448;]')
+ self.assertEqual("[\uDC80]".encode("cp65001", "surrogateescape"),
+ b'[\x80]')
+ self.assertEqual("[\uDC80]".encode("cp65001", "ignore"),
+ b'[]')
+ self.assertEqual("[\uDC80]".encode("cp65001", "replace"),
+ b'[?]')
+
+ @unittest.skipUnless(VISTA_OR_LATER, 'require Windows Vista or later')
+ def test_surrogatepass_handler(self):
+ self.assertEqual("abc\ud800def".encode("cp65001", "surrogatepass"),
+ b"abc\xed\xa0\x80def")
+ self.assertEqual(b"abc\xed\xa0\x80def".decode("cp65001", "surrogatepass"),
+ "abc\ud800def")
+ self.assertEqual("\U00010fff\uD800".encode("cp65001", "surrogatepass"),
+ b"\xf0\x90\xbf\xbf\xed\xa0\x80")
+ self.assertEqual(b"\xf0\x90\xbf\xbf\xed\xa0\x80".decode("cp65001", "surrogatepass"),
+ "\U00010fff\uD800")
+ self.assertTrue(codecs.lookup_error("surrogatepass"))
+
+
+
+class UTF7Test(ReadTest, unittest.TestCase):
encoding = "utf-7"
def test_partial(self):
@@ -722,7 +844,7 @@ class ReadBufferTest(unittest.TestCase):
self.assertRaises(TypeError, codecs.readbuffer_encode)
self.assertRaises(TypeError, codecs.readbuffer_encode, 42)
-class UTF8SigTest(ReadTest):
+class UTF8SigTest(ReadTest, unittest.TestCase):
encoding = "utf-8-sig"
def test_partial(self):
@@ -995,61 +1117,80 @@ class PunycodeTest(unittest.TestCase):
self.assertEqual(uni, puny.decode("punycode"))
class UnicodeInternalTest(unittest.TestCase):
+ @unittest.skipUnless(SIZEOF_WCHAR_T == 4, 'specific to 32-bit wchar_t')
def test_bug1251300(self):
# Decoding with unicode_internal used to not correctly handle "code
# points" above 0x10ffff on UCS-4 builds.
- if sys.maxunicode > 0xffff:
- ok = [
- (b"\x00\x10\xff\xff", "\U0010ffff"),
- (b"\x00\x00\x01\x01", "\U00000101"),
- (b"", ""),
- ]
- not_ok = [
- b"\x7f\xff\xff\xff",
- b"\x80\x00\x00\x00",
- b"\x81\x00\x00\x00",
- b"\x00",
- b"\x00\x00\x00\x00\x00",
- ]
- for internal, uni in ok:
- if sys.byteorder == "little":
- internal = bytes(reversed(internal))
+ ok = [
+ (b"\x00\x10\xff\xff", "\U0010ffff"),
+ (b"\x00\x00\x01\x01", "\U00000101"),
+ (b"", ""),
+ ]
+ not_ok = [
+ b"\x7f\xff\xff\xff",
+ b"\x80\x00\x00\x00",
+ b"\x81\x00\x00\x00",
+ b"\x00",
+ b"\x00\x00\x00\x00\x00",
+ ]
+ for internal, uni in ok:
+ if sys.byteorder == "little":
+ internal = bytes(reversed(internal))
+ with support.check_warnings():
self.assertEqual(uni, internal.decode("unicode_internal"))
- for internal in not_ok:
- if sys.byteorder == "little":
- internal = bytes(reversed(internal))
+ for internal in not_ok:
+ if sys.byteorder == "little":
+ internal = bytes(reversed(internal))
+ with support.check_warnings(('unicode_internal codec has been '
+ 'deprecated', DeprecationWarning)):
self.assertRaises(UnicodeDecodeError, internal.decode,
- "unicode_internal")
-
+ "unicode_internal")
+ if sys.byteorder == "little":
+ invalid = b"\x00\x00\x11\x00"
+ else:
+ invalid = b"\x00\x11\x00\x00"
+ with support.check_warnings():
+ self.assertRaises(UnicodeDecodeError,
+ invalid.decode, "unicode_internal")
+ with support.check_warnings():
+ self.assertEqual(invalid.decode("unicode_internal", "replace"),
+ '\ufffd')
+
+ @unittest.skipUnless(SIZEOF_WCHAR_T == 4, 'specific to 32-bit wchar_t')
def test_decode_error_attributes(self):
- if sys.maxunicode > 0xffff:
- try:
+ try:
+ with support.check_warnings(('unicode_internal codec has been '
+ 'deprecated', DeprecationWarning)):
b"\x00\x00\x00\x00\x00\x11\x11\x00".decode("unicode_internal")
- except UnicodeDecodeError as ex:
- self.assertEqual("unicode_internal", ex.encoding)
- self.assertEqual(b"\x00\x00\x00\x00\x00\x11\x11\x00", ex.object)
- self.assertEqual(4, ex.start)
- self.assertEqual(8, ex.end)
- else:
- self.fail()
+ except UnicodeDecodeError as ex:
+ self.assertEqual("unicode_internal", ex.encoding)
+ self.assertEqual(b"\x00\x00\x00\x00\x00\x11\x11\x00", ex.object)
+ self.assertEqual(4, ex.start)
+ self.assertEqual(8, ex.end)
+ else:
+ self.fail()
+ @unittest.skipUnless(SIZEOF_WCHAR_T == 4, 'specific to 32-bit wchar_t')
def test_decode_callback(self):
- if sys.maxunicode > 0xffff:
- codecs.register_error("UnicodeInternalTest", codecs.ignore_errors)
- decoder = codecs.getdecoder("unicode_internal")
+ codecs.register_error("UnicodeInternalTest", codecs.ignore_errors)
+ decoder = codecs.getdecoder("unicode_internal")
+ with support.check_warnings(('unicode_internal codec has been '
+ 'deprecated', DeprecationWarning)):
ab = "ab".encode("unicode_internal").decode()
ignored = decoder(bytes("%s\x22\x22\x22\x22%s" % (ab[:4], ab[4:]),
"ascii"),
"UnicodeInternalTest")
- self.assertEqual(("ab", 12), ignored)
+ self.assertEqual(("ab", 12), ignored)
def test_encode_length(self):
- # Issue 3739
- encoder = codecs.getencoder("unicode_internal")
- self.assertEqual(encoder("a")[1], 1)
- self.assertEqual(encoder("\xe9\u0142")[1], 2)
+ with support.check_warnings(('unicode_internal codec has been '
+ 'deprecated', DeprecationWarning)):
+ # Issue 3739
+ encoder = codecs.getencoder("unicode_internal")
+ self.assertEqual(encoder("a")[1], 1)
+ self.assertEqual(encoder("\xe9\u0142")[1], 2)
- self.assertEqual(codecs.escape_encode(br'\x00')[1], 4)
+ self.assertEqual(codecs.escape_encode(br'\x00')[1], 4)
# From http://www.gnu.org/software/libidn/draft-josefsson-idn-test-vectors.html
nameprep_tests = [
@@ -1373,7 +1514,7 @@ class EncodedFileTest(unittest.TestCase):
self.assertEqual(ef.read(), b'\\\xd5\n\x00\x00\xae')
f = io.BytesIO()
- ef = codecs.EncodedFile(f, 'utf-8', 'latin1')
+ ef = codecs.EncodedFile(f, 'utf-8', 'latin-1')
ef.write(b'\xc3\xbc')
self.assertEqual(f.getvalue(), b'\xfc')
@@ -1505,10 +1646,13 @@ class BasicUnicodeTest(unittest.TestCase, MixInCheckStateHandling):
elif encoding == "latin_1":
name = "latin_1"
self.assertEqual(encoding.replace("_", "-"), name.replace("_", "-"))
- (b, size) = codecs.getencoder(encoding)(s)
- self.assertEqual(size, len(s), "%r != %r (encoding=%r)" % (size, len(s), encoding))
- (chars, size) = codecs.getdecoder(encoding)(b)
- self.assertEqual(chars, s, "%r != %r (encoding=%r)" % (chars, s, encoding))
+
+ with support.check_warnings():
+ # unicode-internal has been deprecated
+ (b, size) = codecs.getencoder(encoding)(s)
+ self.assertEqual(size, len(s), "%r != %r (encoding=%r)" % (size, len(s), encoding))
+ (chars, size) = codecs.getdecoder(encoding)(b)
+ self.assertEqual(chars, s, "%r != %r (encoding=%r)" % (chars, s, encoding))
if encoding not in broken_unicode_with_streams:
# check stream reader/writer
@@ -1612,7 +1756,9 @@ class BasicUnicodeTest(unittest.TestCase, MixInCheckStateHandling):
def test_bad_encode_args(self):
for encoding in all_unicode_encodings:
encoder = codecs.getencoder(encoding)
- self.assertRaises(TypeError, encoder)
+ with support.check_warnings():
+ # unicode-internal has been deprecated
+ self.assertRaises(TypeError, encoder)
def test_encoding_map_type_initialized(self):
from encodings import cp1140
@@ -1635,6 +1781,11 @@ class CharmapTest(unittest.TestCase):
("abc", 3)
)
+ self.assertEqual(
+ codecs.charmap_decode(b"\x00\x01\x02", "strict", "\U0010FFFFbc"),
+ ("\U0010FFFFbc", 3)
+ )
+
self.assertRaises(UnicodeDecodeError,
codecs.charmap_decode, b"\x00\x01\x02", "strict", "ab"
)
@@ -1772,9 +1923,15 @@ class CharmapTest(unittest.TestCase):
("\U0010FFFFbc", 3)
)
+ self.assertEqual(
+ codecs.charmap_decode(b"\x00\x01\x02", "strict",
+ {0: sys.maxunicode, 1: b, 2: c}),
+ (chr(sys.maxunicode) + "bc", 3)
+ )
+
self.assertRaises(TypeError,
codecs.charmap_decode, b"\x00\x01\x02", "strict",
- {0: 0x110000, 1: b, 2: c}
+ {0: sys.maxunicode + 1, 1: b, 2: c}
)
self.assertRaises(UnicodeDecodeError,
@@ -1855,6 +2012,12 @@ class TypesTest(unittest.TestCase):
self.assertEqual(codecs.raw_unicode_escape_decode(r"\u1234"), ("\u1234", 6))
self.assertEqual(codecs.raw_unicode_escape_decode(br"\u1234"), ("\u1234", 6))
+ self.assertRaises(UnicodeDecodeError, codecs.unicode_escape_decode, br"\U00110000")
+ self.assertEqual(codecs.unicode_escape_decode(r"\U00110000", "replace"), ("\ufffd", 10))
+
+ self.assertRaises(UnicodeDecodeError, codecs.raw_unicode_escape_decode, br"\U00110000")
+ self.assertEqual(codecs.raw_unicode_escape_decode(r"\U00110000", "replace"), ("\ufffd", 10))
+
class UnicodeEscapeTest(unittest.TestCase):
def test_empty(self):
@@ -2014,7 +2177,7 @@ class SurrogateEscapeTest(unittest.TestCase):
def test_latin1(self):
# Issue6373
- self.assertEqual("\udce4\udceb\udcef\udcf6\udcfc".encode("latin1", "surrogateescape"),
+ self.assertEqual("\udce4\udceb\udcef\udcf6\udcfc".encode("latin-1", "surrogateescape"),
b"\xe4\xeb\xef\xf6\xfc")
@@ -2123,39 +2286,154 @@ class TransformCodecTest(unittest.TestCase):
self.assertEqual(sout, b"\x80")
-def test_main():
- support.run_unittest(
- UTF32Test,
- UTF32LETest,
- UTF32BETest,
- UTF16Test,
- UTF16LETest,
- UTF16BETest,
- UTF8Test,
- UTF8SigTest,
- EscapeDecodeTest,
- UTF7Test,
- UTF16ExTest,
- ReadBufferTest,
- RecodingTest,
- PunycodeTest,
- UnicodeInternalTest,
- NameprepTest,
- IDNACodecTest,
- CodecsModuleTest,
- StreamReaderTest,
- EncodedFileTest,
- BasicUnicodeTest,
- CharmapTest,
- WithStmtTest,
- TypesTest,
- UnicodeEscapeTest,
- RawUnicodeEscapeTest,
- SurrogateEscapeTest,
- BomTest,
- TransformCodecTest,
- )
+@unittest.skipUnless(sys.platform == 'win32',
+ 'code pages are specific to Windows')
+class CodePageTest(unittest.TestCase):
+ # CP_UTF8 is already tested by CP65001Test
+ CP_UTF8 = 65001
+
+ def test_invalid_code_page(self):
+ self.assertRaises(ValueError, codecs.code_page_encode, -1, 'a')
+ self.assertRaises(ValueError, codecs.code_page_decode, -1, b'a')
+ self.assertRaises(WindowsError, codecs.code_page_encode, 123, 'a')
+ self.assertRaises(WindowsError, codecs.code_page_decode, 123, b'a')
+
+ def test_code_page_name(self):
+ self.assertRaisesRegex(UnicodeEncodeError, 'cp932',
+ codecs.code_page_encode, 932, '\xff')
+ self.assertRaisesRegex(UnicodeDecodeError, 'cp932',
+ codecs.code_page_decode, 932, b'\x81\x00')
+ self.assertRaisesRegex(UnicodeDecodeError, 'CP_UTF8',
+ codecs.code_page_decode, self.CP_UTF8, b'\xff')
+
+ def check_decode(self, cp, tests):
+ for raw, errors, expected in tests:
+ if expected is not None:
+ try:
+ decoded = codecs.code_page_decode(cp, raw, errors)
+ except UnicodeDecodeError as err:
+ self.fail('Unable to decode %a from "cp%s" with '
+ 'errors=%r: %s' % (raw, cp, errors, err))
+ self.assertEqual(decoded[0], expected,
+ '%a.decode("cp%s", %r)=%a != %a'
+ % (raw, cp, errors, decoded[0], expected))
+ # assert 0 <= decoded[1] <= len(raw)
+ self.assertGreaterEqual(decoded[1], 0)
+ self.assertLessEqual(decoded[1], len(raw))
+ else:
+ self.assertRaises(UnicodeDecodeError,
+ codecs.code_page_decode, cp, raw, errors)
+
+ def check_encode(self, cp, tests):
+ for text, errors, expected in tests:
+ if expected is not None:
+ try:
+ encoded = codecs.code_page_encode(cp, text, errors)
+ except UnicodeEncodeError as err:
+ self.fail('Unable to encode %a to "cp%s" with '
+ 'errors=%r: %s' % (text, cp, errors, err))
+ self.assertEqual(encoded[0], expected,
+ '%a.encode("cp%s", %r)=%a != %a'
+ % (text, cp, errors, encoded[0], expected))
+ self.assertEqual(encoded[1], len(text))
+ else:
+ self.assertRaises(UnicodeEncodeError,
+ codecs.code_page_encode, cp, text, errors)
+
+ def test_cp932(self):
+ self.check_encode(932, (
+ ('abc', 'strict', b'abc'),
+ ('\uff44\u9a3e', 'strict', b'\x82\x84\xe9\x80'),
+ # test error handlers
+ ('\xff', 'strict', None),
+ ('[\xff]', 'ignore', b'[]'),
+ ('[\xff]', 'replace', b'[y]'),
+ ('[\u20ac]', 'replace', b'[?]'),
+ ('[\xff]', 'backslashreplace', b'[\\xff]'),
+ ('[\xff]', 'xmlcharrefreplace', b'[&#255;]'),
+ ))
+ self.check_decode(932, (
+ (b'abc', 'strict', 'abc'),
+ (b'\x82\x84\xe9\x80', 'strict', '\uff44\u9a3e'),
+ # invalid bytes
+ (b'[\xff]', 'strict', None),
+ (b'[\xff]', 'ignore', '[]'),
+ (b'[\xff]', 'replace', '[\ufffd]'),
+ (b'[\xff]', 'surrogateescape', '[\udcff]'),
+ (b'\x81\x00abc', 'strict', None),
+ (b'\x81\x00abc', 'ignore', '\x00abc'),
+ (b'\x81\x00abc', 'replace', '\ufffd\x00abc'),
+ ))
+
+ def test_cp1252(self):
+ self.check_encode(1252, (
+ ('abc', 'strict', b'abc'),
+ ('\xe9\u20ac', 'strict', b'\xe9\x80'),
+ ('\xff', 'strict', b'\xff'),
+ ('\u0141', 'strict', None),
+ ('\u0141', 'ignore', b''),
+ ('\u0141', 'replace', b'L'),
+ ))
+ self.check_decode(1252, (
+ (b'abc', 'strict', 'abc'),
+ (b'\xe9\x80', 'strict', '\xe9\u20ac'),
+ (b'\xff', 'strict', '\xff'),
+ ))
+
+ def test_cp_utf7(self):
+ cp = 65000
+ self.check_encode(cp, (
+ ('abc', 'strict', b'abc'),
+ ('\xe9\u20ac', 'strict', b'+AOkgrA-'),
+ ('\U0010ffff', 'strict', b'+2//f/w-'),
+ ('\udc80', 'strict', b'+3IA-'),
+ ('\ufffd', 'strict', b'+//0-'),
+ ))
+ self.check_decode(cp, (
+ (b'abc', 'strict', 'abc'),
+ (b'+AOkgrA-', 'strict', '\xe9\u20ac'),
+ (b'+2//f/w-', 'strict', '\U0010ffff'),
+ (b'+3IA-', 'strict', '\udc80'),
+ (b'+//0-', 'strict', '\ufffd'),
+ # invalid bytes
+ (b'[+/]', 'strict', '[]'),
+ (b'[\xff]', 'strict', '[\xff]'),
+ ))
+
+ def test_multibyte_encoding(self):
+ self.check_decode(932, (
+ (b'\x84\xe9\x80', 'ignore', '\u9a3e'),
+ (b'\x84\xe9\x80', 'replace', '\ufffd\u9a3e'),
+ ))
+ self.check_decode(self.CP_UTF8, (
+ (b'\xff\xf4\x8f\xbf\xbf', 'ignore', '\U0010ffff'),
+ (b'\xff\xf4\x8f\xbf\xbf', 'replace', '\ufffd\U0010ffff'),
+ ))
+ if VISTA_OR_LATER:
+ self.check_encode(self.CP_UTF8, (
+ ('[\U0010ffff\uDC80]', 'ignore', b'[\xf4\x8f\xbf\xbf]'),
+ ('[\U0010ffff\uDC80]', 'replace', b'[\xf4\x8f\xbf\xbf?]'),
+ ))
+
+ def test_incremental(self):
+ decoded = codecs.code_page_decode(932, b'\x82', 'strict', False)
+ self.assertEqual(decoded, ('', 0))
+
+ decoded = codecs.code_page_decode(932,
+ b'\xe9\x80\xe9', 'strict',
+ False)
+ self.assertEqual(decoded, ('\u9a3e', 2))
+
+ decoded = codecs.code_page_decode(932,
+ b'\xe9\x80\xe9\x80', 'strict',
+ False)
+ self.assertEqual(decoded, ('\u9a3e\u9a3e', 4))
+
+ decoded = codecs.code_page_decode(932,
+ b'abc', 'strict',
+ False)
+ self.assertEqual(decoded, ('abc', 3))
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_coding.py b/Lib/test/test_coding.py
index f9db0b4..dfd5431 100644
--- a/Lib/test/test_coding.py
+++ b/Lib/test/test_coding.py
@@ -1,7 +1,6 @@
-
import test.support, unittest
from test.support import TESTFN, unlink, unload
-import os, sys
+import importlib, os, sys
class CodingTest(unittest.TestCase):
def test_bad_coding(self):
@@ -40,6 +39,7 @@ class CodingTest(unittest.TestCase):
f.write("'A very long string %s'\n" % ("X" * 1000))
f.close()
+ importlib.invalidate_caches()
__import__(TESTFN)
finally:
f.close()
diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py
index b2a5f05..8850e8b 100644
--- a/Lib/test/test_collections.py
+++ b/Lib/test/test_collections.py
@@ -1,6 +1,7 @@
"""Unit tests for collections.py."""
import unittest, doctest, operator
+from test.support import TESTFN, forget, unlink
import inspect
from test import support
from collections import namedtuple, Counter, OrderedDict, _count_elements
@@ -10,21 +11,20 @@ from random import randrange, shuffle
import keyword
import re
import sys
-from collections import _ChainMap
-from collections import Hashable, Iterable, Iterator
-from collections import Sized, Container, Callable
-from collections import Set, MutableSet
-from collections import Mapping, MutableMapping, KeysView, ItemsView, UserDict
-from collections import Sequence, MutableSequence
-from collections import ByteString
+from collections import UserDict
+from collections import ChainMap
+from collections.abc import Hashable, Iterable, Iterator
+from collections.abc import Sized, Container, Callable
+from collections.abc import Set, MutableSet
+from collections.abc import Mapping, MutableMapping, KeysView, ItemsView
+from collections.abc import Sequence, MutableSequence
+from collections.abc import ByteString
################################################################################
-### _ChainMap (helper class for configparser)
+### ChainMap (helper class for configparser and the string module)
################################################################################
-ChainMap = _ChainMap # rename to keep test code in sync with 3.3 version
-
class TestChainMap(unittest.TestCase):
def test_basics(self):
@@ -128,6 +128,7 @@ class TestNamedTuple(unittest.TestCase):
self.assertEqual(Point.__module__, __name__)
self.assertEqual(Point.__getitem__, tuple.__getitem__)
self.assertEqual(Point._fields, ('x', 'y'))
+ self.assertIn('class Point(tuple)', Point._source)
self.assertRaises(ValueError, namedtuple, 'abc%', 'efg ghi') # type has non-alpha char
self.assertRaises(ValueError, namedtuple, 'class', 'efg ghi') # type has keyword
@@ -327,6 +328,17 @@ class TestNamedTuple(unittest.TestCase):
pass
self.assertEqual(repr(B(1)), 'B(x=1)')
+ def test_source(self):
+ # verify that _source can be run through exec()
+ tmp = namedtuple('NTColor', 'red green blue')
+ globals().pop('NTColor', None) # remove artifacts from other tests
+ exec(tmp._source, globals())
+ self.assertIn('NTColor', globals())
+ c = NTColor(10, 20, 30)
+ self.assertEqual((c.red, c.green, c.blue), (10, 20, 30))
+ self.assertEqual(NTColor._fields, ('red', 'green', 'blue'))
+ globals().pop('NTColor', None) # clean-up after this test
+
################################################################################
### Abstract Base Classes
@@ -762,6 +774,44 @@ class TestCollectionABCs(ABCTestCase):
self.validate_abstract_methods(MutableSequence, '__contains__', '__iter__',
'__len__', '__getitem__', '__setitem__', '__delitem__', 'insert')
+ def test_MutableSequence_mixins(self):
+ # Test the mixins of MutableSequence by creating a miminal concrete
+ # class inherited from it.
+ class MutableSequenceSubclass(MutableSequence):
+ def __init__(self):
+ self.lst = []
+
+ def __setitem__(self, index, value):
+ self.lst[index] = value
+
+ def __getitem__(self, index):
+ return self.lst[index]
+
+ def __len__(self):
+ return len(self.lst)
+
+ def __delitem__(self, index):
+ del self.lst[index]
+
+ def insert(self, index, value):
+ self.lst.insert(index, value)
+
+ mss = MutableSequenceSubclass()
+ mss.append(0)
+ mss.extend((1, 2, 3, 4))
+ self.assertEqual(len(mss), 5)
+ self.assertEqual(mss[3], 3)
+ mss.reverse()
+ self.assertEqual(mss[3], 1)
+ mss.pop()
+ self.assertEqual(len(mss), 4)
+ mss.remove(3)
+ self.assertEqual(len(mss), 3)
+ mss += (10, 20, 30)
+ self.assertEqual(len(mss), 6)
+ self.assertEqual(mss[-1], 30)
+ mss.clear()
+ self.assertEqual(len(mss), 0)
################################################################################
### Counter
@@ -915,6 +965,27 @@ class TestCounter(unittest.TestCase):
set_result = setop(set(p.elements()), set(q.elements()))
self.assertEqual(counter_result, dict.fromkeys(set_result, 1))
+ def test_inplace_operations(self):
+ elements = 'abcd'
+ for i in range(1000):
+ # test random pairs of multisets
+ p = Counter(dict((elem, randrange(-2,4)) for elem in elements))
+ p.update(e=1, f=-1, g=0)
+ q = Counter(dict((elem, randrange(-2,4)) for elem in elements))
+ q.update(h=1, i=-1, j=0)
+ for inplace_op, regular_op in [
+ (Counter.__iadd__, Counter.__add__),
+ (Counter.__isub__, Counter.__sub__),
+ (Counter.__ior__, Counter.__or__),
+ (Counter.__iand__, Counter.__and__),
+ ]:
+ c = p.copy()
+ c_id = id(c)
+ regular_result = regular_op(c, q)
+ inplace_result = inplace_op(c, q)
+ self.assertEqual(inplace_result, regular_result)
+ self.assertEqual(id(inplace_result), c_id)
+
def test_subtract(self):
c = Counter(a=-5, b=0, c=5, d=10, e=15,g=40)
c.subtract(a=1, b=2, c=-3, d=10, e=20, f=30, h=-50)
@@ -926,6 +997,11 @@ class TestCounter(unittest.TestCase):
c.subtract('aaaabbcce')
self.assertEqual(c, Counter(a=-1, b=0, c=-1, d=1, e=-1))
+ def test_unary(self):
+ c = Counter(a=-5, b=0, c=5, d=10, e=15,g=40)
+ self.assertEqual(dict(+c), dict(c=5, d=10, e=15, g=40))
+ self.assertEqual(dict(-c), dict(a=5))
+
def test_repr_nonsortable(self):
c = Counter(a=2, b=None)
r = repr(c)
diff --git a/Lib/test/test_compile.py b/Lib/test/test_compile.py
index 58ef297..1071f4a 100644
--- a/Lib/test/test_compile.py
+++ b/Lib/test/test_compile.py
@@ -1,10 +1,17 @@
import unittest
import sys
import _ast
+import types
from test import support
class TestSpecifics(unittest.TestCase):
+ def compile_single(self, source):
+ compile(source, "<single>", "single")
+
+ def assertInvalidSingle(self, source):
+ self.assertRaises(SyntaxError, self.compile_single, source)
+
def test_no_ending_newline(self):
compile("hi", "<test>", "exec")
compile("hi\r", "<test>", "exec")
@@ -433,6 +440,66 @@ if 1:
ast.body = [_ast.BoolOp()]
self.assertRaises(TypeError, compile, ast, '<ast>', 'exec')
+ @support.cpython_only
+ def test_same_filename_used(self):
+ s = """def f(): pass\ndef g(): pass"""
+ c = compile(s, "myfile", "exec")
+ for obj in c.co_consts:
+ if isinstance(obj, types.CodeType):
+ self.assertIs(obj.co_filename, c.co_filename)
+
+ def test_single_statement(self):
+ self.compile_single("1 + 2")
+ self.compile_single("\n1 + 2")
+ self.compile_single("1 + 2\n")
+ self.compile_single("1 + 2\n\n")
+ self.compile_single("1 + 2\t\t\n")
+ self.compile_single("1 + 2\t\t\n ")
+ self.compile_single("1 + 2 # one plus two")
+ self.compile_single("1; 2")
+ self.compile_single("import sys; sys")
+ self.compile_single("def f():\n pass")
+ self.compile_single("while False:\n pass")
+ self.compile_single("if x:\n f(x)")
+ self.compile_single("if x:\n f(x)\nelse:\n g(x)")
+ self.compile_single("class T:\n pass")
+
+ def test_bad_single_statement(self):
+ self.assertInvalidSingle('1\n2')
+ self.assertInvalidSingle('def f(): pass')
+ self.assertInvalidSingle('a = 13\nb = 187')
+ self.assertInvalidSingle('del x\ndel y')
+ self.assertInvalidSingle('f()\ng()')
+ self.assertInvalidSingle('f()\n# blah\nblah()')
+ self.assertInvalidSingle('f()\nxy # blah\nblah()')
+ self.assertInvalidSingle('x = 5 # comment\nx = 6\n')
+
+ @support.cpython_only
+ def test_compiler_recursion_limit(self):
+ # Expected limit is sys.getrecursionlimit() * the scaling factor
+ # in symtable.c (currently 3)
+ # We expect to fail *at* that limit, because we use up some of
+ # the stack depth limit in the test suite code
+ # So we check the expected limit and 75% of that
+ # XXX (ncoghlan): duplicating the scaling factor here is a little
+ # ugly. Perhaps it should be exposed somewhere...
+ fail_depth = sys.getrecursionlimit() * 3
+ success_depth = int(fail_depth * 0.75)
+
+ def check_limit(prefix, repeated):
+ expect_ok = prefix + repeated * success_depth
+ self.compile_single(expect_ok)
+ broken = prefix + repeated * fail_depth
+ details = "Compiling ({!r} + {!r} * {})".format(
+ prefix, repeated, fail_depth)
+ with self.assertRaises(RuntimeError, msg=details):
+ self.compile_single(broken)
+
+ check_limit("a", "()")
+ check_limit("a", ".b")
+ check_limit("a", "[0]")
+ check_limit("a", "*a")
+
def test_main():
support.run_unittest(TestSpecifics)
diff --git a/Lib/test/test_concurrent_futures.py b/Lib/test/test_concurrent_futures.py
index 2afa938..6ae450d 100644
--- a/Lib/test/test_concurrent_futures.py
+++ b/Lib/test/test_concurrent_futures.py
@@ -19,7 +19,7 @@ import unittest
from concurrent import futures
from concurrent.futures._base import (
PENDING, RUNNING, CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED, Future)
-import concurrent.futures.process
+from concurrent.futures.process import BrokenProcessPool
def create_future(state=PENDING, exception=None, result=None):
@@ -34,7 +34,7 @@ PENDING_FUTURE = create_future(state=PENDING)
RUNNING_FUTURE = create_future(state=RUNNING)
CANCELLED_FUTURE = create_future(state=CANCELLED)
CANCELLED_AND_NOTIFIED_FUTURE = create_future(state=CANCELLED_AND_NOTIFIED)
-EXCEPTION_FUTURE = create_future(state=FINISHED, exception=IOError())
+EXCEPTION_FUTURE = create_future(state=FINISHED, exception=OSError())
SUCCESSFUL_FUTURE = create_future(state=FINISHED, result=42)
@@ -160,7 +160,7 @@ class ProcessPoolShutdownTest(ProcessPoolMixin, ExecutorShutdownTest):
processes = self.executor._processes
self.executor.shutdown()
- for p in processes:
+ for p in processes.values():
p.join()
def test_context_manager_shutdown(self):
@@ -169,7 +169,7 @@ class ProcessPoolShutdownTest(ProcessPoolMixin, ExecutorShutdownTest):
self.assertEqual(list(e.map(abs, range(-5, 5))),
[5, 4, 3, 2, 1, 0, 1, 2, 3, 4])
- for p in processes:
+ for p in processes.values():
p.join()
def test_del_shutdown(self):
@@ -180,7 +180,7 @@ class ProcessPoolShutdownTest(ProcessPoolMixin, ExecutorShutdownTest):
del executor
queue_management_thread.join()
- for p in processes:
+ for p in processes.values():
p.join()
@@ -268,14 +268,14 @@ class WaitTests(unittest.TestCase):
def test_timeout(self):
future1 = self.executor.submit(mul, 6, 7)
- future2 = self.executor.submit(time.sleep, 3)
+ future2 = self.executor.submit(time.sleep, 6)
finished, pending = futures.wait(
[CANCELLED_AND_NOTIFIED_FUTURE,
EXCEPTION_FUTURE,
SUCCESSFUL_FUTURE,
future1, future2],
- timeout=1.5,
+ timeout=5,
return_when=futures.ALL_COMPLETED)
self.assertEqual(set([CANCELLED_AND_NOTIFIED_FUTURE,
@@ -379,8 +379,8 @@ class ExecutorTest(unittest.TestCase):
results = []
try:
for i in self.executor.map(time.sleep,
- [0, 0, 3],
- timeout=1.5):
+ [0, 0, 6],
+ timeout=5):
results.append(i)
except futures.TimeoutError:
pass
@@ -389,13 +389,38 @@ class ExecutorTest(unittest.TestCase):
self.assertEqual([None, None], results)
+ def test_shutdown_race_issue12456(self):
+ # Issue #12456: race condition at shutdown where trying to post a
+ # sentinel in the call queue blocks (the queue is full while processes
+ # have exited).
+ self.executor.map(str, [2] * (self.worker_count + 1))
+ self.executor.shutdown()
+
class ThreadPoolExecutorTest(ThreadPoolMixin, ExecutorTest):
- pass
+ def test_map_submits_without_iteration(self):
+ """Tests verifying issue 11777."""
+ finished = []
+ def record_finished(n):
+ finished.append(n)
+
+ self.executor.map(record_finished, range(10))
+ self.executor.shutdown(wait=True)
+ self.assertCountEqual(finished, range(10))
class ProcessPoolExecutorTest(ProcessPoolMixin, ExecutorTest):
- pass
+ def test_killed_child(self):
+ # When a child process is abruptly terminated, the whole pool gets
+ # "broken".
+ futures = [self.executor.submit(time.sleep, 3)]
+ # Get one of the processes, and terminate (kill) it
+ p = next(iter(self.executor._processes.values()))
+ p.terminate()
+ for fut in futures:
+ self.assertRaises(BrokenProcessPool, fut.result)
+ # Submitting other jobs fails as well.
+ self.assertRaises(BrokenProcessPool, self.executor.submit, pow, 2, 8)
class FutureTests(unittest.TestCase):
@@ -498,7 +523,7 @@ class FutureTests(unittest.TestCase):
'<Future at 0x[0-9a-f]+ state=cancelled>')
self.assertRegex(
repr(EXCEPTION_FUTURE),
- '<Future at 0x[0-9a-f]+ state=finished raised IOError>')
+ '<Future at 0x[0-9a-f]+ state=finished raised OSError>')
self.assertRegex(
repr(SUCCESSFUL_FUTURE),
'<Future at 0x[0-9a-f]+ state=finished returned int>')
@@ -509,7 +534,7 @@ class FutureTests(unittest.TestCase):
f2 = create_future(state=RUNNING)
f3 = create_future(state=CANCELLED)
f4 = create_future(state=CANCELLED_AND_NOTIFIED)
- f5 = create_future(state=FINISHED, exception=IOError())
+ f5 = create_future(state=FINISHED, exception=OSError())
f6 = create_future(state=FINISHED, result=5)
self.assertTrue(f1.cancel())
@@ -563,7 +588,7 @@ class FutureTests(unittest.TestCase):
CANCELLED_FUTURE.result, timeout=0)
self.assertRaises(futures.CancelledError,
CANCELLED_AND_NOTIFIED_FUTURE.result, timeout=0)
- self.assertRaises(IOError, EXCEPTION_FUTURE.result, timeout=0)
+ self.assertRaises(OSError, EXCEPTION_FUTURE.result, timeout=0)
self.assertEqual(SUCCESSFUL_FUTURE.result(timeout=0), 42)
def test_result_with_success(self):
@@ -602,7 +627,7 @@ class FutureTests(unittest.TestCase):
self.assertRaises(futures.CancelledError,
CANCELLED_AND_NOTIFIED_FUTURE.exception, timeout=0)
self.assertTrue(isinstance(EXCEPTION_FUTURE.exception(timeout=0),
- IOError))
+ OSError))
self.assertEqual(SUCCESSFUL_FUTURE.exception(timeout=0), None)
def test_exception_with_success(self):
@@ -611,14 +636,14 @@ class FutureTests(unittest.TestCase):
time.sleep(1)
with f1._condition:
f1._state = FINISHED
- f1._exception = IOError()
+ f1._exception = OSError()
f1._condition.notify_all()
f1 = create_future(state=PENDING)
t = threading.Thread(target=notification)
t.start()
- self.assertTrue(isinstance(f1.exception(timeout=5), IOError))
+ self.assertTrue(isinstance(f1.exception(timeout=5), OSError))
@test.support.reap_threads
def test_main():
@@ -631,7 +656,8 @@ def test_main():
ThreadPoolAsCompletedTests,
FutureTests,
ProcessPoolShutdownTest,
- ThreadPoolShutdownTest)
+ ThreadPoolShutdownTest,
+ )
finally:
test.support.reap_children()
diff --git a/Lib/test/test_cfgparser.py b/Lib/test/test_configparser.py
index cec9b44..fc56550 100644
--- a/Lib/test/test_cfgparser.py
+++ b/Lib/test/test_configparser.py
@@ -32,7 +32,7 @@ class SortedDict(collections.UserDict):
__iter__ = iterkeys
-class CfgParserTestCaseClass(unittest.TestCase):
+class CfgParserTestCaseClass:
allow_no_value = False
delimiters = ('=', ':')
comment_prefixes = (';', '#')
@@ -822,14 +822,20 @@ boolean {0[0]} NO
self.assertEqual(set(cf['section3'].keys()), {'named'})
self.assertNotIn('name3', cf['section3'])
self.assertEqual(cf.sections(), ['section1', 'section2', 'section3'])
+ cf[self.default_section] = {}
+ self.assertEqual(set(cf[self.default_section].keys()), set())
+ self.assertEqual(set(cf['section1'].keys()), {'name1'})
+ self.assertEqual(set(cf['section2'].keys()), {'name22'})
+ self.assertEqual(set(cf['section3'].keys()), set())
+ self.assertEqual(cf.sections(), ['section1', 'section2', 'section3'])
-class StrictTestCase(BasicTestCase):
+class StrictTestCase(BasicTestCase, unittest.TestCase):
config_class = configparser.RawConfigParser
strict = True
-class ConfigParserTestCase(BasicTestCase):
+class ConfigParserTestCase(BasicTestCase, unittest.TestCase):
config_class = configparser.ConfigParser
def test_interpolation(self):
@@ -918,7 +924,7 @@ class ConfigParserTestCase(BasicTestCase):
self.assertRaises(ValueError, cf.add_section, self.default_section)
-class ConfigParserTestCaseNoInterpolation(BasicTestCase):
+class ConfigParserTestCaseNoInterpolation(BasicTestCase, unittest.TestCase):
config_class = configparser.ConfigParser
interpolation = None
ini = textwrap.dedent("""
@@ -983,7 +989,7 @@ class ConfigParserTestCaseNonStandardDelimiters(ConfigParserTestCase):
class ConfigParserTestCaseNonStandardDefaultSection(ConfigParserTestCase):
default_section = 'general'
-class MultilineValuesTestCase(BasicTestCase):
+class MultilineValuesTestCase(BasicTestCase, unittest.TestCase):
config_class = configparser.ConfigParser
wonderful_spam = ("I'm having spam spam spam spam "
"spam spam spam beaked beans spam "
@@ -1011,7 +1017,7 @@ class MultilineValuesTestCase(BasicTestCase):
self.assertEqual(cf_from_file.get('section8', 'lovely_spam4'),
self.wonderful_spam.replace('\t\n', '\n'))
-class RawConfigParserTestCase(BasicTestCase):
+class RawConfigParserTestCase(BasicTestCase, unittest.TestCase):
config_class = configparser.RawConfigParser
def test_interpolation(self):
@@ -1058,7 +1064,7 @@ class RawConfigParserTestCaseNonStandardDelimiters(RawConfigParserTestCase):
comment_prefixes = ('//', '"')
inline_comment_prefixes = ('//', '"')
-class RawConfigParserTestSambaConf(CfgParserTestCaseClass):
+class RawConfigParserTestSambaConf(CfgParserTestCaseClass, unittest.TestCase):
config_class = configparser.RawConfigParser
comment_prefixes = ('#', ';', '----')
inline_comment_prefixes = ('//',)
@@ -1078,7 +1084,7 @@ class RawConfigParserTestSambaConf(CfgParserTestCaseClass):
self.assertEqual(cf.get("global", "hosts allow"), "127.")
self.assertEqual(cf.get("tmp", "echo command"), "cat %s; rm %s")
-class ConfigParserTestCaseExtendedInterpolation(BasicTestCase):
+class ConfigParserTestCaseExtendedInterpolation(BasicTestCase, unittest.TestCase):
config_class = configparser.ConfigParser
interpolation = configparser.ExtendedInterpolation()
default_section = 'common'
@@ -1252,7 +1258,7 @@ class ConfigParserTestCaseExtendedInterpolation(BasicTestCase):
class ConfigParserTestCaseNoValue(ConfigParserTestCase):
allow_no_value = True
-class ConfigParserTestCaseTrickyFile(CfgParserTestCaseClass):
+class ConfigParserTestCaseTrickyFile(CfgParserTestCaseClass, unittest.TestCase):
config_class = configparser.ConfigParser
delimiters = {'='}
comment_prefixes = {'#'}
@@ -1349,7 +1355,7 @@ class SortedTestCase(RawConfigParserTestCase):
"o4 = 1\n\n")
-class CompatibleTestCase(CfgParserTestCaseClass):
+class CompatibleTestCase(CfgParserTestCaseClass, unittest.TestCase):
config_class = configparser.RawConfigParser
comment_prefixes = '#;'
inline_comment_prefixes = ';'
@@ -1371,7 +1377,7 @@ class CompatibleTestCase(CfgParserTestCaseClass):
self.assertEqual(cf.get('Commented Bar', 'quirk'),
'this;is not a comment')
-class CopyTestCase(BasicTestCase):
+class CopyTestCase(BasicTestCase, unittest.TestCase):
config_class = configparser.ConfigParser
def fromstring(self, string, defaults=None):
@@ -1671,26 +1677,41 @@ class ExceptionPicklingTestCase(unittest.TestCase):
self.assertEqual(repr(e1), repr(e2))
-def test_main():
- support.run_unittest(
- ConfigParserTestCase,
- ConfigParserTestCaseNonStandardDelimiters,
- ConfigParserTestCaseNoValue,
- ConfigParserTestCaseExtendedInterpolation,
- ConfigParserTestCaseLegacyInterpolation,
- ConfigParserTestCaseNoInterpolation,
- ConfigParserTestCaseTrickyFile,
- MultilineValuesTestCase,
- RawConfigParserTestCase,
- RawConfigParserTestCaseNonStandardDelimiters,
- RawConfigParserTestSambaConf,
- SortedTestCase,
- Issue7005TestCase,
- StrictTestCase,
- CompatibleTestCase,
- CopyTestCase,
- ConfigParserTestCaseNonStandardDefaultSection,
- ReadFileTestCase,
- CoverageOneHundredTestCase,
- ExceptionPicklingTestCase,
- )
+class InlineCommentStrippingTestCase(unittest.TestCase):
+ """Tests for issue #14590: ConfigParser doesn't strip inline comment when
+ delimiter occurs earlier without preceding space.."""
+
+ def test_stripping(self):
+ cfg = configparser.ConfigParser(inline_comment_prefixes=(';', '#',
+ '//'))
+ cfg.read_string("""
+ [section]
+ k1 = v1;still v1
+ k2 = v2 ;a comment
+ k3 = v3 ; also a comment
+ k4 = v4;still v4 ;a comment
+ k5 = v5;still v5 ; also a comment
+ k6 = v6;still v6; and still v6 ;a comment
+ k7 = v7;still v7; and still v7 ; also a comment
+
+ [multiprefix]
+ k1 = v1;still v1 #a comment ; yeah, pretty much
+ k2 = v2 // this already is a comment ; continued
+ k3 = v3;#//still v3# and still v3 ; a comment
+ """)
+ s = cfg['section']
+ self.assertEqual(s['k1'], 'v1;still v1')
+ self.assertEqual(s['k2'], 'v2')
+ self.assertEqual(s['k3'], 'v3')
+ self.assertEqual(s['k4'], 'v4;still v4')
+ self.assertEqual(s['k5'], 'v5;still v5')
+ self.assertEqual(s['k6'], 'v6;still v6; and still v6')
+ self.assertEqual(s['k7'], 'v7;still v7; and still v7')
+ s = cfg['multiprefix']
+ self.assertEqual(s['k1'], 'v1;still v1')
+ self.assertEqual(s['k2'], 'v2')
+ self.assertEqual(s['k3'], 'v3;#//still v3# and still v3')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py
index 6e38305..e52ed91 100644
--- a/Lib/test/test_contextlib.py
+++ b/Lib/test/test_contextlib.py
@@ -370,6 +370,231 @@ class TestContextDecorator(unittest.TestCase):
self.assertEqual(state, [1, 'something else', 999])
+class TestExitStack(unittest.TestCase):
+
+ def test_no_resources(self):
+ with ExitStack():
+ pass
+
+ def test_callback(self):
+ expected = [
+ ((), {}),
+ ((1,), {}),
+ ((1,2), {}),
+ ((), dict(example=1)),
+ ((1,), dict(example=1)),
+ ((1,2), dict(example=1)),
+ ]
+ result = []
+ def _exit(*args, **kwds):
+ """Test metadata propagation"""
+ result.append((args, kwds))
+ with ExitStack() as stack:
+ for args, kwds in reversed(expected):
+ if args and kwds:
+ f = stack.callback(_exit, *args, **kwds)
+ elif args:
+ f = stack.callback(_exit, *args)
+ elif kwds:
+ f = stack.callback(_exit, **kwds)
+ else:
+ f = stack.callback(_exit)
+ self.assertIs(f, _exit)
+ for wrapper in stack._exit_callbacks:
+ self.assertIs(wrapper.__wrapped__, _exit)
+ self.assertNotEqual(wrapper.__name__, _exit.__name__)
+ self.assertIsNone(wrapper.__doc__, _exit.__doc__)
+ self.assertEqual(result, expected)
+
+ def test_push(self):
+ exc_raised = ZeroDivisionError
+ def _expect_exc(exc_type, exc, exc_tb):
+ self.assertIs(exc_type, exc_raised)
+ def _suppress_exc(*exc_details):
+ return True
+ def _expect_ok(exc_type, exc, exc_tb):
+ self.assertIsNone(exc_type)
+ self.assertIsNone(exc)
+ self.assertIsNone(exc_tb)
+ class ExitCM(object):
+ def __init__(self, check_exc):
+ self.check_exc = check_exc
+ def __enter__(self):
+ self.fail("Should not be called!")
+ def __exit__(self, *exc_details):
+ self.check_exc(*exc_details)
+ with ExitStack() as stack:
+ stack.push(_expect_ok)
+ self.assertIs(stack._exit_callbacks[-1], _expect_ok)
+ cm = ExitCM(_expect_ok)
+ stack.push(cm)
+ self.assertIs(stack._exit_callbacks[-1].__self__, cm)
+ stack.push(_suppress_exc)
+ self.assertIs(stack._exit_callbacks[-1], _suppress_exc)
+ cm = ExitCM(_expect_exc)
+ stack.push(cm)
+ self.assertIs(stack._exit_callbacks[-1].__self__, cm)
+ stack.push(_expect_exc)
+ self.assertIs(stack._exit_callbacks[-1], _expect_exc)
+ stack.push(_expect_exc)
+ self.assertIs(stack._exit_callbacks[-1], _expect_exc)
+ 1/0
+
+ def test_enter_context(self):
+ class TestCM(object):
+ def __enter__(self):
+ result.append(1)
+ def __exit__(self, *exc_details):
+ result.append(3)
+
+ result = []
+ cm = TestCM()
+ with ExitStack() as stack:
+ @stack.callback # Registered first => cleaned up last
+ def _exit():
+ result.append(4)
+ self.assertIsNotNone(_exit)
+ stack.enter_context(cm)
+ self.assertIs(stack._exit_callbacks[-1].__self__, cm)
+ result.append(2)
+ self.assertEqual(result, [1, 2, 3, 4])
+
+ def test_close(self):
+ result = []
+ with ExitStack() as stack:
+ @stack.callback
+ def _exit():
+ result.append(1)
+ self.assertIsNotNone(_exit)
+ stack.close()
+ result.append(2)
+ self.assertEqual(result, [1, 2])
+
+ def test_pop_all(self):
+ result = []
+ with ExitStack() as stack:
+ @stack.callback
+ def _exit():
+ result.append(3)
+ self.assertIsNotNone(_exit)
+ new_stack = stack.pop_all()
+ result.append(1)
+ result.append(2)
+ new_stack.close()
+ self.assertEqual(result, [1, 2, 3])
+
+ def test_exit_raise(self):
+ with self.assertRaises(ZeroDivisionError):
+ with ExitStack() as stack:
+ stack.push(lambda *exc: False)
+ 1/0
+
+ def test_exit_suppress(self):
+ with ExitStack() as stack:
+ stack.push(lambda *exc: True)
+ 1/0
+
+ def test_exit_exception_chaining_reference(self):
+ # Sanity check to make sure that ExitStack chaining matches
+ # actual nested with statements
+ class RaiseExc:
+ def __init__(self, exc):
+ self.exc = exc
+ def __enter__(self):
+ return self
+ def __exit__(self, *exc_details):
+ raise self.exc
+
+ class RaiseExcWithContext:
+ def __init__(self, outer, inner):
+ self.outer = outer
+ self.inner = inner
+ def __enter__(self):
+ return self
+ def __exit__(self, *exc_details):
+ try:
+ raise self.inner
+ except:
+ raise self.outer
+
+ class SuppressExc:
+ def __enter__(self):
+ return self
+ def __exit__(self, *exc_details):
+ type(self).saved_details = exc_details
+ return True
+
+ try:
+ with RaiseExc(IndexError):
+ with RaiseExcWithContext(KeyError, AttributeError):
+ with SuppressExc():
+ with RaiseExc(ValueError):
+ 1 / 0
+ except IndexError as exc:
+ self.assertIsInstance(exc.__context__, KeyError)
+ self.assertIsInstance(exc.__context__.__context__, AttributeError)
+ # Inner exceptions were suppressed
+ self.assertIsNone(exc.__context__.__context__.__context__)
+ else:
+ self.fail("Expected IndexError, but no exception was raised")
+ # Check the inner exceptions
+ inner_exc = SuppressExc.saved_details[1]
+ self.assertIsInstance(inner_exc, ValueError)
+ self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
+
+ def test_exit_exception_chaining(self):
+ # Ensure exception chaining matches the reference behaviour
+ def raise_exc(exc):
+ raise exc
+
+ saved_details = None
+ def suppress_exc(*exc_details):
+ nonlocal saved_details
+ saved_details = exc_details
+ return True
+
+ try:
+ with ExitStack() as stack:
+ stack.callback(raise_exc, IndexError)
+ stack.callback(raise_exc, KeyError)
+ stack.callback(raise_exc, AttributeError)
+ stack.push(suppress_exc)
+ stack.callback(raise_exc, ValueError)
+ 1 / 0
+ except IndexError as exc:
+ self.assertIsInstance(exc.__context__, KeyError)
+ self.assertIsInstance(exc.__context__.__context__, AttributeError)
+ # Inner exceptions were suppressed
+ self.assertIsNone(exc.__context__.__context__.__context__)
+ else:
+ self.fail("Expected IndexError, but no exception was raised")
+ # Check the inner exceptions
+ inner_exc = saved_details[1]
+ self.assertIsInstance(inner_exc, ValueError)
+ self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
+
+ def test_exit_exception_chaining_suppress(self):
+ with ExitStack() as stack:
+ stack.push(lambda *exc: True)
+ stack.push(lambda *exc: 1/0)
+ stack.push(lambda *exc: {}[1])
+
+ def test_excessive_nesting(self):
+ # The original implementation would die with RecursionError here
+ with ExitStack() as stack:
+ for i in range(10000):
+ stack.callback(int)
+
+ def test_instance_bypass(self):
+ class Example(object): pass
+ cm = Example()
+ cm.__exit__ = object()
+ stack = ExitStack()
+ self.assertRaises(AttributeError, stack.enter_context, cm)
+ stack.push(cm)
+ self.assertIs(stack._exit_callbacks[-1], cm)
+
+
# This is needed to make the test actually run under regrtest.py!
def test_main():
support.run_unittest(__name__)
diff --git a/Lib/test/test_copy.py b/Lib/test/test_copy.py
index a84c109..c4baae4 100644
--- a/Lib/test/test_copy.py
+++ b/Lib/test/test_copy.py
@@ -17,7 +17,7 @@ class TestCopy(unittest.TestCase):
# Attempt full line coverage of copy.py from top to bottom
def test_exceptions(self):
- self.assertTrue(copy.Error is copy.error)
+ self.assertIs(copy.Error, copy.error)
self.assertTrue(issubclass(copy.Error, Exception))
# The copy() method
@@ -54,20 +54,26 @@ class TestCopy(unittest.TestCase):
def test_copy_reduce_ex(self):
class C(object):
def __reduce_ex__(self, proto):
+ c.append(1)
return ""
def __reduce__(self):
- raise support.TestFailed("shouldn't call this")
+ self.fail("shouldn't call this")
+ c = []
x = C()
y = copy.copy(x)
- self.assertTrue(y is x)
+ self.assertIs(y, x)
+ self.assertEqual(c, [1])
def test_copy_reduce(self):
class C(object):
def __reduce__(self):
+ c.append(1)
return ""
+ c = []
x = C()
y = copy.copy(x)
- self.assertTrue(y is x)
+ self.assertIs(y, x)
+ self.assertEqual(c, [1])
def test_copy_cant(self):
class C(object):
@@ -91,7 +97,7 @@ class TestCopy(unittest.TestCase):
"hello", "hello\u1234", f.__code__,
NewStyle, range(10), Classic, max]
for x in tests:
- self.assertTrue(copy.copy(x) is x, repr(x))
+ self.assertIs(copy.copy(x), x)
def test_copy_list(self):
x = [1, 2, 3]
@@ -185,9 +191,9 @@ class TestCopy(unittest.TestCase):
x = [x, x]
y = copy.deepcopy(x)
self.assertEqual(y, x)
- self.assertTrue(y is not x)
- self.assertTrue(y[0] is not x[0])
- self.assertTrue(y[0] is y[1])
+ self.assertIsNot(y, x)
+ self.assertIsNot(y[0], x[0])
+ self.assertIs(y[0], y[1])
def test_deepcopy_issubclass(self):
# XXX Note: there's no way to test the TypeError coming out of
@@ -227,20 +233,26 @@ class TestCopy(unittest.TestCase):
def test_deepcopy_reduce_ex(self):
class C(object):
def __reduce_ex__(self, proto):
+ c.append(1)
return ""
def __reduce__(self):
- raise support.TestFailed("shouldn't call this")
+ self.fail("shouldn't call this")
+ c = []
x = C()
y = copy.deepcopy(x)
- self.assertTrue(y is x)
+ self.assertIs(y, x)
+ self.assertEqual(c, [1])
def test_deepcopy_reduce(self):
class C(object):
def __reduce__(self):
+ c.append(1)
return ""
+ c = []
x = C()
y = copy.deepcopy(x)
- self.assertTrue(y is x)
+ self.assertIs(y, x)
+ self.assertEqual(c, [1])
def test_deepcopy_cant(self):
class C(object):
@@ -264,14 +276,14 @@ class TestCopy(unittest.TestCase):
"hello", "hello\u1234", f.__code__,
NewStyle, range(10), Classic, max]
for x in tests:
- self.assertTrue(copy.deepcopy(x) is x, repr(x))
+ self.assertIs(copy.deepcopy(x), x)
def test_deepcopy_list(self):
x = [[1, 2], 3]
y = copy.deepcopy(x)
self.assertEqual(y, x)
- self.assertTrue(x is not y)
- self.assertTrue(x[0] is not y[0])
+ self.assertIsNot(x, y)
+ self.assertIsNot(x[0], y[0])
def test_deepcopy_reflexive_list(self):
x = []
@@ -279,16 +291,26 @@ class TestCopy(unittest.TestCase):
y = copy.deepcopy(x)
for op in comparisons:
self.assertRaises(RuntimeError, op, y, x)
- self.assertTrue(y is not x)
- self.assertTrue(y[0] is y)
+ self.assertIsNot(y, x)
+ self.assertIs(y[0], y)
self.assertEqual(len(y), 1)
+ def test_deepcopy_empty_tuple(self):
+ x = ()
+ y = copy.deepcopy(x)
+ self.assertIs(x, y)
+
def test_deepcopy_tuple(self):
x = ([1, 2], 3)
y = copy.deepcopy(x)
self.assertEqual(y, x)
- self.assertTrue(x is not y)
- self.assertTrue(x[0] is not y[0])
+ self.assertIsNot(x, y)
+ self.assertIsNot(x[0], y[0])
+
+ def test_deepcopy_tuple_of_immutables(self):
+ x = ((1, 2), 3)
+ y = copy.deepcopy(x)
+ self.assertIs(x, y)
def test_deepcopy_reflexive_tuple(self):
x = ([],)
@@ -296,16 +318,16 @@ class TestCopy(unittest.TestCase):
y = copy.deepcopy(x)
for op in comparisons:
self.assertRaises(RuntimeError, op, y, x)
- self.assertTrue(y is not x)
- self.assertTrue(y[0] is not x[0])
- self.assertTrue(y[0][0] is y)
+ self.assertIsNot(y, x)
+ self.assertIsNot(y[0], x[0])
+ self.assertIs(y[0][0], y)
def test_deepcopy_dict(self):
x = {"foo": [1, 2], "bar": 3}
y = copy.deepcopy(x)
self.assertEqual(y, x)
- self.assertTrue(x is not y)
- self.assertTrue(x["foo"] is not y["foo"])
+ self.assertIsNot(x, y)
+ self.assertIsNot(x["foo"], y["foo"])
def test_deepcopy_reflexive_dict(self):
x = {}
@@ -315,15 +337,30 @@ class TestCopy(unittest.TestCase):
self.assertRaises(TypeError, op, y, x)
for op in equality_comparisons:
self.assertRaises(RuntimeError, op, y, x)
- self.assertTrue(y is not x)
- self.assertTrue(y['foo'] is y)
+ self.assertIsNot(y, x)
+ self.assertIs(y['foo'], y)
self.assertEqual(len(y), 1)
def test_deepcopy_keepalive(self):
memo = {}
- x = 42
+ x = []
+ y = copy.deepcopy(x, memo)
+ self.assertIs(memo[id(memo)][0], x)
+
+ def test_deepcopy_dont_memo_immutable(self):
+ memo = {}
+ x = [1, 2, 3, 4]
y = copy.deepcopy(x, memo)
- self.assertTrue(memo[id(x)] is x)
+ self.assertEqual(y, x)
+ # There's the entry for the new list, and the keep alive.
+ self.assertEqual(len(memo), 2)
+
+ memo = {}
+ x = [(1, 2)]
+ y = copy.deepcopy(x, memo)
+ self.assertEqual(y, x)
+ # Tuples with immutable contents are immutable for deepcopy.
+ self.assertEqual(len(memo), 2)
def test_deepcopy_inst_vanilla(self):
class C:
@@ -334,7 +371,7 @@ class TestCopy(unittest.TestCase):
x = C([42])
y = copy.deepcopy(x)
self.assertEqual(y, x)
- self.assertTrue(y.foo is not x.foo)
+ self.assertIsNot(y.foo, x.foo)
def test_deepcopy_inst_deepcopy(self):
class C:
@@ -347,8 +384,8 @@ class TestCopy(unittest.TestCase):
x = C([42])
y = copy.deepcopy(x)
self.assertEqual(y, x)
- self.assertTrue(y is not x)
- self.assertTrue(y.foo is not x.foo)
+ self.assertIsNot(y, x)
+ self.assertIsNot(y.foo, x.foo)
def test_deepcopy_inst_getinitargs(self):
class C:
@@ -361,8 +398,8 @@ class TestCopy(unittest.TestCase):
x = C([42])
y = copy.deepcopy(x)
self.assertEqual(y, x)
- self.assertTrue(y is not x)
- self.assertTrue(y.foo is not x.foo)
+ self.assertIsNot(y, x)
+ self.assertIsNot(y.foo, x.foo)
def test_deepcopy_inst_getstate(self):
class C:
@@ -375,8 +412,8 @@ class TestCopy(unittest.TestCase):
x = C([42])
y = copy.deepcopy(x)
self.assertEqual(y, x)
- self.assertTrue(y is not x)
- self.assertTrue(y.foo is not x.foo)
+ self.assertIsNot(y, x)
+ self.assertIsNot(y.foo, x.foo)
def test_deepcopy_inst_setstate(self):
class C:
@@ -389,8 +426,8 @@ class TestCopy(unittest.TestCase):
x = C([42])
y = copy.deepcopy(x)
self.assertEqual(y, x)
- self.assertTrue(y is not x)
- self.assertTrue(y.foo is not x.foo)
+ self.assertIsNot(y, x)
+ self.assertIsNot(y.foo, x.foo)
def test_deepcopy_inst_getstate_setstate(self):
class C:
@@ -405,8 +442,8 @@ class TestCopy(unittest.TestCase):
x = C([42])
y = copy.deepcopy(x)
self.assertEqual(y, x)
- self.assertTrue(y is not x)
- self.assertTrue(y.foo is not x.foo)
+ self.assertIsNot(y, x)
+ self.assertIsNot(y.foo, x.foo)
def test_deepcopy_reflexive_inst(self):
class C:
@@ -414,8 +451,8 @@ class TestCopy(unittest.TestCase):
x = C()
x.foo = x
y = copy.deepcopy(x)
- self.assertTrue(y is not x)
- self.assertTrue(y.foo is y)
+ self.assertIsNot(y, x)
+ self.assertIs(y.foo, y)
# _reconstruct()
@@ -425,9 +462,9 @@ class TestCopy(unittest.TestCase):
return ""
x = C()
y = copy.copy(x)
- self.assertTrue(y is x)
+ self.assertIs(y, x)
y = copy.deepcopy(x)
- self.assertTrue(y is x)
+ self.assertIs(y, x)
def test_reconstruct_nostate(self):
class C(object):
@@ -436,9 +473,9 @@ class TestCopy(unittest.TestCase):
x = C()
x.foo = 42
y = copy.copy(x)
- self.assertTrue(y.__class__ is x.__class__)
+ self.assertIs(y.__class__, x.__class__)
y = copy.deepcopy(x)
- self.assertTrue(y.__class__ is x.__class__)
+ self.assertIs(y.__class__, x.__class__)
def test_reconstruct_state(self):
class C(object):
@@ -452,7 +489,7 @@ class TestCopy(unittest.TestCase):
self.assertEqual(y, x)
y = copy.deepcopy(x)
self.assertEqual(y, x)
- self.assertTrue(y.foo is not x.foo)
+ self.assertIsNot(y.foo, x.foo)
def test_reconstruct_state_setstate(self):
class C(object):
@@ -468,7 +505,7 @@ class TestCopy(unittest.TestCase):
self.assertEqual(y, x)
y = copy.deepcopy(x)
self.assertEqual(y, x)
- self.assertTrue(y.foo is not x.foo)
+ self.assertIsNot(y.foo, x.foo)
def test_reconstruct_reflexive(self):
class C(object):
@@ -476,8 +513,8 @@ class TestCopy(unittest.TestCase):
x = C()
x.foo = x
y = copy.deepcopy(x)
- self.assertTrue(y is not x)
- self.assertTrue(y.foo is y)
+ self.assertIsNot(y, x)
+ self.assertIs(y.foo, y)
# Additions for Python 2.3 and pickle protocol 2
@@ -491,12 +528,12 @@ class TestCopy(unittest.TestCase):
x = C([[1, 2], 3])
y = copy.copy(x)
self.assertEqual(x, y)
- self.assertTrue(x is not y)
- self.assertTrue(x[0] is y[0])
+ self.assertIsNot(x, y)
+ self.assertIs(x[0], y[0])
y = copy.deepcopy(x)
self.assertEqual(x, y)
- self.assertTrue(x is not y)
- self.assertTrue(x[0] is not y[0])
+ self.assertIsNot(x, y)
+ self.assertIsNot(x[0], y[0])
def test_reduce_5tuple(self):
class C(dict):
@@ -508,12 +545,12 @@ class TestCopy(unittest.TestCase):
x = C([("foo", [1, 2]), ("bar", 3)])
y = copy.copy(x)
self.assertEqual(x, y)
- self.assertTrue(x is not y)
- self.assertTrue(x["foo"] is y["foo"])
+ self.assertIsNot(x, y)
+ self.assertIs(x["foo"], y["foo"])
y = copy.deepcopy(x)
self.assertEqual(x, y)
- self.assertTrue(x is not y)
- self.assertTrue(x["foo"] is not y["foo"])
+ self.assertIsNot(x, y)
+ self.assertIsNot(x["foo"], y["foo"])
def test_copy_slots(self):
class C(object):
@@ -521,7 +558,7 @@ class TestCopy(unittest.TestCase):
x = C()
x.foo = [42]
y = copy.copy(x)
- self.assertTrue(x.foo is y.foo)
+ self.assertIs(x.foo, y.foo)
def test_deepcopy_slots(self):
class C(object):
@@ -530,7 +567,7 @@ class TestCopy(unittest.TestCase):
x.foo = [42]
y = copy.deepcopy(x)
self.assertEqual(x.foo, y.foo)
- self.assertTrue(x.foo is not y.foo)
+ self.assertIsNot(x.foo, y.foo)
def test_deepcopy_dict_subclass(self):
class C(dict):
@@ -547,7 +584,7 @@ class TestCopy(unittest.TestCase):
y = copy.deepcopy(x)
self.assertEqual(x, y)
self.assertEqual(x._keys, y._keys)
- self.assertTrue(x is not y)
+ self.assertIsNot(x, y)
x['bar'] = 1
self.assertNotEqual(x, y)
self.assertNotEqual(x._keys, y._keys)
@@ -560,8 +597,8 @@ class TestCopy(unittest.TestCase):
y = copy.copy(x)
self.assertEqual(list(x), list(y))
self.assertEqual(x.foo, y.foo)
- self.assertTrue(x[0] is y[0])
- self.assertTrue(x.foo is y.foo)
+ self.assertIs(x[0], y[0])
+ self.assertIs(x.foo, y.foo)
def test_deepcopy_list_subclass(self):
class C(list):
@@ -571,8 +608,8 @@ class TestCopy(unittest.TestCase):
y = copy.deepcopy(x)
self.assertEqual(list(x), list(y))
self.assertEqual(x.foo, y.foo)
- self.assertTrue(x[0] is not y[0])
- self.assertTrue(x.foo is not y.foo)
+ self.assertIsNot(x[0], y[0])
+ self.assertIsNot(x.foo, y.foo)
def test_copy_tuple_subclass(self):
class C(tuple):
@@ -589,8 +626,8 @@ class TestCopy(unittest.TestCase):
self.assertEqual(tuple(x), ([1, 2], 3))
y = copy.deepcopy(x)
self.assertEqual(tuple(y), ([1, 2], 3))
- self.assertTrue(x is not y)
- self.assertTrue(x[0] is not y[0])
+ self.assertIsNot(x, y)
+ self.assertIsNot(x[0], y[0])
def test_getstate_exc(self):
class EvilState(object):
@@ -618,10 +655,10 @@ class TestCopy(unittest.TestCase):
obj = C()
x = weakref.ref(obj)
y = _copy(x)
- self.assertTrue(y is x)
+ self.assertIs(y, x)
del obj
y = _copy(x)
- self.assertTrue(y is x)
+ self.assertIs(y, x)
def test_copy_weakref(self):
self._check_weakref(copy.copy)
@@ -637,7 +674,7 @@ class TestCopy(unittest.TestCase):
u[a] = b
u[c] = d
v = copy.copy(u)
- self.assertFalse(v is u)
+ self.assertIsNot(v, u)
self.assertEqual(v, u)
self.assertEqual(v[a], b)
self.assertEqual(v[c], d)
@@ -667,8 +704,8 @@ class TestCopy(unittest.TestCase):
v = copy.deepcopy(u)
self.assertNotEqual(v, u)
self.assertEqual(len(v), 2)
- self.assertFalse(v[a] is b)
- self.assertFalse(v[c] is d)
+ self.assertIsNot(v[a], b)
+ self.assertIsNot(v[c], d)
self.assertEqual(v[a].i, b.i)
self.assertEqual(v[c].i, d.i)
del c
@@ -687,12 +724,12 @@ class TestCopy(unittest.TestCase):
self.assertNotEqual(v, u)
self.assertEqual(len(v), 2)
(x, y), (z, t) = sorted(v.items(), key=lambda pair: pair[0].i)
- self.assertFalse(x is a)
+ self.assertIsNot(x, a)
self.assertEqual(x.i, a.i)
- self.assertTrue(y is b)
- self.assertFalse(z is c)
+ self.assertIs(y, b)
+ self.assertIsNot(z, c)
self.assertEqual(z.i, c.i)
- self.assertTrue(t is d)
+ self.assertIs(t, d)
del x, y, z, t
del d
self.assertEqual(len(v), 1)
@@ -705,7 +742,7 @@ class TestCopy(unittest.TestCase):
f.b = f.m
g = copy.deepcopy(f)
self.assertEqual(g.m, g.b)
- self.assertTrue(g.b.__self__ is g)
+ self.assertIs(g.b.__self__, g)
g.b()
diff --git a/Lib/test/test_cprofile.py b/Lib/test/test_cprofile.py
index ae17c2b..5676668 100644
--- a/Lib/test/test_cprofile.py
+++ b/Lib/test/test_cprofile.py
@@ -18,16 +18,19 @@ class CProfileTest(ProfileTest):
def test_bad_counter_during_dealloc(self):
import _lsprof
# Must use a file as StringIO doesn't trigger the bug.
- with open(TESTFN, 'w') as file:
- sys.stderr = file
- try:
- obj = _lsprof.Profiler(lambda: int)
- obj.enable()
- obj = _lsprof.Profiler(1)
- obj.disable()
- finally:
- sys.stderr = sys.__stderr__
- unlink(TESTFN)
+ orig_stderr = sys.stderr
+ try:
+ with open(TESTFN, 'w') as file:
+ sys.stderr = file
+ try:
+ obj = _lsprof.Profiler(lambda: int)
+ obj.enable()
+ obj = _lsprof.Profiler(1)
+ obj.disable()
+ finally:
+ sys.stderr = orig_stderr
+ finally:
+ unlink(TESTFN)
def test_main():
diff --git a/Lib/test/test_crashers.py b/Lib/test/test_crashers.py
new file mode 100644
index 0000000..336ccbe
--- /dev/null
+++ b/Lib/test/test_crashers.py
@@ -0,0 +1,38 @@
+# Tests that the crashers in the Lib/test/crashers directory actually
+# do crash the interpreter as expected
+#
+# If a crasher is fixed, it should be moved elsewhere in the test suite to
+# ensure it continues to work correctly.
+
+import unittest
+import glob
+import os.path
+import test.support
+from test.script_helper import assert_python_failure
+
+CRASHER_DIR = os.path.join(os.path.dirname(__file__), "crashers")
+CRASHER_FILES = os.path.join(CRASHER_DIR, "*.py")
+
+infinite_loops = ["infinite_loop_re.py", "nasty_eq_vs_dict.py"]
+
+class CrasherTest(unittest.TestCase):
+
+ @unittest.skip("these tests are too fragile")
+ @test.support.cpython_only
+ def test_crashers_crash(self):
+ for fname in glob.glob(CRASHER_FILES):
+ if os.path.basename(fname) in infinite_loops:
+ continue
+ # Some "crashers" only trigger an exception rather than a
+ # segfault. Consider that an acceptable outcome.
+ if test.support.verbose:
+ print("Checking crasher:", fname)
+ assert_python_failure(fname)
+
+
+def test_main():
+ test.support.run_unittest(CrasherTest)
+ test.support.reap_children()
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_crypt.py b/Lib/test/test_crypt.py
index 2adb28d..cfb7341 100644
--- a/Lib/test/test_crypt.py
+++ b/Lib/test/test_crypt.py
@@ -1,7 +1,11 @@
from test import support
import unittest
-crypt = support.import_module('crypt')
+def setUpModule():
+ # this import will raise unittest.SkipTest if _crypt doesn't exist,
+ # so it has to be done in setUpModule for test discovery to work
+ global crypt
+ crypt = support.import_module('crypt')
class CryptTestCase(unittest.TestCase):
@@ -10,8 +14,24 @@ class CryptTestCase(unittest.TestCase):
if support.verbose:
print('Test encryption: ', c)
-def test_main():
- support.run_unittest(CryptTestCase)
+ def test_salt(self):
+ self.assertEqual(len(crypt._saltchars), 64)
+ for method in crypt.methods:
+ salt = crypt.mksalt(method)
+ self.assertEqual(len(salt),
+ method.salt_chars + (3 if method.ident else 0))
+
+ def test_saltedcrypt(self):
+ for method in crypt.methods:
+ pw = crypt.crypt('assword', method)
+ self.assertEqual(len(pw), method.total_size)
+ pw = crypt.crypt('assword', crypt.mksalt(method))
+ self.assertEqual(len(pw), method.total_size)
+
+ def test_methods(self):
+ # Gurantee that METHOD_CRYPT is the last method in crypt.methods.
+ self.assertTrue(len(crypt.methods) >= 1)
+ self.assertEqual(crypt.METHOD_CRYPT, crypt.methods[-1])
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py
index 55796a2..96f8aa7 100644
--- a/Lib/test/test_csv.py
+++ b/Lib/test/test_csv.py
@@ -197,6 +197,17 @@ class Test_Csv(unittest.TestCase):
fileobj.seek(0)
self.assertEqual(fileobj.read(), "a,b\r\nc,d\r\n")
+ @support.cpython_only
+ def test_writerows_legacy_strings(self):
+ import _testcapi
+
+ c = _testcapi.unicode_legacy_string('a')
+ with TemporaryFile("w+", newline='') as fileobj:
+ writer = csv.writer(fileobj)
+ writer.writerows([[c]])
+ fileobj.seek(0)
+ self.assertEqual(fileobj.read(), "a\r\n")
+
def _read_test(self, input, expect, **kwargs):
reader = csv.reader(input, **kwargs)
result = list(reader)
diff --git a/Lib/test/test_curses.py b/Lib/test/test_curses.py
index 5812147..e959622 100644
--- a/Lib/test/test_curses.py
+++ b/Lib/test/test_curses.py
@@ -264,11 +264,53 @@ def test_issue6243(stdscr):
curses.ungetch(1025)
stdscr.getkey()
+def test_unget_wch(stdscr):
+ if not hasattr(curses, 'unget_wch'):
+ return
+ encoding = stdscr.encoding
+ for ch in ('a', '\xe9', '\u20ac', '\U0010FFFF'):
+ try:
+ ch.encode(encoding)
+ except UnicodeEncodeError:
+ continue
+ try:
+ curses.unget_wch(ch)
+ except Exception as err:
+ raise Exception("unget_wch(%a) failed with encoding %s: %s"
+ % (ch, stdscr.encoding, err))
+ read = stdscr.get_wch()
+ if read != ch:
+ raise AssertionError("%r != %r" % (read, ch))
+
+ code = ord(ch)
+ curses.unget_wch(code)
+ read = stdscr.get_wch()
+ if read != ch:
+ raise AssertionError("%r != %r" % (read, ch))
+
def test_issue10570():
b = curses.tparm(curses.tigetstr("cup"), 5, 3)
assert type(b) is bytes
curses.putp(b)
+def test_encoding(stdscr):
+ import codecs
+ encoding = stdscr.encoding
+ codecs.lookup(encoding)
+ try:
+ stdscr.encoding = 10
+ except TypeError:
+ pass
+ else:
+ raise AssertionError("TypeError not raised")
+ stdscr.encoding = encoding
+ try:
+ del stdscr.encoding
+ except TypeError:
+ pass
+ else:
+ raise AssertionError("TypeError not raised")
+
def main(stdscr):
curses.savetty()
try:
@@ -277,16 +319,18 @@ def main(stdscr):
test_userptr_without_set(stdscr)
test_resize_term(stdscr)
test_issue6243(stdscr)
+ test_unget_wch(stdscr)
test_issue10570()
+ test_encoding(stdscr)
finally:
curses.resetty()
def test_main():
- if not sys.stdout.isatty():
- raise unittest.SkipTest("sys.stdout is not a tty")
+ if not sys.__stdout__.isatty():
+ raise unittest.SkipTest("sys.__stdout__ is not a tty")
# testing setupterm() inside initscr/endwin
# causes terminal breakage
- curses.setupterm(fd=sys.stdout.fileno())
+ curses.setupterm(fd=sys.__stdout__.fileno())
try:
stdscr = curses.initscr()
main(stdscr)
diff --git a/Lib/test/test_dbm.py b/Lib/test/test_dbm.py
index 26d4c14..02df7e3 100644
--- a/Lib/test/test_dbm.py
+++ b/Lib/test/test_dbm.py
@@ -71,8 +71,8 @@ class AnyDBMTestCase(unittest.TestCase):
f.close()
def test_anydbm_creation_n_file_exists_with_invalid_contents(self):
- with open(_fname, "w") as w:
- pass # create an empty file
+ # create an empty file
+ test.support.create_empty_file(_fname)
f = dbm.open(_fname, 'n')
self.addCleanup(f.close)
diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py
index 30d3971..69a2faf 100644
--- a/Lib/test/test_decimal.py
+++ b/Lib/test/test_decimal.py
@@ -16,7 +16,7 @@ test the pythonic behaviour according to PEP 327.
Cowlishaw's tests can be downloaded from:
- www2.hursley.ibm.com/decimal/dectest.zip
+ http://speleotrove.com/decimal/dectest.zip
This test module can be called from command line with one parameter (Arithmetic
or Behaviour) to test each part, or without parameter to test both parts. If
@@ -30,37 +30,81 @@ import operator
import warnings
import pickle, copy
import unittest
-from decimal import *
import numbers
+import locale
from test.support import (run_unittest, run_doctest, is_resource_enabled,
requires_IEEE_754)
-from test.support import check_warnings
+from test.support import (check_warnings, import_fresh_module, TestFailed,
+ run_with_locale, cpython_only)
import random
+import time
+import warnings
try:
import threading
except ImportError:
threading = None
-# Useful Test Constant
-Signals = tuple(getcontext().flags.keys())
+C = import_fresh_module('decimal', fresh=['_decimal'])
+P = import_fresh_module('decimal', blocked=['_decimal'])
+orig_sys_decimal = sys.modules['decimal']
+
+# fractions module must import the correct decimal module.
+cfractions = import_fresh_module('fractions', fresh=['fractions'])
+sys.modules['decimal'] = P
+pfractions = import_fresh_module('fractions', fresh=['fractions'])
+sys.modules['decimal'] = C
+fractions = {C:cfractions, P:pfractions}
+sys.modules['decimal'] = orig_sys_decimal
+
+
+# Useful Test Constant
+Signals = {
+ C: tuple(C.getcontext().flags.keys()) if C else None,
+ P: tuple(P.getcontext().flags.keys())
+}
# Signals ordered with respect to precedence: when an operation
# produces multiple signals, signals occurring later in the list
# should be handled before those occurring earlier in the list.
-OrderedSignals = (Clamped, Rounded, Inexact, Subnormal,
- Underflow, Overflow, DivisionByZero, InvalidOperation)
+OrderedSignals = {
+ C: [C.Clamped, C.Rounded, C.Inexact, C.Subnormal, C.Underflow,
+ C.Overflow, C.DivisionByZero, C.InvalidOperation,
+ C.FloatOperation] if C else None,
+ P: [P.Clamped, P.Rounded, P.Inexact, P.Subnormal, P.Underflow,
+ P.Overflow, P.DivisionByZero, P.InvalidOperation,
+ P.FloatOperation]
+}
+def assert_signals(cls, context, attr, expected):
+ d = getattr(context, attr)
+ cls.assertTrue(all(d[s] if s in expected else not d[s] for s in d))
+
+ROUND_UP = P.ROUND_UP
+ROUND_DOWN = P.ROUND_DOWN
+ROUND_CEILING = P.ROUND_CEILING
+ROUND_FLOOR = P.ROUND_FLOOR
+ROUND_HALF_UP = P.ROUND_HALF_UP
+ROUND_HALF_DOWN = P.ROUND_HALF_DOWN
+ROUND_HALF_EVEN = P.ROUND_HALF_EVEN
+ROUND_05UP = P.ROUND_05UP
+
+RoundingModes = [
+ ROUND_UP, ROUND_DOWN, ROUND_CEILING, ROUND_FLOOR,
+ ROUND_HALF_UP, ROUND_HALF_DOWN, ROUND_HALF_EVEN,
+ ROUND_05UP
+]
# Tests are built around these assumed context defaults.
# test_main() restores the original context.
-def init():
- global ORIGINAL_CONTEXT
- ORIGINAL_CONTEXT = getcontext().copy()
- DefaultTestContext = Context(
- prec = 9,
- rounding = ROUND_HALF_EVEN,
- traps = dict.fromkeys(Signals, 0)
- )
- setcontext(DefaultTestContext)
+ORIGINAL_CONTEXT = {
+ C: C.getcontext().copy() if C else None,
+ P: P.getcontext().copy()
+}
+def init(m):
+ if not m: return
+ DefaultTestContext = m.Context(
+ prec=9, rounding=ROUND_HALF_EVEN, traps=dict.fromkeys(Signals[m], 0)
+ )
+ m.setcontext(DefaultTestContext)
TESTDATADIR = 'decimaltestdata'
if __name__ == '__main__':
@@ -72,149 +116,175 @@ directory = testdir + os.sep + TESTDATADIR + os.sep
skip_expected = not os.path.isdir(directory)
-# list of individual .decTest test ids that correspond to tests that
-# we're skipping for one reason or another.
-skipped_test_ids = set([
- # Skip implementation-specific scaleb tests.
- 'scbx164',
- 'scbx165',
-
- # For some operations (currently exp, ln, log10, power), the decNumber
- # reference implementation imposes additional restrictions on the context
- # and operands. These restrictions are not part of the specification;
- # however, the effect of these restrictions does show up in some of the
- # testcases. We skip testcases that violate these restrictions, since
- # Decimal behaves differently from decNumber for these testcases so these
- # testcases would otherwise fail.
- 'expx901',
- 'expx902',
- 'expx903',
- 'expx905',
- 'lnx901',
- 'lnx902',
- 'lnx903',
- 'lnx905',
- 'logx901',
- 'logx902',
- 'logx903',
- 'logx905',
- 'powx1183',
- 'powx1184',
- 'powx4001',
- 'powx4002',
- 'powx4003',
- 'powx4005',
- 'powx4008',
- 'powx4010',
- 'powx4012',
- 'powx4014',
- ])
-
# Make sure it actually raises errors when not expected and caught in flags
# Slower, since it runs some things several times.
EXTENDEDERRORTEST = False
-#Map the test cases' error names to the actual errors
-ErrorNames = {'clamped' : Clamped,
- 'conversion_syntax' : InvalidOperation,
- 'division_by_zero' : DivisionByZero,
- 'division_impossible' : InvalidOperation,
- 'division_undefined' : InvalidOperation,
- 'inexact' : Inexact,
- 'invalid_context' : InvalidOperation,
- 'invalid_operation' : InvalidOperation,
- 'overflow' : Overflow,
- 'rounded' : Rounded,
- 'subnormal' : Subnormal,
- 'underflow' : Underflow}
-
-
-def Nonfunction(*args):
- """Doesn't do anything."""
- return None
-
-RoundingDict = {'ceiling' : ROUND_CEILING, #Maps test-case names to roundings.
- 'down' : ROUND_DOWN,
- 'floor' : ROUND_FLOOR,
- 'half_down' : ROUND_HALF_DOWN,
- 'half_even' : ROUND_HALF_EVEN,
- 'half_up' : ROUND_HALF_UP,
- 'up' : ROUND_UP,
- '05up' : ROUND_05UP}
-
-# Name adapter to be able to change the Decimal and Context
-# interface without changing the test files from Cowlishaw
-nameAdapter = {'and':'logical_and',
- 'apply':'_apply',
- 'class':'number_class',
- 'comparesig':'compare_signal',
- 'comparetotal':'compare_total',
- 'comparetotmag':'compare_total_mag',
- 'copy':'copy_decimal',
- 'copyabs':'copy_abs',
- 'copynegate':'copy_negate',
- 'copysign':'copy_sign',
- 'divideint':'divide_int',
- 'invert':'logical_invert',
- 'iscanonical':'is_canonical',
- 'isfinite':'is_finite',
- 'isinfinite':'is_infinite',
- 'isnan':'is_nan',
- 'isnormal':'is_normal',
- 'isqnan':'is_qnan',
- 'issigned':'is_signed',
- 'issnan':'is_snan',
- 'issubnormal':'is_subnormal',
- 'iszero':'is_zero',
- 'maxmag':'max_mag',
- 'minmag':'min_mag',
- 'nextminus':'next_minus',
- 'nextplus':'next_plus',
- 'nexttoward':'next_toward',
- 'or':'logical_or',
- 'reduce':'normalize',
- 'remaindernear':'remainder_near',
- 'samequantum':'same_quantum',
- 'squareroot':'sqrt',
- 'toeng':'to_eng_string',
- 'tointegral':'to_integral_value',
- 'tointegralx':'to_integral_exact',
- 'tosci':'to_sci_string',
- 'xor':'logical_xor',
- }
-
-# The following functions return True/False rather than a Decimal instance
-
-LOGICAL_FUNCTIONS = (
- 'is_canonical',
- 'is_finite',
- 'is_infinite',
- 'is_nan',
- 'is_normal',
- 'is_qnan',
- 'is_signed',
- 'is_snan',
- 'is_subnormal',
- 'is_zero',
- 'same_quantum',
- )
+# Test extra functionality in the C version (-DEXTRA_FUNCTIONALITY).
+EXTRA_FUNCTIONALITY = True if hasattr(C, 'DecClamped') else False
+requires_extra_functionality = unittest.skipUnless(
+ EXTRA_FUNCTIONALITY, "test requires build with -DEXTRA_FUNCTIONALITY")
+skip_if_extra_functionality = unittest.skipIf(
+ EXTRA_FUNCTIONALITY, "test requires regular build")
-class DecimalTest(unittest.TestCase):
- """Class which tests the Decimal class against the test cases.
- Changed for unittest.
- """
+class IBMTestCases(unittest.TestCase):
+ """Class which tests the Decimal class against the IBM test cases."""
+
def setUp(self):
- self.context = Context()
+ self.context = self.decimal.Context()
+ self.readcontext = self.decimal.Context()
self.ignore_list = ['#']
- # Basically, a # means return NaN InvalidOperation.
- # Different from a sNaN in trim
+ # List of individual .decTest test ids that correspond to tests that
+ # we're skipping for one reason or another.
+ self.skipped_test_ids = set([
+ # Skip implementation-specific scaleb tests.
+ 'scbx164',
+ 'scbx165',
+
+ # For some operations (currently exp, ln, log10, power), the decNumber
+ # reference implementation imposes additional restrictions on the context
+ # and operands. These restrictions are not part of the specification;
+ # however, the effect of these restrictions does show up in some of the
+ # testcases. We skip testcases that violate these restrictions, since
+ # Decimal behaves differently from decNumber for these testcases so these
+ # testcases would otherwise fail.
+ 'expx901',
+ 'expx902',
+ 'expx903',
+ 'expx905',
+ 'lnx901',
+ 'lnx902',
+ 'lnx903',
+ 'lnx905',
+ 'logx901',
+ 'logx902',
+ 'logx903',
+ 'logx905',
+ 'powx1183',
+ 'powx1184',
+ 'powx4001',
+ 'powx4002',
+ 'powx4003',
+ 'powx4005',
+ 'powx4008',
+ 'powx4010',
+ 'powx4012',
+ 'powx4014',
+ ])
+
+ if self.decimal == C:
+ # status has additional Subnormal, Underflow
+ self.skipped_test_ids.add('pwsx803')
+ self.skipped_test_ids.add('pwsx805')
+ # Correct rounding (skipped for decNumber, too)
+ self.skipped_test_ids.add('powx4302')
+ self.skipped_test_ids.add('powx4303')
+ self.skipped_test_ids.add('powx4342')
+ self.skipped_test_ids.add('powx4343')
+ # http://bugs.python.org/issue7049
+ self.skipped_test_ids.add('pwmx325')
+ self.skipped_test_ids.add('pwmx326')
+
+ # Map test directives to setter functions.
self.ChangeDict = {'precision' : self.change_precision,
- 'rounding' : self.change_rounding_method,
- 'maxexponent' : self.change_max_exponent,
- 'minexponent' : self.change_min_exponent,
- 'clamp' : self.change_clamp}
+ 'rounding' : self.change_rounding_method,
+ 'maxexponent' : self.change_max_exponent,
+ 'minexponent' : self.change_min_exponent,
+ 'clamp' : self.change_clamp}
+
+ # Name adapter to be able to change the Decimal and Context
+ # interface without changing the test files from Cowlishaw.
+ self.NameAdapter = {'and':'logical_and',
+ 'apply':'_apply',
+ 'class':'number_class',
+ 'comparesig':'compare_signal',
+ 'comparetotal':'compare_total',
+ 'comparetotmag':'compare_total_mag',
+ 'copy':'copy_decimal',
+ 'copyabs':'copy_abs',
+ 'copynegate':'copy_negate',
+ 'copysign':'copy_sign',
+ 'divideint':'divide_int',
+ 'invert':'logical_invert',
+ 'iscanonical':'is_canonical',
+ 'isfinite':'is_finite',
+ 'isinfinite':'is_infinite',
+ 'isnan':'is_nan',
+ 'isnormal':'is_normal',
+ 'isqnan':'is_qnan',
+ 'issigned':'is_signed',
+ 'issnan':'is_snan',
+ 'issubnormal':'is_subnormal',
+ 'iszero':'is_zero',
+ 'maxmag':'max_mag',
+ 'minmag':'min_mag',
+ 'nextminus':'next_minus',
+ 'nextplus':'next_plus',
+ 'nexttoward':'next_toward',
+ 'or':'logical_or',
+ 'reduce':'normalize',
+ 'remaindernear':'remainder_near',
+ 'samequantum':'same_quantum',
+ 'squareroot':'sqrt',
+ 'toeng':'to_eng_string',
+ 'tointegral':'to_integral_value',
+ 'tointegralx':'to_integral_exact',
+ 'tosci':'to_sci_string',
+ 'xor':'logical_xor'}
+
+ # Map test-case names to roundings.
+ self.RoundingDict = {'ceiling' : ROUND_CEILING,
+ 'down' : ROUND_DOWN,
+ 'floor' : ROUND_FLOOR,
+ 'half_down' : ROUND_HALF_DOWN,
+ 'half_even' : ROUND_HALF_EVEN,
+ 'half_up' : ROUND_HALF_UP,
+ 'up' : ROUND_UP,
+ '05up' : ROUND_05UP}
+
+ # Map the test cases' error names to the actual errors.
+ self.ErrorNames = {'clamped' : self.decimal.Clamped,
+ 'conversion_syntax' : self.decimal.InvalidOperation,
+ 'division_by_zero' : self.decimal.DivisionByZero,
+ 'division_impossible' : self.decimal.InvalidOperation,
+ 'division_undefined' : self.decimal.InvalidOperation,
+ 'inexact' : self.decimal.Inexact,
+ 'invalid_context' : self.decimal.InvalidOperation,
+ 'invalid_operation' : self.decimal.InvalidOperation,
+ 'overflow' : self.decimal.Overflow,
+ 'rounded' : self.decimal.Rounded,
+ 'subnormal' : self.decimal.Subnormal,
+ 'underflow' : self.decimal.Underflow}
+
+ # The following functions return True/False rather than a
+ # Decimal instance.
+ self.LogicalFunctions = ('is_canonical',
+ 'is_finite',
+ 'is_infinite',
+ 'is_nan',
+ 'is_normal',
+ 'is_qnan',
+ 'is_signed',
+ 'is_snan',
+ 'is_subnormal',
+ 'is_zero',
+ 'same_quantum')
+
+ def read_unlimited(self, v, context):
+ """Work around the limitations of the 32-bit _decimal version. The
+ guaranteed maximum values for prec, Emax etc. are 425000000,
+ but higher values usually work, except for rare corner cases.
+ In particular, all of the IBM tests pass with maximum values
+ of 1070000000."""
+ if self.decimal == C and self.decimal.MAX_EMAX == 425000000:
+ self.readcontext._unsafe_setprec(1070000000)
+ self.readcontext._unsafe_setemax(1070000000)
+ self.readcontext._unsafe_setemin(-1070000000)
+ return self.readcontext.create_decimal(v)
+ else:
+ return self.decimal.Decimal(v, context)
def eval_file(self, file):
global skip_expected
@@ -227,7 +297,7 @@ class DecimalTest(unittest.TestCase):
#print line
try:
t = self.eval_line(line)
- except DecimalException as exception:
+ except self.decimal.DecimalException as exception:
#Exception raised where there shouldn't have been one.
self.fail('Exception "'+exception.__class__.__name__ + '" raised on line '+line)
@@ -254,23 +324,23 @@ class DecimalTest(unittest.TestCase):
def eval_directive(self, s):
funct, value = (x.strip().lower() for x in s.split(':'))
if funct == 'rounding':
- value = RoundingDict[value]
+ value = self.RoundingDict[value]
else:
try:
value = int(value)
except ValueError:
pass
- funct = self.ChangeDict.get(funct, Nonfunction)
+ funct = self.ChangeDict.get(funct, (lambda *args: None))
funct(value)
def eval_equation(self, s):
- #global DEFAULT_PRECISION
- #print DEFAULT_PRECISION
if not TEST_ALL and random.random() < 0.90:
return
+ self.context.clear_flags()
+
try:
Sides = s.split('->')
L = Sides[0].strip().split()
@@ -283,26 +353,26 @@ class DecimalTest(unittest.TestCase):
ans = L[0]
exceptions = L[1:]
except (TypeError, AttributeError, IndexError):
- raise InvalidOperation
+ raise self.decimal.InvalidOperation
def FixQuotes(val):
val = val.replace("''", 'SingleQuote').replace('""', 'DoubleQuote')
val = val.replace("'", '').replace('"', '')
val = val.replace('SingleQuote', "'").replace('DoubleQuote', '"')
return val
- if id in skipped_test_ids:
+ if id in self.skipped_test_ids:
return
- fname = nameAdapter.get(funct, funct)
+ fname = self.NameAdapter.get(funct, funct)
if fname == 'rescale':
return
funct = getattr(self.context, fname)
vals = []
conglomerate = ''
quote = 0
- theirexceptions = [ErrorNames[x.lower()] for x in exceptions]
+ theirexceptions = [self.ErrorNames[x.lower()] for x in exceptions]
- for exception in Signals:
+ for exception in Signals[self.decimal]:
self.context.traps[exception] = 1 #Catch these bugs...
for exception in theirexceptions:
self.context.traps[exception] = 0
@@ -324,7 +394,7 @@ class DecimalTest(unittest.TestCase):
funct(self.context.create_decimal(v))
except error:
pass
- except Signals as e:
+ except Signals[self.decimal] as e:
self.fail("Raised %s in %s when %s disabled" % \
(e, s, error))
else:
@@ -332,7 +402,7 @@ class DecimalTest(unittest.TestCase):
self.context.traps[error] = 0
v = self.context.create_decimal(v)
else:
- v = Decimal(v, self.context)
+ v = self.read_unlimited(v, self.context)
vals.append(v)
ans = FixQuotes(ans)
@@ -344,7 +414,7 @@ class DecimalTest(unittest.TestCase):
funct(*vals)
except error:
pass
- except Signals as e:
+ except Signals[self.decimal] as e:
self.fail("Raised %s in %s when %s disabled" % \
(e, s, error))
else:
@@ -352,14 +422,14 @@ class DecimalTest(unittest.TestCase):
self.context.traps[error] = 0
# as above, but add traps cumulatively, to check precedence
- ordered_errors = [e for e in OrderedSignals if e in theirexceptions]
+ ordered_errors = [e for e in OrderedSignals[self.decimal] if e in theirexceptions]
for error in ordered_errors:
self.context.traps[error] = 1
try:
funct(*vals)
except error:
pass
- except Signals as e:
+ except Signals[self.decimal] as e:
self.fail("Raised %s in %s; expected %s" %
(type(e), s, error))
else:
@@ -373,54 +443,69 @@ class DecimalTest(unittest.TestCase):
print("--", self.context)
try:
result = str(funct(*vals))
- if fname in LOGICAL_FUNCTIONS:
+ if fname in self.LogicalFunctions:
result = str(int(eval(result))) # 'True', 'False' -> '1', '0'
- except Signals as error:
+ except Signals[self.decimal] as error:
self.fail("Raised %s in %s" % (error, s))
except: #Catch any error long enough to state the test case.
print("ERROR:", s)
raise
myexceptions = self.getexceptions()
- self.context.clear_flags()
myexceptions.sort(key=repr)
theirexceptions.sort(key=repr)
self.assertEqual(result, ans,
'Incorrect answer for ' + s + ' -- got ' + result)
+
self.assertEqual(myexceptions, theirexceptions,
'Incorrect flags set in ' + s + ' -- got ' + str(myexceptions))
return
def getexceptions(self):
- return [e for e in Signals if self.context.flags[e]]
+ return [e for e in Signals[self.decimal] if self.context.flags[e]]
def change_precision(self, prec):
- self.context.prec = prec
+ if self.decimal == C and self.decimal.MAX_PREC == 425000000:
+ self.context._unsafe_setprec(prec)
+ else:
+ self.context.prec = prec
def change_rounding_method(self, rounding):
self.context.rounding = rounding
def change_min_exponent(self, exp):
- self.context.Emin = exp
+ if self.decimal == C and self.decimal.MAX_PREC == 425000000:
+ self.context._unsafe_setemin(exp)
+ else:
+ self.context.Emin = exp
def change_max_exponent(self, exp):
- self.context.Emax = exp
+ if self.decimal == C and self.decimal.MAX_PREC == 425000000:
+ self.context._unsafe_setemax(exp)
+ else:
+ self.context.Emax = exp
def change_clamp(self, clamp):
self.context.clamp = clamp
-
+class CIBMTestCases(IBMTestCases):
+ decimal = C
+class PyIBMTestCases(IBMTestCases):
+ decimal = P
# The following classes test the behaviour of Decimal according to PEP 327
-class DecimalExplicitConstructionTest(unittest.TestCase):
+class ExplicitConstructionTest(unittest.TestCase):
'''Unit tests for Explicit Construction cases of Decimal.'''
def test_explicit_empty(self):
+ Decimal = self.decimal.Decimal
self.assertEqual(Decimal(), Decimal("0"))
def test_explicit_from_None(self):
+ Decimal = self.decimal.Decimal
self.assertRaises(TypeError, Decimal, None)
def test_explicit_from_int(self):
+ Decimal = self.decimal.Decimal
#positive
d = Decimal(45)
@@ -438,7 +523,18 @@ class DecimalExplicitConstructionTest(unittest.TestCase):
d = Decimal(0)
self.assertEqual(str(d), '0')
+ # single word longs
+ for n in range(0, 32):
+ for sign in (-1, 1):
+ for x in range(-5, 5):
+ i = sign * (2**n + x)
+ d = Decimal(i)
+ self.assertEqual(str(d), str(i))
+
def test_explicit_from_string(self):
+ Decimal = self.decimal.Decimal
+ InvalidOperation = self.decimal.InvalidOperation
+ localcontext = self.decimal.localcontext
#empty
self.assertEqual(str(Decimal('')), 'NaN')
@@ -458,8 +554,44 @@ class DecimalExplicitConstructionTest(unittest.TestCase):
#leading and trailing whitespace permitted
self.assertEqual(str(Decimal('1.3E4 \n')), '1.3E+4')
self.assertEqual(str(Decimal(' -7.89')), '-7.89')
+ self.assertEqual(str(Decimal(" 3.45679 ")), '3.45679')
+
+ # unicode whitespace
+ for lead in ["", ' ', '\u00a0', '\u205f']:
+ for trail in ["", ' ', '\u00a0', '\u205f']:
+ self.assertEqual(str(Decimal(lead + '9.311E+28' + trail)),
+ '9.311E+28')
+
+ with localcontext() as c:
+ c.traps[InvalidOperation] = True
+ # Invalid string
+ self.assertRaises(InvalidOperation, Decimal, "xyz")
+ # Two arguments max
+ self.assertRaises(TypeError, Decimal, "1234", "x", "y")
+
+ # space within the numeric part
+ self.assertRaises(InvalidOperation, Decimal, "1\u00a02\u00a03")
+ self.assertRaises(InvalidOperation, Decimal, "\u00a01\u00a02\u00a0")
+
+ # unicode whitespace
+ self.assertRaises(InvalidOperation, Decimal, "\u00a0")
+ self.assertRaises(InvalidOperation, Decimal, "\u00a0\u00a0")
+
+ # embedded NUL
+ self.assertRaises(InvalidOperation, Decimal, "12\u00003")
+
+ @cpython_only
+ def test_from_legacy_strings(self):
+ import _testcapi
+ Decimal = self.decimal.Decimal
+ context = self.decimal.Context()
+
+ s = _testcapi.unicode_legacy_string('9.999999')
+ self.assertEqual(str(Decimal(s)), '9.999999')
+ self.assertEqual(str(context.create_decimal(s)), '9.999999')
def test_explicit_from_tuples(self):
+ Decimal = self.decimal.Decimal
#zero
d = Decimal( (0, (0,), 0) )
@@ -477,6 +609,10 @@ class DecimalExplicitConstructionTest(unittest.TestCase):
d = Decimal( (1, (4, 3, 4, 9, 1, 3, 5, 3, 4), -25) )
self.assertEqual(str(d), '-4.34913534E-17')
+ #inf
+ d = Decimal( (0, (), "F") )
+ self.assertEqual(str(d), 'Infinity')
+
#wrong number of items
self.assertRaises(ValueError, Decimal, (1, (4, 3, 4, 9, 1)) )
@@ -491,45 +627,63 @@ class DecimalExplicitConstructionTest(unittest.TestCase):
self.assertRaises(ValueError, Decimal, (1, (4, 3, 4, 9, 1), '1') )
#bad coefficients
+ self.assertRaises(ValueError, Decimal, (1, "xyz", 2) )
self.assertRaises(ValueError, Decimal, (1, (4, 3, 4, None, 1), 2) )
self.assertRaises(ValueError, Decimal, (1, (4, -3, 4, 9, 1), 2) )
self.assertRaises(ValueError, Decimal, (1, (4, 10, 4, 9, 1), 2) )
self.assertRaises(ValueError, Decimal, (1, (4, 3, 4, 'a', 1), 2) )
+ def test_explicit_from_list(self):
+ Decimal = self.decimal.Decimal
+
+ d = Decimal([0, [0], 0])
+ self.assertEqual(str(d), '0')
+
+ d = Decimal([1, [4, 3, 4, 9, 1, 3, 5, 3, 4], -25])
+ self.assertEqual(str(d), '-4.34913534E-17')
+
+ d = Decimal([1, (4, 3, 4, 9, 1, 3, 5, 3, 4), -25])
+ self.assertEqual(str(d), '-4.34913534E-17')
+
+ d = Decimal((1, [4, 3, 4, 9, 1, 3, 5, 3, 4], -25))
+ self.assertEqual(str(d), '-4.34913534E-17')
+
def test_explicit_from_bool(self):
+ Decimal = self.decimal.Decimal
+
self.assertIs(bool(Decimal(0)), False)
self.assertIs(bool(Decimal(1)), True)
self.assertEqual(Decimal(False), Decimal(0))
self.assertEqual(Decimal(True), Decimal(1))
def test_explicit_from_Decimal(self):
+ Decimal = self.decimal.Decimal
#positive
d = Decimal(45)
e = Decimal(d)
self.assertEqual(str(e), '45')
- self.assertNotEqual(id(d), id(e))
#very large positive
d = Decimal(500000123)
e = Decimal(d)
self.assertEqual(str(e), '500000123')
- self.assertNotEqual(id(d), id(e))
#negative
d = Decimal(-45)
e = Decimal(d)
self.assertEqual(str(e), '-45')
- self.assertNotEqual(id(d), id(e))
#zero
d = Decimal(0)
e = Decimal(d)
self.assertEqual(str(e), '0')
- self.assertNotEqual(id(d), id(e))
@requires_IEEE_754
def test_explicit_from_float(self):
+
+ Decimal = self.decimal.Decimal
+
r = Decimal(0.1)
self.assertEqual(type(r), Decimal)
self.assertEqual(str(r),
@@ -550,8 +704,11 @@ class DecimalExplicitConstructionTest(unittest.TestCase):
self.assertEqual(x, float(Decimal(x))) # roundtrip
def test_explicit_context_create_decimal(self):
+ Decimal = self.decimal.Decimal
+ InvalidOperation = self.decimal.InvalidOperation
+ Rounded = self.decimal.Rounded
- nc = copy.copy(getcontext())
+ nc = copy.copy(self.decimal.getcontext())
nc.prec = 3
# empty
@@ -592,7 +749,73 @@ class DecimalExplicitConstructionTest(unittest.TestCase):
d = nc.create_decimal(prevdec)
self.assertEqual(str(d), '5.00E+8')
+ # more integers
+ nc.prec = 28
+ nc.traps[InvalidOperation] = True
+
+ for v in [-2**63-1, -2**63, -2**31-1, -2**31, 0,
+ 2**31-1, 2**31, 2**63-1, 2**63]:
+ d = nc.create_decimal(v)
+ self.assertTrue(isinstance(d, Decimal))
+ self.assertEqual(int(d), v)
+
+ nc.prec = 3
+ nc.traps[Rounded] = True
+ self.assertRaises(Rounded, nc.create_decimal, 1234)
+
+ # from string
+ nc.prec = 28
+ self.assertEqual(str(nc.create_decimal('0E-017')), '0E-17')
+ self.assertEqual(str(nc.create_decimal('45')), '45')
+ self.assertEqual(str(nc.create_decimal('-Inf')), '-Infinity')
+ self.assertEqual(str(nc.create_decimal('NaN123')), 'NaN123')
+
+ # invalid arguments
+ self.assertRaises(InvalidOperation, nc.create_decimal, "xyz")
+ self.assertRaises(ValueError, nc.create_decimal, (1, "xyz", -25))
+ self.assertRaises(TypeError, nc.create_decimal, "1234", "5678")
+
+ # too many NaN payload digits
+ nc.prec = 3
+ self.assertRaises(InvalidOperation, nc.create_decimal, 'NaN12345')
+ self.assertRaises(InvalidOperation, nc.create_decimal,
+ Decimal('NaN12345'))
+
+ nc.traps[InvalidOperation] = False
+ self.assertEqual(str(nc.create_decimal('NaN12345')), 'NaN')
+ self.assertTrue(nc.flags[InvalidOperation])
+
+ nc.flags[InvalidOperation] = False
+ self.assertEqual(str(nc.create_decimal(Decimal('NaN12345'))), 'NaN')
+ self.assertTrue(nc.flags[InvalidOperation])
+
+ def test_explicit_context_create_from_float(self):
+
+ Decimal = self.decimal.Decimal
+
+ nc = self.decimal.Context()
+ r = nc.create_decimal(0.1)
+ self.assertEqual(type(r), Decimal)
+ self.assertEqual(str(r), '0.1000000000000000055511151231')
+ self.assertTrue(nc.create_decimal(float('nan')).is_qnan())
+ self.assertTrue(nc.create_decimal(float('inf')).is_infinite())
+ self.assertTrue(nc.create_decimal(float('-inf')).is_infinite())
+ self.assertEqual(str(nc.create_decimal(float('nan'))),
+ str(nc.create_decimal('NaN')))
+ self.assertEqual(str(nc.create_decimal(float('inf'))),
+ str(nc.create_decimal('Infinity')))
+ self.assertEqual(str(nc.create_decimal(float('-inf'))),
+ str(nc.create_decimal('-Infinity')))
+ self.assertEqual(str(nc.create_decimal(float('-0.0'))),
+ str(nc.create_decimal('-0')))
+ nc.prec = 100
+ for i in range(200):
+ x = random.expovariate(0.01) * (random.random() * 2.0 - 1.0)
+ self.assertEqual(x, float(nc.create_decimal(x))) # roundtrip
+
def test_unicode_digits(self):
+ Decimal = self.decimal.Decimal
+
test_values = {
'\uff11': '1',
'\u0660.\u0660\u0663\u0667\u0662e-\u0663' : '0.0000372',
@@ -601,29 +824,41 @@ class DecimalExplicitConstructionTest(unittest.TestCase):
for input, expected in test_values.items():
self.assertEqual(str(Decimal(input)), expected)
+class CExplicitConstructionTest(ExplicitConstructionTest):
+ decimal = C
+class PyExplicitConstructionTest(ExplicitConstructionTest):
+ decimal = P
-class DecimalImplicitConstructionTest(unittest.TestCase):
+class ImplicitConstructionTest(unittest.TestCase):
'''Unit tests for Implicit Construction cases of Decimal.'''
def test_implicit_from_None(self):
- self.assertRaises(TypeError, eval, 'Decimal(5) + None', globals())
+ Decimal = self.decimal.Decimal
+ self.assertRaises(TypeError, eval, 'Decimal(5) + None', locals())
def test_implicit_from_int(self):
+ Decimal = self.decimal.Decimal
+
#normal
self.assertEqual(str(Decimal(5) + 45), '50')
#exceeding precision
self.assertEqual(Decimal(5) + 123456789000, Decimal(123456789000))
def test_implicit_from_string(self):
- self.assertRaises(TypeError, eval, 'Decimal(5) + "3"', globals())
+ Decimal = self.decimal.Decimal
+ self.assertRaises(TypeError, eval, 'Decimal(5) + "3"', locals())
def test_implicit_from_float(self):
- self.assertRaises(TypeError, eval, 'Decimal(5) + 2.2', globals())
+ Decimal = self.decimal.Decimal
+ self.assertRaises(TypeError, eval, 'Decimal(5) + 2.2', locals())
def test_implicit_from_Decimal(self):
+ Decimal = self.decimal.Decimal
self.assertEqual(Decimal(5) + Decimal(45), Decimal(50))
def test_rop(self):
+ Decimal = self.decimal.Decimal
+
# Allow other classes to be trained to interact with Decimals
class E:
def __divmod__(self, other):
@@ -671,10 +906,16 @@ class DecimalImplicitConstructionTest(unittest.TestCase):
self.assertEqual(eval('Decimal(10)' + sym + 'E()'),
'10' + rop + 'str')
+class CImplicitConstructionTest(ImplicitConstructionTest):
+ decimal = C
+class PyImplicitConstructionTest(ImplicitConstructionTest):
+ decimal = P
-class DecimalFormatTest(unittest.TestCase):
+class FormatTest(unittest.TestCase):
'''Unit tests for the format function.'''
def test_formatting(self):
+ Decimal = self.decimal.Decimal
+
# triples giving a format, a Decimal, and the expected result
test_values = [
('e', '0E-15', '0e-15'),
@@ -730,6 +971,7 @@ class DecimalFormatTest(unittest.TestCase):
('g', '0E-7', '0e-7'),
('g', '-0E2', '-0e+2'),
('.0g', '3.14159265', '3'), # 0 sig fig -> 1 sig fig
+ ('.0n', '3.14159265', '3'), # same for 'n'
('.1g', '3.14159265', '3'),
('.2g', '3.14159265', '3.1'),
('.5g', '3.14159265', '3.1416'),
@@ -814,56 +1056,60 @@ class DecimalFormatTest(unittest.TestCase):
# issue 6850
('a=-7.0', '0.12345', 'aaaa0.1'),
-
- # Issue 7094: Alternate formatting (specified by #)
- ('.0e', '1.0', '1e+0'),
- ('#.0e', '1.0', '1.e+0'),
- ('.0f', '1.0', '1'),
- ('#.0f', '1.0', '1.'),
- ('g', '1.1', '1.1'),
- ('#g', '1.1', '1.1'),
- ('.0g', '1', '1'),
- ('#.0g', '1', '1.'),
- ('.0%', '1.0', '100%'),
- ('#.0%', '1.0', '100.%'),
]
for fmt, d, result in test_values:
self.assertEqual(format(Decimal(d), fmt), result)
+ # bytes format argument
+ self.assertRaises(TypeError, Decimal(1).__format__, b'-020')
+
def test_n_format(self):
+ Decimal = self.decimal.Decimal
+
try:
from locale import CHAR_MAX
except ImportError:
return
+ def make_grouping(lst):
+ return ''.join([chr(x) for x in lst]) if self.decimal == C else lst
+
+ def get_fmt(x, override=None, fmt='n'):
+ if self.decimal == C:
+ return Decimal(x).__format__(fmt, override)
+ else:
+ return Decimal(x).__format__(fmt, _localeconv=override)
+
# Set up some localeconv-like dictionaries
en_US = {
'decimal_point' : '.',
- 'grouping' : [3, 3, 0],
- 'thousands_sep': ','
+ 'grouping' : make_grouping([3, 3, 0]),
+ 'thousands_sep' : ','
}
fr_FR = {
'decimal_point' : ',',
- 'grouping' : [CHAR_MAX],
+ 'grouping' : make_grouping([CHAR_MAX]),
'thousands_sep' : ''
}
ru_RU = {
'decimal_point' : ',',
- 'grouping' : [3, 3, 0],
+ 'grouping': make_grouping([3, 3, 0]),
'thousands_sep' : ' '
}
crazy = {
'decimal_point' : '&',
- 'grouping' : [1, 4, 2, CHAR_MAX],
+ 'grouping': make_grouping([1, 4, 2, CHAR_MAX]),
'thousands_sep' : '-'
}
-
- def get_fmt(x, locale, fmt='n'):
- return Decimal.__format__(Decimal(x), fmt, _localeconv=locale)
+ dotsep_wide = {
+ 'decimal_point' : b'\xc2\xbf'.decode('utf-8'),
+ 'grouping': make_grouping([3, 3, 0]),
+ 'thousands_sep' : b'\xc2\xb4'.decode('utf-8')
+ }
self.assertEqual(get_fmt(Decimal('12.7'), en_US), '12.7')
self.assertEqual(get_fmt(Decimal('12.7'), fr_FR), '12,7')
@@ -902,11 +1148,34 @@ class DecimalFormatTest(unittest.TestCase):
self.assertEqual(get_fmt(123456, crazy, '012n'), '00-01-2345-6')
self.assertEqual(get_fmt(123456, crazy, '013n'), '000-01-2345-6')
+ # wide char separator and decimal point
+ self.assertEqual(get_fmt(Decimal('-1.5'), dotsep_wide, '020n'),
+ '-0\u00b4000\u00b4000\u00b4000\u00b4001\u00bf5')
+
+ @run_with_locale('LC_ALL', 'ps_AF')
+ def test_wide_char_separator_decimal_point(self):
+ # locale with wide char separator and decimal point
+ import locale
+ Decimal = self.decimal.Decimal
-class DecimalArithmeticOperatorsTest(unittest.TestCase):
+ decimal_point = locale.localeconv()['decimal_point']
+ thousands_sep = locale.localeconv()['thousands_sep']
+ if decimal_point != '\u066b' or thousands_sep != '\u066c':
+ return
+
+ self.assertEqual(format(Decimal('100000000.123'), 'n'),
+ '100\u066c000\u066c000\u066b123')
+
+class CFormatTest(FormatTest):
+ decimal = C
+class PyFormatTest(FormatTest):
+ decimal = P
+
+class ArithmeticOperatorsTest(unittest.TestCase):
'''Unit tests for all arithmetic operators, binary and unary.'''
def test_addition(self):
+ Decimal = self.decimal.Decimal
d1 = Decimal('-11.1')
d2 = Decimal('22.2')
@@ -934,6 +1203,7 @@ class DecimalArithmeticOperatorsTest(unittest.TestCase):
self.assertEqual(d1, Decimal('16.1'))
def test_subtraction(self):
+ Decimal = self.decimal.Decimal
d1 = Decimal('-11.1')
d2 = Decimal('22.2')
@@ -961,6 +1231,7 @@ class DecimalArithmeticOperatorsTest(unittest.TestCase):
self.assertEqual(d1, Decimal('-38.3'))
def test_multiplication(self):
+ Decimal = self.decimal.Decimal
d1 = Decimal('-5')
d2 = Decimal('3')
@@ -988,6 +1259,7 @@ class DecimalArithmeticOperatorsTest(unittest.TestCase):
self.assertEqual(d1, Decimal('-75'))
def test_division(self):
+ Decimal = self.decimal.Decimal
d1 = Decimal('-5')
d2 = Decimal('2')
@@ -1015,6 +1287,7 @@ class DecimalArithmeticOperatorsTest(unittest.TestCase):
self.assertEqual(d1, Decimal('-0.625'))
def test_floor_division(self):
+ Decimal = self.decimal.Decimal
d1 = Decimal('5')
d2 = Decimal('2')
@@ -1042,6 +1315,7 @@ class DecimalArithmeticOperatorsTest(unittest.TestCase):
self.assertEqual(d1, Decimal('1'))
def test_powering(self):
+ Decimal = self.decimal.Decimal
d1 = Decimal('5')
d2 = Decimal('2')
@@ -1069,6 +1343,7 @@ class DecimalArithmeticOperatorsTest(unittest.TestCase):
self.assertEqual(d1, Decimal('390625'))
def test_module(self):
+ Decimal = self.decimal.Decimal
d1 = Decimal('5')
d2 = Decimal('2')
@@ -1096,6 +1371,7 @@ class DecimalArithmeticOperatorsTest(unittest.TestCase):
self.assertEqual(d1, Decimal('1'))
def test_floor_div_module(self):
+ Decimal = self.decimal.Decimal
d1 = Decimal('5')
d2 = Decimal('2')
@@ -1122,6 +1398,8 @@ class DecimalArithmeticOperatorsTest(unittest.TestCase):
self.assertEqual(type(q), type(d1))
def test_unary_operators(self):
+ Decimal = self.decimal.Decimal
+
self.assertEqual(+Decimal(45), Decimal(+45)) # +
self.assertEqual(-Decimal(45), Decimal(-45)) # -
self.assertEqual(abs(Decimal(45)), abs(Decimal(-45))) # abs
@@ -1134,6 +1412,9 @@ class DecimalArithmeticOperatorsTest(unittest.TestCase):
# equality comparisons (==, !=) involving only quiet nans
# don't signal, but return False or True respectively.
+ Decimal = self.decimal.Decimal
+ InvalidOperation = self.decimal.InvalidOperation
+ localcontext = self.decimal.localcontext
n = Decimal('NaN')
s = Decimal('sNaN')
@@ -1179,53 +1460,124 @@ class DecimalArithmeticOperatorsTest(unittest.TestCase):
self.assertRaises(InvalidOperation, op, x, y)
def test_copy_sign(self):
- d = Decimal(1).copy_sign(Decimal(-2))
+ Decimal = self.decimal.Decimal
+ d = Decimal(1).copy_sign(Decimal(-2))
self.assertEqual(Decimal(1).copy_sign(-2), d)
self.assertRaises(TypeError, Decimal(1).copy_sign, '-2')
+class CArithmeticOperatorsTest(ArithmeticOperatorsTest):
+ decimal = C
+class PyArithmeticOperatorsTest(ArithmeticOperatorsTest):
+ decimal = P
+
# The following are two functions used to test threading in the next class
def thfunc1(cls):
+ Decimal = cls.decimal.Decimal
+ InvalidOperation = cls.decimal.InvalidOperation
+ DivisionByZero = cls.decimal.DivisionByZero
+ Overflow = cls.decimal.Overflow
+ Underflow = cls.decimal.Underflow
+ Inexact = cls.decimal.Inexact
+ getcontext = cls.decimal.getcontext
+ localcontext = cls.decimal.localcontext
+
d1 = Decimal(1)
d3 = Decimal(3)
test1 = d1/d3
- cls.synchro.wait()
- test2 = d1/d3
+
cls.finish1.set()
+ cls.synchro.wait()
- cls.assertEqual(test1, Decimal('0.3333333333333333333333333333'))
- cls.assertEqual(test2, Decimal('0.3333333333333333333333333333'))
+ test2 = d1/d3
+ with localcontext() as c2:
+ cls.assertTrue(c2.flags[Inexact])
+ cls.assertRaises(DivisionByZero, c2.divide, d1, 0)
+ cls.assertTrue(c2.flags[DivisionByZero])
+ with localcontext() as c3:
+ cls.assertTrue(c3.flags[Inexact])
+ cls.assertTrue(c3.flags[DivisionByZero])
+ cls.assertRaises(InvalidOperation, c3.compare, d1, Decimal('sNaN'))
+ cls.assertTrue(c3.flags[InvalidOperation])
+ del c3
+ cls.assertFalse(c2.flags[InvalidOperation])
+ del c2
+
+ cls.assertEqual(test1, Decimal('0.333333333333333333333333'))
+ cls.assertEqual(test2, Decimal('0.333333333333333333333333'))
+
+ c1 = getcontext()
+ cls.assertTrue(c1.flags[Inexact])
+ for sig in Overflow, Underflow, DivisionByZero, InvalidOperation:
+ cls.assertFalse(c1.flags[sig])
return
def thfunc2(cls):
+ Decimal = cls.decimal.Decimal
+ InvalidOperation = cls.decimal.InvalidOperation
+ DivisionByZero = cls.decimal.DivisionByZero
+ Overflow = cls.decimal.Overflow
+ Underflow = cls.decimal.Underflow
+ Inexact = cls.decimal.Inexact
+ getcontext = cls.decimal.getcontext
+ localcontext = cls.decimal.localcontext
+
d1 = Decimal(1)
d3 = Decimal(3)
test1 = d1/d3
+
thiscontext = getcontext()
thiscontext.prec = 18
test2 = d1/d3
+
+ with localcontext() as c2:
+ cls.assertTrue(c2.flags[Inexact])
+ cls.assertRaises(Overflow, c2.multiply, Decimal('1e425000000'), 999)
+ cls.assertTrue(c2.flags[Overflow])
+ with localcontext(thiscontext) as c3:
+ cls.assertTrue(c3.flags[Inexact])
+ cls.assertFalse(c3.flags[Overflow])
+ c3.traps[Underflow] = True
+ cls.assertRaises(Underflow, c3.divide, Decimal('1e-425000000'), 999)
+ cls.assertTrue(c3.flags[Underflow])
+ del c3
+ cls.assertFalse(c2.flags[Underflow])
+ cls.assertFalse(c2.traps[Underflow])
+ del c2
+
cls.synchro.set()
cls.finish2.set()
- cls.assertEqual(test1, Decimal('0.3333333333333333333333333333'))
+ cls.assertEqual(test1, Decimal('0.333333333333333333333333'))
cls.assertEqual(test2, Decimal('0.333333333333333333'))
- return
-
-class DecimalUseOfContextTest(unittest.TestCase):
- '''Unit tests for Use of Context cases in Decimal.'''
+ cls.assertFalse(thiscontext.traps[Underflow])
+ cls.assertTrue(thiscontext.flags[Inexact])
+ for sig in Overflow, Underflow, DivisionByZero, InvalidOperation:
+ cls.assertFalse(thiscontext.flags[sig])
+ return
- try:
- import threading
- except ImportError:
- threading = None
+class ThreadingTest(unittest.TestCase):
+ '''Unit tests for thread local contexts in Decimal.'''
# Take care executing this test from IDLE, there's an issue in threading
# that hangs IDLE and I couldn't find it
def test_threading(self):
- #Test the "threading isolation" of a Context.
+ DefaultContext = self.decimal.DefaultContext
+
+ if self.decimal == C and not self.decimal.HAVE_THREADS:
+ self.skipTest("compiled without threading")
+ # Test the "threading isolation" of a Context. Also test changing
+ # the DefaultContext, which acts as a template for the thread-local
+ # contexts.
+ save_prec = DefaultContext.prec
+ save_emax = DefaultContext.Emax
+ save_emin = DefaultContext.Emin
+ DefaultContext.prec = 24
+ DefaultContext.Emax = 425000000
+ DefaultContext.Emin = -425000000
self.synchro = threading.Event()
self.finish1 = threading.Event()
@@ -1239,17 +1591,29 @@ class DecimalUseOfContextTest(unittest.TestCase):
self.finish1.wait()
self.finish2.wait()
- return
- if threading is None:
- del test_threading
+ for sig in Signals[self.decimal]:
+ self.assertFalse(DefaultContext.flags[sig])
+
+ DefaultContext.prec = save_prec
+ DefaultContext.Emax = save_emax
+ DefaultContext.Emin = save_emin
+ return
+@unittest.skipUnless(threading, 'threading required')
+class CThreadingTest(ThreadingTest):
+ decimal = C
+@unittest.skipUnless(threading, 'threading required')
+class PyThreadingTest(ThreadingTest):
+ decimal = P
-class DecimalUsabilityTest(unittest.TestCase):
+class UsabilityTest(unittest.TestCase):
'''Unit tests for Usability cases of Decimal.'''
def test_comparison_operators(self):
+ Decimal = self.decimal.Decimal
+
da = Decimal('23.42')
db = Decimal('23.42')
dc = Decimal('45')
@@ -1283,6 +1647,8 @@ class DecimalUsabilityTest(unittest.TestCase):
self.assertEqual(a, b)
def test_decimal_float_comparison(self):
+ Decimal = self.decimal.Decimal
+
da = Decimal('0.25')
db = Decimal('3.0')
self.assertLess(da, 3.0)
@@ -1299,7 +1665,71 @@ class DecimalUsabilityTest(unittest.TestCase):
self.assertEqual(3.0, db)
self.assertNotEqual(0.1, Decimal('0.1'))
+ def test_decimal_complex_comparison(self):
+ Decimal = self.decimal.Decimal
+
+ da = Decimal('0.25')
+ db = Decimal('3.0')
+ self.assertNotEqual(da, (1.5+0j))
+ self.assertNotEqual((1.5+0j), da)
+ self.assertEqual(da, (0.25+0j))
+ self.assertEqual((0.25+0j), da)
+ self.assertEqual((3.0+0j), db)
+ self.assertEqual(db, (3.0+0j))
+
+ self.assertNotEqual(db, (3.0+1j))
+ self.assertNotEqual((3.0+1j), db)
+
+ self.assertIs(db.__lt__(3.0+0j), NotImplemented)
+ self.assertIs(db.__le__(3.0+0j), NotImplemented)
+ self.assertIs(db.__gt__(3.0+0j), NotImplemented)
+ self.assertIs(db.__le__(3.0+0j), NotImplemented)
+
+ def test_decimal_fraction_comparison(self):
+ D = self.decimal.Decimal
+ F = fractions[self.decimal].Fraction
+ Context = self.decimal.Context
+ localcontext = self.decimal.localcontext
+ InvalidOperation = self.decimal.InvalidOperation
+
+
+ emax = C.MAX_EMAX if C else 999999999
+ emin = C.MIN_EMIN if C else -999999999
+ etiny = C.MIN_ETINY if C else -1999999997
+ c = Context(Emax=emax, Emin=emin)
+
+ with localcontext(c):
+ c.prec = emax
+ self.assertLess(D(0), F(1,9999999999999999999999999999999999999))
+ self.assertLess(F(-1,9999999999999999999999999999999999999), D(0))
+ self.assertLess(F(0,1), D("1e" + str(etiny)))
+ self.assertLess(D("-1e" + str(etiny)), F(0,1))
+ self.assertLess(F(0,9999999999999999999999999), D("1e" + str(etiny)))
+ self.assertLess(D("-1e" + str(etiny)), F(0,9999999999999999999999999))
+
+ self.assertEqual(D("0.1"), F(1,10))
+ self.assertEqual(F(1,10), D("0.1"))
+
+ c.prec = 300
+ self.assertNotEqual(D(1)/3, F(1,3))
+ self.assertNotEqual(F(1,3), D(1)/3)
+
+ self.assertLessEqual(F(120984237, 9999999999), D("9e" + str(emax)))
+ self.assertGreaterEqual(D("9e" + str(emax)), F(120984237, 9999999999))
+
+ self.assertGreater(D('inf'), F(99999999999,123))
+ self.assertGreater(D('inf'), F(-99999999999,123))
+ self.assertLess(D('-inf'), F(99999999999,123))
+ self.assertLess(D('-inf'), F(-99999999999,123))
+
+ self.assertRaises(InvalidOperation, D('nan').__gt__, F(-9,123))
+ self.assertIs(NotImplemented, F(-9,123).__lt__(D('nan')))
+ self.assertNotEqual(D('nan'), F(-9,123))
+ self.assertNotEqual(F(-9,123), D('nan'))
+
def test_copy_and_deepcopy_methods(self):
+ Decimal = self.decimal.Decimal
+
d = Decimal('43.24')
c = copy.copy(d)
self.assertEqual(id(c), id(d))
@@ -1307,6 +1737,10 @@ class DecimalUsabilityTest(unittest.TestCase):
self.assertEqual(id(dc), id(d))
def test_hash_method(self):
+
+ Decimal = self.decimal.Decimal
+ localcontext = self.decimal.localcontext
+
def hashit(d):
a = hash(d)
b = d.__hash__()
@@ -1367,24 +1801,27 @@ class DecimalUsabilityTest(unittest.TestCase):
d = Decimal(s)
self.assertEqual(hashit(f), hashit(d))
- # check that the value of the hash doesn't depend on the
- # current context (issue #1757)
- c = getcontext()
- old_precision = c.prec
- x = Decimal("123456789.1")
+ with localcontext() as c:
+ # check that the value of the hash doesn't depend on the
+ # current context (issue #1757)
+ x = Decimal("123456789.1")
- c.prec = 6
- h1 = hashit(x)
- c.prec = 10
- h2 = hashit(x)
- c.prec = 16
- h3 = hashit(x)
+ c.prec = 6
+ h1 = hashit(x)
+ c.prec = 10
+ h2 = hashit(x)
+ c.prec = 16
+ h3 = hashit(x)
- self.assertEqual(h1, h2)
- self.assertEqual(h1, h3)
- c.prec = old_precision
+ self.assertEqual(h1, h2)
+ self.assertEqual(h1, h3)
+
+ c.prec = 10000
+ x = 1100 ** 1248
+ self.assertEqual(hashit(Decimal(x)), hashit(x))
def test_min_and_max_methods(self):
+ Decimal = self.decimal.Decimal
d1 = Decimal('15.32')
d2 = Decimal('28.5')
@@ -1404,6 +1841,8 @@ class DecimalUsabilityTest(unittest.TestCase):
self.assertIs(max(d2,l1), d2)
def test_as_nonzero(self):
+ Decimal = self.decimal.Decimal
+
#as false
self.assertFalse(Decimal(0))
#as true
@@ -1411,6 +1850,7 @@ class DecimalUsabilityTest(unittest.TestCase):
def test_tostring_methods(self):
#Test str and repr methods.
+ Decimal = self.decimal.Decimal
d = Decimal('15.32')
self.assertEqual(str(d), '15.32') # str
@@ -1418,6 +1858,7 @@ class DecimalUsabilityTest(unittest.TestCase):
def test_tonum_methods(self):
#Test float and int methods.
+ Decimal = self.decimal.Decimal
d1 = Decimal('66')
d2 = Decimal('15.32')
@@ -1440,6 +1881,7 @@ class DecimalUsabilityTest(unittest.TestCase):
('-11.0', -11),
('0.0', 0),
('-0E3', 0),
+ ('89891211712379812736.1', 89891211712379812736),
]
for d, i in test_pairs:
self.assertEqual(math.floor(Decimal(d)), i)
@@ -1459,6 +1901,7 @@ class DecimalUsabilityTest(unittest.TestCase):
('-11.0', -11),
('0.0', 0),
('-0E3', 0),
+ ('89891211712379812736.1', 89891211712379812737),
]
for d, i in test_pairs:
self.assertEqual(math.ceil(Decimal(d)), i)
@@ -1519,16 +1962,21 @@ class DecimalUsabilityTest(unittest.TestCase):
def test_nan_to_float(self):
# Test conversions of decimal NANs to float.
# See http://bugs.python.org/issue15544
+ Decimal = self.decimal.Decimal
for s in ('nan', 'nan1234', '-nan', '-nan2468'):
f = float(Decimal(s))
self.assertTrue(math.isnan(f))
+ sign = math.copysign(1.0, f)
+ self.assertEqual(sign, -1.0 if s.startswith('-') else 1.0)
def test_snan_to_float(self):
+ Decimal = self.decimal.Decimal
for s in ('snan', '-snan', 'snan1357', '-snan1234'):
d = Decimal(s)
self.assertRaises(ValueError, float, d)
def test_eval_round_trip(self):
+ Decimal = self.decimal.Decimal
#with zero
d = Decimal( (0, (0,), 0) )
@@ -1547,6 +1995,7 @@ class DecimalUsabilityTest(unittest.TestCase):
self.assertEqual(d, eval(repr(d)))
def test_as_tuple(self):
+ Decimal = self.decimal.Decimal
#with zero
d = Decimal(0)
@@ -1560,7 +2009,8 @@ class DecimalUsabilityTest(unittest.TestCase):
d = Decimal("-4.34913534E-17")
self.assertEqual(d.as_tuple(), (1, (4, 3, 4, 9, 1, 3, 5, 3, 4), -25) )
- #inf
+ # The '0' coefficient is implementation specific to decimal.py.
+ # It has no meaning in the C-version and is ignored there.
d = Decimal("Infinity")
self.assertEqual(d.as_tuple(), (0, (0,), 'F') )
@@ -1580,90 +2030,21 @@ class DecimalUsabilityTest(unittest.TestCase):
d = Decimal( (1, (), 'n') )
self.assertEqual(d.as_tuple(), (1, (), 'n') )
- #coefficient in infinity should be ignored
+ # For infinities, decimal.py has always silently accepted any
+ # coefficient tuple.
+ d = Decimal( (0, (0,), 'F') )
+ self.assertEqual(d.as_tuple(), (0, (0,), 'F'))
d = Decimal( (0, (4, 5, 3, 4), 'F') )
self.assertEqual(d.as_tuple(), (0, (0,), 'F'))
d = Decimal( (1, (0, 2, 7, 1), 'F') )
self.assertEqual(d.as_tuple(), (1, (0,), 'F'))
- def test_immutability_operations(self):
- # Do operations and check that it didn't change change internal objects.
-
- d1 = Decimal('-25e55')
- b1 = Decimal('-25e55')
- d2 = Decimal('33e+33')
- b2 = Decimal('33e+33')
-
- def checkSameDec(operation, useOther=False):
- if useOther:
- eval("d1." + operation + "(d2)")
- self.assertEqual(d1._sign, b1._sign)
- self.assertEqual(d1._int, b1._int)
- self.assertEqual(d1._exp, b1._exp)
- self.assertEqual(d2._sign, b2._sign)
- self.assertEqual(d2._int, b2._int)
- self.assertEqual(d2._exp, b2._exp)
- else:
- eval("d1." + operation + "()")
- self.assertEqual(d1._sign, b1._sign)
- self.assertEqual(d1._int, b1._int)
- self.assertEqual(d1._exp, b1._exp)
- return
-
- Decimal(d1)
- self.assertEqual(d1._sign, b1._sign)
- self.assertEqual(d1._int, b1._int)
- self.assertEqual(d1._exp, b1._exp)
-
- checkSameDec("__abs__")
- checkSameDec("__add__", True)
- checkSameDec("__divmod__", True)
- checkSameDec("__eq__", True)
- checkSameDec("__ne__", True)
- checkSameDec("__le__", True)
- checkSameDec("__lt__", True)
- checkSameDec("__ge__", True)
- checkSameDec("__gt__", True)
- checkSameDec("__float__")
- checkSameDec("__floordiv__", True)
- checkSameDec("__hash__")
- checkSameDec("__int__")
- checkSameDec("__trunc__")
- checkSameDec("__mod__", True)
- checkSameDec("__mul__", True)
- checkSameDec("__neg__")
- checkSameDec("__bool__")
- checkSameDec("__pos__")
- checkSameDec("__pow__", True)
- checkSameDec("__radd__", True)
- checkSameDec("__rdivmod__", True)
- checkSameDec("__repr__")
- checkSameDec("__rfloordiv__", True)
- checkSameDec("__rmod__", True)
- checkSameDec("__rmul__", True)
- checkSameDec("__rpow__", True)
- checkSameDec("__rsub__", True)
- checkSameDec("__str__")
- checkSameDec("__sub__", True)
- checkSameDec("__truediv__", True)
- checkSameDec("adjusted")
- checkSameDec("as_tuple")
- checkSameDec("compare", True)
- checkSameDec("max", True)
- checkSameDec("min", True)
- checkSameDec("normalize")
- checkSameDec("quantize", True)
- checkSameDec("remainder_near", True)
- checkSameDec("same_quantum", True)
- checkSameDec("sqrt")
- checkSameDec("to_eng_string")
- checkSameDec("to_integral")
-
def test_subclassing(self):
# Different behaviours when subclassing Decimal
+ Decimal = self.decimal.Decimal
class MyDecimal(Decimal):
- pass
+ y = None
d1 = MyDecimal(1)
d2 = MyDecimal(2)
@@ -1673,15 +2054,291 @@ class DecimalUsabilityTest(unittest.TestCase):
d = d1.max(d2)
self.assertIs(type(d), Decimal)
+ d = copy.copy(d1)
+ self.assertIs(type(d), MyDecimal)
+ self.assertEqual(d, d1)
+
+ d = copy.deepcopy(d1)
+ self.assertIs(type(d), MyDecimal)
+ self.assertEqual(d, d1)
+
+ # Decimal(Decimal)
+ d = Decimal('1.0')
+ x = Decimal(d)
+ self.assertIs(type(x), Decimal)
+ self.assertEqual(x, d)
+
+ # MyDecimal(Decimal)
+ m = MyDecimal(d)
+ self.assertIs(type(m), MyDecimal)
+ self.assertEqual(m, d)
+ self.assertIs(m.y, None)
+
+ # Decimal(MyDecimal)
+ x = Decimal(m)
+ self.assertIs(type(x), Decimal)
+ self.assertEqual(x, d)
+
+ # MyDecimal(MyDecimal)
+ m.y = 9
+ x = MyDecimal(m)
+ self.assertIs(type(x), MyDecimal)
+ self.assertEqual(x, d)
+ self.assertIs(x.y, None)
+
def test_implicit_context(self):
+ Decimal = self.decimal.Decimal
+ getcontext = self.decimal.getcontext
+
# Check results when context given implicitly. (Issue 2478)
c = getcontext()
self.assertEqual(str(Decimal(0).sqrt()),
str(c.sqrt(Decimal(0))))
+ def test_none_args(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+ localcontext = self.decimal.localcontext
+ InvalidOperation = self.decimal.InvalidOperation
+ DivisionByZero = self.decimal.DivisionByZero
+ Overflow = self.decimal.Overflow
+ Underflow = self.decimal.Underflow
+ Subnormal = self.decimal.Subnormal
+ Inexact = self.decimal.Inexact
+ Rounded = self.decimal.Rounded
+ Clamped = self.decimal.Clamped
+
+ with localcontext(Context()) as c:
+ c.prec = 7
+ c.Emax = 999
+ c.Emin = -999
+
+ x = Decimal("111")
+ y = Decimal("1e9999")
+ z = Decimal("1e-9999")
+
+ ##### Unary functions
+ c.clear_flags()
+ self.assertEqual(str(x.exp(context=None)), '1.609487E+48')
+ self.assertTrue(c.flags[Inexact])
+ self.assertTrue(c.flags[Rounded])
+ c.clear_flags()
+ self.assertRaises(Overflow, y.exp, context=None)
+ self.assertTrue(c.flags[Overflow])
+
+ self.assertIs(z.is_normal(context=None), False)
+ self.assertIs(z.is_subnormal(context=None), True)
+
+ c.clear_flags()
+ self.assertEqual(str(x.ln(context=None)), '4.709530')
+ self.assertTrue(c.flags[Inexact])
+ self.assertTrue(c.flags[Rounded])
+ c.clear_flags()
+ self.assertRaises(InvalidOperation, Decimal(-1).ln, context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.clear_flags()
+ self.assertEqual(str(x.log10(context=None)), '2.045323')
+ self.assertTrue(c.flags[Inexact])
+ self.assertTrue(c.flags[Rounded])
+ c.clear_flags()
+ self.assertRaises(InvalidOperation, Decimal(-1).log10, context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.clear_flags()
+ self.assertEqual(str(x.logb(context=None)), '2')
+ self.assertRaises(DivisionByZero, Decimal(0).logb, context=None)
+ self.assertTrue(c.flags[DivisionByZero])
+
+ c.clear_flags()
+ self.assertEqual(str(x.logical_invert(context=None)), '1111000')
+ self.assertRaises(InvalidOperation, y.logical_invert, context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.clear_flags()
+ self.assertEqual(str(y.next_minus(context=None)), '9.999999E+999')
+ self.assertRaises(InvalidOperation, Decimal('sNaN').next_minus, context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.clear_flags()
+ self.assertEqual(str(y.next_plus(context=None)), 'Infinity')
+ self.assertRaises(InvalidOperation, Decimal('sNaN').next_plus, context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.clear_flags()
+ self.assertEqual(str(z.normalize(context=None)), '0')
+ self.assertRaises(Overflow, y.normalize, context=None)
+ self.assertTrue(c.flags[Overflow])
+
+ self.assertEqual(str(z.number_class(context=None)), '+Subnormal')
+
+ c.clear_flags()
+ self.assertEqual(str(z.sqrt(context=None)), '0E-1005')
+ self.assertTrue(c.flags[Clamped])
+ self.assertTrue(c.flags[Inexact])
+ self.assertTrue(c.flags[Rounded])
+ self.assertTrue(c.flags[Subnormal])
+ self.assertTrue(c.flags[Underflow])
+ c.clear_flags()
+ self.assertRaises(Overflow, y.sqrt, context=None)
+ self.assertTrue(c.flags[Overflow])
+
+ c.capitals = 0
+ self.assertEqual(str(z.to_eng_string(context=None)), '1e-9999')
+ c.capitals = 1
+
+
+ ##### Binary functions
+ c.clear_flags()
+ ans = str(x.compare(Decimal('Nan891287828'), context=None))
+ self.assertEqual(ans, 'NaN1287828')
+ self.assertRaises(InvalidOperation, x.compare, Decimal('sNaN'), context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.clear_flags()
+ ans = str(x.compare_signal(8224, context=None))
+ self.assertEqual(ans, '-1')
+ self.assertRaises(InvalidOperation, x.compare_signal, Decimal('NaN'), context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.clear_flags()
+ ans = str(x.logical_and(101, context=None))
+ self.assertEqual(ans, '101')
+ self.assertRaises(InvalidOperation, x.logical_and, 123, context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.clear_flags()
+ ans = str(x.logical_or(101, context=None))
+ self.assertEqual(ans, '111')
+ self.assertRaises(InvalidOperation, x.logical_or, 123, context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.clear_flags()
+ ans = str(x.logical_xor(101, context=None))
+ self.assertEqual(ans, '10')
+ self.assertRaises(InvalidOperation, x.logical_xor, 123, context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.clear_flags()
+ ans = str(x.max(101, context=None))
+ self.assertEqual(ans, '111')
+ self.assertRaises(InvalidOperation, x.max, Decimal('sNaN'), context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.clear_flags()
+ ans = str(x.max_mag(101, context=None))
+ self.assertEqual(ans, '111')
+ self.assertRaises(InvalidOperation, x.max_mag, Decimal('sNaN'), context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.clear_flags()
+ ans = str(x.min(101, context=None))
+ self.assertEqual(ans, '101')
+ self.assertRaises(InvalidOperation, x.min, Decimal('sNaN'), context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.clear_flags()
+ ans = str(x.min_mag(101, context=None))
+ self.assertEqual(ans, '101')
+ self.assertRaises(InvalidOperation, x.min_mag, Decimal('sNaN'), context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.clear_flags()
+ ans = str(x.remainder_near(101, context=None))
+ self.assertEqual(ans, '10')
+ self.assertRaises(InvalidOperation, y.remainder_near, 101, context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.clear_flags()
+ ans = str(x.rotate(2, context=None))
+ self.assertEqual(ans, '11100')
+ self.assertRaises(InvalidOperation, x.rotate, 101, context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.clear_flags()
+ ans = str(x.scaleb(7, context=None))
+ self.assertEqual(ans, '1.11E+9')
+ self.assertRaises(InvalidOperation, x.scaleb, 10000, context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.clear_flags()
+ ans = str(x.shift(2, context=None))
+ self.assertEqual(ans, '11100')
+ self.assertRaises(InvalidOperation, x.shift, 10000, context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+
+ ##### Ternary functions
+ c.clear_flags()
+ ans = str(x.fma(2, 3, context=None))
+ self.assertEqual(ans, '225')
+ self.assertRaises(Overflow, x.fma, Decimal('1e9999'), 3, context=None)
+ self.assertTrue(c.flags[Overflow])
+
+
+ ##### Special cases
+ c.rounding = ROUND_HALF_EVEN
+ ans = str(Decimal('1.5').to_integral(rounding=None, context=None))
+ self.assertEqual(ans, '2')
+ c.rounding = ROUND_DOWN
+ ans = str(Decimal('1.5').to_integral(rounding=None, context=None))
+ self.assertEqual(ans, '1')
+ ans = str(Decimal('1.5').to_integral(rounding=ROUND_UP, context=None))
+ self.assertEqual(ans, '2')
+ c.clear_flags()
+ self.assertRaises(InvalidOperation, Decimal('sNaN').to_integral, context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.rounding = ROUND_HALF_EVEN
+ ans = str(Decimal('1.5').to_integral_value(rounding=None, context=None))
+ self.assertEqual(ans, '2')
+ c.rounding = ROUND_DOWN
+ ans = str(Decimal('1.5').to_integral_value(rounding=None, context=None))
+ self.assertEqual(ans, '1')
+ ans = str(Decimal('1.5').to_integral_value(rounding=ROUND_UP, context=None))
+ self.assertEqual(ans, '2')
+ c.clear_flags()
+ self.assertRaises(InvalidOperation, Decimal('sNaN').to_integral_value, context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.rounding = ROUND_HALF_EVEN
+ ans = str(Decimal('1.5').to_integral_exact(rounding=None, context=None))
+ self.assertEqual(ans, '2')
+ c.rounding = ROUND_DOWN
+ ans = str(Decimal('1.5').to_integral_exact(rounding=None, context=None))
+ self.assertEqual(ans, '1')
+ ans = str(Decimal('1.5').to_integral_exact(rounding=ROUND_UP, context=None))
+ self.assertEqual(ans, '2')
+ c.clear_flags()
+ self.assertRaises(InvalidOperation, Decimal('sNaN').to_integral_exact, context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.rounding = ROUND_UP
+ ans = str(Decimal('1.50001').quantize(exp=Decimal('1e-3'), rounding=None, context=None))
+ self.assertEqual(ans, '1.501')
+ c.rounding = ROUND_DOWN
+ ans = str(Decimal('1.50001').quantize(exp=Decimal('1e-3'), rounding=None, context=None))
+ self.assertEqual(ans, '1.500')
+ ans = str(Decimal('1.50001').quantize(exp=Decimal('1e-3'), rounding=ROUND_UP, context=None))
+ self.assertEqual(ans, '1.501')
+ c.clear_flags()
+ self.assertRaises(InvalidOperation, y.quantize, Decimal('1e-10'), rounding=ROUND_UP, context=None)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ with localcontext(Context()) as context:
+ context.prec = 7
+ context.Emax = 999
+ context.Emin = -999
+ with localcontext(ctx=None) as c:
+ self.assertEqual(c.prec, 7)
+ self.assertEqual(c.Emax, 999)
+ self.assertEqual(c.Emin, -999)
+
def test_conversions_from_int(self):
# Check that methods taking a second Decimal argument will
# always accept an integer in place of a Decimal.
+ Decimal = self.decimal.Decimal
+
self.assertEqual(Decimal(4).compare(3),
Decimal(4).compare(Decimal(3)))
self.assertEqual(Decimal(4).compare_signal(3),
@@ -1726,22 +2383,57 @@ class DecimalUsabilityTest(unittest.TestCase):
self.assertEqual(Decimal(-12).fma(45, Decimal(67)),
Decimal(-12).fma(Decimal(45), Decimal(67)))
+class CUsabilityTest(UsabilityTest):
+ decimal = C
+class PyUsabilityTest(UsabilityTest):
+ decimal = P
-class DecimalPythonAPItests(unittest.TestCase):
+class PythonAPItests(unittest.TestCase):
def test_abc(self):
+ Decimal = self.decimal.Decimal
+
self.assertTrue(issubclass(Decimal, numbers.Number))
self.assertFalse(issubclass(Decimal, numbers.Real))
self.assertIsInstance(Decimal(0), numbers.Number)
self.assertNotIsInstance(Decimal(0), numbers.Real)
def test_pickle(self):
+ Decimal = self.decimal.Decimal
+
+ savedecimal = sys.modules['decimal']
+
+ # Round trip
+ sys.modules['decimal'] = self.decimal
d = Decimal('-3.141590000')
p = pickle.dumps(d)
e = pickle.loads(p)
self.assertEqual(d, e)
+ if C:
+ # Test interchangeability
+ x = C.Decimal('-3.123e81723')
+ y = P.Decimal('-3.123e81723')
+
+ sys.modules['decimal'] = C
+ sx = pickle.dumps(x)
+ sys.modules['decimal'] = P
+ r = pickle.loads(sx)
+ self.assertIsInstance(r, P.Decimal)
+ self.assertEqual(r, y)
+
+ sys.modules['decimal'] = P
+ sy = pickle.dumps(y)
+ sys.modules['decimal'] = C
+ r = pickle.loads(sy)
+ self.assertIsInstance(r, C.Decimal)
+ self.assertEqual(r, x)
+
+ sys.modules['decimal'] = savedecimal
+
def test_int(self):
+ Decimal = self.decimal.Decimal
+
for x in range(-250, 250):
s = '%0.2f' % (x / 100.0)
# should work the same as for floats
@@ -1757,6 +2449,8 @@ class DecimalPythonAPItests(unittest.TestCase):
self.assertRaises(OverflowError, int, Decimal('-inf'))
def test_trunc(self):
+ Decimal = self.decimal.Decimal
+
for x in range(-250, 250):
s = '%0.2f' % (x / 100.0)
# should work the same as for floats
@@ -1768,9 +2462,13 @@ class DecimalPythonAPItests(unittest.TestCase):
def test_from_float(self):
- class MyDecimal(Decimal):
+ Decimal = self.decimal.Decimal
+
+ class MyDecimal(Decimal):
pass
+ self.assertTrue(issubclass(MyDecimal, Decimal))
+
r = MyDecimal.from_float(0.1)
self.assertEqual(type(r), MyDecimal)
self.assertEqual(str(r),
@@ -1792,6 +2490,10 @@ class DecimalPythonAPItests(unittest.TestCase):
self.assertEqual(x, float(MyDecimal.from_float(x))) # roundtrip
def test_create_decimal_from_float(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+ Inexact = self.decimal.Inexact
+
context = Context(prec=5, rounding=ROUND_DOWN)
self.assertEqual(
context.create_decimal_from_float(math.pi),
@@ -1815,27 +2517,320 @@ class DecimalPythonAPItests(unittest.TestCase):
self.assertEqual(repr(context.create_decimal_from_float(10)),
"Decimal('10')")
+ def test_quantize(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+ InvalidOperation = self.decimal.InvalidOperation
+
+ c = Context(Emax=99999, Emin=-99999)
+ self.assertEqual(
+ Decimal('7.335').quantize(Decimal('.01')),
+ Decimal('7.34')
+ )
+ self.assertEqual(
+ Decimal('7.335').quantize(Decimal('.01'), rounding=ROUND_DOWN),
+ Decimal('7.33')
+ )
+ self.assertRaises(
+ InvalidOperation,
+ Decimal("10e99999").quantize, Decimal('1e100000'), context=c
+ )
+
+ c = Context()
+ d = Decimal("0.871831e800")
+ x = d.quantize(context=c, exp=Decimal("1e797"), rounding=ROUND_DOWN)
+ self.assertEqual(x, Decimal('8.71E+799'))
+
+ def test_complex(self):
+ Decimal = self.decimal.Decimal
+
+ x = Decimal("9.8182731e181273")
+ self.assertEqual(x.real, x)
+ self.assertEqual(x.imag, 0)
+ self.assertEqual(x.conjugate(), x)
+
+ x = Decimal("1")
+ self.assertEqual(complex(x), complex(float(1)))
+
+ self.assertRaises(AttributeError, setattr, x, 'real', 100)
+ self.assertRaises(AttributeError, setattr, x, 'imag', 100)
+ self.assertRaises(AttributeError, setattr, x, 'conjugate', 100)
+ self.assertRaises(AttributeError, setattr, x, '__complex__', 100)
+
+ def test_named_parameters(self):
+ D = self.decimal.Decimal
+ Context = self.decimal.Context
+ localcontext = self.decimal.localcontext
+ InvalidOperation = self.decimal.InvalidOperation
+ Overflow = self.decimal.Overflow
+
+ xc = Context()
+ xc.prec = 1
+ xc.Emax = 1
+ xc.Emin = -1
+
+ with localcontext() as c:
+ c.clear_flags()
+
+ self.assertEqual(D(9, xc), 9)
+ self.assertEqual(D(9, context=xc), 9)
+ self.assertEqual(D(context=xc, value=9), 9)
+ self.assertEqual(D(context=xc), 0)
+ xc.clear_flags()
+ self.assertRaises(InvalidOperation, D, "xyz", context=xc)
+ self.assertTrue(xc.flags[InvalidOperation])
+ self.assertFalse(c.flags[InvalidOperation])
+
+ xc.clear_flags()
+ self.assertEqual(D(2).exp(context=xc), 7)
+ self.assertRaises(Overflow, D(8).exp, context=xc)
+ self.assertTrue(xc.flags[Overflow])
+ self.assertFalse(c.flags[Overflow])
+
+ xc.clear_flags()
+ self.assertEqual(D(2).ln(context=xc), D('0.7'))
+ self.assertRaises(InvalidOperation, D(-1).ln, context=xc)
+ self.assertTrue(xc.flags[InvalidOperation])
+ self.assertFalse(c.flags[InvalidOperation])
+
+ self.assertEqual(D(0).log10(context=xc), D('-inf'))
+ self.assertEqual(D(-1).next_minus(context=xc), -2)
+ self.assertEqual(D(-1).next_plus(context=xc), D('-0.9'))
+ self.assertEqual(D("9.73").normalize(context=xc), D('1E+1'))
+ self.assertEqual(D("9999").to_integral(context=xc), 9999)
+ self.assertEqual(D("-2000").to_integral_exact(context=xc), -2000)
+ self.assertEqual(D("123").to_integral_value(context=xc), 123)
+ self.assertEqual(D("0.0625").sqrt(context=xc), D('0.2'))
+
+ self.assertEqual(D("0.0625").compare(context=xc, other=3), -1)
+ xc.clear_flags()
+ self.assertRaises(InvalidOperation,
+ D("0").compare_signal, D('nan'), context=xc)
+ self.assertTrue(xc.flags[InvalidOperation])
+ self.assertFalse(c.flags[InvalidOperation])
+ self.assertEqual(D("0.01").max(D('0.0101'), context=xc), D('0.0'))
+ self.assertEqual(D("0.01").max(D('0.0101'), context=xc), D('0.0'))
+ self.assertEqual(D("0.2").max_mag(D('-0.3'), context=xc),
+ D('-0.3'))
+ self.assertEqual(D("0.02").min(D('-0.03'), context=xc), D('-0.0'))
+ self.assertEqual(D("0.02").min_mag(D('-0.03'), context=xc),
+ D('0.0'))
+ self.assertEqual(D("0.2").next_toward(D('-1'), context=xc), D('0.1'))
+ xc.clear_flags()
+ self.assertRaises(InvalidOperation,
+ D("0.2").quantize, D('1e10'), context=xc)
+ self.assertTrue(xc.flags[InvalidOperation])
+ self.assertFalse(c.flags[InvalidOperation])
+ self.assertEqual(D("9.99").remainder_near(D('1.5'), context=xc),
+ D('-0.5'))
+
+ self.assertEqual(D("9.9").fma(third=D('0.9'), context=xc, other=7),
+ D('7E+1'))
+
+ self.assertRaises(TypeError, D(1).is_canonical, context=xc)
+ self.assertRaises(TypeError, D(1).is_finite, context=xc)
+ self.assertRaises(TypeError, D(1).is_infinite, context=xc)
+ self.assertRaises(TypeError, D(1).is_nan, context=xc)
+ self.assertRaises(TypeError, D(1).is_qnan, context=xc)
+ self.assertRaises(TypeError, D(1).is_snan, context=xc)
+ self.assertRaises(TypeError, D(1).is_signed, context=xc)
+ self.assertRaises(TypeError, D(1).is_zero, context=xc)
+
+ self.assertFalse(D("0.01").is_normal(context=xc))
+ self.assertTrue(D("0.01").is_subnormal(context=xc))
+
+ self.assertRaises(TypeError, D(1).adjusted, context=xc)
+ self.assertRaises(TypeError, D(1).conjugate, context=xc)
+ self.assertRaises(TypeError, D(1).radix, context=xc)
+
+ self.assertEqual(D(-111).logb(context=xc), 2)
+ self.assertEqual(D(0).logical_invert(context=xc), 1)
+ self.assertEqual(D('0.01').number_class(context=xc), '+Subnormal')
+ self.assertEqual(D('0.21').to_eng_string(context=xc), '0.21')
+
+ self.assertEqual(D('11').logical_and(D('10'), context=xc), 0)
+ self.assertEqual(D('11').logical_or(D('10'), context=xc), 1)
+ self.assertEqual(D('01').logical_xor(D('10'), context=xc), 1)
+ self.assertEqual(D('23').rotate(1, context=xc), 3)
+ self.assertEqual(D('23').rotate(1, context=xc), 3)
+ xc.clear_flags()
+ self.assertRaises(Overflow,
+ D('23').scaleb, 1, context=xc)
+ self.assertTrue(xc.flags[Overflow])
+ self.assertFalse(c.flags[Overflow])
+ self.assertEqual(D('23').shift(-1, context=xc), 0)
+
+ self.assertRaises(TypeError, D.from_float, 1.1, context=xc)
+ self.assertRaises(TypeError, D(0).as_tuple, context=xc)
+
+ self.assertEqual(D(1).canonical(), 1)
+ self.assertRaises(TypeError, D("-1").copy_abs, context=xc)
+ self.assertRaises(TypeError, D("-1").copy_negate, context=xc)
+ self.assertRaises(TypeError, D(1).canonical, context="x")
+ self.assertRaises(TypeError, D(1).canonical, xyz="x")
+
+ def test_exception_hierarchy(self):
+
+ decimal = self.decimal
+ DecimalException = decimal.DecimalException
+ InvalidOperation = decimal.InvalidOperation
+ FloatOperation = decimal.FloatOperation
+ DivisionByZero = decimal.DivisionByZero
+ Overflow = decimal.Overflow
+ Underflow = decimal.Underflow
+ Subnormal = decimal.Subnormal
+ Inexact = decimal.Inexact
+ Rounded = decimal.Rounded
+ Clamped = decimal.Clamped
+
+ self.assertTrue(issubclass(DecimalException, ArithmeticError))
+
+ self.assertTrue(issubclass(InvalidOperation, DecimalException))
+ self.assertTrue(issubclass(FloatOperation, DecimalException))
+ self.assertTrue(issubclass(FloatOperation, TypeError))
+ self.assertTrue(issubclass(DivisionByZero, DecimalException))
+ self.assertTrue(issubclass(DivisionByZero, ZeroDivisionError))
+ self.assertTrue(issubclass(Overflow, Rounded))
+ self.assertTrue(issubclass(Overflow, Inexact))
+ self.assertTrue(issubclass(Overflow, DecimalException))
+ self.assertTrue(issubclass(Underflow, Inexact))
+ self.assertTrue(issubclass(Underflow, Rounded))
+ self.assertTrue(issubclass(Underflow, Subnormal))
+ self.assertTrue(issubclass(Underflow, DecimalException))
+
+ self.assertTrue(issubclass(Subnormal, DecimalException))
+ self.assertTrue(issubclass(Inexact, DecimalException))
+ self.assertTrue(issubclass(Rounded, DecimalException))
+ self.assertTrue(issubclass(Clamped, DecimalException))
+
+ self.assertTrue(issubclass(decimal.ConversionSyntax, InvalidOperation))
+ self.assertTrue(issubclass(decimal.DivisionImpossible, InvalidOperation))
+ self.assertTrue(issubclass(decimal.DivisionUndefined, InvalidOperation))
+ self.assertTrue(issubclass(decimal.DivisionUndefined, ZeroDivisionError))
+ self.assertTrue(issubclass(decimal.InvalidContext, InvalidOperation))
+
+class CPythonAPItests(PythonAPItests):
+ decimal = C
+class PyPythonAPItests(PythonAPItests):
+ decimal = P
+
class ContextAPItests(unittest.TestCase):
+ def test_none_args(self):
+ Context = self.decimal.Context
+ InvalidOperation = self.decimal.InvalidOperation
+ DivisionByZero = self.decimal.DivisionByZero
+ Overflow = self.decimal.Overflow
+
+ c1 = Context()
+ c2 = Context(prec=None, rounding=None, Emax=None, Emin=None,
+ capitals=None, clamp=None, flags=None, traps=None)
+ for c in [c1, c2]:
+ self.assertEqual(c.prec, 28)
+ self.assertEqual(c.rounding, ROUND_HALF_EVEN)
+ self.assertEqual(c.Emax, 999999)
+ self.assertEqual(c.Emin, -999999)
+ self.assertEqual(c.capitals, 1)
+ self.assertEqual(c.clamp, 0)
+ assert_signals(self, c, 'flags', [])
+ assert_signals(self, c, 'traps', [InvalidOperation, DivisionByZero,
+ Overflow])
+
+ @cpython_only
+ def test_from_legacy_strings(self):
+ import _testcapi
+ c = self.decimal.Context()
+
+ for rnd in RoundingModes:
+ c.rounding = _testcapi.unicode_legacy_string(rnd)
+ self.assertEqual(c.rounding, rnd)
+
+ s = _testcapi.unicode_legacy_string('')
+ self.assertRaises(TypeError, setattr, c, 'rounding', s)
+
+ s = _testcapi.unicode_legacy_string('ROUND_\x00UP')
+ self.assertRaises(TypeError, setattr, c, 'rounding', s)
+
def test_pickle(self):
+
+ Context = self.decimal.Context
+
+ savedecimal = sys.modules['decimal']
+
+ # Round trip
+ sys.modules['decimal'] = self.decimal
c = Context()
e = pickle.loads(pickle.dumps(c))
- for k in vars(c):
- v1 = vars(c)[k]
- v2 = vars(e)[k]
- self.assertEqual(v1, v2)
+
+ self.assertEqual(c.prec, e.prec)
+ self.assertEqual(c.Emin, e.Emin)
+ self.assertEqual(c.Emax, e.Emax)
+ self.assertEqual(c.rounding, e.rounding)
+ self.assertEqual(c.capitals, e.capitals)
+ self.assertEqual(c.clamp, e.clamp)
+ self.assertEqual(c.flags, e.flags)
+ self.assertEqual(c.traps, e.traps)
+
+ # Test interchangeability
+ combinations = [(C, P), (P, C)] if C else [(P, P)]
+ for dumper, loader in combinations:
+ for ri, _ in enumerate(RoundingModes):
+ for fi, _ in enumerate(OrderedSignals[dumper]):
+ for ti, _ in enumerate(OrderedSignals[dumper]):
+
+ prec = random.randrange(1, 100)
+ emin = random.randrange(-100, 0)
+ emax = random.randrange(1, 100)
+ caps = random.randrange(2)
+ clamp = random.randrange(2)
+
+ # One module dumps
+ sys.modules['decimal'] = dumper
+ c = dumper.Context(
+ prec=prec, Emin=emin, Emax=emax,
+ rounding=RoundingModes[ri],
+ capitals=caps, clamp=clamp,
+ flags=OrderedSignals[dumper][:fi],
+ traps=OrderedSignals[dumper][:ti]
+ )
+ s = pickle.dumps(c)
+
+ # The other module loads
+ sys.modules['decimal'] = loader
+ d = pickle.loads(s)
+ self.assertIsInstance(d, loader.Context)
+
+ self.assertEqual(d.prec, prec)
+ self.assertEqual(d.Emin, emin)
+ self.assertEqual(d.Emax, emax)
+ self.assertEqual(d.rounding, RoundingModes[ri])
+ self.assertEqual(d.capitals, caps)
+ self.assertEqual(d.clamp, clamp)
+ assert_signals(self, d, 'flags', OrderedSignals[loader][:fi])
+ assert_signals(self, d, 'traps', OrderedSignals[loader][:ti])
+
+ sys.modules['decimal'] = savedecimal
def test_equality_with_other_types(self):
+ Decimal = self.decimal.Decimal
+
self.assertIn(Decimal(10), ['a', 1.0, Decimal(10), (1,2), {}])
self.assertNotIn(Decimal(10), ['a', 1.0, (1,2), {}])
def test_copy(self):
# All copies should be deep
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.copy()
self.assertNotEqual(id(c), id(d))
self.assertNotEqual(id(c.flags), id(d.flags))
self.assertNotEqual(id(c.traps), id(d.traps))
+ k1 = set(c.flags.keys())
+ k2 = set(d.flags.keys())
+ self.assertEqual(k1, k2)
+ self.assertEqual(c.flags, d.flags)
def test__clamp(self):
# In Python 3.2, the private attribute `_clamp` was made
@@ -1844,26 +2839,23 @@ class ContextAPItests(unittest.TestCase):
# only, the attribute should be gettable/settable via both
# `clamp` and `_clamp`; in Python 3.3, `_clamp` should be
# removed.
- c = Context(clamp = 0)
- self.assertEqual(c.clamp, 0)
-
- with check_warnings(("", DeprecationWarning)):
- c._clamp = 1
- self.assertEqual(c.clamp, 1)
- with check_warnings(("", DeprecationWarning)):
- self.assertEqual(c._clamp, 1)
- c.clamp = 0
- self.assertEqual(c.clamp, 0)
- with check_warnings(("", DeprecationWarning)):
- self.assertEqual(c._clamp, 0)
+ Context = self.decimal.Context
+ c = Context()
+ self.assertRaises(AttributeError, getattr, c, '_clamp')
def test_abs(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.abs(Decimal(-1))
self.assertEqual(c.abs(-1), d)
self.assertRaises(TypeError, c.abs, '-1')
def test_add(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.add(Decimal(1), Decimal(1))
self.assertEqual(c.add(1, 1), d)
@@ -1873,6 +2865,9 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.add, 1, '1')
def test_compare(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.compare(Decimal(1), Decimal(1))
self.assertEqual(c.compare(1, 1), d)
@@ -1882,6 +2877,9 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.compare, 1, '1')
def test_compare_signal(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.compare_signal(Decimal(1), Decimal(1))
self.assertEqual(c.compare_signal(1, 1), d)
@@ -1891,6 +2889,9 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.compare_signal, 1, '1')
def test_compare_total(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.compare_total(Decimal(1), Decimal(1))
self.assertEqual(c.compare_total(1, 1), d)
@@ -1900,6 +2901,9 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.compare_total, 1, '1')
def test_compare_total_mag(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.compare_total_mag(Decimal(1), Decimal(1))
self.assertEqual(c.compare_total_mag(1, 1), d)
@@ -1909,24 +2913,36 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.compare_total_mag, 1, '1')
def test_copy_abs(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.copy_abs(Decimal(-1))
self.assertEqual(c.copy_abs(-1), d)
self.assertRaises(TypeError, c.copy_abs, '-1')
def test_copy_decimal(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.copy_decimal(Decimal(-1))
self.assertEqual(c.copy_decimal(-1), d)
self.assertRaises(TypeError, c.copy_decimal, '-1')
def test_copy_negate(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.copy_negate(Decimal(-1))
self.assertEqual(c.copy_negate(-1), d)
self.assertRaises(TypeError, c.copy_negate, '-1')
def test_copy_sign(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.copy_sign(Decimal(1), Decimal(-2))
self.assertEqual(c.copy_sign(1, -2), d)
@@ -1936,6 +2952,9 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.copy_sign, 1, '-2')
def test_divide(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.divide(Decimal(1), Decimal(2))
self.assertEqual(c.divide(1, 2), d)
@@ -1945,6 +2964,9 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.divide, 1, '2')
def test_divide_int(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.divide_int(Decimal(1), Decimal(2))
self.assertEqual(c.divide_int(1, 2), d)
@@ -1954,6 +2976,9 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.divide_int, 1, '2')
def test_divmod(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.divmod(Decimal(1), Decimal(2))
self.assertEqual(c.divmod(1, 2), d)
@@ -1963,12 +2988,18 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.divmod, 1, '2')
def test_exp(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.exp(Decimal(10))
self.assertEqual(c.exp(10), d)
self.assertRaises(TypeError, c.exp, '10')
def test_fma(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.fma(Decimal(2), Decimal(3), Decimal(4))
self.assertEqual(c.fma(2, 3, 4), d)
@@ -1980,79 +3011,129 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.fma, 2, '3', 4)
self.assertRaises(TypeError, c.fma, 2, 3, '4')
+ # Issue 12079 for Context.fma ...
+ self.assertRaises(TypeError, c.fma,
+ Decimal('Infinity'), Decimal(0), "not a decimal")
+ self.assertRaises(TypeError, c.fma,
+ Decimal(1), Decimal('snan'), 1.222)
+ # ... and for Decimal.fma.
+ self.assertRaises(TypeError, Decimal('Infinity').fma,
+ Decimal(0), "not a decimal")
+ self.assertRaises(TypeError, Decimal(1).fma,
+ Decimal('snan'), 1.222)
+
def test_is_finite(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.is_finite(Decimal(10))
self.assertEqual(c.is_finite(10), d)
self.assertRaises(TypeError, c.is_finite, '10')
def test_is_infinite(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.is_infinite(Decimal(10))
self.assertEqual(c.is_infinite(10), d)
self.assertRaises(TypeError, c.is_infinite, '10')
def test_is_nan(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.is_nan(Decimal(10))
self.assertEqual(c.is_nan(10), d)
self.assertRaises(TypeError, c.is_nan, '10')
def test_is_normal(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.is_normal(Decimal(10))
self.assertEqual(c.is_normal(10), d)
self.assertRaises(TypeError, c.is_normal, '10')
def test_is_qnan(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.is_qnan(Decimal(10))
self.assertEqual(c.is_qnan(10), d)
self.assertRaises(TypeError, c.is_qnan, '10')
def test_is_signed(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.is_signed(Decimal(10))
self.assertEqual(c.is_signed(10), d)
self.assertRaises(TypeError, c.is_signed, '10')
def test_is_snan(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.is_snan(Decimal(10))
self.assertEqual(c.is_snan(10), d)
self.assertRaises(TypeError, c.is_snan, '10')
def test_is_subnormal(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.is_subnormal(Decimal(10))
self.assertEqual(c.is_subnormal(10), d)
self.assertRaises(TypeError, c.is_subnormal, '10')
def test_is_zero(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.is_zero(Decimal(10))
self.assertEqual(c.is_zero(10), d)
self.assertRaises(TypeError, c.is_zero, '10')
def test_ln(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.ln(Decimal(10))
self.assertEqual(c.ln(10), d)
self.assertRaises(TypeError, c.ln, '10')
def test_log10(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.log10(Decimal(10))
self.assertEqual(c.log10(10), d)
self.assertRaises(TypeError, c.log10, '10')
def test_logb(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.logb(Decimal(10))
self.assertEqual(c.logb(10), d)
self.assertRaises(TypeError, c.logb, '10')
def test_logical_and(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.logical_and(Decimal(1), Decimal(1))
self.assertEqual(c.logical_and(1, 1), d)
@@ -2062,12 +3143,18 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.logical_and, 1, '1')
def test_logical_invert(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.logical_invert(Decimal(1000))
self.assertEqual(c.logical_invert(1000), d)
self.assertRaises(TypeError, c.logical_invert, '1000')
def test_logical_or(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.logical_or(Decimal(1), Decimal(1))
self.assertEqual(c.logical_or(1, 1), d)
@@ -2077,6 +3164,9 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.logical_or, 1, '1')
def test_logical_xor(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.logical_xor(Decimal(1), Decimal(1))
self.assertEqual(c.logical_xor(1, 1), d)
@@ -2086,6 +3176,9 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.logical_xor, 1, '1')
def test_max(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.max(Decimal(1), Decimal(2))
self.assertEqual(c.max(1, 2), d)
@@ -2095,6 +3188,9 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.max, 1, '2')
def test_max_mag(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.max_mag(Decimal(1), Decimal(2))
self.assertEqual(c.max_mag(1, 2), d)
@@ -2104,6 +3200,9 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.max_mag, 1, '2')
def test_min(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.min(Decimal(1), Decimal(2))
self.assertEqual(c.min(1, 2), d)
@@ -2113,6 +3212,9 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.min, 1, '2')
def test_min_mag(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.min_mag(Decimal(1), Decimal(2))
self.assertEqual(c.min_mag(1, 2), d)
@@ -2122,12 +3224,18 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.min_mag, 1, '2')
def test_minus(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.minus(Decimal(10))
self.assertEqual(c.minus(10), d)
self.assertRaises(TypeError, c.minus, '10')
def test_multiply(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.multiply(Decimal(1), Decimal(2))
self.assertEqual(c.multiply(1, 2), d)
@@ -2137,18 +3245,27 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.multiply, 1, '2')
def test_next_minus(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.next_minus(Decimal(10))
self.assertEqual(c.next_minus(10), d)
self.assertRaises(TypeError, c.next_minus, '10')
def test_next_plus(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.next_plus(Decimal(10))
self.assertEqual(c.next_plus(10), d)
self.assertRaises(TypeError, c.next_plus, '10')
def test_next_toward(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.next_toward(Decimal(1), Decimal(2))
self.assertEqual(c.next_toward(1, 2), d)
@@ -2158,36 +3275,50 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.next_toward, 1, '2')
def test_normalize(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.normalize(Decimal(10))
self.assertEqual(c.normalize(10), d)
self.assertRaises(TypeError, c.normalize, '10')
def test_number_class(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
self.assertEqual(c.number_class(123), c.number_class(Decimal(123)))
self.assertEqual(c.number_class(0), c.number_class(Decimal(0)))
self.assertEqual(c.number_class(-45), c.number_class(Decimal(-45)))
- def test_power(self):
- c = Context()
- d = c.power(Decimal(1), Decimal(4), Decimal(2))
- self.assertEqual(c.power(1, 4, 2), d)
- self.assertEqual(c.power(Decimal(1), 4, 2), d)
- self.assertEqual(c.power(1, Decimal(4), 2), d)
- self.assertEqual(c.power(1, 4, Decimal(2)), d)
- self.assertEqual(c.power(Decimal(1), Decimal(4), 2), d)
- self.assertRaises(TypeError, c.power, '1', 4, 2)
- self.assertRaises(TypeError, c.power, 1, '4', 2)
- self.assertRaises(TypeError, c.power, 1, 4, '2')
-
def test_plus(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.plus(Decimal(10))
self.assertEqual(c.plus(10), d)
self.assertRaises(TypeError, c.plus, '10')
+ def test_power(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
+ c = Context()
+ d = c.power(Decimal(1), Decimal(4))
+ self.assertEqual(c.power(1, 4), d)
+ self.assertEqual(c.power(Decimal(1), 4), d)
+ self.assertEqual(c.power(1, Decimal(4)), d)
+ self.assertEqual(c.power(Decimal(1), Decimal(4)), d)
+ self.assertRaises(TypeError, c.power, '1', 4)
+ self.assertRaises(TypeError, c.power, 1, '4')
+ self.assertEqual(c.power(modulo=5, b=8, a=2), 1)
+
def test_quantize(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.quantize(Decimal(1), Decimal(2))
self.assertEqual(c.quantize(1, 2), d)
@@ -2197,6 +3328,9 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.quantize, 1, '2')
def test_remainder(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.remainder(Decimal(1), Decimal(2))
self.assertEqual(c.remainder(1, 2), d)
@@ -2206,6 +3340,9 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.remainder, 1, '2')
def test_remainder_near(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.remainder_near(Decimal(1), Decimal(2))
self.assertEqual(c.remainder_near(1, 2), d)
@@ -2215,6 +3352,9 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.remainder_near, 1, '2')
def test_rotate(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.rotate(Decimal(1), Decimal(2))
self.assertEqual(c.rotate(1, 2), d)
@@ -2224,12 +3364,18 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.rotate, 1, '2')
def test_sqrt(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.sqrt(Decimal(10))
self.assertEqual(c.sqrt(10), d)
self.assertRaises(TypeError, c.sqrt, '10')
def test_same_quantum(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.same_quantum(Decimal(1), Decimal(2))
self.assertEqual(c.same_quantum(1, 2), d)
@@ -2239,6 +3385,9 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.same_quantum, 1, '2')
def test_scaleb(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.scaleb(Decimal(1), Decimal(2))
self.assertEqual(c.scaleb(1, 2), d)
@@ -2248,6 +3397,9 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.scaleb, 1, '2')
def test_shift(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.shift(Decimal(1), Decimal(2))
self.assertEqual(c.shift(1, 2), d)
@@ -2257,6 +3409,9 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.shift, 1, '2')
def test_subtract(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.subtract(Decimal(1), Decimal(2))
self.assertEqual(c.subtract(1, 2), d)
@@ -2266,35 +3421,56 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.subtract, 1, '2')
def test_to_eng_string(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.to_eng_string(Decimal(10))
self.assertEqual(c.to_eng_string(10), d)
self.assertRaises(TypeError, c.to_eng_string, '10')
def test_to_sci_string(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.to_sci_string(Decimal(10))
self.assertEqual(c.to_sci_string(10), d)
self.assertRaises(TypeError, c.to_sci_string, '10')
def test_to_integral_exact(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.to_integral_exact(Decimal(10))
self.assertEqual(c.to_integral_exact(10), d)
self.assertRaises(TypeError, c.to_integral_exact, '10')
def test_to_integral_value(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+
c = Context()
d = c.to_integral_value(Decimal(10))
self.assertEqual(c.to_integral_value(10), d)
self.assertRaises(TypeError, c.to_integral_value, '10')
+ self.assertRaises(TypeError, c.to_integral_value, 10, 'x')
+
+class CContextAPItests(ContextAPItests):
+ decimal = C
+class PyContextAPItests(ContextAPItests):
+ decimal = P
-class WithStatementTest(unittest.TestCase):
+class ContextWithStatement(unittest.TestCase):
# Can't do these as docstrings until Python 2.6
# as doctest can't handle __future__ statements
def test_localcontext(self):
# Use a copy of the current context in the block
+ getcontext = self.decimal.getcontext
+ localcontext = self.decimal.localcontext
+
orig_ctx = getcontext()
with localcontext() as enter_ctx:
set_ctx = getcontext()
@@ -2305,6 +3481,11 @@ class WithStatementTest(unittest.TestCase):
def test_localcontextarg(self):
# Use a copy of the supplied context in the block
+ Context = self.decimal.Context
+ getcontext = self.decimal.getcontext
+ localcontext = self.decimal.localcontext
+
+ localcontext = self.decimal.localcontext
orig_ctx = getcontext()
new_ctx = Context(prec=42)
with localcontext(new_ctx) as enter_ctx:
@@ -2315,17 +3496,130 @@ class WithStatementTest(unittest.TestCase):
self.assertIsNot(new_ctx, set_ctx, 'did not copy the context')
self.assertIs(set_ctx, enter_ctx, '__enter__ returned wrong context')
+ def test_nested_with_statements(self):
+ # Use a copy of the supplied context in the block
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+ getcontext = self.decimal.getcontext
+ localcontext = self.decimal.localcontext
+ Clamped = self.decimal.Clamped
+ Overflow = self.decimal.Overflow
+
+ orig_ctx = getcontext()
+ orig_ctx.clear_flags()
+ new_ctx = Context(Emax=384)
+ with localcontext() as c1:
+ self.assertEqual(c1.flags, orig_ctx.flags)
+ self.assertEqual(c1.traps, orig_ctx.traps)
+ c1.traps[Clamped] = True
+ c1.Emin = -383
+ self.assertNotEqual(orig_ctx.Emin, -383)
+ self.assertRaises(Clamped, c1.create_decimal, '0e-999')
+ self.assertTrue(c1.flags[Clamped])
+ with localcontext(new_ctx) as c2:
+ self.assertEqual(c2.flags, new_ctx.flags)
+ self.assertEqual(c2.traps, new_ctx.traps)
+ self.assertRaises(Overflow, c2.power, Decimal('3.4e200'), 2)
+ self.assertFalse(c2.flags[Clamped])
+ self.assertTrue(c2.flags[Overflow])
+ del c2
+ self.assertFalse(c1.flags[Overflow])
+ del c1
+ self.assertNotEqual(orig_ctx.Emin, -383)
+ self.assertFalse(orig_ctx.flags[Clamped])
+ self.assertFalse(orig_ctx.flags[Overflow])
+ self.assertFalse(new_ctx.flags[Clamped])
+ self.assertFalse(new_ctx.flags[Overflow])
+
+ def test_with_statements_gc1(self):
+ localcontext = self.decimal.localcontext
+
+ with localcontext() as c1:
+ del c1
+ with localcontext() as c2:
+ del c2
+ with localcontext() as c3:
+ del c3
+ with localcontext() as c4:
+ del c4
+
+ def test_with_statements_gc2(self):
+ localcontext = self.decimal.localcontext
+
+ with localcontext() as c1:
+ with localcontext(c1) as c2:
+ del c1
+ with localcontext(c2) as c3:
+ del c2
+ with localcontext(c3) as c4:
+ del c3
+ del c4
+
+ def test_with_statements_gc3(self):
+ Context = self.decimal.Context
+ localcontext = self.decimal.localcontext
+ getcontext = self.decimal.getcontext
+ setcontext = self.decimal.setcontext
+
+ with localcontext() as c1:
+ del c1
+ n1 = Context(prec=1)
+ setcontext(n1)
+ with localcontext(n1) as c2:
+ del n1
+ self.assertEqual(c2.prec, 1)
+ del c2
+ n2 = Context(prec=2)
+ setcontext(n2)
+ del n2
+ self.assertEqual(getcontext().prec, 2)
+ n3 = Context(prec=3)
+ setcontext(n3)
+ self.assertEqual(getcontext().prec, 3)
+ with localcontext(n3) as c3:
+ del n3
+ self.assertEqual(c3.prec, 3)
+ del c3
+ n4 = Context(prec=4)
+ setcontext(n4)
+ del n4
+ self.assertEqual(getcontext().prec, 4)
+ with localcontext() as c4:
+ self.assertEqual(c4.prec, 4)
+ del c4
+
+class CContextWithStatement(ContextWithStatement):
+ decimal = C
+class PyContextWithStatement(ContextWithStatement):
+ decimal = P
+
class ContextFlags(unittest.TestCase):
+
def test_flags_irrelevant(self):
# check that the result (numeric result + flags raised) of an
# arithmetic operation doesn't depend on the current flags
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+ Inexact = self.decimal.Inexact
+ Rounded = self.decimal.Rounded
+ Underflow = self.decimal.Underflow
+ Clamped = self.decimal.Clamped
+ Subnormal = self.decimal.Subnormal
+
+ def raise_error(context, flag):
+ if self.decimal == C:
+ context.flags[flag] = True
+ if context.traps[flag]:
+ raise flag
+ else:
+ context._raise_error(flag)
- context = Context(prec=9, Emin = -999999999, Emax = 999999999,
- rounding=ROUND_HALF_EVEN, traps=[], flags=[])
+ context = Context(prec=9, Emin = -425000000, Emax = 425000000,
+ rounding=ROUND_HALF_EVEN, traps=[], flags=[])
# operations that raise various flags, in the form (function, arglist)
operations = [
- (context._apply, [Decimal("100E-1000000009")]),
+ (context._apply, [Decimal("100E-425000010")]),
(context.sqrt, [Decimal(2)]),
(context.add, [Decimal("1.23456789"), Decimal("9.87654321")]),
(context.multiply, [Decimal("1.23456789"), Decimal("9.87654321")]),
@@ -2346,7 +3640,7 @@ class ContextFlags(unittest.TestCase):
# set flags, before calling operation
context.clear_flags()
for flag in extra_flags:
- context._raise_error(flag)
+ raise_error(context, flag)
new_ans = fn(*args)
# flags that we expect to be set after the operation
@@ -2367,6 +3661,1742 @@ class ContextFlags(unittest.TestCase):
"operation raises different flags depending on flags set: " +
"expected %s, got %s" % (expected_flags, new_flags))
+ def test_flag_comparisons(self):
+ Context = self.decimal.Context
+ Inexact = self.decimal.Inexact
+ Rounded = self.decimal.Rounded
+
+ c = Context()
+
+ # Valid SignalDict
+ self.assertNotEqual(c.flags, c.traps)
+ self.assertNotEqual(c.traps, c.flags)
+
+ c.flags = c.traps
+ self.assertEqual(c.flags, c.traps)
+ self.assertEqual(c.traps, c.flags)
+
+ c.flags[Rounded] = True
+ c.traps = c.flags
+ self.assertEqual(c.flags, c.traps)
+ self.assertEqual(c.traps, c.flags)
+
+ d = {}
+ d.update(c.flags)
+ self.assertEqual(d, c.flags)
+ self.assertEqual(c.flags, d)
+
+ d[Inexact] = True
+ self.assertNotEqual(d, c.flags)
+ self.assertNotEqual(c.flags, d)
+
+ # Invalid SignalDict
+ d = {Inexact:False}
+ self.assertNotEqual(d, c.flags)
+ self.assertNotEqual(c.flags, d)
+
+ d = ["xyz"]
+ self.assertNotEqual(d, c.flags)
+ self.assertNotEqual(c.flags, d)
+
+ @requires_IEEE_754
+ def test_float_operation(self):
+ Decimal = self.decimal.Decimal
+ FloatOperation = self.decimal.FloatOperation
+ localcontext = self.decimal.localcontext
+
+ with localcontext() as c:
+ ##### trap is off by default
+ self.assertFalse(c.traps[FloatOperation])
+
+ # implicit conversion sets the flag
+ c.clear_flags()
+ self.assertEqual(Decimal(7.5), 7.5)
+ self.assertTrue(c.flags[FloatOperation])
+
+ c.clear_flags()
+ self.assertEqual(c.create_decimal(7.5), 7.5)
+ self.assertTrue(c.flags[FloatOperation])
+
+ # explicit conversion does not set the flag
+ c.clear_flags()
+ x = Decimal.from_float(7.5)
+ self.assertFalse(c.flags[FloatOperation])
+ # comparison sets the flag
+ self.assertEqual(x, 7.5)
+ self.assertTrue(c.flags[FloatOperation])
+
+ c.clear_flags()
+ x = c.create_decimal_from_float(7.5)
+ self.assertFalse(c.flags[FloatOperation])
+ self.assertEqual(x, 7.5)
+ self.assertTrue(c.flags[FloatOperation])
+
+ ##### set the trap
+ c.traps[FloatOperation] = True
+
+ # implicit conversion raises
+ c.clear_flags()
+ self.assertRaises(FloatOperation, Decimal, 7.5)
+ self.assertTrue(c.flags[FloatOperation])
+
+ c.clear_flags()
+ self.assertRaises(FloatOperation, c.create_decimal, 7.5)
+ self.assertTrue(c.flags[FloatOperation])
+
+ # explicit conversion is silent
+ c.clear_flags()
+ x = Decimal.from_float(7.5)
+ self.assertFalse(c.flags[FloatOperation])
+
+ c.clear_flags()
+ x = c.create_decimal_from_float(7.5)
+ self.assertFalse(c.flags[FloatOperation])
+
+ def test_float_comparison(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+ FloatOperation = self.decimal.FloatOperation
+ localcontext = self.decimal.localcontext
+
+ def assert_attr(a, b, attr, context, signal=None):
+ context.clear_flags()
+ f = getattr(a, attr)
+ if signal == FloatOperation:
+ self.assertRaises(signal, f, b)
+ else:
+ self.assertIs(f(b), True)
+ self.assertTrue(context.flags[FloatOperation])
+
+ small_d = Decimal('0.25')
+ big_d = Decimal('3.0')
+ small_f = 0.25
+ big_f = 3.0
+
+ zero_d = Decimal('0.0')
+ neg_zero_d = Decimal('-0.0')
+ zero_f = 0.0
+ neg_zero_f = -0.0
+
+ inf_d = Decimal('Infinity')
+ neg_inf_d = Decimal('-Infinity')
+ inf_f = float('inf')
+ neg_inf_f = float('-inf')
+
+ def doit(c, signal=None):
+ # Order
+ for attr in '__lt__', '__le__':
+ assert_attr(small_d, big_f, attr, c, signal)
+
+ for attr in '__gt__', '__ge__':
+ assert_attr(big_d, small_f, attr, c, signal)
+
+ # Equality
+ assert_attr(small_d, small_f, '__eq__', c, None)
+
+ assert_attr(neg_zero_d, neg_zero_f, '__eq__', c, None)
+ assert_attr(neg_zero_d, zero_f, '__eq__', c, None)
+
+ assert_attr(zero_d, neg_zero_f, '__eq__', c, None)
+ assert_attr(zero_d, zero_f, '__eq__', c, None)
+
+ assert_attr(neg_inf_d, neg_inf_f, '__eq__', c, None)
+ assert_attr(inf_d, inf_f, '__eq__', c, None)
+
+ # Inequality
+ assert_attr(small_d, big_f, '__ne__', c, None)
+
+ assert_attr(Decimal('0.1'), 0.1, '__ne__', c, None)
+
+ assert_attr(neg_inf_d, inf_f, '__ne__', c, None)
+ assert_attr(inf_d, neg_inf_f, '__ne__', c, None)
+
+ assert_attr(Decimal('NaN'), float('nan'), '__ne__', c, None)
+
+ def test_containers(c, signal=None):
+ c.clear_flags()
+ s = set([100.0, Decimal('100.0')])
+ self.assertEqual(len(s), 1)
+ self.assertTrue(c.flags[FloatOperation])
+
+ c.clear_flags()
+ if signal:
+ self.assertRaises(signal, sorted, [1.0, Decimal('10.0')])
+ else:
+ s = sorted([10.0, Decimal('10.0')])
+ self.assertTrue(c.flags[FloatOperation])
+
+ c.clear_flags()
+ b = 10.0 in [Decimal('10.0'), 1.0]
+ self.assertTrue(c.flags[FloatOperation])
+
+ c.clear_flags()
+ b = 10.0 in {Decimal('10.0'):'a', 1.0:'b'}
+ self.assertTrue(c.flags[FloatOperation])
+
+ nc = Context()
+ with localcontext(nc) as c:
+ self.assertFalse(c.traps[FloatOperation])
+ doit(c, signal=None)
+ test_containers(c, signal=None)
+
+ c.traps[FloatOperation] = True
+ doit(c, signal=FloatOperation)
+ test_containers(c, signal=FloatOperation)
+
+ def test_float_operation_default(self):
+ Decimal = self.decimal.Decimal
+ Context = self.decimal.Context
+ Inexact = self.decimal.Inexact
+ FloatOperation= self.decimal.FloatOperation
+
+ context = Context()
+ self.assertFalse(context.flags[FloatOperation])
+ self.assertFalse(context.traps[FloatOperation])
+
+ context.clear_traps()
+ context.traps[Inexact] = True
+ context.traps[FloatOperation] = True
+ self.assertTrue(context.traps[FloatOperation])
+ self.assertTrue(context.traps[Inexact])
+
+class CContextFlags(ContextFlags):
+ decimal = C
+class PyContextFlags(ContextFlags):
+ decimal = P
+
+class SpecialContexts(unittest.TestCase):
+ """Test the context templates."""
+
+ def test_context_templates(self):
+ BasicContext = self.decimal.BasicContext
+ ExtendedContext = self.decimal.ExtendedContext
+ getcontext = self.decimal.getcontext
+ setcontext = self.decimal.setcontext
+ InvalidOperation = self.decimal.InvalidOperation
+ DivisionByZero = self.decimal.DivisionByZero
+ Overflow = self.decimal.Overflow
+ Underflow = self.decimal.Underflow
+ Clamped = self.decimal.Clamped
+
+ assert_signals(self, BasicContext, 'traps',
+ [InvalidOperation, DivisionByZero, Overflow, Underflow, Clamped]
+ )
+
+ savecontext = getcontext().copy()
+ basic_context_prec = BasicContext.prec
+ extended_context_prec = ExtendedContext.prec
+
+ ex = None
+ try:
+ BasicContext.prec = ExtendedContext.prec = 441
+ for template in BasicContext, ExtendedContext:
+ setcontext(template)
+ c = getcontext()
+ self.assertIsNot(c, template)
+ self.assertEqual(c.prec, 441)
+ except Exception as e:
+ ex = e.__class__
+ finally:
+ BasicContext.prec = basic_context_prec
+ ExtendedContext.prec = extended_context_prec
+ setcontext(savecontext)
+ if ex:
+ raise ex
+
+ def test_default_context(self):
+ DefaultContext = self.decimal.DefaultContext
+ BasicContext = self.decimal.BasicContext
+ ExtendedContext = self.decimal.ExtendedContext
+ getcontext = self.decimal.getcontext
+ setcontext = self.decimal.setcontext
+ InvalidOperation = self.decimal.InvalidOperation
+ DivisionByZero = self.decimal.DivisionByZero
+ Overflow = self.decimal.Overflow
+
+ self.assertEqual(BasicContext.prec, 9)
+ self.assertEqual(ExtendedContext.prec, 9)
+
+ assert_signals(self, DefaultContext, 'traps',
+ [InvalidOperation, DivisionByZero, Overflow]
+ )
+
+ savecontext = getcontext().copy()
+ default_context_prec = DefaultContext.prec
+
+ ex = None
+ try:
+ c = getcontext()
+ saveprec = c.prec
+
+ DefaultContext.prec = 961
+ c = getcontext()
+ self.assertEqual(c.prec, saveprec)
+
+ setcontext(DefaultContext)
+ c = getcontext()
+ self.assertIsNot(c, DefaultContext)
+ self.assertEqual(c.prec, 961)
+ except Exception as e:
+ ex = e.__class__
+ finally:
+ DefaultContext.prec = default_context_prec
+ setcontext(savecontext)
+ if ex:
+ raise ex
+
+class CSpecialContexts(SpecialContexts):
+ decimal = C
+class PySpecialContexts(SpecialContexts):
+ decimal = P
+
+class ContextInputValidation(unittest.TestCase):
+
+ def test_invalid_context(self):
+ Context = self.decimal.Context
+ DefaultContext = self.decimal.DefaultContext
+
+ c = DefaultContext.copy()
+
+ # prec, Emax
+ for attr in ['prec', 'Emax']:
+ setattr(c, attr, 999999)
+ self.assertEqual(getattr(c, attr), 999999)
+ self.assertRaises(ValueError, setattr, c, attr, -1)
+ self.assertRaises(TypeError, setattr, c, attr, 'xyz')
+
+ # Emin
+ setattr(c, 'Emin', -999999)
+ self.assertEqual(getattr(c, 'Emin'), -999999)
+ self.assertRaises(ValueError, setattr, c, 'Emin', 1)
+ self.assertRaises(TypeError, setattr, c, 'Emin', (1,2,3))
+
+ self.assertRaises(TypeError, setattr, c, 'rounding', -1)
+ self.assertRaises(TypeError, setattr, c, 'rounding', 9)
+ self.assertRaises(TypeError, setattr, c, 'rounding', 1.0)
+ self.assertRaises(TypeError, setattr, c, 'rounding', 'xyz')
+
+ # capitals, clamp
+ for attr in ['capitals', 'clamp']:
+ self.assertRaises(ValueError, setattr, c, attr, -1)
+ self.assertRaises(ValueError, setattr, c, attr, 2)
+ self.assertRaises(TypeError, setattr, c, attr, [1,2,3])
+
+ # Invalid attribute
+ self.assertRaises(AttributeError, setattr, c, 'emax', 100)
+
+ # Invalid signal dict
+ self.assertRaises(TypeError, setattr, c, 'flags', [])
+ self.assertRaises(KeyError, setattr, c, 'flags', {})
+ self.assertRaises(KeyError, setattr, c, 'traps',
+ {'InvalidOperation':0})
+
+ # Attributes cannot be deleted
+ for attr in ['prec', 'Emax', 'Emin', 'rounding', 'capitals', 'clamp',
+ 'flags', 'traps']:
+ self.assertRaises(AttributeError, c.__delattr__, attr)
+
+ # Invalid attributes
+ self.assertRaises(TypeError, getattr, c, 9)
+ self.assertRaises(TypeError, setattr, c, 9)
+
+ # Invalid values in constructor
+ self.assertRaises(TypeError, Context, rounding=999999)
+ self.assertRaises(TypeError, Context, rounding='xyz')
+ self.assertRaises(ValueError, Context, clamp=2)
+ self.assertRaises(ValueError, Context, capitals=-1)
+ self.assertRaises(KeyError, Context, flags=["P"])
+ self.assertRaises(KeyError, Context, traps=["Q"])
+
+ # Type error in conversion
+ self.assertRaises(TypeError, Context, flags=(0,1))
+ self.assertRaises(TypeError, Context, traps=(1,0))
+
+class CContextInputValidation(ContextInputValidation):
+ decimal = C
+class PyContextInputValidation(ContextInputValidation):
+ decimal = P
+
+class ContextSubclassing(unittest.TestCase):
+
+ def test_context_subclassing(self):
+ decimal = self.decimal
+ Decimal = decimal.Decimal
+ Context = decimal.Context
+ Clamped = decimal.Clamped
+ DivisionByZero = decimal.DivisionByZero
+ Inexact = decimal.Inexact
+ Overflow = decimal.Overflow
+ Rounded = decimal.Rounded
+ Subnormal = decimal.Subnormal
+ Underflow = decimal.Underflow
+ InvalidOperation = decimal.InvalidOperation
+
+ class MyContext(Context):
+ def __init__(self, prec=None, rounding=None, Emin=None, Emax=None,
+ capitals=None, clamp=None, flags=None,
+ traps=None):
+ Context.__init__(self)
+ if prec is not None:
+ self.prec = prec
+ if rounding is not None:
+ self.rounding = rounding
+ if Emin is not None:
+ self.Emin = Emin
+ if Emax is not None:
+ self.Emax = Emax
+ if capitals is not None:
+ self.capitals = capitals
+ if clamp is not None:
+ self.clamp = clamp
+ if flags is not None:
+ if isinstance(flags, list):
+ flags = {v:(v in flags) for v in OrderedSignals[decimal] + flags}
+ self.flags = flags
+ if traps is not None:
+ if isinstance(traps, list):
+ traps = {v:(v in traps) for v in OrderedSignals[decimal] + traps}
+ self.traps = traps
+
+ c = Context()
+ d = MyContext()
+ for attr in ('prec', 'rounding', 'Emin', 'Emax', 'capitals', 'clamp',
+ 'flags', 'traps'):
+ self.assertEqual(getattr(c, attr), getattr(d, attr))
+
+ # prec
+ self.assertRaises(ValueError, MyContext, **{'prec':-1})
+ c = MyContext(prec=1)
+ self.assertEqual(c.prec, 1)
+ self.assertRaises(InvalidOperation, c.quantize, Decimal('9e2'), 0)
+
+ # rounding
+ self.assertRaises(TypeError, MyContext, **{'rounding':'XYZ'})
+ c = MyContext(rounding=ROUND_DOWN, prec=1)
+ self.assertEqual(c.rounding, ROUND_DOWN)
+ self.assertEqual(c.plus(Decimal('9.9')), 9)
+
+ # Emin
+ self.assertRaises(ValueError, MyContext, **{'Emin':5})
+ c = MyContext(Emin=-1, prec=1)
+ self.assertEqual(c.Emin, -1)
+ x = c.add(Decimal('1e-99'), Decimal('2.234e-2000'))
+ self.assertEqual(x, Decimal('0.0'))
+ for signal in (Inexact, Underflow, Subnormal, Rounded, Clamped):
+ self.assertTrue(c.flags[signal])
+
+ # Emax
+ self.assertRaises(ValueError, MyContext, **{'Emax':-1})
+ c = MyContext(Emax=1, prec=1)
+ self.assertEqual(c.Emax, 1)
+ self.assertRaises(Overflow, c.add, Decimal('1e99'), Decimal('2.234e2000'))
+ if self.decimal == C:
+ for signal in (Inexact, Overflow, Rounded):
+ self.assertTrue(c.flags[signal])
+
+ # capitals
+ self.assertRaises(ValueError, MyContext, **{'capitals':-1})
+ c = MyContext(capitals=0)
+ self.assertEqual(c.capitals, 0)
+ x = c.create_decimal('1E222')
+ self.assertEqual(c.to_sci_string(x), '1e+222')
+
+ # clamp
+ self.assertRaises(ValueError, MyContext, **{'clamp':2})
+ c = MyContext(clamp=1, Emax=99)
+ self.assertEqual(c.clamp, 1)
+ x = c.plus(Decimal('1e99'))
+ self.assertEqual(str(x), '1.000000000000000000000000000E+99')
+
+ # flags
+ self.assertRaises(TypeError, MyContext, **{'flags':'XYZ'})
+ c = MyContext(flags=[Rounded, DivisionByZero])
+ for signal in (Rounded, DivisionByZero):
+ self.assertTrue(c.flags[signal])
+ c.clear_flags()
+ for signal in OrderedSignals[decimal]:
+ self.assertFalse(c.flags[signal])
+
+ # traps
+ self.assertRaises(TypeError, MyContext, **{'traps':'XYZ'})
+ c = MyContext(traps=[Rounded, DivisionByZero])
+ for signal in (Rounded, DivisionByZero):
+ self.assertTrue(c.traps[signal])
+ c.clear_traps()
+ for signal in OrderedSignals[decimal]:
+ self.assertFalse(c.traps[signal])
+
+class CContextSubclassing(ContextSubclassing):
+ decimal = C
+class PyContextSubclassing(ContextSubclassing):
+ decimal = P
+
+@skip_if_extra_functionality
+class CheckAttributes(unittest.TestCase):
+
+ def test_module_attributes(self):
+
+ # Architecture dependent context limits
+ self.assertEqual(C.MAX_PREC, P.MAX_PREC)
+ self.assertEqual(C.MAX_EMAX, P.MAX_EMAX)
+ self.assertEqual(C.MIN_EMIN, P.MIN_EMIN)
+ self.assertEqual(C.MIN_ETINY, P.MIN_ETINY)
+
+ self.assertTrue(C.HAVE_THREADS is True or C.HAVE_THREADS is False)
+ self.assertTrue(P.HAVE_THREADS is True or P.HAVE_THREADS is False)
+
+ self.assertEqual(C.__version__, P.__version__)
+
+ x = dir(C)
+ y = [s for s in dir(P) if '__' in s or not s.startswith('_')]
+ self.assertEqual(set(x) - set(y), set())
+
+ def test_context_attributes(self):
+
+ x = [s for s in dir(C.Context()) if '__' in s or not s.startswith('_')]
+ y = [s for s in dir(P.Context()) if '__' in s or not s.startswith('_')]
+ self.assertEqual(set(x) - set(y), set())
+
+ def test_decimal_attributes(self):
+
+ x = [s for s in dir(C.Decimal(9)) if '__' in s or not s.startswith('_')]
+ y = [s for s in dir(C.Decimal(9)) if '__' in s or not s.startswith('_')]
+ self.assertEqual(set(x) - set(y), set())
+
+class Coverage(unittest.TestCase):
+
+ def test_adjusted(self):
+ Decimal = self.decimal.Decimal
+
+ self.assertEqual(Decimal('1234e9999').adjusted(), 10002)
+ # XXX raise?
+ self.assertEqual(Decimal('nan').adjusted(), 0)
+ self.assertEqual(Decimal('inf').adjusted(), 0)
+
+ def test_canonical(self):
+ Decimal = self.decimal.Decimal
+ getcontext = self.decimal.getcontext
+
+ x = Decimal(9).canonical()
+ self.assertEqual(x, 9)
+
+ c = getcontext()
+ x = c.canonical(Decimal(9))
+ self.assertEqual(x, 9)
+
+ def test_context_repr(self):
+ c = self.decimal.DefaultContext.copy()
+
+ c.prec = 425000000
+ c.Emax = 425000000
+ c.Emin = -425000000
+ c.rounding = ROUND_HALF_DOWN
+ c.capitals = 0
+ c.clamp = 1
+ for sig in OrderedSignals[self.decimal]:
+ c.flags[sig] = False
+ c.traps[sig] = False
+
+ s = c.__repr__()
+ t = "Context(prec=425000000, rounding=ROUND_HALF_DOWN, " \
+ "Emin=-425000000, Emax=425000000, capitals=0, clamp=1, " \
+ "flags=[], traps=[])"
+ self.assertEqual(s, t)
+
+ def test_implicit_context(self):
+ Decimal = self.decimal.Decimal
+ localcontext = self.decimal.localcontext
+
+ with localcontext() as c:
+ c.prec = 1
+ c.Emax = 1
+ c.Emin = -1
+
+ # abs
+ self.assertEqual(abs(Decimal("-10")), 10)
+ # add
+ self.assertEqual(Decimal("7") + 1, 8)
+ # divide
+ self.assertEqual(Decimal("10") / 5, 2)
+ # divide_int
+ self.assertEqual(Decimal("10") // 7, 1)
+ # fma
+ self.assertEqual(Decimal("1.2").fma(Decimal("0.01"), 1), 1)
+ self.assertIs(Decimal("NaN").fma(7, 1).is_nan(), True)
+ # three arg power
+ self.assertEqual(pow(Decimal(10), 2, 7), 2)
+ # exp
+ self.assertEqual(Decimal("1.01").exp(), 3)
+ # is_normal
+ self.assertIs(Decimal("0.01").is_normal(), False)
+ # is_subnormal
+ self.assertIs(Decimal("0.01").is_subnormal(), True)
+ # ln
+ self.assertEqual(Decimal("20").ln(), 3)
+ # log10
+ self.assertEqual(Decimal("20").log10(), 1)
+ # logb
+ self.assertEqual(Decimal("580").logb(), 2)
+ # logical_invert
+ self.assertEqual(Decimal("10").logical_invert(), 1)
+ # minus
+ self.assertEqual(-Decimal("-10"), 10)
+ # multiply
+ self.assertEqual(Decimal("2") * 4, 8)
+ # next_minus
+ self.assertEqual(Decimal("10").next_minus(), 9)
+ # next_plus
+ self.assertEqual(Decimal("10").next_plus(), Decimal('2E+1'))
+ # normalize
+ self.assertEqual(Decimal("-10").normalize(), Decimal('-1E+1'))
+ # number_class
+ self.assertEqual(Decimal("10").number_class(), '+Normal')
+ # plus
+ self.assertEqual(+Decimal("-1"), -1)
+ # remainder
+ self.assertEqual(Decimal("10") % 7, 3)
+ # subtract
+ self.assertEqual(Decimal("10") - 7, 3)
+ # to_integral_exact
+ self.assertEqual(Decimal("1.12345").to_integral_exact(), 1)
+
+ # Boolean functions
+ self.assertTrue(Decimal("1").is_canonical())
+ self.assertTrue(Decimal("1").is_finite())
+ self.assertTrue(Decimal("1").is_finite())
+ self.assertTrue(Decimal("snan").is_snan())
+ self.assertTrue(Decimal("-1").is_signed())
+ self.assertTrue(Decimal("0").is_zero())
+ self.assertTrue(Decimal("0").is_zero())
+
+ # Copy
+ with localcontext() as c:
+ c.prec = 10000
+ x = 1228 ** 1523
+ y = -Decimal(x)
+
+ z = y.copy_abs()
+ self.assertEqual(z, x)
+
+ z = y.copy_negate()
+ self.assertEqual(z, x)
+
+ z = y.copy_sign(Decimal(1))
+ self.assertEqual(z, x)
+
+ def test_divmod(self):
+ Decimal = self.decimal.Decimal
+ localcontext = self.decimal.localcontext
+ InvalidOperation = self.decimal.InvalidOperation
+ DivisionByZero = self.decimal.DivisionByZero
+
+ with localcontext() as c:
+ q, r = divmod(Decimal("10912837129"), 1001)
+ self.assertEqual(q, Decimal('10901935'))
+ self.assertEqual(r, Decimal('194'))
+
+ q, r = divmod(Decimal("NaN"), 7)
+ self.assertTrue(q.is_nan() and r.is_nan())
+
+ c.traps[InvalidOperation] = False
+ q, r = divmod(Decimal("NaN"), 7)
+ self.assertTrue(q.is_nan() and r.is_nan())
+
+ c.traps[InvalidOperation] = False
+ c.clear_flags()
+ q, r = divmod(Decimal("inf"), Decimal("inf"))
+ self.assertTrue(q.is_nan() and r.is_nan())
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.clear_flags()
+ q, r = divmod(Decimal("inf"), 101)
+ self.assertTrue(q.is_infinite() and r.is_nan())
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.clear_flags()
+ q, r = divmod(Decimal(0), 0)
+ self.assertTrue(q.is_nan() and r.is_nan())
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.traps[DivisionByZero] = False
+ c.clear_flags()
+ q, r = divmod(Decimal(11), 0)
+ self.assertTrue(q.is_infinite() and r.is_nan())
+ self.assertTrue(c.flags[InvalidOperation] and
+ c.flags[DivisionByZero])
+
+ def test_power(self):
+ Decimal = self.decimal.Decimal
+ localcontext = self.decimal.localcontext
+ Overflow = self.decimal.Overflow
+ Rounded = self.decimal.Rounded
+
+ with localcontext() as c:
+ c.prec = 3
+ c.clear_flags()
+ self.assertEqual(Decimal("1.0") ** 100, Decimal('1.00'))
+ self.assertTrue(c.flags[Rounded])
+
+ c.prec = 1
+ c.Emax = 1
+ c.Emin = -1
+ c.clear_flags()
+ c.traps[Overflow] = False
+ self.assertEqual(Decimal(10000) ** Decimal("0.5"), Decimal('inf'))
+ self.assertTrue(c.flags[Overflow])
+
+ def test_quantize(self):
+ Decimal = self.decimal.Decimal
+ localcontext = self.decimal.localcontext
+ InvalidOperation = self.decimal.InvalidOperation
+
+ with localcontext() as c:
+ c.prec = 1
+ c.Emax = 1
+ c.Emin = -1
+ c.traps[InvalidOperation] = False
+ x = Decimal(99).quantize(Decimal("1e1"))
+ self.assertTrue(x.is_nan())
+
+ def test_radix(self):
+ Decimal = self.decimal.Decimal
+ getcontext = self.decimal.getcontext
+
+ c = getcontext()
+ self.assertEqual(Decimal("1").radix(), 10)
+ self.assertEqual(c.radix(), 10)
+
+ def test_rop(self):
+ Decimal = self.decimal.Decimal
+
+ for attr in ('__radd__', '__rsub__', '__rmul__', '__rtruediv__',
+ '__rdivmod__', '__rmod__', '__rfloordiv__', '__rpow__'):
+ self.assertIs(getattr(Decimal("1"), attr)("xyz"), NotImplemented)
+
+ def test_round(self):
+ # Python3 behavior: round() returns Decimal
+ Decimal = self.decimal.Decimal
+ getcontext = self.decimal.getcontext
+
+ c = getcontext()
+ c.prec = 28
+
+ self.assertEqual(str(Decimal("9.99").__round__()), "10")
+ self.assertEqual(str(Decimal("9.99e-5").__round__()), "0")
+ self.assertEqual(str(Decimal("1.23456789").__round__(5)), "1.23457")
+ self.assertEqual(str(Decimal("1.2345").__round__(10)), "1.2345000000")
+ self.assertEqual(str(Decimal("1.2345").__round__(-10)), "0E+10")
+
+ self.assertRaises(TypeError, Decimal("1.23").__round__, "5")
+ self.assertRaises(TypeError, Decimal("1.23").__round__, 5, 8)
+
+ def test_create_decimal(self):
+ c = self.decimal.Context()
+ self.assertRaises(ValueError, c.create_decimal, ["%"])
+
+ def test_int(self):
+ Decimal = self.decimal.Decimal
+ localcontext = self.decimal.localcontext
+
+ with localcontext() as c:
+ c.prec = 9999
+ x = Decimal(1221**1271) / 10**3923
+ self.assertEqual(int(x), 1)
+ self.assertEqual(x.to_integral(), 2)
+
+ def test_copy(self):
+ Context = self.decimal.Context
+
+ c = Context()
+ c.prec = 10000
+ x = -(1172 ** 1712)
+
+ y = c.copy_abs(x)
+ self.assertEqual(y, -x)
+
+ y = c.copy_negate(x)
+ self.assertEqual(y, -x)
+
+ y = c.copy_sign(x, 1)
+ self.assertEqual(y, -x)
+
+class CCoverage(Coverage):
+ decimal = C
+class PyCoverage(Coverage):
+ decimal = P
+
+class PyFunctionality(unittest.TestCase):
+ """Extra functionality in decimal.py"""
+
+ def test_py_quantize_watchexp(self):
+ # watchexp functionality
+ Decimal = P.Decimal
+ localcontext = P.localcontext
+
+ with localcontext() as c:
+ c.prec = 1
+ c.Emax = 1
+ c.Emin = -1
+ x = Decimal(99999).quantize(Decimal("1e3"), watchexp=False)
+ self.assertEqual(x, Decimal('1.00E+5'))
+
+ def test_py_alternate_formatting(self):
+ # triples giving a format, a Decimal, and the expected result
+ Decimal = P.Decimal
+ localcontext = P.localcontext
+
+ test_values = [
+ # Issue 7094: Alternate formatting (specified by #)
+ ('.0e', '1.0', '1e+0'),
+ ('#.0e', '1.0', '1.e+0'),
+ ('.0f', '1.0', '1'),
+ ('#.0f', '1.0', '1.'),
+ ('g', '1.1', '1.1'),
+ ('#g', '1.1', '1.1'),
+ ('.0g', '1', '1'),
+ ('#.0g', '1', '1.'),
+ ('.0%', '1.0', '100%'),
+ ('#.0%', '1.0', '100.%'),
+ ]
+ for fmt, d, result in test_values:
+ self.assertEqual(format(Decimal(d), fmt), result)
+
+class PyWhitebox(unittest.TestCase):
+ """White box testing for decimal.py"""
+
+ def test_py_exact_power(self):
+ # Rarely exercised lines in _power_exact.
+ Decimal = P.Decimal
+ localcontext = P.localcontext
+
+ with localcontext() as c:
+ c.prec = 8
+ x = Decimal(2**16) ** Decimal("-0.5")
+ self.assertEqual(x, Decimal('0.00390625'))
+
+ x = Decimal(2**16) ** Decimal("-0.6")
+ self.assertEqual(x, Decimal('0.0012885819'))
+
+ x = Decimal("256e7") ** Decimal("-0.5")
+
+ x = Decimal(152587890625) ** Decimal('-0.0625')
+ self.assertEqual(x, Decimal("0.2"))
+
+ x = Decimal("152587890625e7") ** Decimal('-0.0625')
+
+ x = Decimal(5**2659) ** Decimal('-0.0625')
+
+ c.prec = 1
+ x = Decimal("152587890625") ** Decimal('-0.5')
+ c.prec = 201
+ x = Decimal(2**578) ** Decimal("-0.5")
+
+ def test_py_immutability_operations(self):
+ # Do operations and check that it didn't change change internal objects.
+ Decimal = P.Decimal
+ DefaultContext = P.DefaultContext
+ setcontext = P.setcontext
+
+ c = DefaultContext.copy()
+ c.traps = dict((s, 0) for s in OrderedSignals[P])
+ setcontext(c)
+
+ d1 = Decimal('-25e55')
+ b1 = Decimal('-25e55')
+ d2 = Decimal('33e+33')
+ b2 = Decimal('33e+33')
+
+ def checkSameDec(operation, useOther=False):
+ if useOther:
+ eval("d1." + operation + "(d2)")
+ self.assertEqual(d1._sign, b1._sign)
+ self.assertEqual(d1._int, b1._int)
+ self.assertEqual(d1._exp, b1._exp)
+ self.assertEqual(d2._sign, b2._sign)
+ self.assertEqual(d2._int, b2._int)
+ self.assertEqual(d2._exp, b2._exp)
+ else:
+ eval("d1." + operation + "()")
+ self.assertEqual(d1._sign, b1._sign)
+ self.assertEqual(d1._int, b1._int)
+ self.assertEqual(d1._exp, b1._exp)
+ return
+
+ Decimal(d1)
+ self.assertEqual(d1._sign, b1._sign)
+ self.assertEqual(d1._int, b1._int)
+ self.assertEqual(d1._exp, b1._exp)
+
+ checkSameDec("__abs__")
+ checkSameDec("__add__", True)
+ checkSameDec("__divmod__", True)
+ checkSameDec("__eq__", True)
+ checkSameDec("__ne__", True)
+ checkSameDec("__le__", True)
+ checkSameDec("__lt__", True)
+ checkSameDec("__ge__", True)
+ checkSameDec("__gt__", True)
+ checkSameDec("__float__")
+ checkSameDec("__floordiv__", True)
+ checkSameDec("__hash__")
+ checkSameDec("__int__")
+ checkSameDec("__trunc__")
+ checkSameDec("__mod__", True)
+ checkSameDec("__mul__", True)
+ checkSameDec("__neg__")
+ checkSameDec("__bool__")
+ checkSameDec("__pos__")
+ checkSameDec("__pow__", True)
+ checkSameDec("__radd__", True)
+ checkSameDec("__rdivmod__", True)
+ checkSameDec("__repr__")
+ checkSameDec("__rfloordiv__", True)
+ checkSameDec("__rmod__", True)
+ checkSameDec("__rmul__", True)
+ checkSameDec("__rpow__", True)
+ checkSameDec("__rsub__", True)
+ checkSameDec("__str__")
+ checkSameDec("__sub__", True)
+ checkSameDec("__truediv__", True)
+ checkSameDec("adjusted")
+ checkSameDec("as_tuple")
+ checkSameDec("compare", True)
+ checkSameDec("max", True)
+ checkSameDec("min", True)
+ checkSameDec("normalize")
+ checkSameDec("quantize", True)
+ checkSameDec("remainder_near", True)
+ checkSameDec("same_quantum", True)
+ checkSameDec("sqrt")
+ checkSameDec("to_eng_string")
+ checkSameDec("to_integral")
+
+ def test_py_decimal_id(self):
+ Decimal = P.Decimal
+
+ d = Decimal(45)
+ e = Decimal(d)
+ self.assertEqual(str(e), '45')
+ self.assertNotEqual(id(d), id(e))
+
+ def test_py_rescale(self):
+ # Coverage
+ Decimal = P.Decimal
+ localcontext = P.localcontext
+
+ with localcontext() as c:
+ x = Decimal("NaN")._rescale(3, ROUND_UP)
+ self.assertTrue(x.is_nan())
+
+ def test_py__round(self):
+ # Coverage
+ Decimal = P.Decimal
+
+ self.assertRaises(ValueError, Decimal("3.1234")._round, 0, ROUND_UP)
+
+class CFunctionality(unittest.TestCase):
+ """Extra functionality in _decimal"""
+
+ @requires_extra_functionality
+ def test_c_ieee_context(self):
+ # issue 8786: Add support for IEEE 754 contexts to decimal module.
+ IEEEContext = C.IEEEContext
+ DECIMAL32 = C.DECIMAL32
+ DECIMAL64 = C.DECIMAL64
+ DECIMAL128 = C.DECIMAL128
+
+ def assert_rest(self, context):
+ self.assertEqual(context.clamp, 1)
+ assert_signals(self, context, 'traps', [])
+ assert_signals(self, context, 'flags', [])
+
+ c = IEEEContext(DECIMAL32)
+ self.assertEqual(c.prec, 7)
+ self.assertEqual(c.Emax, 96)
+ self.assertEqual(c.Emin, -95)
+ assert_rest(self, c)
+
+ c = IEEEContext(DECIMAL64)
+ self.assertEqual(c.prec, 16)
+ self.assertEqual(c.Emax, 384)
+ self.assertEqual(c.Emin, -383)
+ assert_rest(self, c)
+
+ c = IEEEContext(DECIMAL128)
+ self.assertEqual(c.prec, 34)
+ self.assertEqual(c.Emax, 6144)
+ self.assertEqual(c.Emin, -6143)
+ assert_rest(self, c)
+
+ # Invalid values
+ self.assertRaises(OverflowError, IEEEContext, 2**63)
+ self.assertRaises(ValueError, IEEEContext, -1)
+ self.assertRaises(ValueError, IEEEContext, 1024)
+
+ @requires_extra_functionality
+ def test_c_context(self):
+ Context = C.Context
+
+ c = Context(flags=C.DecClamped, traps=C.DecRounded)
+ self.assertEqual(c._flags, C.DecClamped)
+ self.assertEqual(c._traps, C.DecRounded)
+
+ @requires_extra_functionality
+ def test_constants(self):
+ # Condition flags
+ cond = (
+ C.DecClamped, C.DecConversionSyntax, C.DecDivisionByZero,
+ C.DecDivisionImpossible, C.DecDivisionUndefined,
+ C.DecFpuError, C.DecInexact, C.DecInvalidContext,
+ C.DecInvalidOperation, C.DecMallocError,
+ C.DecFloatOperation, C.DecOverflow, C.DecRounded,
+ C.DecSubnormal, C.DecUnderflow
+ )
+
+ # IEEEContext
+ self.assertEqual(C.DECIMAL32, 32)
+ self.assertEqual(C.DECIMAL64, 64)
+ self.assertEqual(C.DECIMAL128, 128)
+ self.assertEqual(C.IEEE_CONTEXT_MAX_BITS, 512)
+
+ # Conditions
+ for i, v in enumerate(cond):
+ self.assertEqual(v, 1<<i)
+
+ self.assertEqual(C.DecIEEEInvalidOperation,
+ C.DecConversionSyntax|
+ C.DecDivisionImpossible|
+ C.DecDivisionUndefined|
+ C.DecFpuError|
+ C.DecInvalidContext|
+ C.DecInvalidOperation|
+ C.DecMallocError)
+
+ self.assertEqual(C.DecErrors,
+ C.DecIEEEInvalidOperation|
+ C.DecDivisionByZero)
+
+ self.assertEqual(C.DecTraps,
+ C.DecErrors|C.DecOverflow|C.DecUnderflow)
+
+class CWhitebox(unittest.TestCase):
+ """Whitebox testing for _decimal"""
+
+ def test_bignum(self):
+ # Not exactly whitebox, but too slow with pydecimal.
+
+ Decimal = C.Decimal
+ localcontext = C.localcontext
+
+ b1 = 10**35
+ b2 = 10**36
+ with localcontext() as c:
+ c.prec = 1000000
+ for i in range(5):
+ a = random.randrange(b1, b2)
+ b = random.randrange(1000, 1200)
+ x = a ** b
+ y = Decimal(a) ** Decimal(b)
+ self.assertEqual(x, y)
+
+ def test_invalid_construction(self):
+ self.assertRaises(TypeError, C.Decimal, 9, "xyz")
+
+ def test_c_input_restriction(self):
+ # Too large for _decimal to be converted exactly
+ Decimal = C.Decimal
+ InvalidOperation = C.InvalidOperation
+ Context = C.Context
+ localcontext = C.localcontext
+
+ with localcontext(Context()):
+ self.assertRaises(InvalidOperation, Decimal,
+ "1e9999999999999999999")
+
+ def test_c_context_repr(self):
+ # This test is _decimal-only because flags are not printed
+ # in the same order.
+ DefaultContext = C.DefaultContext
+ FloatOperation = C.FloatOperation
+
+ c = DefaultContext.copy()
+
+ c.prec = 425000000
+ c.Emax = 425000000
+ c.Emin = -425000000
+ c.rounding = ROUND_HALF_DOWN
+ c.capitals = 0
+ c.clamp = 1
+ for sig in OrderedSignals[C]:
+ c.flags[sig] = True
+ c.traps[sig] = True
+ c.flags[FloatOperation] = True
+ c.traps[FloatOperation] = True
+
+ s = c.__repr__()
+ t = "Context(prec=425000000, rounding=ROUND_HALF_DOWN, " \
+ "Emin=-425000000, Emax=425000000, capitals=0, clamp=1, " \
+ "flags=[Clamped, InvalidOperation, DivisionByZero, Inexact, " \
+ "FloatOperation, Overflow, Rounded, Subnormal, Underflow], " \
+ "traps=[Clamped, InvalidOperation, DivisionByZero, Inexact, " \
+ "FloatOperation, Overflow, Rounded, Subnormal, Underflow])"
+ self.assertEqual(s, t)
+
+ def test_c_context_errors(self):
+ Context = C.Context
+ InvalidOperation = C.InvalidOperation
+ Overflow = C.Overflow
+ FloatOperation = C.FloatOperation
+ localcontext = C.localcontext
+ getcontext = C.getcontext
+ setcontext = C.setcontext
+ HAVE_CONFIG_64 = (C.MAX_PREC > 425000000)
+
+ c = Context()
+
+ # SignalDict: input validation
+ self.assertRaises(KeyError, c.flags.__setitem__, 801, 0)
+ self.assertRaises(KeyError, c.traps.__setitem__, 801, 0)
+ self.assertRaises(ValueError, c.flags.__delitem__, Overflow)
+ self.assertRaises(ValueError, c.traps.__delitem__, InvalidOperation)
+ self.assertRaises(TypeError, setattr, c, 'flags', ['x'])
+ self.assertRaises(TypeError, setattr, c,'traps', ['y'])
+ self.assertRaises(KeyError, setattr, c, 'flags', {0:1})
+ self.assertRaises(KeyError, setattr, c, 'traps', {0:1})
+
+ # Test assignment from a signal dict with the correct length but
+ # one invalid key.
+ d = c.flags.copy()
+ del d[FloatOperation]
+ d["XYZ"] = 91283719
+ self.assertRaises(KeyError, setattr, c, 'flags', d)
+ self.assertRaises(KeyError, setattr, c, 'traps', d)
+
+ # Input corner cases
+ int_max = 2**63-1 if HAVE_CONFIG_64 else 2**31-1
+ gt_max_emax = 10**18 if HAVE_CONFIG_64 else 10**9
+
+ # prec, Emax, Emin
+ for attr in ['prec', 'Emax']:
+ self.assertRaises(ValueError, setattr, c, attr, gt_max_emax)
+ self.assertRaises(ValueError, setattr, c, 'Emin', -gt_max_emax)
+
+ # prec, Emax, Emin in context constructor
+ self.assertRaises(ValueError, Context, prec=gt_max_emax)
+ self.assertRaises(ValueError, Context, Emax=gt_max_emax)
+ self.assertRaises(ValueError, Context, Emin=-gt_max_emax)
+
+ # Overflow in conversion
+ self.assertRaises(OverflowError, Context, prec=int_max+1)
+ self.assertRaises(OverflowError, Context, Emax=int_max+1)
+ self.assertRaises(OverflowError, Context, Emin=-int_max-2)
+ self.assertRaises(OverflowError, Context, clamp=int_max+1)
+ self.assertRaises(OverflowError, Context, capitals=int_max+1)
+
+ # OverflowError, general ValueError
+ for attr in ('prec', 'Emin', 'Emax', 'capitals', 'clamp'):
+ self.assertRaises(OverflowError, setattr, c, attr, int_max+1)
+ self.assertRaises(OverflowError, setattr, c, attr, -int_max-2)
+ if sys.platform != 'win32':
+ self.assertRaises(ValueError, setattr, c, attr, int_max)
+ self.assertRaises(ValueError, setattr, c, attr, -int_max-1)
+
+ # OverflowError: _unsafe_setprec, _unsafe_setemin, _unsafe_setemax
+ if C.MAX_PREC == 425000000:
+ self.assertRaises(OverflowError, getattr(c, '_unsafe_setprec'),
+ int_max+1)
+ self.assertRaises(OverflowError, getattr(c, '_unsafe_setemax'),
+ int_max+1)
+ self.assertRaises(OverflowError, getattr(c, '_unsafe_setemin'),
+ -int_max-2)
+
+ # ValueError: _unsafe_setprec, _unsafe_setemin, _unsafe_setemax
+ if C.MAX_PREC == 425000000:
+ self.assertRaises(ValueError, getattr(c, '_unsafe_setprec'), 0)
+ self.assertRaises(ValueError, getattr(c, '_unsafe_setprec'),
+ 1070000001)
+ self.assertRaises(ValueError, getattr(c, '_unsafe_setemax'), -1)
+ self.assertRaises(ValueError, getattr(c, '_unsafe_setemax'),
+ 1070000001)
+ self.assertRaises(ValueError, getattr(c, '_unsafe_setemin'),
+ -1070000001)
+ self.assertRaises(ValueError, getattr(c, '_unsafe_setemin'), 1)
+
+ # capitals, clamp
+ for attr in ['capitals', 'clamp']:
+ self.assertRaises(ValueError, setattr, c, attr, -1)
+ self.assertRaises(ValueError, setattr, c, attr, 2)
+ self.assertRaises(TypeError, setattr, c, attr, [1,2,3])
+ if HAVE_CONFIG_64:
+ self.assertRaises(ValueError, setattr, c, attr, 2**32)
+ self.assertRaises(ValueError, setattr, c, attr, 2**32+1)
+
+ # Invalid local context
+ self.assertRaises(TypeError, exec, 'with localcontext("xyz"): pass',
+ locals())
+ self.assertRaises(TypeError, exec,
+ 'with localcontext(context=getcontext()): pass',
+ locals())
+
+ # setcontext
+ saved_context = getcontext()
+ self.assertRaises(TypeError, setcontext, "xyz")
+ setcontext(saved_context)
+
+ def test_rounding_strings_interned(self):
+
+ self.assertIs(C.ROUND_UP, P.ROUND_UP)
+ self.assertIs(C.ROUND_DOWN, P.ROUND_DOWN)
+ self.assertIs(C.ROUND_CEILING, P.ROUND_CEILING)
+ self.assertIs(C.ROUND_FLOOR, P.ROUND_FLOOR)
+ self.assertIs(C.ROUND_HALF_UP, P.ROUND_HALF_UP)
+ self.assertIs(C.ROUND_HALF_DOWN, P.ROUND_HALF_DOWN)
+ self.assertIs(C.ROUND_HALF_EVEN, P.ROUND_HALF_EVEN)
+ self.assertIs(C.ROUND_05UP, P.ROUND_05UP)
+
+ @requires_extra_functionality
+ def test_c_context_errors_extra(self):
+ Context = C.Context
+ InvalidOperation = C.InvalidOperation
+ Overflow = C.Overflow
+ localcontext = C.localcontext
+ getcontext = C.getcontext
+ setcontext = C.setcontext
+ HAVE_CONFIG_64 = (C.MAX_PREC > 425000000)
+
+ c = Context()
+
+ # Input corner cases
+ int_max = 2**63-1 if HAVE_CONFIG_64 else 2**31-1
+
+ # OverflowError, general ValueError
+ self.assertRaises(OverflowError, setattr, c, '_allcr', int_max+1)
+ self.assertRaises(OverflowError, setattr, c, '_allcr', -int_max-2)
+ if sys.platform != 'win32':
+ self.assertRaises(ValueError, setattr, c, '_allcr', int_max)
+ self.assertRaises(ValueError, setattr, c, '_allcr', -int_max-1)
+
+ # OverflowError, general TypeError
+ for attr in ('_flags', '_traps'):
+ self.assertRaises(OverflowError, setattr, c, attr, int_max+1)
+ self.assertRaises(OverflowError, setattr, c, attr, -int_max-2)
+ if sys.platform != 'win32':
+ self.assertRaises(TypeError, setattr, c, attr, int_max)
+ self.assertRaises(TypeError, setattr, c, attr, -int_max-1)
+
+ # _allcr
+ self.assertRaises(ValueError, setattr, c, '_allcr', -1)
+ self.assertRaises(ValueError, setattr, c, '_allcr', 2)
+ self.assertRaises(TypeError, setattr, c, '_allcr', [1,2,3])
+ if HAVE_CONFIG_64:
+ self.assertRaises(ValueError, setattr, c, '_allcr', 2**32)
+ self.assertRaises(ValueError, setattr, c, '_allcr', 2**32+1)
+
+ # _flags, _traps
+ for attr in ['_flags', '_traps']:
+ self.assertRaises(TypeError, setattr, c, attr, 999999)
+ self.assertRaises(TypeError, setattr, c, attr, 'x')
+
+ def test_c_valid_context(self):
+ # These tests are for code coverage in _decimal.
+ DefaultContext = C.DefaultContext
+ Clamped = C.Clamped
+ Underflow = C.Underflow
+ Inexact = C.Inexact
+ Rounded = C.Rounded
+ Subnormal = C.Subnormal
+
+ c = DefaultContext.copy()
+
+ # Exercise all getters and setters
+ c.prec = 34
+ c.rounding = ROUND_HALF_UP
+ c.Emax = 3000
+ c.Emin = -3000
+ c.capitals = 1
+ c.clamp = 0
+
+ self.assertEqual(c.prec, 34)
+ self.assertEqual(c.rounding, ROUND_HALF_UP)
+ self.assertEqual(c.Emin, -3000)
+ self.assertEqual(c.Emax, 3000)
+ self.assertEqual(c.capitals, 1)
+ self.assertEqual(c.clamp, 0)
+
+ self.assertEqual(c.Etiny(), -3033)
+ self.assertEqual(c.Etop(), 2967)
+
+ # Exercise all unsafe setters
+ if C.MAX_PREC == 425000000:
+ c._unsafe_setprec(999999999)
+ c._unsafe_setemax(999999999)
+ c._unsafe_setemin(-999999999)
+ self.assertEqual(c.prec, 999999999)
+ self.assertEqual(c.Emax, 999999999)
+ self.assertEqual(c.Emin, -999999999)
+
+ @requires_extra_functionality
+ def test_c_valid_context_extra(self):
+ DefaultContext = C.DefaultContext
+
+ c = DefaultContext.copy()
+ self.assertEqual(c._allcr, 1)
+ c._allcr = 0
+ self.assertEqual(c._allcr, 0)
+
+ def test_c_round(self):
+ # Restricted input.
+ Decimal = C.Decimal
+ InvalidOperation = C.InvalidOperation
+ localcontext = C.localcontext
+ MAX_EMAX = C.MAX_EMAX
+ MIN_ETINY = C.MIN_ETINY
+ int_max = 2**63-1 if C.MAX_PREC > 425000000 else 2**31-1
+
+ with localcontext() as c:
+ c.traps[InvalidOperation] = True
+ self.assertRaises(InvalidOperation, Decimal("1.23").__round__,
+ -int_max-1)
+ self.assertRaises(InvalidOperation, Decimal("1.23").__round__,
+ int_max)
+ self.assertRaises(InvalidOperation, Decimal("1").__round__,
+ int(MAX_EMAX+1))
+ self.assertRaises(C.InvalidOperation, Decimal("1").__round__,
+ -int(MIN_ETINY-1))
+ self.assertRaises(OverflowError, Decimal("1.23").__round__,
+ -int_max-2)
+ self.assertRaises(OverflowError, Decimal("1.23").__round__,
+ int_max+1)
+
+ def test_c_format(self):
+ # Restricted input
+ Decimal = C.Decimal
+ HAVE_CONFIG_64 = (C.MAX_PREC > 425000000)
+
+ self.assertRaises(TypeError, Decimal(1).__format__, "=10.10", [], 9)
+ self.assertRaises(TypeError, Decimal(1).__format__, "=10.10", 9)
+ self.assertRaises(TypeError, Decimal(1).__format__, [])
+
+ self.assertRaises(ValueError, Decimal(1).__format__, "<>=10.10")
+ maxsize = 2**63-1 if HAVE_CONFIG_64 else 2**31-1
+ self.assertRaises(ValueError, Decimal("1.23456789").__format__,
+ "=%d.1" % maxsize)
+
+ def test_c_integral(self):
+ Decimal = C.Decimal
+ Inexact = C.Inexact
+ localcontext = C.localcontext
+
+ x = Decimal(10)
+ self.assertEqual(x.to_integral(), 10)
+ self.assertRaises(TypeError, x.to_integral, '10')
+ self.assertRaises(TypeError, x.to_integral, 10, 'x')
+ self.assertRaises(TypeError, x.to_integral, 10)
+
+ self.assertEqual(x.to_integral_value(), 10)
+ self.assertRaises(TypeError, x.to_integral_value, '10')
+ self.assertRaises(TypeError, x.to_integral_value, 10, 'x')
+ self.assertRaises(TypeError, x.to_integral_value, 10)
+
+ self.assertEqual(x.to_integral_exact(), 10)
+ self.assertRaises(TypeError, x.to_integral_exact, '10')
+ self.assertRaises(TypeError, x.to_integral_exact, 10, 'x')
+ self.assertRaises(TypeError, x.to_integral_exact, 10)
+
+ with localcontext() as c:
+ x = Decimal("99999999999999999999999999.9").to_integral_value(ROUND_UP)
+ self.assertEqual(x, Decimal('100000000000000000000000000'))
+
+ x = Decimal("99999999999999999999999999.9").to_integral_exact(ROUND_UP)
+ self.assertEqual(x, Decimal('100000000000000000000000000'))
+
+ c.traps[Inexact] = True
+ self.assertRaises(Inexact, Decimal("999.9").to_integral_exact, ROUND_UP)
+
+ def test_c_funcs(self):
+ # Invalid arguments
+ Decimal = C.Decimal
+ InvalidOperation = C.InvalidOperation
+ DivisionByZero = C.DivisionByZero
+ getcontext = C.getcontext
+ localcontext = C.localcontext
+
+ self.assertEqual(Decimal('9.99e10').to_eng_string(), '99.9E+9')
+
+ self.assertRaises(TypeError, pow, Decimal(1), 2, "3")
+ self.assertRaises(TypeError, Decimal(9).number_class, "x", "y")
+ self.assertRaises(TypeError, Decimal(9).same_quantum, 3, "x", "y")
+
+ self.assertRaises(
+ TypeError,
+ Decimal("1.23456789").quantize, Decimal('1e-100000'), []
+ )
+ self.assertRaises(
+ TypeError,
+ Decimal("1.23456789").quantize, Decimal('1e-100000'), getcontext()
+ )
+ self.assertRaises(
+ TypeError,
+ Decimal("1.23456789").quantize, Decimal('1e-100000'), 10
+ )
+ self.assertRaises(
+ TypeError,
+ Decimal("1.23456789").quantize, Decimal('1e-100000'), ROUND_UP, 1000
+ )
+
+ with localcontext() as c:
+ c.clear_traps()
+
+ # Invalid arguments
+ self.assertRaises(TypeError, c.copy_sign, Decimal(1), "x", "y")
+ self.assertRaises(TypeError, c.canonical, 200)
+ self.assertRaises(TypeError, c.is_canonical, 200)
+ self.assertRaises(TypeError, c.divmod, 9, 8, "x", "y")
+ self.assertRaises(TypeError, c.same_quantum, 9, 3, "x", "y")
+
+ self.assertEqual(str(c.canonical(Decimal(200))), '200')
+ self.assertEqual(c.radix(), 10)
+
+ c.traps[DivisionByZero] = True
+ self.assertRaises(DivisionByZero, Decimal(9).__divmod__, 0)
+ self.assertRaises(DivisionByZero, c.divmod, 9, 0)
+ self.assertTrue(c.flags[InvalidOperation])
+
+ c.clear_flags()
+ c.traps[InvalidOperation] = True
+ self.assertRaises(InvalidOperation, Decimal(9).__divmod__, 0)
+ self.assertRaises(InvalidOperation, c.divmod, 9, 0)
+ self.assertTrue(c.flags[DivisionByZero])
+
+ c.traps[InvalidOperation] = True
+ c.prec = 2
+ self.assertRaises(InvalidOperation, pow, Decimal(1000), 1, 501)
+
+ def test_va_args_exceptions(self):
+ Decimal = C.Decimal
+ Context = C.Context
+
+ x = Decimal("10001111111")
+
+ for attr in ['exp', 'is_normal', 'is_subnormal', 'ln', 'log10',
+ 'logb', 'logical_invert', 'next_minus', 'next_plus',
+ 'normalize', 'number_class', 'sqrt', 'to_eng_string']:
+ func = getattr(x, attr)
+ self.assertRaises(TypeError, func, context="x")
+ self.assertRaises(TypeError, func, "x", context=None)
+
+ for attr in ['compare', 'compare_signal', 'logical_and',
+ 'logical_or', 'max', 'max_mag', 'min', 'min_mag',
+ 'remainder_near', 'rotate', 'scaleb', 'shift']:
+ func = getattr(x, attr)
+ self.assertRaises(TypeError, func, context="x")
+ self.assertRaises(TypeError, func, "x", context=None)
+
+ self.assertRaises(TypeError, x.to_integral, rounding=None, context=[])
+ self.assertRaises(TypeError, x.to_integral, rounding={}, context=[])
+ self.assertRaises(TypeError, x.to_integral, [], [])
+
+ self.assertRaises(TypeError, x.to_integral_value, rounding=None, context=[])
+ self.assertRaises(TypeError, x.to_integral_value, rounding={}, context=[])
+ self.assertRaises(TypeError, x.to_integral_value, [], [])
+
+ self.assertRaises(TypeError, x.to_integral_exact, rounding=None, context=[])
+ self.assertRaises(TypeError, x.to_integral_exact, rounding={}, context=[])
+ self.assertRaises(TypeError, x.to_integral_exact, [], [])
+
+ self.assertRaises(TypeError, x.fma, 1, 2, context="x")
+ self.assertRaises(TypeError, x.fma, 1, 2, "x", context=None)
+
+ self.assertRaises(TypeError, x.quantize, 1, [], context=None)
+ self.assertRaises(TypeError, x.quantize, 1, [], rounding=None)
+ self.assertRaises(TypeError, x.quantize, 1, [], [])
+
+ c = Context()
+ self.assertRaises(TypeError, c.power, 1, 2, mod="x")
+ self.assertRaises(TypeError, c.power, 1, "x", mod=None)
+ self.assertRaises(TypeError, c.power, "x", 2, mod=None)
+
+ @requires_extra_functionality
+ def test_c_context_templates(self):
+ self.assertEqual(
+ C.BasicContext._traps,
+ C.DecIEEEInvalidOperation|C.DecDivisionByZero|C.DecOverflow|
+ C.DecUnderflow|C.DecClamped
+ )
+ self.assertEqual(
+ C.DefaultContext._traps,
+ C.DecIEEEInvalidOperation|C.DecDivisionByZero|C.DecOverflow
+ )
+
+ @requires_extra_functionality
+ def test_c_signal_dict(self):
+
+ # SignalDict coverage
+ Context = C.Context
+ DefaultContext = C.DefaultContext
+
+ InvalidOperation = C.InvalidOperation
+ DivisionByZero = C.DivisionByZero
+ Overflow = C.Overflow
+ Subnormal = C.Subnormal
+ Underflow = C.Underflow
+ Rounded = C.Rounded
+ Inexact = C.Inexact
+ Clamped = C.Clamped
+
+ DecClamped = C.DecClamped
+ DecInvalidOperation = C.DecInvalidOperation
+ DecIEEEInvalidOperation = C.DecIEEEInvalidOperation
+
+ def assertIsExclusivelySet(signal, signal_dict):
+ for sig in signal_dict:
+ if sig == signal:
+ self.assertTrue(signal_dict[sig])
+ else:
+ self.assertFalse(signal_dict[sig])
+
+ c = DefaultContext.copy()
+
+ # Signal dict methods
+ self.assertTrue(Overflow in c.traps)
+ c.clear_traps()
+ for k in c.traps.keys():
+ c.traps[k] = True
+ for v in c.traps.values():
+ self.assertTrue(v)
+ c.clear_traps()
+ for k, v in c.traps.items():
+ self.assertFalse(v)
+
+ self.assertFalse(c.flags.get(Overflow))
+ self.assertIs(c.flags.get("x"), None)
+ self.assertEqual(c.flags.get("x", "y"), "y")
+ self.assertRaises(TypeError, c.flags.get, "x", "y", "z")
+
+ self.assertEqual(len(c.flags), len(c.traps))
+ s = sys.getsizeof(c.flags)
+ s = sys.getsizeof(c.traps)
+ s = c.flags.__repr__()
+
+ # Set flags/traps.
+ c.clear_flags()
+ c._flags = DecClamped
+ self.assertTrue(c.flags[Clamped])
+
+ c.clear_traps()
+ c._traps = DecInvalidOperation
+ self.assertTrue(c.traps[InvalidOperation])
+
+ # Set flags/traps from dictionary.
+ c.clear_flags()
+ d = c.flags.copy()
+ d[DivisionByZero] = True
+ c.flags = d
+ assertIsExclusivelySet(DivisionByZero, c.flags)
+
+ c.clear_traps()
+ d = c.traps.copy()
+ d[Underflow] = True
+ c.traps = d
+ assertIsExclusivelySet(Underflow, c.traps)
+
+ # Random constructors
+ IntSignals = {
+ Clamped: C.DecClamped,
+ Rounded: C.DecRounded,
+ Inexact: C.DecInexact,
+ Subnormal: C.DecSubnormal,
+ Underflow: C.DecUnderflow,
+ Overflow: C.DecOverflow,
+ DivisionByZero: C.DecDivisionByZero,
+ InvalidOperation: C.DecIEEEInvalidOperation
+ }
+ IntCond = [
+ C.DecDivisionImpossible, C.DecDivisionUndefined, C.DecFpuError,
+ C.DecInvalidContext, C.DecInvalidOperation, C.DecMallocError,
+ C.DecConversionSyntax,
+ ]
+
+ lim = len(OrderedSignals[C])
+ for r in range(lim):
+ for t in range(lim):
+ for round in RoundingModes:
+ flags = random.sample(OrderedSignals[C], r)
+ traps = random.sample(OrderedSignals[C], t)
+ prec = random.randrange(1, 10000)
+ emin = random.randrange(-10000, 0)
+ emax = random.randrange(0, 10000)
+ clamp = random.randrange(0, 2)
+ caps = random.randrange(0, 2)
+ cr = random.randrange(0, 2)
+ c = Context(prec=prec, rounding=round, Emin=emin, Emax=emax,
+ capitals=caps, clamp=clamp, flags=list(flags),
+ traps=list(traps))
+
+ self.assertEqual(c.prec, prec)
+ self.assertEqual(c.rounding, round)
+ self.assertEqual(c.Emin, emin)
+ self.assertEqual(c.Emax, emax)
+ self.assertEqual(c.capitals, caps)
+ self.assertEqual(c.clamp, clamp)
+
+ f = 0
+ for x in flags:
+ f |= IntSignals[x]
+ self.assertEqual(c._flags, f)
+
+ f = 0
+ for x in traps:
+ f |= IntSignals[x]
+ self.assertEqual(c._traps, f)
+
+ for cond in IntCond:
+ c._flags = cond
+ self.assertTrue(c._flags&DecIEEEInvalidOperation)
+ assertIsExclusivelySet(InvalidOperation, c.flags)
+
+ for cond in IntCond:
+ c._traps = cond
+ self.assertTrue(c._traps&DecIEEEInvalidOperation)
+ assertIsExclusivelySet(InvalidOperation, c.traps)
+
+ def test_invalid_override(self):
+ Decimal = C.Decimal
+
+ try:
+ from locale import CHAR_MAX
+ except ImportError:
+ return
+
+ def make_grouping(lst):
+ return ''.join([chr(x) for x in lst])
+
+ def get_fmt(x, override=None, fmt='n'):
+ return Decimal(x).__format__(fmt, override)
+
+ invalid_grouping = {
+ 'decimal_point' : ',',
+ 'grouping' : make_grouping([255, 255, 0]),
+ 'thousands_sep' : ','
+ }
+ invalid_dot = {
+ 'decimal_point' : 'xxxxx',
+ 'grouping' : make_grouping([3, 3, 0]),
+ 'thousands_sep' : ','
+ }
+ invalid_sep = {
+ 'decimal_point' : '.',
+ 'grouping' : make_grouping([3, 3, 0]),
+ 'thousands_sep' : 'yyyyy'
+ }
+
+ if CHAR_MAX == 127: # negative grouping in override
+ self.assertRaises(ValueError, get_fmt, 12345,
+ invalid_grouping, 'g')
+
+ self.assertRaises(ValueError, get_fmt, 12345, invalid_dot, 'g')
+ self.assertRaises(ValueError, get_fmt, 12345, invalid_sep, 'g')
+
+ def test_exact_conversion(self):
+ Decimal = C.Decimal
+ localcontext = C.localcontext
+ InvalidOperation = C.InvalidOperation
+
+ with localcontext() as c:
+
+ c.traps[InvalidOperation] = True
+
+ # Clamped
+ x = "0e%d" % sys.maxsize
+ self.assertRaises(InvalidOperation, Decimal, x)
+
+ x = "0e%d" % (-sys.maxsize-1)
+ self.assertRaises(InvalidOperation, Decimal, x)
+
+ # Overflow
+ x = "1e%d" % sys.maxsize
+ self.assertRaises(InvalidOperation, Decimal, x)
+
+ # Underflow
+ x = "1e%d" % (-sys.maxsize-1)
+ self.assertRaises(InvalidOperation, Decimal, x)
+
+ def test_from_tuple(self):
+ Decimal = C.Decimal
+ localcontext = C.localcontext
+ InvalidOperation = C.InvalidOperation
+ Overflow = C.Overflow
+ Underflow = C.Underflow
+
+ with localcontext() as c:
+
+ c.traps[InvalidOperation] = True
+ c.traps[Overflow] = True
+ c.traps[Underflow] = True
+
+ # SSIZE_MAX
+ x = (1, (), sys.maxsize)
+ self.assertEqual(str(c.create_decimal(x)), '-0E+999999')
+ self.assertRaises(InvalidOperation, Decimal, x)
+
+ x = (1, (0, 1, 2), sys.maxsize)
+ self.assertRaises(Overflow, c.create_decimal, x)
+ self.assertRaises(InvalidOperation, Decimal, x)
+
+ # SSIZE_MIN
+ x = (1, (), -sys.maxsize-1)
+ self.assertEqual(str(c.create_decimal(x)), '-0E-1000026')
+ self.assertRaises(InvalidOperation, Decimal, x)
+
+ x = (1, (0, 1, 2), -sys.maxsize-1)
+ self.assertRaises(Underflow, c.create_decimal, x)
+ self.assertRaises(InvalidOperation, Decimal, x)
+
+ # OverflowError
+ x = (1, (), sys.maxsize+1)
+ self.assertRaises(OverflowError, c.create_decimal, x)
+ self.assertRaises(OverflowError, Decimal, x)
+
+ x = (1, (), -sys.maxsize-2)
+ self.assertRaises(OverflowError, c.create_decimal, x)
+ self.assertRaises(OverflowError, Decimal, x)
+
+ # Specials
+ x = (1, (), "N")
+ self.assertEqual(str(Decimal(x)), '-sNaN')
+ x = (1, (0,), "N")
+ self.assertEqual(str(Decimal(x)), '-sNaN')
+ x = (1, (0, 1), "N")
+ self.assertEqual(str(Decimal(x)), '-sNaN1')
+
+
+all_tests = [
+ CExplicitConstructionTest, PyExplicitConstructionTest,
+ CImplicitConstructionTest, PyImplicitConstructionTest,
+ CFormatTest, PyFormatTest,
+ CArithmeticOperatorsTest, PyArithmeticOperatorsTest,
+ CThreadingTest, PyThreadingTest,
+ CUsabilityTest, PyUsabilityTest,
+ CPythonAPItests, PyPythonAPItests,
+ CContextAPItests, PyContextAPItests,
+ CContextWithStatement, PyContextWithStatement,
+ CContextFlags, PyContextFlags,
+ CSpecialContexts, PySpecialContexts,
+ CContextInputValidation, PyContextInputValidation,
+ CContextSubclassing, PyContextSubclassing,
+ CCoverage, PyCoverage,
+ CFunctionality, PyFunctionality,
+ CWhitebox, PyWhitebox,
+ CIBMTestCases, PyIBMTestCases,
+]
+
+# Delete C tests if _decimal.so is not present.
+if not C:
+ all_tests = all_tests[1::2]
+else:
+ all_tests.insert(0, CheckAttributes)
+
+
def test_main(arith=False, verbose=None, todo_tests=None, debug=None):
""" Execute the tests.
@@ -2374,27 +5404,16 @@ def test_main(arith=False, verbose=None, todo_tests=None, debug=None):
is enabled in regrtest.py
"""
- init()
+ init(C)
+ init(P)
global TEST_ALL, DEBUG
TEST_ALL = arith or is_resource_enabled('decimal')
DEBUG = debug
if todo_tests is None:
- test_classes = [
- DecimalExplicitConstructionTest,
- DecimalImplicitConstructionTest,
- DecimalArithmeticOperatorsTest,
- DecimalFormatTest,
- DecimalUseOfContextTest,
- DecimalUsabilityTest,
- DecimalPythonAPItests,
- ContextAPItests,
- DecimalTest,
- WithStatementTest,
- ContextFlags
- ]
+ test_classes = all_tests
else:
- test_classes = [DecimalTest]
+ test_classes = [CIBMTestCases, PyIBMTestCases]
# Dynamically build custom test definition for each file in the test
# directory and add the definitions to the DecimalTest class. This
@@ -2406,17 +5425,32 @@ def test_main(arith=False, verbose=None, todo_tests=None, debug=None):
if todo_tests is not None and head not in todo_tests:
continue
tester = lambda self, f=filename: self.eval_file(directory + f)
- setattr(DecimalTest, 'test_' + head, tester)
+ setattr(CIBMTestCases, 'test_' + head, tester)
+ setattr(PyIBMTestCases, 'test_' + head, tester)
del filename, head, tail, tester
try:
run_unittest(*test_classes)
if todo_tests is None:
- import decimal as DecimalModule
- run_doctest(DecimalModule, verbose)
+ from doctest import IGNORE_EXCEPTION_DETAIL
+ savedecimal = sys.modules['decimal']
+ if C:
+ sys.modules['decimal'] = C
+ run_doctest(C, verbose, optionflags=IGNORE_EXCEPTION_DETAIL)
+ sys.modules['decimal'] = P
+ run_doctest(P, verbose)
+ sys.modules['decimal'] = savedecimal
finally:
- setcontext(ORIGINAL_CONTEXT)
+ if C: C.setcontext(ORIGINAL_CONTEXT[C])
+ P.setcontext(ORIGINAL_CONTEXT[P])
+ if not C:
+ warnings.warn('C tests skipped: no module named _decimal.',
+ UserWarning)
+ if not orig_sys_decimal is sys.modules['decimal']:
+ raise TestFailed("Internal error: unbalanced number of changes to "
+ "sys.modules['decimal'].")
+
if __name__ == '__main__':
import optparse
diff --git a/Lib/test/test_deque.py b/Lib/test/test_deque.py
index f0afe1d..a8487d2 100644
--- a/Lib/test/test_deque.py
+++ b/Lib/test/test_deque.py
@@ -472,6 +472,19 @@ class TestBasic(unittest.TestCase):
## self.assertNotEqual(id(d), id(e))
## self.assertEqual(id(e), id(e[-1]))
+ def test_iterator_pickle(self):
+ data = deque(range(200))
+ it = itorg = iter(data)
+ d = pickle.dumps(it)
+ it = pickle.loads(d)
+ self.assertEqual(type(itorg), type(it))
+ self.assertEqual(list(it), list(data))
+
+ it = pickle.loads(d)
+ next(it)
+ d = pickle.dumps(it)
+ self.assertEqual(list(it), list(data)[1:])
+
def test_deepcopy(self):
mut = [10]
d = deque([mut])
@@ -524,7 +537,7 @@ class TestBasic(unittest.TestCase):
@support.cpython_only
def test_sizeof(self):
BLOCKLEN = 62
- basesize = support.calcobjsize('2P4PlP')
+ basesize = support.calcobjsize('2P4nlP')
blocksize = struct.calcsize('2P%dP' % BLOCKLEN)
self.assertEqual(object.__sizeof__(deque()), basesize)
check = self.check_sizeof
diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py
index cdaf7d2..b5a10ed 100644
--- a/Lib/test/test_descr.py
+++ b/Lib/test/test_descr.py
@@ -654,19 +654,19 @@ class ClassPropertiesAndMethods(unittest.TestCase):
class A(metaclass=AMeta):
pass
self.assertEqual(['AMeta'], new_calls)
- new_calls[:] = []
+ new_calls.clear()
class B(metaclass=BMeta):
pass
# BMeta.__new__ calls AMeta.__new__ with super:
self.assertEqual(['BMeta', 'AMeta'], new_calls)
- new_calls[:] = []
+ new_calls.clear()
class C(A, B):
pass
# The most derived metaclass is BMeta:
self.assertEqual(['BMeta', 'AMeta'], new_calls)
- new_calls[:] = []
+ new_calls.clear()
# BMeta.__prepare__ should've been called:
self.assertIn('BMeta_was_here', C.__dict__)
@@ -674,20 +674,20 @@ class ClassPropertiesAndMethods(unittest.TestCase):
class C2(B, A):
pass
self.assertEqual(['BMeta', 'AMeta'], new_calls)
- new_calls[:] = []
+ new_calls.clear()
self.assertIn('BMeta_was_here', C2.__dict__)
# Check correct metaclass calculation when a metaclass is declared:
class D(C, metaclass=type):
pass
self.assertEqual(['BMeta', 'AMeta'], new_calls)
- new_calls[:] = []
+ new_calls.clear()
self.assertIn('BMeta_was_here', D.__dict__)
class E(C, metaclass=AMeta):
pass
self.assertEqual(['BMeta', 'AMeta'], new_calls)
- new_calls[:] = []
+ new_calls.clear()
self.assertIn('BMeta_was_here', E.__dict__)
# Special case: the given metaclass isn't a class,
@@ -729,33 +729,33 @@ class ClassPropertiesAndMethods(unittest.TestCase):
pass
self.assertIs(ANotMeta, type(A))
self.assertEqual(['ANotMeta'], prepare_calls)
- prepare_calls[:] = []
+ prepare_calls.clear()
self.assertEqual(['ANotMeta'], new_calls)
- new_calls[:] = []
+ new_calls.clear()
class B(metaclass=BNotMeta):
pass
self.assertIs(BNotMeta, type(B))
self.assertEqual(['BNotMeta', 'ANotMeta'], prepare_calls)
- prepare_calls[:] = []
+ prepare_calls.clear()
self.assertEqual(['BNotMeta', 'ANotMeta'], new_calls)
- new_calls[:] = []
+ new_calls.clear()
class C(A, B):
pass
self.assertIs(BNotMeta, type(C))
self.assertEqual(['BNotMeta', 'ANotMeta'], new_calls)
- new_calls[:] = []
+ new_calls.clear()
self.assertEqual(['BNotMeta', 'ANotMeta'], prepare_calls)
- prepare_calls[:] = []
+ prepare_calls.clear()
class C2(B, A):
pass
self.assertIs(BNotMeta, type(C2))
self.assertEqual(['BNotMeta', 'ANotMeta'], new_calls)
- new_calls[:] = []
+ new_calls.clear()
self.assertEqual(['BNotMeta', 'ANotMeta'], prepare_calls)
- prepare_calls[:] = []
+ prepare_calls.clear()
# This is a TypeError, because of a metaclass conflict:
# BNotMeta is neither a subclass, nor a superclass of type
@@ -767,25 +767,25 @@ class ClassPropertiesAndMethods(unittest.TestCase):
pass
self.assertIs(BNotMeta, type(E))
self.assertEqual(['BNotMeta', 'ANotMeta'], new_calls)
- new_calls[:] = []
+ new_calls.clear()
self.assertEqual(['BNotMeta', 'ANotMeta'], prepare_calls)
- prepare_calls[:] = []
+ prepare_calls.clear()
class F(object(), C):
pass
self.assertIs(BNotMeta, type(F))
self.assertEqual(['BNotMeta', 'ANotMeta'], new_calls)
- new_calls[:] = []
+ new_calls.clear()
self.assertEqual(['BNotMeta', 'ANotMeta'], prepare_calls)
- prepare_calls[:] = []
+ prepare_calls.clear()
class F2(C, object()):
pass
self.assertIs(BNotMeta, type(F2))
self.assertEqual(['BNotMeta', 'ANotMeta'], new_calls)
- new_calls[:] = []
+ new_calls.clear()
self.assertEqual(['BNotMeta', 'ANotMeta'], prepare_calls)
- prepare_calls[:] = []
+ prepare_calls.clear()
# TypeError: BNotMeta is neither a
# subclass, nor a superclass of int
@@ -1444,6 +1444,14 @@ order (MRO) for bases """
else:
self.fail("classmethod shouldn't accept keyword args")
+ cm = classmethod(f)
+ self.assertEqual(cm.__dict__, {})
+ cm.x = 42
+ self.assertEqual(cm.x, 42)
+ self.assertEqual(cm.__dict__, {"x" : 42})
+ del cm.x
+ self.assertFalse(hasattr(cm, "x"))
+
@support.impl_detail("the module 'xxsubtype' is internal")
def test_classmethods_in_c(self):
# Testing C-based class methods...
@@ -1491,6 +1499,13 @@ order (MRO) for bases """
self.assertEqual(d.goo(1), (1,))
self.assertEqual(d.foo(1), (d, 1))
self.assertEqual(D.foo(d, 1), (d, 1))
+ sm = staticmethod(None)
+ self.assertEqual(sm.__dict__, {})
+ sm.x = 42
+ self.assertEqual(sm.x, 42)
+ self.assertEqual(sm.__dict__, {"x" : 42})
+ del sm.x
+ self.assertFalse(hasattr(sm, "x"))
@support.impl_detail("the module 'xxsubtype' is internal")
def test_staticmethods_in_c(self):
@@ -1771,6 +1786,7 @@ order (MRO) for bases """
("__format__", format, format_impl, set(), {}),
("__floor__", math.floor, zero, set(), {}),
("__trunc__", math.trunc, zero, set(), {}),
+ ("__trunc__", int, zero, set(), {}),
("__ceil__", math.ceil, zero, set(), {}),
("__dir__", dir, empty_seq, set(), {}),
]
@@ -1816,12 +1832,7 @@ order (MRO) for bases """
for attr, obj in env.items():
setattr(X, attr, obj)
setattr(X, name, ErrDescr())
- try:
- runner(X())
- except MyException:
- pass
- else:
- self.fail("{0!r} didn't raise".format(name))
+ self.assertRaises(MyException, runner, X())
def test_specials(self):
# Testing special operators...
@@ -2258,9 +2269,6 @@ order (MRO) for bases """
# Two essentially featureless objects, just inheriting stuff from
# object.
self.assertEqual(dir(NotImplemented), dir(Ellipsis))
- if support.check_impl_detail():
- # None differs in PyPy: it has a __nonzero__
- self.assertEqual(dir(None), dir(Ellipsis))
# Nasty test case for proxied objects
class Wrapper(object):
@@ -4456,6 +4464,62 @@ order (MRO) for bases """
x.y = 42
self.assertEqual(x["y"], 42)
+ def test_slot_shadows_class_variable(self):
+ with self.assertRaises(ValueError) as cm:
+ class X:
+ __slots__ = ["foo"]
+ foo = None
+ m = str(cm.exception)
+ self.assertEqual("'foo' in __slots__ conflicts with class variable", m)
+
+ def test_set_doc(self):
+ class X:
+ "elephant"
+ X.__doc__ = "banana"
+ self.assertEqual(X.__doc__, "banana")
+ with self.assertRaises(TypeError) as cm:
+ type(list).__dict__["__doc__"].__set__(list, "blah")
+ self.assertIn("can't set list.__doc__", str(cm.exception))
+ with self.assertRaises(TypeError) as cm:
+ type(X).__dict__["__doc__"].__delete__(X)
+ self.assertIn("can't delete X.__doc__", str(cm.exception))
+ self.assertEqual(X.__doc__, "banana")
+
+ def test_qualname(self):
+ descriptors = [str.lower, complex.real, float.real, int.__add__]
+ types = ['method', 'member', 'getset', 'wrapper']
+
+ # make sure we have an example of each type of descriptor
+ for d, n in zip(descriptors, types):
+ self.assertEqual(type(d).__name__, n + '_descriptor')
+
+ for d in descriptors:
+ qualname = d.__objclass__.__qualname__ + '.' + d.__name__
+ self.assertEqual(d.__qualname__, qualname)
+
+ self.assertEqual(str.lower.__qualname__, 'str.lower')
+ self.assertEqual(complex.real.__qualname__, 'complex.real')
+ self.assertEqual(float.real.__qualname__, 'float.real')
+ self.assertEqual(int.__add__.__qualname__, 'int.__add__')
+
+ class X:
+ pass
+ with self.assertRaises(TypeError):
+ del X.__qualname__
+
+ self.assertRaises(TypeError, type.__dict__['__qualname__'].__set__,
+ str, 'Oink')
+
+ def test_qualname_dict(self):
+ ns = {'__qualname__': 'some.name'}
+ tp = type('Foo', (), ns)
+ self.assertEqual(tp.__qualname__, 'some.name')
+ self.assertNotIn('__qualname__', tp.__dict__)
+ self.assertEqual(ns, {'__qualname__': 'some.name'})
+
+ ns = {'__qualname__': 1}
+ self.assertRaises(TypeError, type, 'Foo', (), ns)
+
def test_cycle_through_dict(self):
# See bug #1469629
class X(dict):
@@ -4471,6 +4535,27 @@ order (MRO) for bases """
for o in gc.get_objects():
self.assertIsNot(type(o), X)
+ def test_object_new_and_init_with_parameters(self):
+ # See issue #1683368
+ class OverrideNeither:
+ pass
+ self.assertRaises(TypeError, OverrideNeither, 1)
+ self.assertRaises(TypeError, OverrideNeither, kw=1)
+ class OverrideNew:
+ def __new__(cls, foo, kw=0, *args, **kwds):
+ return object.__new__(cls, *args, **kwds)
+ class OverrideInit:
+ def __init__(self, foo, kw=0, *args, **kwargs):
+ return object.__init__(self, *args, **kwargs)
+ class OverrideBoth(OverrideNew, OverrideInit):
+ pass
+ for case in OverrideNew, OverrideInit, OverrideBoth:
+ case(1)
+ case(1, kw=2)
+ self.assertRaises(TypeError, case, 1, 2, 3)
+ self.assertRaises(TypeError, case, 1, 2, foo=3)
+
+
class DictProxyTests(unittest.TestCase):
def setUp(self):
class C(object):
@@ -4478,6 +4563,8 @@ class DictProxyTests(unittest.TestCase):
pass
self.C = C
+ @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(),
+ 'trace function introduces __local__')
def test_iter_keys(self):
# Testing dict-proxy keys...
it = self.C.__dict__.keys()
@@ -4485,8 +4572,10 @@ class DictProxyTests(unittest.TestCase):
keys = list(it)
keys.sort()
self.assertEqual(keys, ['__dict__', '__doc__', '__module__',
- '__weakref__', 'meth'])
+ '__weakref__', 'meth'])
+ @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(),
+ 'trace function introduces __local__')
def test_iter_values(self):
# Testing dict-proxy values...
it = self.C.__dict__.values()
@@ -4494,6 +4583,8 @@ class DictProxyTests(unittest.TestCase):
values = list(it)
self.assertEqual(len(values), 5)
+ @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(),
+ 'trace function introduces __local__')
def test_iter_items(self):
# Testing dict-proxy iteritems...
it = self.C.__dict__.items()
@@ -4501,7 +4592,7 @@ class DictProxyTests(unittest.TestCase):
keys = [item[0] for item in it]
keys.sort()
self.assertEqual(keys, ['__dict__', '__doc__', '__module__',
- '__weakref__', 'meth'])
+ '__weakref__', 'meth'])
def test_dict_type_with_metaclass(self):
# Testing type of __dict__ when metaclass set...
@@ -4515,19 +4606,14 @@ class DictProxyTests(unittest.TestCase):
self.assertEqual(type(C.__dict__), type(B.__dict__))
def test_repr(self):
- # Testing dict_proxy.__repr__
- def sorted_dict_repr(repr_):
- # Given the repr of a dict, sort the keys
- assert repr_.startswith('{')
- assert repr_.endswith('}')
- kvs = repr_[1:-1].split(', ')
- return '{' + ', '.join(sorted(kvs)) + '}'
- dict_ = {k: v for k, v in self.C.__dict__.items()}
- repr_ = repr(self.C.__dict__)
- self.assertTrue(repr_.startswith('dict_proxy('))
- self.assertTrue(repr_.endswith(')'))
- self.assertEqual(sorted_dict_repr(repr_[len('dict_proxy('):-len(')')]),
- sorted_dict_repr('{!r}'.format(dict_)))
+ # Testing mappingproxy.__repr__.
+ # We can't blindly compare with the repr of another dict as ordering
+ # of keys and values is arbitrary and may differ.
+ r = repr(self.C.__dict__)
+ self.assertTrue(r.startswith('mappingproxy('), r)
+ self.assertTrue(r.endswith(')'), r)
+ for k, v in self.C.__dict__.items():
+ self.assertIn('{!r}: {!r}'.format(k, v), r)
class PTypesLongInitTest(unittest.TestCase):
@@ -4552,10 +4638,38 @@ class PTypesLongInitTest(unittest.TestCase):
type.mro(tuple)
+class MiscTests(unittest.TestCase):
+ def test_type_lookup_mro_reference(self):
+ # Issue #14199: _PyType_Lookup() has to keep a strong reference to
+ # the type MRO because it may be modified during the lookup, if
+ # __bases__ is set during the lookup for example.
+ class MyKey(object):
+ def __hash__(self):
+ return hash('mykey')
+
+ def __eq__(self, other):
+ X.__bases__ = (Base2,)
+
+ class Base(object):
+ mykey = 'from Base'
+ mykey2 = 'from Base'
+
+ class Base2(object):
+ mykey = 'from Base2'
+ mykey2 = 'from Base2'
+
+ X = type('X', (Base,), {MyKey(): 5})
+ # mykey is read from Base
+ self.assertEqual(X.mykey, 'from Base')
+ # mykey2 is read from Base2 because MyKey.__eq__ has set __bases__
+ self.assertEqual(X.mykey2, 'from Base2')
+
+
def test_main():
# Run all local test cases, with PTypesLongInitTest first.
support.run_unittest(PTypesLongInitTest, OperatorsTest,
- ClassPropertiesAndMethods, DictProxyTests)
+ ClassPropertiesAndMethods, DictProxyTests,
+ MiscTests)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_descrtut.py b/Lib/test/test_descrtut.py
index 2db3d33..f495e18 100644
--- a/Lib/test/test_descrtut.py
+++ b/Lib/test/test_descrtut.py
@@ -170,6 +170,7 @@ You can get the information from the list type:
'__contains__',
'__delattr__',
'__delitem__',
+ '__dir__',
'__doc__',
'__eq__',
'__format__',
@@ -199,6 +200,8 @@ You can get the information from the list type:
'__str__',
'__subclasshook__',
'append',
+ 'clear',
+ 'copy',
'count',
'extend',
'index',
diff --git a/Lib/test/test_devpoll.py b/Lib/test/test_devpoll.py
new file mode 100644
index 0000000..bef4e18
--- /dev/null
+++ b/Lib/test/test_devpoll.py
@@ -0,0 +1,94 @@
+# Test case for the select.devpoll() function
+
+# Initial tests are copied as is from "test_poll.py"
+
+import os, select, random, unittest, sys
+from test.support import TESTFN, run_unittest
+
+try:
+ select.devpoll
+except AttributeError:
+ raise unittest.SkipTest("select.devpoll not defined -- skipping test_devpoll")
+
+
+def find_ready_matching(ready, flag):
+ match = []
+ for fd, mode in ready:
+ if mode & flag:
+ match.append(fd)
+ return match
+
+class DevPollTests(unittest.TestCase):
+
+ def test_devpoll1(self):
+ # Basic functional test of poll object
+ # Create a bunch of pipe and test that poll works with them.
+
+ p = select.devpoll()
+
+ NUM_PIPES = 12
+ MSG = b" This is a test."
+ MSG_LEN = len(MSG)
+ readers = []
+ writers = []
+ r2w = {}
+ w2r = {}
+
+ for i in range(NUM_PIPES):
+ rd, wr = os.pipe()
+ p.register(rd)
+ p.modify(rd, select.POLLIN)
+ p.register(wr, select.POLLOUT)
+ readers.append(rd)
+ writers.append(wr)
+ r2w[rd] = wr
+ w2r[wr] = rd
+
+ bufs = []
+
+ while writers:
+ ready = p.poll()
+ ready_writers = find_ready_matching(ready, select.POLLOUT)
+ if not ready_writers:
+ self.fail("no pipes ready for writing")
+ wr = random.choice(ready_writers)
+ os.write(wr, MSG)
+
+ ready = p.poll()
+ ready_readers = find_ready_matching(ready, select.POLLIN)
+ if not ready_readers:
+ self.fail("no pipes ready for reading")
+ self.assertEqual([w2r[wr]], ready_readers)
+ rd = ready_readers[0]
+ buf = os.read(rd, MSG_LEN)
+ self.assertEqual(len(buf), MSG_LEN)
+ bufs.append(buf)
+ os.close(r2w[rd]) ; os.close(rd)
+ p.unregister(r2w[rd])
+ p.unregister(rd)
+ writers.remove(r2w[rd])
+
+ self.assertEqual(bufs, [MSG] * NUM_PIPES)
+
+ def test_timeout_overflow(self):
+ pollster = select.devpoll()
+ w, r = os.pipe()
+ pollster.register(w)
+
+ pollster.poll(-1)
+ self.assertRaises(OverflowError, pollster.poll, -2)
+ self.assertRaises(OverflowError, pollster.poll, -1 << 31)
+ self.assertRaises(OverflowError, pollster.poll, -1 << 64)
+
+ pollster.poll(0)
+ pollster.poll(1)
+ pollster.poll(1 << 30)
+ self.assertRaises(OverflowError, pollster.poll, 1 << 31)
+ self.assertRaises(OverflowError, pollster.poll, 1 << 63)
+ self.assertRaises(OverflowError, pollster.poll, 1 << 64)
+
+def test_main():
+ run_unittest(DevPollTests)
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_dict.py b/Lib/test/test_dict.py
index 9666f9a..d1f3192 100644
--- a/Lib/test/test_dict.py
+++ b/Lib/test/test_dict.py
@@ -2,7 +2,9 @@ import unittest
from test import support
import collections, random, string
+import collections.abc
import gc, weakref
+import pickle
class DictTest(unittest.TestCase):
@@ -327,6 +329,27 @@ class DictTest(unittest.TestCase):
self.assertEqual(hashed2.hash_count, 1)
self.assertEqual(hashed1.eq_count + hashed2.eq_count, 1)
+ def test_setitem_atomic_at_resize(self):
+ class Hashed(object):
+ def __init__(self):
+ self.hash_count = 0
+ self.eq_count = 0
+ def __hash__(self):
+ self.hash_count += 1
+ return 42
+ def __eq__(self, other):
+ self.eq_count += 1
+ return id(self) == id(other)
+ hashed1 = Hashed()
+ # 5 items
+ y = {hashed1: 5, 0: 0, 1: 1, 2: 2, 3: 3}
+ hashed2 = Hashed()
+ # 6th item forces a resize
+ y[hashed2] = []
+ self.assertEqual(hashed1.hash_count, 1)
+ self.assertEqual(hashed2.hash_count, 1)
+ self.assertEqual(hashed1.eq_count + hashed2.eq_count, 1)
+
def test_popitem(self):
# dict.popitem()
for copymode in -1, +1:
@@ -387,7 +410,7 @@ class DictTest(unittest.TestCase):
x.fail = True
self.assertRaises(Exc, d.pop, x)
- def test_mutatingiteration(self):
+ def test_mutating_iteration(self):
# changing dict size during iteration
d = {}
d[1] = 1
@@ -395,6 +418,32 @@ class DictTest(unittest.TestCase):
for i in d:
d[i+1] = 1
+ def test_mutating_lookup(self):
+ # changing dict during a lookup (issue #14417)
+ class NastyKey:
+ mutate_dict = None
+
+ def __init__(self, value):
+ self.value = value
+
+ def __hash__(self):
+ # hash collision!
+ return 1
+
+ def __eq__(self, other):
+ if NastyKey.mutate_dict:
+ mydict, key = NastyKey.mutate_dict
+ NastyKey.mutate_dict = None
+ del mydict[key]
+ return self.value == other.value
+
+ key1 = NastyKey(1)
+ key2 = NastyKey(2)
+ d = {key1: 1}
+ NastyKey.mutate_dict = (d, key1)
+ d[key2] = 2
+ self.assertEqual(d, {key2: 2})
+
def test_repr(self):
d = {}
self.assertEqual(repr(d), '{}')
@@ -784,6 +833,75 @@ class DictTest(unittest.TestCase):
pass
self._tracked(MyDict())
+ def test_iterator_pickling(self):
+ data = {1:"a", 2:"b", 3:"c"}
+ it = iter(data)
+ d = pickle.dumps(it)
+ it = pickle.loads(d)
+ self.assertEqual(sorted(it), sorted(data))
+
+ it = pickle.loads(d)
+ try:
+ drop = next(it)
+ except StopIteration:
+ return
+ d = pickle.dumps(it)
+ it = pickle.loads(d)
+ del data[drop]
+ self.assertEqual(sorted(it), sorted(data))
+
+ def test_itemiterator_pickling(self):
+ data = {1:"a", 2:"b", 3:"c"}
+ # dictviews aren't picklable, only their iterators
+ itorg = iter(data.items())
+ d = pickle.dumps(itorg)
+ it = pickle.loads(d)
+ # note that the type of type of the unpickled iterator
+ # is not necessarily the same as the original. It is
+ # merely an object supporting the iterator protocol, yielding
+ # the same objects as the original one.
+ # self.assertEqual(type(itorg), type(it))
+ self.assertTrue(isinstance(it, collections.abc.Iterator))
+ self.assertEqual(dict(it), data)
+
+ it = pickle.loads(d)
+ drop = next(it)
+ d = pickle.dumps(it)
+ it = pickle.loads(d)
+ del data[drop[0]]
+ self.assertEqual(dict(it), data)
+
+ def test_valuesiterator_pickling(self):
+ data = {1:"a", 2:"b", 3:"c"}
+ # data.values() isn't picklable, only its iterator
+ it = iter(data.values())
+ d = pickle.dumps(it)
+ it = pickle.loads(d)
+ self.assertEqual(sorted(list(it)), sorted(list(data.values())))
+
+ it = pickle.loads(d)
+ drop = next(it)
+ d = pickle.dumps(it)
+ it = pickle.loads(d)
+ values = list(it) + [drop]
+ self.assertEqual(sorted(values), sorted(list(data.values())))
+
+ def test_instance_dict_getattr_str_subclass(self):
+ class Foo:
+ def __init__(self, msg):
+ self.msg = msg
+ f = Foo('123')
+ class _str(str):
+ pass
+ self.assertEqual(f.msg, getattr(f, _str('msg')))
+ self.assertEqual(f.msg, f.__dict__[_str('msg')])
+
+ def test_object_set_item_single_instance_non_str_key(self):
+ class Foo: pass
+ f = Foo()
+ f.__dict__[1] = 1
+ f.a = 'a'
+ self.assertEqual(f.__dict__, {1:1, 'a':'a'})
from test import mapping_tests
diff --git a/Lib/test/test_dictcomps.py b/Lib/test/test_dictcomps.py
index 5a19319..3c8b95c 100644
--- a/Lib/test/test_dictcomps.py
+++ b/Lib/test/test_dictcomps.py
@@ -84,8 +84,5 @@ class DictComprehensionTest(unittest.TestCase):
"exec")
-def test_main():
- support.run_unittest(__name__)
-
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_dis.py b/Lib/test/test_dis.py
index 5c59eaa..1bcd693 100644
--- a/Lib/test/test_dis.py
+++ b/Lib/test/test_dis.py
@@ -1,11 +1,35 @@
# Minimal tests for dis module
from test.support import run_unittest, captured_stdout
+import difflib
import unittest
import sys
import dis
import io
+class _C:
+ def __init__(self, x):
+ self.x = x == 1
+
+dis_c_instance_method = """\
+ %-4d 0 LOAD_FAST 1 (x)
+ 3 LOAD_CONST 1 (1)
+ 6 COMPARE_OP 2 (==)
+ 9 LOAD_FAST 0 (self)
+ 12 STORE_ATTR 0 (x)
+ 15 LOAD_CONST 0 (None)
+ 18 RETURN_VALUE
+""" % (_C.__init__.__code__.co_firstlineno + 1,)
+
+dis_c_instance_method_bytes = """\
+ 0 LOAD_FAST 1 (1)
+ 3 LOAD_CONST 1 (1)
+ 6 COMPARE_OP 2 (==)
+ 9 LOAD_FAST 0 (0)
+ 12 STORE_ATTR 0 (0)
+ 15 LOAD_CONST 0 (0)
+ 18 RETURN_VALUE
+"""
def _f(a):
print(a)
@@ -14,7 +38,7 @@ def _f(a):
dis_f = """\
%-4d 0 LOAD_GLOBAL 0 (print)
3 LOAD_FAST 0 (a)
- 6 CALL_FUNCTION 1
+ 6 CALL_FUNCTION 1 (1 positional, 0 keyword pair)
9 POP_TOP
%-4d 10 LOAD_CONST 1 (1)
@@ -23,6 +47,16 @@ dis_f = """\
_f.__code__.co_firstlineno + 2)
+dis_f_co_code = """\
+ 0 LOAD_GLOBAL 0 (0)
+ 3 LOAD_FAST 0 (0)
+ 6 CALL_FUNCTION 1 (1 positional, 0 keyword pair)
+ 9 POP_TOP
+ 10 LOAD_CONST 1 (1)
+ 13 RETURN_VALUE
+"""
+
+
def bug708901():
for res in range(1,
10):
@@ -34,7 +68,7 @@ dis_bug708901 = """\
6 LOAD_CONST 1 (1)
%-4d 9 LOAD_CONST 2 (10)
- 12 CALL_FUNCTION 2
+ 12 CALL_FUNCTION 2 (2 positional, 0 keyword pair)
15 GET_ITER
>> 16 FOR_ITER 6 (to 25)
19 STORE_FAST 0 (res)
@@ -138,18 +172,27 @@ dis_compound_stmt_str = """\
"""
class DisTests(unittest.TestCase):
- def do_disassembly_test(self, func, expected):
+
+ def get_disassembly(self, func, lasti=-1, wrapper=True):
s = io.StringIO()
save_stdout = sys.stdout
sys.stdout = s
- dis.dis(func)
- sys.stdout = save_stdout
- got = s.getvalue()
+ try:
+ if wrapper:
+ dis.dis(func)
+ else:
+ dis.disassemble(func, lasti)
+ finally:
+ sys.stdout = save_stdout
# Trim trailing blanks (if any).
- lines = got.split('\n')
- lines = [line.rstrip() for line in lines]
- expected = expected.split("\n")
- import difflib
+ return [line.rstrip() for line in s.getvalue().splitlines()]
+
+ def get_disassemble_as_string(self, func, lasti=-1):
+ return '\n'.join(self.get_disassembly(func, lasti, False))
+
+ def do_disassembly_test(self, func, expected):
+ lines = self.get_disassembly(func)
+ expected = expected.splitlines()
if expected != lines:
self.fail(
"events did not match expectation:\n" +
@@ -157,7 +200,7 @@ class DisTests(unittest.TestCase):
lines)))
def test_opmap(self):
- self.assertEqual(dis.opmap["STOP_CODE"], 0)
+ self.assertEqual(dis.opmap["NOP"], 9)
self.assertIn(dis.opmap["LOAD_CONST"], dis.hasconst)
self.assertIn(dis.opmap["STORE_NAME"], dis.hasname)
@@ -211,6 +254,44 @@ class DisTests(unittest.TestCase):
self.do_disassembly_test(simple_stmt_str, dis_simple_stmt_str)
self.do_disassembly_test(compound_stmt_str, dis_compound_stmt_str)
+ def test_disassemble_bytes(self):
+ self.do_disassembly_test(_f.__code__.co_code, dis_f_co_code)
+
+ def test_disassemble_method(self):
+ self.do_disassembly_test(_C(1).__init__, dis_c_instance_method)
+
+ def test_disassemble_method_bytes(self):
+ method_bytecode = _C(1).__init__.__code__.co_code
+ self.do_disassembly_test(method_bytecode, dis_c_instance_method_bytes)
+
+ def test_dis_none(self):
+ try:
+ del sys.last_traceback
+ except AttributeError:
+ pass
+ self.assertRaises(RuntimeError, dis.dis, None)
+
+ def test_dis_object(self):
+ self.assertRaises(TypeError, dis.dis, object())
+
+ def test_dis_traceback(self):
+ try:
+ del sys.last_traceback
+ except AttributeError:
+ pass
+
+ try:
+ 1/0
+ except Exception as e:
+ tb = e.__traceback__
+ sys.last_traceback = tb
+
+ tb_dis = self.get_disassemble_as_string(tb.tb_frame.f_code, tb.tb_lasti)
+ self.do_disassembly_test(None, tb_dis)
+
+ def test_dis_object(self):
+ self.assertRaises(TypeError, dis.dis, object())
+
code_info_code_info = """\
Name: code_info
Filename: (.*)
@@ -258,6 +339,7 @@ Flags: OPTIMIZED, NEWLOCALS, VARARGS, VARKEYWORDS, GENERATOR
Constants:
0: None
1: <code object f at (.*), file "(.*)", line (.*)>
+ 2: 'tricky.<locals>.f'
Variable names:
0: x
1: y
@@ -364,6 +446,13 @@ class CodeInfoTests(unittest.TestCase):
dis.show_code(x)
self.assertRegex(output.getvalue(), expected+"\n")
+ def test_code_info_object(self):
+ self.assertRaises(TypeError, dis.code_info, object())
+
+ def test_pretty_flags_no_flags(self):
+ self.assertEqual(dis.pretty_flags(0), '0x0')
+
+
def test_main():
run_unittest(DisTests, CodeInfoTests)
diff --git a/Lib/test/test_doctest.py b/Lib/test/test_doctest.py
index a6c17cc..8f8c7c7 100644
--- a/Lib/test/test_doctest.py
+++ b/Lib/test/test_doctest.py
@@ -5,6 +5,7 @@ Test script for doctest.
from test import support
import doctest
import os
+import sys
# NOTE: There are some additional tests relating to interaction with
@@ -432,7 +433,7 @@ We'll simulate a __file__ attr that ends in pyc:
>>> tests = finder.find(sample_func)
>>> print(tests) # doctest: +ELLIPSIS
- [<DocTest sample_func from ...:17 (1 example)>]
+ [<DocTest sample_func from ...:18 (1 example)>]
The exact name depends on how test_doctest was invoked, so allow for
leading path components.
@@ -1745,226 +1746,227 @@ Run the debugger on the docstring, and then restore sys.stdin.
"""
-def test_pdb_set_trace():
- """Using pdb.set_trace from a doctest.
-
- You can use pdb.set_trace from a doctest. To do so, you must
- retrieve the set_trace function from the pdb module at the time
- you use it. The doctest module changes sys.stdout so that it can
- capture program output. It also temporarily replaces pdb.set_trace
- with a version that restores stdout. This is necessary for you to
- see debugger output.
-
- >>> doc = '''
- ... >>> x = 42
- ... >>> raise Exception('clé')
- ... Traceback (most recent call last):
- ... Exception: clé
- ... >>> import pdb; pdb.set_trace()
- ... '''
- >>> parser = doctest.DocTestParser()
- >>> test = parser.get_doctest(doc, {}, "foo-bar@baz", "foo-bar@baz.py", 0)
- >>> runner = doctest.DocTestRunner(verbose=False)
-
- To demonstrate this, we'll create a fake standard input that
- captures our debugger input:
-
- >>> import tempfile
- >>> real_stdin = sys.stdin
- >>> sys.stdin = _FakeInput([
- ... 'print(x)', # print data defined by the example
- ... 'continue', # stop debugging
- ... ''])
-
- >>> try: runner.run(test)
- ... finally: sys.stdin = real_stdin
- --Return--
- > <doctest foo-bar@baz[2]>(1)<module>()->None
- -> import pdb; pdb.set_trace()
- (Pdb) print(x)
- 42
- (Pdb) continue
- TestResults(failed=0, attempted=3)
-
- You can also put pdb.set_trace in a function called from a test:
-
- >>> def calls_set_trace():
- ... y=2
- ... import pdb; pdb.set_trace()
-
- >>> doc = '''
- ... >>> x=1
- ... >>> calls_set_trace()
- ... '''
- >>> test = parser.get_doctest(doc, globals(), "foo-bar@baz", "foo-bar@baz.py", 0)
- >>> real_stdin = sys.stdin
- >>> sys.stdin = _FakeInput([
- ... 'print(y)', # print data defined in the function
- ... 'up', # out of function
- ... 'print(x)', # print data defined by the example
- ... 'continue', # stop debugging
- ... ''])
-
- >>> try:
- ... runner.run(test)
- ... finally:
- ... sys.stdin = real_stdin
- --Return--
- > <doctest test.test_doctest.test_pdb_set_trace[8]>(3)calls_set_trace()->None
- -> import pdb; pdb.set_trace()
- (Pdb) print(y)
- 2
- (Pdb) up
- > <doctest foo-bar@baz[1]>(1)<module>()
- -> calls_set_trace()
- (Pdb) print(x)
- 1
- (Pdb) continue
- TestResults(failed=0, attempted=2)
-
- During interactive debugging, source code is shown, even for
- doctest examples:
-
- >>> doc = '''
- ... >>> def f(x):
- ... ... g(x*2)
- ... >>> def g(x):
- ... ... print(x+3)
- ... ... import pdb; pdb.set_trace()
- ... >>> f(3)
- ... '''
- >>> test = parser.get_doctest(doc, globals(), "foo-bar@baz", "foo-bar@baz.py", 0)
- >>> real_stdin = sys.stdin
- >>> sys.stdin = _FakeInput([
- ... 'list', # list source from example 2
- ... 'next', # return from g()
- ... 'list', # list source from example 1
- ... 'next', # return from f()
- ... 'list', # list source from example 3
- ... 'continue', # stop debugging
- ... ''])
- >>> try: runner.run(test)
- ... finally: sys.stdin = real_stdin
- ... # doctest: +NORMALIZE_WHITESPACE
- --Return--
- > <doctest foo-bar@baz[1]>(3)g()->None
- -> import pdb; pdb.set_trace()
- (Pdb) list
- 1 def g(x):
- 2 print(x+3)
- 3 -> import pdb; pdb.set_trace()
- [EOF]
- (Pdb) next
- --Return--
- > <doctest foo-bar@baz[0]>(2)f()->None
- -> g(x*2)
- (Pdb) list
- 1 def f(x):
- 2 -> g(x*2)
- [EOF]
- (Pdb) next
- --Return--
- > <doctest foo-bar@baz[2]>(1)<module>()->None
- -> f(3)
- (Pdb) list
- 1 -> f(3)
- [EOF]
- (Pdb) continue
- **********************************************************************
- File "foo-bar@baz.py", line 7, in foo-bar@baz
- Failed example:
- f(3)
- Expected nothing
- Got:
- 9
- TestResults(failed=1, attempted=3)
- """
-
-def test_pdb_set_trace_nested():
- """This illustrates more-demanding use of set_trace with nested functions.
-
- >>> class C(object):
- ... def calls_set_trace(self):
- ... y = 1
- ... import pdb; pdb.set_trace()
- ... self.f1()
- ... y = 2
- ... def f1(self):
- ... x = 1
- ... self.f2()
- ... x = 2
- ... def f2(self):
- ... z = 1
- ... z = 2
-
- >>> calls_set_trace = C().calls_set_trace
-
- >>> doc = '''
- ... >>> a = 1
- ... >>> calls_set_trace()
- ... '''
- >>> parser = doctest.DocTestParser()
- >>> runner = doctest.DocTestRunner(verbose=False)
- >>> test = parser.get_doctest(doc, globals(), "foo-bar@baz", "foo-bar@baz.py", 0)
- >>> real_stdin = sys.stdin
- >>> sys.stdin = _FakeInput([
- ... 'print(y)', # print data defined in the function
- ... 'step', 'step', 'step', 'step', 'step', 'step', 'print(z)',
- ... 'up', 'print(x)',
- ... 'up', 'print(y)',
- ... 'up', 'print(foo)',
- ... 'continue', # stop debugging
- ... ''])
-
- >>> try:
- ... runner.run(test)
- ... finally:
- ... sys.stdin = real_stdin
- ... # doctest: +REPORT_NDIFF
- > <doctest test.test_doctest.test_pdb_set_trace_nested[0]>(5)calls_set_trace()
- -> self.f1()
- (Pdb) print(y)
- 1
- (Pdb) step
- --Call--
- > <doctest test.test_doctest.test_pdb_set_trace_nested[0]>(7)f1()
- -> def f1(self):
- (Pdb) step
- > <doctest test.test_doctest.test_pdb_set_trace_nested[0]>(8)f1()
- -> x = 1
- (Pdb) step
- > <doctest test.test_doctest.test_pdb_set_trace_nested[0]>(9)f1()
- -> self.f2()
- (Pdb) step
- --Call--
- > <doctest test.test_doctest.test_pdb_set_trace_nested[0]>(11)f2()
- -> def f2(self):
- (Pdb) step
- > <doctest test.test_doctest.test_pdb_set_trace_nested[0]>(12)f2()
- -> z = 1
- (Pdb) step
- > <doctest test.test_doctest.test_pdb_set_trace_nested[0]>(13)f2()
- -> z = 2
- (Pdb) print(z)
- 1
- (Pdb) up
- > <doctest test.test_doctest.test_pdb_set_trace_nested[0]>(9)f1()
- -> self.f2()
- (Pdb) print(x)
- 1
- (Pdb) up
- > <doctest test.test_doctest.test_pdb_set_trace_nested[0]>(5)calls_set_trace()
- -> self.f1()
- (Pdb) print(y)
- 1
- (Pdb) up
- > <doctest foo-bar@baz[1]>(1)<module>()
- -> calls_set_trace()
- (Pdb) print(foo)
- *** NameError: name 'foo' is not defined
- (Pdb) continue
- TestResults(failed=0, attempted=2)
-"""
+if not hasattr(sys, 'gettrace') or not sys.gettrace():
+ def test_pdb_set_trace():
+ """Using pdb.set_trace from a doctest.
+
+ You can use pdb.set_trace from a doctest. To do so, you must
+ retrieve the set_trace function from the pdb module at the time
+ you use it. The doctest module changes sys.stdout so that it can
+ capture program output. It also temporarily replaces pdb.set_trace
+ with a version that restores stdout. This is necessary for you to
+ see debugger output.
+
+ >>> doc = '''
+ ... >>> x = 42
+ ... >>> raise Exception('clé')
+ ... Traceback (most recent call last):
+ ... Exception: clé
+ ... >>> import pdb; pdb.set_trace()
+ ... '''
+ >>> parser = doctest.DocTestParser()
+ >>> test = parser.get_doctest(doc, {}, "foo-bar@baz", "foo-bar@baz.py", 0)
+ >>> runner = doctest.DocTestRunner(verbose=False)
+
+ To demonstrate this, we'll create a fake standard input that
+ captures our debugger input:
+
+ >>> import tempfile
+ >>> real_stdin = sys.stdin
+ >>> sys.stdin = _FakeInput([
+ ... 'print(x)', # print data defined by the example
+ ... 'continue', # stop debugging
+ ... ''])
+
+ >>> try: runner.run(test)
+ ... finally: sys.stdin = real_stdin
+ --Return--
+ > <doctest foo-bar@baz[2]>(1)<module>()->None
+ -> import pdb; pdb.set_trace()
+ (Pdb) print(x)
+ 42
+ (Pdb) continue
+ TestResults(failed=0, attempted=3)
+
+ You can also put pdb.set_trace in a function called from a test:
+
+ >>> def calls_set_trace():
+ ... y=2
+ ... import pdb; pdb.set_trace()
+
+ >>> doc = '''
+ ... >>> x=1
+ ... >>> calls_set_trace()
+ ... '''
+ >>> test = parser.get_doctest(doc, globals(), "foo-bar@baz", "foo-bar@baz.py", 0)
+ >>> real_stdin = sys.stdin
+ >>> sys.stdin = _FakeInput([
+ ... 'print(y)', # print data defined in the function
+ ... 'up', # out of function
+ ... 'print(x)', # print data defined by the example
+ ... 'continue', # stop debugging
+ ... ''])
+
+ >>> try:
+ ... runner.run(test)
+ ... finally:
+ ... sys.stdin = real_stdin
+ --Return--
+ > <doctest test.test_doctest.test_pdb_set_trace[8]>(3)calls_set_trace()->None
+ -> import pdb; pdb.set_trace()
+ (Pdb) print(y)
+ 2
+ (Pdb) up
+ > <doctest foo-bar@baz[1]>(1)<module>()
+ -> calls_set_trace()
+ (Pdb) print(x)
+ 1
+ (Pdb) continue
+ TestResults(failed=0, attempted=2)
+
+ During interactive debugging, source code is shown, even for
+ doctest examples:
+
+ >>> doc = '''
+ ... >>> def f(x):
+ ... ... g(x*2)
+ ... >>> def g(x):
+ ... ... print(x+3)
+ ... ... import pdb; pdb.set_trace()
+ ... >>> f(3)
+ ... '''
+ >>> test = parser.get_doctest(doc, globals(), "foo-bar@baz", "foo-bar@baz.py", 0)
+ >>> real_stdin = sys.stdin
+ >>> sys.stdin = _FakeInput([
+ ... 'list', # list source from example 2
+ ... 'next', # return from g()
+ ... 'list', # list source from example 1
+ ... 'next', # return from f()
+ ... 'list', # list source from example 3
+ ... 'continue', # stop debugging
+ ... ''])
+ >>> try: runner.run(test)
+ ... finally: sys.stdin = real_stdin
+ ... # doctest: +NORMALIZE_WHITESPACE
+ --Return--
+ > <doctest foo-bar@baz[1]>(3)g()->None
+ -> import pdb; pdb.set_trace()
+ (Pdb) list
+ 1 def g(x):
+ 2 print(x+3)
+ 3 -> import pdb; pdb.set_trace()
+ [EOF]
+ (Pdb) next
+ --Return--
+ > <doctest foo-bar@baz[0]>(2)f()->None
+ -> g(x*2)
+ (Pdb) list
+ 1 def f(x):
+ 2 -> g(x*2)
+ [EOF]
+ (Pdb) next
+ --Return--
+ > <doctest foo-bar@baz[2]>(1)<module>()->None
+ -> f(3)
+ (Pdb) list
+ 1 -> f(3)
+ [EOF]
+ (Pdb) continue
+ **********************************************************************
+ File "foo-bar@baz.py", line 7, in foo-bar@baz
+ Failed example:
+ f(3)
+ Expected nothing
+ Got:
+ 9
+ TestResults(failed=1, attempted=3)
+ """
+
+ def test_pdb_set_trace_nested():
+ """This illustrates more-demanding use of set_trace with nested functions.
+
+ >>> class C(object):
+ ... def calls_set_trace(self):
+ ... y = 1
+ ... import pdb; pdb.set_trace()
+ ... self.f1()
+ ... y = 2
+ ... def f1(self):
+ ... x = 1
+ ... self.f2()
+ ... x = 2
+ ... def f2(self):
+ ... z = 1
+ ... z = 2
+
+ >>> calls_set_trace = C().calls_set_trace
+
+ >>> doc = '''
+ ... >>> a = 1
+ ... >>> calls_set_trace()
+ ... '''
+ >>> parser = doctest.DocTestParser()
+ >>> runner = doctest.DocTestRunner(verbose=False)
+ >>> test = parser.get_doctest(doc, globals(), "foo-bar@baz", "foo-bar@baz.py", 0)
+ >>> real_stdin = sys.stdin
+ >>> sys.stdin = _FakeInput([
+ ... 'print(y)', # print data defined in the function
+ ... 'step', 'step', 'step', 'step', 'step', 'step', 'print(z)',
+ ... 'up', 'print(x)',
+ ... 'up', 'print(y)',
+ ... 'up', 'print(foo)',
+ ... 'continue', # stop debugging
+ ... ''])
+
+ >>> try:
+ ... runner.run(test)
+ ... finally:
+ ... sys.stdin = real_stdin
+ ... # doctest: +REPORT_NDIFF
+ > <doctest test.test_doctest.test_pdb_set_trace_nested[0]>(5)calls_set_trace()
+ -> self.f1()
+ (Pdb) print(y)
+ 1
+ (Pdb) step
+ --Call--
+ > <doctest test.test_doctest.test_pdb_set_trace_nested[0]>(7)f1()
+ -> def f1(self):
+ (Pdb) step
+ > <doctest test.test_doctest.test_pdb_set_trace_nested[0]>(8)f1()
+ -> x = 1
+ (Pdb) step
+ > <doctest test.test_doctest.test_pdb_set_trace_nested[0]>(9)f1()
+ -> self.f2()
+ (Pdb) step
+ --Call--
+ > <doctest test.test_doctest.test_pdb_set_trace_nested[0]>(11)f2()
+ -> def f2(self):
+ (Pdb) step
+ > <doctest test.test_doctest.test_pdb_set_trace_nested[0]>(12)f2()
+ -> z = 1
+ (Pdb) step
+ > <doctest test.test_doctest.test_pdb_set_trace_nested[0]>(13)f2()
+ -> z = 2
+ (Pdb) print(z)
+ 1
+ (Pdb) up
+ > <doctest test.test_doctest.test_pdb_set_trace_nested[0]>(9)f1()
+ -> self.f2()
+ (Pdb) print(x)
+ 1
+ (Pdb) up
+ > <doctest test.test_doctest.test_pdb_set_trace_nested[0]>(5)calls_set_trace()
+ -> self.f1()
+ (Pdb) print(y)
+ 1
+ (Pdb) up
+ > <doctest foo-bar@baz[1]>(1)<module>()
+ -> calls_set_trace()
+ (Pdb) print(foo)
+ *** NameError: name 'foo' is not defined
+ (Pdb) continue
+ TestResults(failed=0, attempted=2)
+ """
def test_DocTestSuite():
"""DocTestSuite creates a unittest test suite from a doctest.
@@ -2566,7 +2568,7 @@ import sys, re, io
def test_coverage(coverdir):
trace = support.import_module('trace')
- tracer = trace.Trace(ignoredirs=[sys.prefix, sys.exec_prefix,],
+ tracer = trace.Trace(ignoredirs=[sys.base_prefix, sys.base_exec_prefix,],
trace=0, count=1)
tracer.run('test_main()')
r = tracer.results()
diff --git a/Lib/test/test_dummy_thread.py b/Lib/test/test_dummy_thread.py
index c61078d..2fafe1d 100644
--- a/Lib/test/test_dummy_thread.py
+++ b/Lib/test/test_dummy_thread.py
@@ -35,8 +35,8 @@ class LockTests(unittest.TestCase):
"Lock object did not release properly.")
def test_improper_release(self):
- #Make sure release of an unlocked thread raises _thread.error
- self.assertRaises(_thread.error, self.lock.release)
+ #Make sure release of an unlocked thread raises RuntimeError
+ self.assertRaises(RuntimeError, self.lock.release)
def test_cond_acquire_success(self):
#Make sure the conditional acquiring of the lock works.
diff --git a/Lib/test/test_email.py b/Lib/test/test_email.py
deleted file mode 100644
index 5eebba5..0000000
--- a/Lib/test/test_email.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright (C) 2001-2007 Python Software Foundation
-# email package unit tests
-
-# The specific tests now live in Lib/email/test
-from email.test.test_email import suite
-from email.test.test_email_codecs import suite as codecs_suite
-from test import support
-
-def test_main():
- support.run_unittest(suite())
- support.run_unittest(codecs_suite())
-
-if __name__ == '__main__':
- test_main()
diff --git a/Lib/test/test_email/__init__.py b/Lib/test/test_email/__init__.py
new file mode 100644
index 0000000..f206ace
--- /dev/null
+++ b/Lib/test/test_email/__init__.py
@@ -0,0 +1,150 @@
+import os
+import sys
+import unittest
+import test.support
+import email
+from email.message import Message
+from email._policybase import compat32
+from test.test_email import __file__ as landmark
+
+# Run all tests in package for '-m unittest test.test_email'
+def load_tests(loader, standard_tests, pattern):
+ this_dir = os.path.dirname(__file__)
+ if pattern is None:
+ pattern = "test*"
+ package_tests = loader.discover(start_dir=this_dir, pattern=pattern)
+ standard_tests.addTests(package_tests)
+ return standard_tests
+
+
+# used by regrtest and __main__.
+def test_main():
+ here = os.path.dirname(__file__)
+ # Unittest mucks with the path, so we have to save and restore
+ # it to keep regrtest happy.
+ savepath = sys.path[:]
+ test.support._run_suite(unittest.defaultTestLoader.discover(here))
+ sys.path[:] = savepath
+
+
+# helper code used by a number of test modules.
+
+def openfile(filename, *args, **kws):
+ path = os.path.join(os.path.dirname(landmark), 'data', filename)
+ return open(path, *args, **kws)
+
+
+# Base test class
+class TestEmailBase(unittest.TestCase):
+
+ maxDiff = None
+ # Currently the default policy is compat32. By setting that as the default
+ # here we make minimal changes in the test_email tests compared to their
+ # pre-3.3 state.
+ policy = compat32
+
+ def __init__(self, *args, **kw):
+ super().__init__(*args, **kw)
+ self.addTypeEqualityFunc(bytes, self.assertBytesEqual)
+
+ # Backward compatibility to minimize test_email test changes.
+ ndiffAssertEqual = unittest.TestCase.assertEqual
+
+ def _msgobj(self, filename):
+ with openfile(filename) as fp:
+ return email.message_from_file(fp, policy=self.policy)
+
+ def _str_msg(self, string, message=Message, policy=None):
+ if policy is None:
+ policy = self.policy
+ return email.message_from_string(string, message, policy=policy)
+
+ def _bytes_repr(self, b):
+ return [repr(x) for x in b.splitlines(keepends=True)]
+
+ def assertBytesEqual(self, first, second, msg):
+ """Our byte strings are really encoded strings; improve diff output"""
+ self.assertEqual(self._bytes_repr(first), self._bytes_repr(second))
+
+ def assertDefectsEqual(self, actual, expected):
+ self.assertEqual(len(actual), len(expected), actual)
+ for i in range(len(actual)):
+ self.assertIsInstance(actual[i], expected[i],
+ 'item {}'.format(i))
+
+
+def parameterize(cls):
+ """A test method parameterization class decorator.
+
+ Parameters are specified as the value of a class attribute that ends with
+ the string '_params'. Call the portion before '_params' the prefix. Then
+ a method to be parameterized must have the same prefix, the string
+ '_as_', and an arbitrary suffix.
+
+ The value of the _params attribute may be either a dictionary or a list.
+ The values in the dictionary and the elements of the list may either be
+ single values, or a list. If single values, they are turned into single
+ element tuples. However derived, the resulting sequence is passed via
+ *args to the parameterized test function.
+
+ In a _params dictioanry, the keys become part of the name of the generated
+ tests. In a _params list, the values in the list are converted into a
+ string by joining the string values of the elements of the tuple by '_' and
+ converting any blanks into '_'s, and this become part of the name.
+ The full name of a generated test is a 'test_' prefix, the portion of the
+ test function name after the '_as_' separator, plus an '_', plus the name
+ derived as explained above.
+
+ For example, if we have:
+
+ count_params = range(2)
+
+ def count_as_foo_arg(self, foo):
+ self.assertEqual(foo+1, myfunc(foo))
+
+ we will get parameterized test methods named:
+ test_foo_arg_0
+ test_foo_arg_1
+ test_foo_arg_2
+
+ Or we could have:
+
+ example_params = {'foo': ('bar', 1), 'bing': ('bang', 2)}
+
+ def example_as_myfunc_input(self, name, count):
+ self.assertEqual(name+str(count), myfunc(name, count))
+
+ and get:
+ test_myfunc_input_foo
+ test_myfunc_input_bing
+
+ Note: if and only if the generated test name is a valid identifier can it
+ be used to select the test individually from the unittest command line.
+
+ """
+ paramdicts = {}
+ for name, attr in cls.__dict__.items():
+ if name.endswith('_params'):
+ if not hasattr(attr, 'keys'):
+ d = {}
+ for x in attr:
+ if not hasattr(x, '__iter__'):
+ x = (x,)
+ n = '_'.join(str(v) for v in x).replace(' ', '_')
+ d[n] = x
+ attr = d
+ paramdicts[name[:-7] + '_as_'] = attr
+ testfuncs = {}
+ for name, attr in cls.__dict__.items():
+ for paramsname, paramsdict in paramdicts.items():
+ if name.startswith(paramsname):
+ testnameroot = 'test_' + name[len(paramsname):]
+ for paramname, params in paramsdict.items():
+ test = (lambda self, name=name, params=params:
+ getattr(self, name)(*params))
+ testname = testnameroot + '_' + paramname
+ test.__name__ = testname
+ testfuncs[testname] = test
+ for key, value in testfuncs.items():
+ setattr(cls, key, value)
+ return cls
diff --git a/Lib/test/test_email/__main__.py b/Lib/test/test_email/__main__.py
new file mode 100644
index 0000000..98af9ec
--- /dev/null
+++ b/Lib/test/test_email/__main__.py
@@ -0,0 +1,3 @@
+from test.test_email import test_main
+
+test_main()
diff --git a/Lib/test/test_email/data/PyBanner048.gif b/Lib/test/test_email/data/PyBanner048.gif
new file mode 100644
index 0000000..1a5c87f
--- /dev/null
+++ b/Lib/test/test_email/data/PyBanner048.gif
Binary files differ
diff --git a/Lib/test/test_email/data/audiotest.au b/Lib/test/test_email/data/audiotest.au
new file mode 100644
index 0000000..f76b050
--- /dev/null
+++ b/Lib/test/test_email/data/audiotest.au
Binary files differ
diff --git a/Lib/test/test_email/data/msg_01.txt b/Lib/test/test_email/data/msg_01.txt
new file mode 100644
index 0000000..7e33bcf
--- /dev/null
+++ b/Lib/test/test_email/data/msg_01.txt
@@ -0,0 +1,19 @@
+Return-Path: <bbb@zzz.org>
+Delivered-To: bbb@zzz.org
+Received: by mail.zzz.org (Postfix, from userid 889)
+ id 27CEAD38CC; Fri, 4 May 2001 14:05:44 -0400 (EDT)
+MIME-Version: 1.0
+Content-Type: text/plain; charset=us-ascii
+Content-Transfer-Encoding: 7bit
+Message-ID: <15090.61304.110929.45684@aaa.zzz.org>
+From: bbb@ddd.com (John X. Doe)
+To: bbb@zzz.org
+Subject: This is a test message
+Date: Fri, 4 May 2001 14:05:44 -0400
+
+
+Hi,
+
+Do you like this message?
+
+-Me
diff --git a/Lib/test/test_email/data/msg_02.txt b/Lib/test/test_email/data/msg_02.txt
new file mode 100644
index 0000000..43f2480
--- /dev/null
+++ b/Lib/test/test_email/data/msg_02.txt
@@ -0,0 +1,135 @@
+MIME-version: 1.0
+From: ppp-request@zzz.org
+Sender: ppp-admin@zzz.org
+To: ppp@zzz.org
+Subject: Ppp digest, Vol 1 #2 - 5 msgs
+Date: Fri, 20 Apr 2001 20:18:00 -0400 (EDT)
+X-Mailer: Mailman v2.0.4
+X-Mailman-Version: 2.0.4
+Content-Type: multipart/mixed; boundary="192.168.1.2.889.32614.987812255.500.21814"
+
+--192.168.1.2.889.32614.987812255.500.21814
+Content-type: text/plain; charset=us-ascii
+Content-description: Masthead (Ppp digest, Vol 1 #2)
+
+Send Ppp mailing list submissions to
+ ppp@zzz.org
+
+To subscribe or unsubscribe via the World Wide Web, visit
+ http://www.zzz.org/mailman/listinfo/ppp
+or, via email, send a message with subject or body 'help' to
+ ppp-request@zzz.org
+
+You can reach the person managing the list at
+ ppp-admin@zzz.org
+
+When replying, please edit your Subject line so it is more specific
+than "Re: Contents of Ppp digest..."
+
+
+--192.168.1.2.889.32614.987812255.500.21814
+Content-type: text/plain; charset=us-ascii
+Content-description: Today's Topics (5 msgs)
+
+Today's Topics:
+
+ 1. testing #1 (Barry A. Warsaw)
+ 2. testing #2 (Barry A. Warsaw)
+ 3. testing #3 (Barry A. Warsaw)
+ 4. testing #4 (Barry A. Warsaw)
+ 5. testing #5 (Barry A. Warsaw)
+
+--192.168.1.2.889.32614.987812255.500.21814
+Content-Type: multipart/digest; boundary="__--__--"
+
+--__--__--
+
+Message: 1
+Content-Type: text/plain; charset=us-ascii
+Content-Transfer-Encoding: 7bit
+Date: Fri, 20 Apr 2001 20:16:13 -0400
+To: ppp@zzz.org
+From: barry@digicool.com (Barry A. Warsaw)
+Subject: [Ppp] testing #1
+Precedence: bulk
+
+
+hello
+
+
+--__--__--
+
+Message: 2
+Date: Fri, 20 Apr 2001 20:16:21 -0400
+Content-Type: text/plain; charset=us-ascii
+Content-Transfer-Encoding: 7bit
+To: ppp@zzz.org
+From: barry@digicool.com (Barry A. Warsaw)
+Precedence: bulk
+
+
+hello
+
+
+--__--__--
+
+Message: 3
+Date: Fri, 20 Apr 2001 20:16:25 -0400
+Content-Type: text/plain; charset=us-ascii
+Content-Transfer-Encoding: 7bit
+To: ppp@zzz.org
+From: barry@digicool.com (Barry A. Warsaw)
+Subject: [Ppp] testing #3
+Precedence: bulk
+
+
+hello
+
+
+--__--__--
+
+Message: 4
+Date: Fri, 20 Apr 2001 20:16:28 -0400
+Content-Type: text/plain; charset=us-ascii
+Content-Transfer-Encoding: 7bit
+To: ppp@zzz.org
+From: barry@digicool.com (Barry A. Warsaw)
+Subject: [Ppp] testing #4
+Precedence: bulk
+
+
+hello
+
+
+--__--__--
+
+Message: 5
+Date: Fri, 20 Apr 2001 20:16:32 -0400
+Content-Type: text/plain; charset=us-ascii
+Content-Transfer-Encoding: 7bit
+To: ppp@zzz.org
+From: barry@digicool.com (Barry A. Warsaw)
+Subject: [Ppp] testing #5
+Precedence: bulk
+
+
+hello
+
+
+
+
+--__--__----
+--192.168.1.2.889.32614.987812255.500.21814
+Content-type: text/plain; charset=us-ascii
+Content-description: Digest Footer
+
+_______________________________________________
+Ppp mailing list
+Ppp@zzz.org
+http://www.zzz.org/mailman/listinfo/ppp
+
+
+--192.168.1.2.889.32614.987812255.500.21814--
+
+End of Ppp Digest
+
diff --git a/Lib/test/test_email/data/msg_03.txt b/Lib/test/test_email/data/msg_03.txt
new file mode 100644
index 0000000..c748ebf
--- /dev/null
+++ b/Lib/test/test_email/data/msg_03.txt
@@ -0,0 +1,16 @@
+Return-Path: <bbb@zzz.org>
+Delivered-To: bbb@zzz.org
+Received: by mail.zzz.org (Postfix, from userid 889)
+ id 27CEAD38CC; Fri, 4 May 2001 14:05:44 -0400 (EDT)
+Message-ID: <15090.61304.110929.45684@aaa.zzz.org>
+From: bbb@ddd.com (John X. Doe)
+To: bbb@zzz.org
+Subject: This is a test message
+Date: Fri, 4 May 2001 14:05:44 -0400
+
+
+Hi,
+
+Do you like this message?
+
+-Me
diff --git a/Lib/test/test_email/data/msg_04.txt b/Lib/test/test_email/data/msg_04.txt
new file mode 100644
index 0000000..1f633c4
--- /dev/null
+++ b/Lib/test/test_email/data/msg_04.txt
@@ -0,0 +1,37 @@
+Return-Path: <barry@python.org>
+Delivered-To: barry@python.org
+Received: by mail.python.org (Postfix, from userid 889)
+ id C2BF0D37C6; Tue, 11 Sep 2001 00:05:05 -0400 (EDT)
+MIME-Version: 1.0
+Content-Type: multipart/mixed; boundary="h90VIIIKmx"
+Content-Transfer-Encoding: 7bit
+Message-ID: <15261.36209.358846.118674@anthem.python.org>
+From: barry@python.org (Barry A. Warsaw)
+To: barry@python.org
+Subject: a simple multipart
+Date: Tue, 11 Sep 2001 00:05:05 -0400
+X-Mailer: VM 6.95 under 21.4 (patch 4) "Artificial Intelligence" XEmacs Lucid
+X-Attribution: BAW
+X-Oblique-Strategy: Make a door into a window
+
+
+--h90VIIIKmx
+Content-Type: text/plain
+Content-Disposition: inline;
+ filename="msg.txt"
+Content-Transfer-Encoding: 7bit
+
+a simple kind of mirror
+to reflect upon our own
+
+--h90VIIIKmx
+Content-Type: text/plain
+Content-Disposition: inline;
+ filename="msg.txt"
+Content-Transfer-Encoding: 7bit
+
+a simple kind of mirror
+to reflect upon our own
+
+--h90VIIIKmx--
+
diff --git a/Lib/test/test_email/data/msg_05.txt b/Lib/test/test_email/data/msg_05.txt
new file mode 100644
index 0000000..87d5e9c
--- /dev/null
+++ b/Lib/test/test_email/data/msg_05.txt
@@ -0,0 +1,28 @@
+From: foo
+Subject: bar
+To: baz
+MIME-Version: 1.0
+Content-Type: multipart/report; report-type=delivery-status;
+ boundary="D1690A7AC1.996856090/mail.example.com"
+Message-Id: <20010803162810.0CA8AA7ACC@mail.example.com>
+
+This is a MIME-encapsulated message.
+
+--D1690A7AC1.996856090/mail.example.com
+Content-Type: text/plain
+
+Yadda yadda yadda
+
+--D1690A7AC1.996856090/mail.example.com
+
+Yadda yadda yadda
+
+--D1690A7AC1.996856090/mail.example.com
+Content-Type: message/rfc822
+
+From: nobody@python.org
+
+Yadda yadda yadda
+
+--D1690A7AC1.996856090/mail.example.com--
+
diff --git a/Lib/test/test_email/data/msg_06.txt b/Lib/test/test_email/data/msg_06.txt
new file mode 100644
index 0000000..69f3a47
--- /dev/null
+++ b/Lib/test/test_email/data/msg_06.txt
@@ -0,0 +1,33 @@
+Return-Path: <barry@python.org>
+Delivered-To: barry@python.org
+MIME-Version: 1.0
+Content-Type: message/rfc822
+Content-Description: forwarded message
+Content-Transfer-Encoding: 7bit
+Message-ID: <15265.9482.641338.555352@python.org>
+From: barry@zope.com (Barry A. Warsaw)
+Sender: barry@python.org
+To: barry@python.org
+Subject: forwarded message from Barry A. Warsaw
+Date: Thu, 13 Sep 2001 17:28:42 -0400
+X-Mailer: VM 6.95 under 21.4 (patch 4) "Artificial Intelligence" XEmacs Lucid
+X-Attribution: BAW
+X-Oblique-Strategy: Be dirty
+X-Url: http://barry.wooz.org
+
+MIME-Version: 1.0
+Content-Type: text/plain; charset=us-ascii
+Return-Path: <barry@python.org>
+Delivered-To: barry@python.org
+Message-ID: <15265.9468.713530.98441@python.org>
+From: barry@zope.com (Barry A. Warsaw)
+Sender: barry@python.org
+To: barry@python.org
+Subject: testing
+Date: Thu, 13 Sep 2001 17:28:28 -0400
+X-Mailer: VM 6.95 under 21.4 (patch 4) "Artificial Intelligence" XEmacs Lucid
+X-Attribution: BAW
+X-Oblique-Strategy: Spectrum analysis
+X-Url: http://barry.wooz.org
+
+
diff --git a/Lib/test/test_email/data/msg_07.txt b/Lib/test/test_email/data/msg_07.txt
new file mode 100644
index 0000000..721f3a0
--- /dev/null
+++ b/Lib/test/test_email/data/msg_07.txt
@@ -0,0 +1,83 @@
+MIME-Version: 1.0
+From: Barry <barry@digicool.com>
+To: Dingus Lovers <cravindogs@cravindogs.com>
+Subject: Here is your dingus fish
+Date: Fri, 20 Apr 2001 19:35:02 -0400
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+
+Hi there,
+
+This is the dingus fish.
+
+--BOUNDARY
+Content-Type: image/gif; name="dingusfish.gif"
+Content-Transfer-Encoding: base64
+content-disposition: attachment; filename="dingusfish.gif"
+
+R0lGODdhAAEAAfAAAP///wAAACwAAAAAAAEAAQAC/oSPqcvtD6OctNqLs968+w+G4kiW5omm6sq2
+7gvH8kzX9o3n+s73/g8MCofEovGITGICTKbyCV0FDNOo9SqpQqpOrJfXzTQj2vD3TGtqL+NtGQ2f
+qTXmxzuOd7WXdcc9DyjU53ewFni4s0fGhdiYaEhGBelICTNoV1j5NUnFcrmUqemjNifJVWpaOqaI
+oFq3SspZsSraE7sHq3jr1MZqWvi662vxV4tD+pvKW6aLDOCLyur8PDwbanyDeq0N3DctbQYeLDvR
+RY6t95m6UB0d3mwIrV7e2VGNvjjffukeJp4w7F65KecGFsTHQGAygOrgrWs1jt28Rc88KESYcGLA
+/obvTkH6p+CinWJiJmIMqXGQwH/y4qk0SYjgQTczT3ajKZGfuI0uJ4kkVI/DT5s3/ejkxI0aT4Y+
+YTYgWbImUaXk9nlLmnSh1qJiJFl0OpUqRK4oOy7NyRQtHWofhoYVxkwWXKUSn0YsS+fUV6lhqfYb
+6ayd3Z5qQdG1B7bvQzaJjwUV2lixMUZ7JVsOlfjWVr/3NB/uFvnySBN6Dcb6rGwaRM3wsormw5cC
+M9NxWy/bWdufudCvy8bOAjXjVVwta/uO21sE5RHBCzNFXtgq9ORtH4eYjVP4Yryo026nvkFmCeyA
+B29efV6ravCMK5JwWd5897Qrx7ll38o6iHDZ/rXPR//feevhF4l7wjUGX3xq1eeRfM4RSJGBIV1D
+z1gKPkfWag3mVBVvva1RlX5bAJTPR/2YqNtw/FkIYYEi/pIZiAdpcxpoHtmnYYoZtvhUftzdx5ZX
+JSKDW405zkGcZzzGZ6KEv4FI224oDmijlEf+xp6MJK5ojY/ASeVUR+wsKRuJ+XFZ5o7ZeEime8t1
+ouUsU6YjF5ZtUihhkGfCdFQLWQFJ3UXxmElfhQnR+eCdcDbkFZp6vTRmj56ApCihn5QGpaToNZmR
+n3NVSpZcQpZ2KEONusaiCsKAug0wkQbJSFO+PTSjneGxOuFjPlUk3ovWvdIerjUg9ZGIOtGq/qeX
+eCYrrCX+1UPsgTKGGRSbzd5q156d/gpfbJxe66eD5iQKrXj7RGgruGxs62qebBHUKS32CKluCiqZ
+qh+pmehmEb71noAUoe5e9Zm17S7773V10pjrtG4CmuurCV/n6zLK5turWNhqOvFXbjhZrMD0YhKe
+wR0zOyuvsh6MWrGoIuzvyWu5y1WIFAqmJselypxXh6dKLNOKEB98L88bS2rkNqqlKzCNJp9c0G0j
+Gzh0iRrCbHSXmPR643QS+4rWhgFmnSbSuXCjS0xAOWkU2UdLqyuUNfHSFdUouy3bm5i5GnDM3tG8
+doJ4r5tqu3pPbRSVfvs8uJzeNXhp3n4j/tZ42SwH7eaWUUOjc3qFV9453UHTXZfcLH+OeNs5g36x
+lBnHvTm7EbMbLeuaLncao8vWCXimfo1o+843Ak6y4ChNeGntvAYvfLK4ezmoyNIbNCLTCXO9ZV3A
+E8/s88RczPzDwI4Ob7XZyl7+9Miban29h+tJZPrE21wgvBphDfrrfPdCTPKJD/y98L1rZwHcV6Jq
+Zab0metpuNIX/qAFPoz171WUaUb4HAhBSzHuHfjzHb3kha/2Cctis/ORArVHNYfFyYRH2pYIRzic
+isVOfPWD1b6mRTqpCRBozzof6UZVvFXRxWIr3GGrEviGYgyPMfahheiSaLs/9QeFu7oZ/ndSY8DD
+ya9x+uPed+7mxN2IzIISBOMLFYWVqC3Pew1T2nFuuCiwZS5/v6II10i4t1OJcUH2U9zxKodHsGGv
+Oa+zkvNUYUOa/TCCRutF9MzDwdlUMJADTCGSbDQ5OV4PTamDoPEi6Ecc/RF5RWwkcdSXvSOaDWSn
+I9LlvubFTQpuc6JKXLcKeb+xdbKRBnwREemXyjg6ME65aJiOuBgrktzykfPLJBKR9ClMavJ62/Ff
+BlNIyod9yX9wcSXexnXFpvkrbXk64xsx5Db7wXKP5fSgsvwIMM/9631VLBfkmtbHRXpqmtei52hG
+pUwSlo+BASQoeILDOBgREECxBBh5/iYmNsQ9dIv5+OI++QkqdsJPc3uykz5fkM+OraeekcQF7X4n
+B5S67za5U967PmooGQhUXfF7afXyCD7ONdRe17QogYjVx38uLwtrS6nhTnm15LQUnu9E2uK6CNI/
+1HOABj0ESwOjut4FEpFQpdNAm4K2LHnDWHNcmKB2ioKBogysVZtMO2nSxUdZ8Yk2kJc7URioLVI0
+YgmtIwZj4LoeKemgnOnbUdGnzZ4Oa6scqiolBGqS6RgWNLu0RMhcaE6rhhU4hiuqFXPAG8fGwTPW
+FKeLMtdVmXLSs5YJGF/YeVm7rREMlY3UYE+yCxbaMXX8y15m5zVHq6GOKDMynzII/jdUHdyVqIy0
+ifX2+r/EgtZcvRzSb72gU9ui87M2VecjKildW/aFqaYhKoryUjfB/g4qtyVuc60xFDGmCxwjW+qu
+zjuwl2GkOWn66+3QiiEctvd04OVvcCVzjgT7lrkvjVGKKHmmlDUKowSeikb5kK/mJReuWOxONx+s
+ULsl+Lqb0CVn0SrVyJ6wt4t6yTeSCafhPhAf0OXn6L60UMxiLolFAtmN35S2Ob1lZpQ1r/n0Qb5D
+oQ1zJiRVDgF8N3Q8TYfbi3DyWCy3lT1nxyBs6FT3S2GOzWRlxwKvlRP0RPJA9SjxEy0UoEnkA+M4
+cnzLMJrBGWLFEaaUb5lvpqbq/loOaU5+DFuHPxo82/OZuM8FXG3oVNZhtWpMpb/0Xu5m/LfLhHZQ
+7yuVI0MqZ7NE43imC8jH3IwGZlbPm0xkJYs7+2U48hXTsFSMqgGDvai0kLxyynKNT/waj+q1c1tz
+GjOpPBgdCSq3UKZxCSsqFIY+O6JbAWGWcV1pwqLyj5sGqCF1xb1F3varUWqrJv6cN3PrUXzijtfZ
+FshpBL3Xwr4GIPvU2N8EjrJgS1zl21rbXQMXeXc5jjFyrhpCzijSv/RQtyPSzHCFMhlME95fHglt
+pRsX+dfSQjUeHAlpWzJ5iOo79Ldnaxai6bXTcGO3fp07ri7HLEmXXPlYi8bv/qVxvNcdra6m7Rlb
+6JBTb5fd66VhFRjGArh2n7R1rDW4P5NOT9K0I183T2scYkeZ3q/VFyLb09U9ajzXBS8Kgkhc4mBS
+kYY9cy3Vy9lUnuNJH8HGIclUilwnBtjUOH0gteGOZ4c/XNrhXLSYDyxfnD8z1pDy7rYRvDolhnbe
+UMzxCZUs40s6s7UIvBnLgc0+vKuOkIXeOrDymlp+Zxra4MZLBbVrqD/jTJ597pDmnw5c4+DbyB88
+9Cg9DodYcSuMZT/114pptqc/EuTjRPvH/z5slzI3tluOEBBLqOXLOX+0I5929tO97wkvl/atCz+y
+xJrdwteW2FNW/NSmBP+f/maYtVs/bYyBC7Ox3jsYZHL05CIrBa/nS+b3bHfiYm4Ueil1YZZSgAUI
+fFZ1dxUmeA2oQRQ3RuGXNGLFV9/XbGFGPV6kfzk1TBBCd+izc7q1H+OHMJwmaBX2IQNYVAKHYepV
+SSGCe6CnbYHHETKGNe43EDvFgZr0gB/nVHPHZ80VV1ojOiI3XDvYIkl4ayo4bxQIgrFXWTvBI0nH
+VElWMuw2aLUWCRHHf8ymVCHjFlJnOSojfevCYyyyZDH0IcvHhrsnQ5O1OsWzONuVVKIxSxiFZ/tR
+fKDAf6xFTnw4O9Qig2VCfW2hJQrmMOuHW0W3dLQmCMO2ccdUd/xyfflH/olTiHZVdGwb8nIwRzSE
+J15jFlOJuBZBZ4CiyHyd2IFylFlB+HgHhYabhWOGwYO1ZH/Og1dtQlFMk352CGRSIFTapnWQEUtN
+l4zv8S0aaCFDyGCBqDUxZYpxGHX01y/JuH1xhn7TOCnNCI4eKDs5WGX4R425F4vF1o3BJ4vO0otq
+I3rimI7jJY1jISqnBxknCIvruF83mF5wN4X7qGLIhR8A2Vg0yFERSIXn9Vv3GHy3Vj/WIkKddlYi
+yIMv2I/VMjTLpW7pt05SWIZR0RPyxpB4SIUM9lBPGBl0GC7oSEEwRYLe4pJpZY2P0zbI1n+Oc44w
+qY3PUnmF0ixjVpDD/mJ9wpOBGTVgXlaCaZiPcIWK5NiKBIiPdGaQ0TWGvAiG7nMchdZb7Vgf8zNi
+MuMyzRdy/lePe9iC4TRx7WhhOQI/QiSVNAmAa2lT/piFbuh7ofJoYSZzrSZ1bvmWw3eN2nKUPVky
+uPN5/VRfohRd0VYZoqhKIlU6TXYhJxmPUIloAwc1bPmHEpaZYZORHNlXUJM07hATwHR8MJYqkwWR
+WaIezFhxSFlc8/Fq82hEnpeRozg3ULhhr9lAGtVEkCg5ZNRuuVleBPaZadhG0ZgkyPmDOTOKzViM
+YgOcpukKqQcbjAWS0IleQ2ROjdh6A+md1qWdBRSX7iSYgFRTtRmBpJioieXJiHfJiMGIR9fJOn8I
+MSfXYhspn4ooSa2mSAj4n+8Bmg03fBJZoPOJgsVZRxu1oOMRPXYYjdqjihFaEoZpXBREanuJoRI6
+cibFinq4ngUKh/wQd/H5ofYCZ0HJXR62opZFaAT0iFIZo4DIiUojkjeqKiuoZirKo5Y1a7AWckGa
+BkuYoD5lpDK6eUs6CkDqpETwl1EqpfhJpVeKpVl6EgUAADs=
+
+--BOUNDARY--
diff --git a/Lib/test/test_email/data/msg_08.txt b/Lib/test/test_email/data/msg_08.txt
new file mode 100644
index 0000000..b563083
--- /dev/null
+++ b/Lib/test/test_email/data/msg_08.txt
@@ -0,0 +1,24 @@
+MIME-Version: 1.0
+From: Barry Warsaw <barry@zope.com>
+To: Dingus Lovers <cravindogs@cravindogs.com>
+Subject: Lyrics
+Date: Fri, 20 Apr 2001 19:35:02 -0400
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+
+
+--BOUNDARY
+Content-Type: text/html; charset="iso-8859-1"
+
+
+--BOUNDARY
+Content-Type: text/plain; charset="iso-8859-2"
+
+
+--BOUNDARY
+Content-Type: text/plain; charset="koi8-r"
+
+
+--BOUNDARY--
diff --git a/Lib/test/test_email/data/msg_09.txt b/Lib/test/test_email/data/msg_09.txt
new file mode 100644
index 0000000..575c4c2
--- /dev/null
+++ b/Lib/test/test_email/data/msg_09.txt
@@ -0,0 +1,24 @@
+MIME-Version: 1.0
+From: Barry Warsaw <barry@zope.com>
+To: Dingus Lovers <cravindogs@cravindogs.com>
+Subject: Lyrics
+Date: Fri, 20 Apr 2001 19:35:02 -0400
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+
+
+--BOUNDARY
+Content-Type: text/html; charset="iso-8859-1"
+
+
+--BOUNDARY
+Content-Type: text/plain
+
+
+--BOUNDARY
+Content-Type: text/plain; charset="koi8-r"
+
+
+--BOUNDARY--
diff --git a/Lib/test/test_email/data/msg_10.txt b/Lib/test/test_email/data/msg_10.txt
new file mode 100644
index 0000000..0790396
--- /dev/null
+++ b/Lib/test/test_email/data/msg_10.txt
@@ -0,0 +1,39 @@
+MIME-Version: 1.0
+From: Barry Warsaw <barry@zope.com>
+To: Dingus Lovers <cravindogs@cravindogs.com>
+Subject: Lyrics
+Date: Fri, 20 Apr 2001 19:35:02 -0400
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+Content-Transfer-Encoding: 7bit
+
+This is a 7bit encoded message.
+
+--BOUNDARY
+Content-Type: text/html; charset="iso-8859-1"
+Content-Transfer-Encoding: Quoted-Printable
+
+=A1This is a Quoted Printable encoded message!
+
+--BOUNDARY
+Content-Type: text/plain; charset="iso-8859-1"
+Content-Transfer-Encoding: Base64
+
+VGhpcyBpcyBhIEJhc2U2NCBlbmNvZGVkIG1lc3NhZ2Uu
+
+
+--BOUNDARY
+Content-Type: text/plain; charset="iso-8859-1"
+Content-Transfer-Encoding: Base64
+
+VGhpcyBpcyBhIEJhc2U2NCBlbmNvZGVkIG1lc3NhZ2UuCg==
+
+
+--BOUNDARY
+Content-Type: text/plain; charset="iso-8859-1"
+
+This has no Content-Transfer-Encoding: header.
+
+--BOUNDARY--
diff --git a/Lib/test/test_email/data/msg_11.txt b/Lib/test/test_email/data/msg_11.txt
new file mode 100644
index 0000000..8f7f199
--- /dev/null
+++ b/Lib/test/test_email/data/msg_11.txt
@@ -0,0 +1,7 @@
+Content-Type: message/rfc822
+MIME-Version: 1.0
+Subject: The enclosing message
+
+Subject: An enclosed message
+
+Here is the body of the message.
diff --git a/Lib/test/test_email/data/msg_12.txt b/Lib/test/test_email/data/msg_12.txt
new file mode 100644
index 0000000..4bec8d9
--- /dev/null
+++ b/Lib/test/test_email/data/msg_12.txt
@@ -0,0 +1,36 @@
+MIME-Version: 1.0
+From: Barry Warsaw <barry@zope.com>
+To: Dingus Lovers <cravindogs@cravindogs.com>
+Subject: Lyrics
+Date: Fri, 20 Apr 2001 19:35:02 -0400
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+
+
+--BOUNDARY
+Content-Type: text/html; charset="iso-8859-1"
+
+
+--BOUNDARY
+Content-Type: multipart/mixed; boundary="ANOTHER"
+
+--ANOTHER
+Content-Type: text/plain; charset="iso-8859-2"
+
+
+--ANOTHER
+Content-Type: text/plain; charset="iso-8859-3"
+
+--ANOTHER--
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+
+
+--BOUNDARY
+Content-Type: text/plain; charset="koi8-r"
+
+
+--BOUNDARY--
diff --git a/Lib/test/test_email/data/msg_12a.txt b/Lib/test/test_email/data/msg_12a.txt
new file mode 100644
index 0000000..e94224e
--- /dev/null
+++ b/Lib/test/test_email/data/msg_12a.txt
@@ -0,0 +1,38 @@
+MIME-Version: 1.0
+From: Barry Warsaw <barry@zope.com>
+To: Dingus Lovers <cravindogs@cravindogs.com>
+Subject: Lyrics
+Date: Fri, 20 Apr 2001 19:35:02 -0400
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+
+
+--BOUNDARY
+Content-Type: text/html; charset="iso-8859-1"
+
+
+--BOUNDARY
+Content-Type: multipart/mixed; boundary="ANOTHER"
+
+--ANOTHER
+Content-Type: text/plain; charset="iso-8859-2"
+
+
+--ANOTHER
+Content-Type: text/plain; charset="iso-8859-3"
+
+
+--ANOTHER--
+
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+
+
+--BOUNDARY
+Content-Type: text/plain; charset="koi8-r"
+
+
+--BOUNDARY--
diff --git a/Lib/test/test_email/data/msg_13.txt b/Lib/test/test_email/data/msg_13.txt
new file mode 100644
index 0000000..8e6d52d
--- /dev/null
+++ b/Lib/test/test_email/data/msg_13.txt
@@ -0,0 +1,94 @@
+MIME-Version: 1.0
+From: Barry <barry@digicool.com>
+To: Dingus Lovers <cravindogs@cravindogs.com>
+Subject: Here is your dingus fish
+Date: Fri, 20 Apr 2001 19:35:02 -0400
+Content-Type: multipart/mixed; boundary="OUTER"
+
+--OUTER
+Content-Type: text/plain; charset="us-ascii"
+
+A text/plain part
+
+--OUTER
+Content-Type: multipart/mixed; boundary=BOUNDARY
+
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+
+Hi there,
+
+This is the dingus fish.
+
+--BOUNDARY
+Content-Type: image/gif; name="dingusfish.gif"
+Content-Transfer-Encoding: base64
+content-disposition: attachment; filename="dingusfish.gif"
+
+R0lGODdhAAEAAfAAAP///wAAACwAAAAAAAEAAQAC/oSPqcvtD6OctNqLs968+w+G4kiW5omm6sq2
+7gvH8kzX9o3n+s73/g8MCofEovGITGICTKbyCV0FDNOo9SqpQqpOrJfXzTQj2vD3TGtqL+NtGQ2f
+qTXmxzuOd7WXdcc9DyjU53ewFni4s0fGhdiYaEhGBelICTNoV1j5NUnFcrmUqemjNifJVWpaOqaI
+oFq3SspZsSraE7sHq3jr1MZqWvi662vxV4tD+pvKW6aLDOCLyur8PDwbanyDeq0N3DctbQYeLDvR
+RY6t95m6UB0d3mwIrV7e2VGNvjjffukeJp4w7F65KecGFsTHQGAygOrgrWs1jt28Rc88KESYcGLA
+/obvTkH6p+CinWJiJmIMqXGQwH/y4qk0SYjgQTczT3ajKZGfuI0uJ4kkVI/DT5s3/ejkxI0aT4Y+
+YTYgWbImUaXk9nlLmnSh1qJiJFl0OpUqRK4oOy7NyRQtHWofhoYVxkwWXKUSn0YsS+fUV6lhqfYb
+6ayd3Z5qQdG1B7bvQzaJjwUV2lixMUZ7JVsOlfjWVr/3NB/uFvnySBN6Dcb6rGwaRM3wsormw5cC
+M9NxWy/bWdufudCvy8bOAjXjVVwta/uO21sE5RHBCzNFXtgq9ORtH4eYjVP4Yryo026nvkFmCeyA
+B29efV6ravCMK5JwWd5897Qrx7ll38o6iHDZ/rXPR//feevhF4l7wjUGX3xq1eeRfM4RSJGBIV1D
+z1gKPkfWag3mVBVvva1RlX5bAJTPR/2YqNtw/FkIYYEi/pIZiAdpcxpoHtmnYYoZtvhUftzdx5ZX
+JSKDW405zkGcZzzGZ6KEv4FI224oDmijlEf+xp6MJK5ojY/ASeVUR+wsKRuJ+XFZ5o7ZeEime8t1
+ouUsU6YjF5ZtUihhkGfCdFQLWQFJ3UXxmElfhQnR+eCdcDbkFZp6vTRmj56ApCihn5QGpaToNZmR
+n3NVSpZcQpZ2KEONusaiCsKAug0wkQbJSFO+PTSjneGxOuFjPlUk3ovWvdIerjUg9ZGIOtGq/qeX
+eCYrrCX+1UPsgTKGGRSbzd5q156d/gpfbJxe66eD5iQKrXj7RGgruGxs62qebBHUKS32CKluCiqZ
+qh+pmehmEb71noAUoe5e9Zm17S7773V10pjrtG4CmuurCV/n6zLK5turWNhqOvFXbjhZrMD0YhKe
+wR0zOyuvsh6MWrGoIuzvyWu5y1WIFAqmJselypxXh6dKLNOKEB98L88bS2rkNqqlKzCNJp9c0G0j
+Gzh0iRrCbHSXmPR643QS+4rWhgFmnSbSuXCjS0xAOWkU2UdLqyuUNfHSFdUouy3bm5i5GnDM3tG8
+doJ4r5tqu3pPbRSVfvs8uJzeNXhp3n4j/tZ42SwH7eaWUUOjc3qFV9453UHTXZfcLH+OeNs5g36x
+lBnHvTm7EbMbLeuaLncao8vWCXimfo1o+843Ak6y4ChNeGntvAYvfLK4ezmoyNIbNCLTCXO9ZV3A
+E8/s88RczPzDwI4Ob7XZyl7+9Miban29h+tJZPrE21wgvBphDfrrfPdCTPKJD/y98L1rZwHcV6Jq
+Zab0metpuNIX/qAFPoz171WUaUb4HAhBSzHuHfjzHb3kha/2Cctis/ORArVHNYfFyYRH2pYIRzic
+isVOfPWD1b6mRTqpCRBozzof6UZVvFXRxWIr3GGrEviGYgyPMfahheiSaLs/9QeFu7oZ/ndSY8DD
+ya9x+uPed+7mxN2IzIISBOMLFYWVqC3Pew1T2nFuuCiwZS5/v6II10i4t1OJcUH2U9zxKodHsGGv
+Oa+zkvNUYUOa/TCCRutF9MzDwdlUMJADTCGSbDQ5OV4PTamDoPEi6Ecc/RF5RWwkcdSXvSOaDWSn
+I9LlvubFTQpuc6JKXLcKeb+xdbKRBnwREemXyjg6ME65aJiOuBgrktzykfPLJBKR9ClMavJ62/Ff
+BlNIyod9yX9wcSXexnXFpvkrbXk64xsx5Db7wXKP5fSgsvwIMM/9631VLBfkmtbHRXpqmtei52hG
+pUwSlo+BASQoeILDOBgREECxBBh5/iYmNsQ9dIv5+OI++QkqdsJPc3uykz5fkM+OraeekcQF7X4n
+B5S67za5U967PmooGQhUXfF7afXyCD7ONdRe17QogYjVx38uLwtrS6nhTnm15LQUnu9E2uK6CNI/
+1HOABj0ESwOjut4FEpFQpdNAm4K2LHnDWHNcmKB2ioKBogysVZtMO2nSxUdZ8Yk2kJc7URioLVI0
+YgmtIwZj4LoeKemgnOnbUdGnzZ4Oa6scqiolBGqS6RgWNLu0RMhcaE6rhhU4hiuqFXPAG8fGwTPW
+FKeLMtdVmXLSs5YJGF/YeVm7rREMlY3UYE+yCxbaMXX8y15m5zVHq6GOKDMynzII/jdUHdyVqIy0
+ifX2+r/EgtZcvRzSb72gU9ui87M2VecjKildW/aFqaYhKoryUjfB/g4qtyVuc60xFDGmCxwjW+qu
+zjuwl2GkOWn66+3QiiEctvd04OVvcCVzjgT7lrkvjVGKKHmmlDUKowSeikb5kK/mJReuWOxONx+s
+ULsl+Lqb0CVn0SrVyJ6wt4t6yTeSCafhPhAf0OXn6L60UMxiLolFAtmN35S2Ob1lZpQ1r/n0Qb5D
+oQ1zJiRVDgF8N3Q8TYfbi3DyWCy3lT1nxyBs6FT3S2GOzWRlxwKvlRP0RPJA9SjxEy0UoEnkA+M4
+cnzLMJrBGWLFEaaUb5lvpqbq/loOaU5+DFuHPxo82/OZuM8FXG3oVNZhtWpMpb/0Xu5m/LfLhHZQ
+7yuVI0MqZ7NE43imC8jH3IwGZlbPm0xkJYs7+2U48hXTsFSMqgGDvai0kLxyynKNT/waj+q1c1tz
+GjOpPBgdCSq3UKZxCSsqFIY+O6JbAWGWcV1pwqLyj5sGqCF1xb1F3varUWqrJv6cN3PrUXzijtfZ
+FshpBL3Xwr4GIPvU2N8EjrJgS1zl21rbXQMXeXc5jjFyrhpCzijSv/RQtyPSzHCFMhlME95fHglt
+pRsX+dfSQjUeHAlpWzJ5iOo79Ldnaxai6bXTcGO3fp07ri7HLEmXXPlYi8bv/qVxvNcdra6m7Rlb
+6JBTb5fd66VhFRjGArh2n7R1rDW4P5NOT9K0I183T2scYkeZ3q/VFyLb09U9ajzXBS8Kgkhc4mBS
+kYY9cy3Vy9lUnuNJH8HGIclUilwnBtjUOH0gteGOZ4c/XNrhXLSYDyxfnD8z1pDy7rYRvDolhnbe
+UMzxCZUs40s6s7UIvBnLgc0+vKuOkIXeOrDymlp+Zxra4MZLBbVrqD/jTJ597pDmnw5c4+DbyB88
+9Cg9DodYcSuMZT/114pptqc/EuTjRPvH/z5slzI3tluOEBBLqOXLOX+0I5929tO97wkvl/atCz+y
+xJrdwteW2FNW/NSmBP+f/maYtVs/bYyBC7Ox3jsYZHL05CIrBa/nS+b3bHfiYm4Ueil1YZZSgAUI
+fFZ1dxUmeA2oQRQ3RuGXNGLFV9/XbGFGPV6kfzk1TBBCd+izc7q1H+OHMJwmaBX2IQNYVAKHYepV
+SSGCe6CnbYHHETKGNe43EDvFgZr0gB/nVHPHZ80VV1ojOiI3XDvYIkl4ayo4bxQIgrFXWTvBI0nH
+VElWMuw2aLUWCRHHf8ymVCHjFlJnOSojfevCYyyyZDH0IcvHhrsnQ5O1OsWzONuVVKIxSxiFZ/tR
+fKDAf6xFTnw4O9Qig2VCfW2hJQrmMOuHW0W3dLQmCMO2ccdUd/xyfflH/olTiHZVdGwb8nIwRzSE
+J15jFlOJuBZBZ4CiyHyd2IFylFlB+HgHhYabhWOGwYO1ZH/Og1dtQlFMk352CGRSIFTapnWQEUtN
+l4zv8S0aaCFDyGCBqDUxZYpxGHX01y/JuH1xhn7TOCnNCI4eKDs5WGX4R425F4vF1o3BJ4vO0otq
+I3rimI7jJY1jISqnBxknCIvruF83mF5wN4X7qGLIhR8A2Vg0yFERSIXn9Vv3GHy3Vj/WIkKddlYi
+yIMv2I/VMjTLpW7pt05SWIZR0RPyxpB4SIUM9lBPGBl0GC7oSEEwRYLe4pJpZY2P0zbI1n+Oc44w
+qY3PUnmF0ixjVpDD/mJ9wpOBGTVgXlaCaZiPcIWK5NiKBIiPdGaQ0TWGvAiG7nMchdZb7Vgf8zNi
+MuMyzRdy/lePe9iC4TRx7WhhOQI/QiSVNAmAa2lT/piFbuh7ofJoYSZzrSZ1bvmWw3eN2nKUPVky
+uPN5/VRfohRd0VYZoqhKIlU6TXYhJxmPUIloAwc1bPmHEpaZYZORHNlXUJM07hATwHR8MJYqkwWR
+WaIezFhxSFlc8/Fq82hEnpeRozg3ULhhr9lAGtVEkCg5ZNRuuVleBPaZadhG0ZgkyPmDOTOKzViM
+YgOcpukKqQcbjAWS0IleQ2ROjdh6A+md1qWdBRSX7iSYgFRTtRmBpJioieXJiHfJiMGIR9fJOn8I
+MSfXYhspn4ooSa2mSAj4n+8Bmg03fBJZoPOJgsVZRxu1oOMRPXYYjdqjihFaEoZpXBREanuJoRI6
+cibFinq4ngUKh/wQd/H5ofYCZ0HJXR62opZFaAT0iFIZo4DIiUojkjeqKiuoZirKo5Y1a7AWckGa
+BkuYoD5lpDK6eUs6CkDqpETwl1EqpfhJpVeKpVl6EgUAADs=
+
+--BOUNDARY--
+
+--OUTER--
diff --git a/Lib/test/test_email/data/msg_14.txt b/Lib/test/test_email/data/msg_14.txt
new file mode 100644
index 0000000..5d98d2f
--- /dev/null
+++ b/Lib/test/test_email/data/msg_14.txt
@@ -0,0 +1,23 @@
+Return-Path: <bbb@zzz.org>
+Delivered-To: bbb@zzz.org
+Received: by mail.zzz.org (Postfix, from userid 889)
+ id 27CEAD38CC; Fri, 4 May 2001 14:05:44 -0400 (EDT)
+MIME-Version: 1.0
+Content-Type: text; charset=us-ascii
+Content-Transfer-Encoding: 7bit
+Message-ID: <15090.61304.110929.45684@aaa.zzz.org>
+From: bbb@ddd.com (John X. Doe)
+To: bbb@zzz.org
+Subject: This is a test message
+Date: Fri, 4 May 2001 14:05:44 -0400
+
+
+Hi,
+
+I'm sorry but I'm using a drainbread ISP, which although big and
+wealthy can't seem to generate standard compliant email. :(
+
+This message has a Content-Type: header with no subtype. I hope you
+can still read it.
+
+-Me
diff --git a/Lib/test/test_email/data/msg_15.txt b/Lib/test/test_email/data/msg_15.txt
new file mode 100644
index 0000000..0025624
--- /dev/null
+++ b/Lib/test/test_email/data/msg_15.txt
@@ -0,0 +1,52 @@
+Return-Path: <xx@xx.dk>
+Received: from fepD.post.tele.dk (195.41.46.149) by mail.groupcare.dk (LSMTP for Windows NT v1.1b) with SMTP id <0.0014F8A2@mail.groupcare.dk>; Mon, 30 Apr 2001 12:17:50 +0200
+User-Agent: Microsoft-Outlook-Express-Macintosh-Edition/5.02.2106
+Subject: XX
+From: xx@xx.dk
+To: XX
+Message-ID: <xxxx>
+Mime-version: 1.0
+Content-type: multipart/mixed;
+ boundary="MS_Mac_OE_3071477847_720252_MIME_Part"
+
+> Denne meddelelse er i MIME-format. Da dit postl
+
+--MS_Mac_OE_3071477847_720252_MIME_Part
+Content-type: multipart/alternative;
+ boundary="MS_Mac_OE_3071477847_720252_MIME_Part"
+
+
+--MS_Mac_OE_3071477847_720252_MIME_Part
+Content-type: text/plain; charset="ISO-8859-1"
+Content-transfer-encoding: quoted-printable
+
+Some removed test.
+
+--MS_Mac_OE_3071477847_720252_MIME_Part
+Content-type: text/html; charset="ISO-8859-1"
+Content-transfer-encoding: quoted-printable
+
+<HTML>
+<HEAD>
+<TITLE>Some removed HTML</TITLE>
+</HEAD>
+<BODY>
+Some removed text.
+</BODY>
+</HTML>
+
+
+--MS_Mac_OE_3071477847_720252_MIME_Part--
+
+
+--MS_Mac_OE_3071477847_720252_MIME_Part
+Content-type: image/gif; name="xx.gif";
+ x-mac-creator="6F676C65";
+ x-mac-type="47494666"
+Content-disposition: attachment
+Content-transfer-encoding: base64
+
+Some removed base64 encoded chars.
+
+--MS_Mac_OE_3071477847_720252_MIME_Part--
+
diff --git a/Lib/test/test_email/data/msg_16.txt b/Lib/test/test_email/data/msg_16.txt
new file mode 100644
index 0000000..56167e9
--- /dev/null
+++ b/Lib/test/test_email/data/msg_16.txt
@@ -0,0 +1,123 @@
+Return-Path: <>
+Delivered-To: scr-admin@socal-raves.org
+Received: from cougar.noc.ucla.edu (cougar.noc.ucla.edu [169.232.10.18])
+ by babylon.socal-raves.org (Postfix) with ESMTP id CCC2C51B84
+ for <scr-admin@socal-raves.org>; Sun, 23 Sep 2001 20:13:54 -0700 (PDT)
+Received: from sims-ms-daemon by cougar.noc.ucla.edu
+ (Sun Internet Mail Server sims.3.5.2000.03.23.18.03.p10)
+ id <0GK500B01D0B8Y@cougar.noc.ucla.edu> for scr-admin@socal-raves.org; Sun,
+ 23 Sep 2001 20:14:35 -0700 (PDT)
+Received: from cougar.noc.ucla.edu
+ (Sun Internet Mail Server sims.3.5.2000.03.23.18.03.p10)
+ id <0GK500B01D0B8X@cougar.noc.ucla.edu>; Sun, 23 Sep 2001 20:14:35 -0700 (PDT)
+Date: Sun, 23 Sep 2001 20:14:35 -0700 (PDT)
+From: Internet Mail Delivery <postmaster@ucla.edu>
+Subject: Delivery Notification: Delivery has failed
+To: scr-admin@socal-raves.org
+Message-id: <0GK500B04D0B8X@cougar.noc.ucla.edu>
+MIME-version: 1.0
+Sender: scr-owner@socal-raves.org
+Errors-To: scr-owner@socal-raves.org
+X-BeenThere: scr@socal-raves.org
+X-Mailman-Version: 2.1a3
+Precedence: bulk
+List-Help: <mailto:scr-request@socal-raves.org?subject=help>
+List-Post: <mailto:scr@socal-raves.org>
+List-Subscribe: <http://socal-raves.org/mailman/listinfo/scr>,
+ <mailto:scr-request@socal-raves.org?subject=subscribe>
+List-Id: SoCal-Raves <scr.socal-raves.org>
+List-Unsubscribe: <http://socal-raves.org/mailman/listinfo/scr>,
+ <mailto:scr-request@socal-raves.org?subject=unsubscribe>
+List-Archive: <http://socal-raves.org/mailman/private/scr/>
+Content-Type: multipart/report; boundary="Boundary_(ID_PGS2F2a+z+/jL7hupKgRhA)"
+
+
+--Boundary_(ID_PGS2F2a+z+/jL7hupKgRhA)
+Content-type: text/plain; charset=ISO-8859-1
+
+This report relates to a message you sent with the following header fields:
+
+ Message-id: <002001c144a6$8752e060$56104586@oxy.edu>
+ Date: Sun, 23 Sep 2001 20:10:55 -0700
+ From: "Ian T. Henry" <henryi@oxy.edu>
+ To: SoCal Raves <scr@socal-raves.org>
+ Subject: [scr] yeah for Ians!!
+
+Your message cannot be delivered to the following recipients:
+
+ Recipient address: jangel1@cougar.noc.ucla.edu
+ Reason: recipient reached disk quota
+
+
+--Boundary_(ID_PGS2F2a+z+/jL7hupKgRhA)
+Content-type: message/DELIVERY-STATUS
+
+Original-envelope-id: 0GK500B4HD0888@cougar.noc.ucla.edu
+Reporting-MTA: dns; cougar.noc.ucla.edu
+
+Action: failed
+Status: 5.0.0 (recipient reached disk quota)
+Original-recipient: rfc822;jangel1@cougar.noc.ucla.edu
+Final-recipient: rfc822;jangel1@cougar.noc.ucla.edu
+
+--Boundary_(ID_PGS2F2a+z+/jL7hupKgRhA)
+Content-type: MESSAGE/RFC822
+
+Return-path: scr-admin@socal-raves.org
+Received: from sims-ms-daemon by cougar.noc.ucla.edu
+ (Sun Internet Mail Server sims.3.5.2000.03.23.18.03.p10)
+ id <0GK500B01D0B8X@cougar.noc.ucla.edu>; Sun, 23 Sep 2001 20:14:35 -0700 (PDT)
+Received: from panther.noc.ucla.edu by cougar.noc.ucla.edu
+ (Sun Internet Mail Server sims.3.5.2000.03.23.18.03.p10)
+ with ESMTP id <0GK500B4GD0888@cougar.noc.ucla.edu> for jangel1@sims-ms-daemon;
+ Sun, 23 Sep 2001 20:14:33 -0700 (PDT)
+Received: from babylon.socal-raves.org
+ (ip-209-85-222-117.dreamhost.com [209.85.222.117])
+ by panther.noc.ucla.edu (8.9.1a/8.9.1) with ESMTP id UAA09793 for
+ <jangel1@ucla.edu>; Sun, 23 Sep 2001 20:14:32 -0700 (PDT)
+Received: from babylon (localhost [127.0.0.1]) by babylon.socal-raves.org
+ (Postfix) with ESMTP id D3B2951B70; Sun, 23 Sep 2001 20:13:47 -0700 (PDT)
+Received: by babylon.socal-raves.org (Postfix, from userid 60001)
+ id A611F51B82; Sun, 23 Sep 2001 20:13:46 -0700 (PDT)
+Received: from tiger.cc.oxy.edu (tiger.cc.oxy.edu [134.69.3.112])
+ by babylon.socal-raves.org (Postfix) with ESMTP id ADA7351B70 for
+ <scr@socal-raves.org>; Sun, 23 Sep 2001 20:13:44 -0700 (PDT)
+Received: from ent (n16h86.dhcp.oxy.edu [134.69.16.86])
+ by tiger.cc.oxy.edu (8.8.8/8.8.8) with SMTP id UAA08100 for
+ <scr@socal-raves.org>; Sun, 23 Sep 2001 20:14:24 -0700 (PDT)
+Date: Sun, 23 Sep 2001 20:10:55 -0700
+From: "Ian T. Henry" <henryi@oxy.edu>
+Subject: [scr] yeah for Ians!!
+Sender: scr-admin@socal-raves.org
+To: SoCal Raves <scr@socal-raves.org>
+Errors-to: scr-admin@socal-raves.org
+Message-id: <002001c144a6$8752e060$56104586@oxy.edu>
+MIME-version: 1.0
+X-Mailer: Microsoft Outlook Express 5.50.4522.1200
+Content-type: text/plain; charset=us-ascii
+Precedence: bulk
+Delivered-to: scr-post@babylon.socal-raves.org
+Delivered-to: scr@socal-raves.org
+X-Converted-To-Plain-Text: from multipart/alternative by demime 0.98e
+X-Converted-To-Plain-Text: Alternative section used was text/plain
+X-BeenThere: scr@socal-raves.org
+X-Mailman-Version: 2.1a3
+List-Help: <mailto:scr-request@socal-raves.org?subject=help>
+List-Post: <mailto:scr@socal-raves.org>
+List-Subscribe: <http://socal-raves.org/mailman/listinfo/scr>,
+ <mailto:scr-request@socal-raves.org?subject=subscribe>
+List-Id: SoCal-Raves <scr.socal-raves.org>
+List-Unsubscribe: <http://socal-raves.org/mailman/listinfo/scr>,
+ <mailto:scr-request@socal-raves.org?subject=unsubscribe>
+List-Archive: <http://socal-raves.org/mailman/private/scr/>
+
+I always love to find more Ian's that are over 3 years old!!
+
+Ian
+_______________________________________________
+For event info, list questions, or to unsubscribe, see http://www.socal-raves.org/
+
+
+
+--Boundary_(ID_PGS2F2a+z+/jL7hupKgRhA)--
+
diff --git a/Lib/test/test_email/data/msg_17.txt b/Lib/test/test_email/data/msg_17.txt
new file mode 100644
index 0000000..8d86e41
--- /dev/null
+++ b/Lib/test/test_email/data/msg_17.txt
@@ -0,0 +1,12 @@
+MIME-Version: 1.0
+From: Barry <barry@digicool.com>
+To: Dingus Lovers <cravindogs@cravindogs.com>
+Subject: Here is your dingus fish
+Date: Fri, 20 Apr 2001 19:35:02 -0400
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+
+Hi there,
+
+This is the dingus fish.
+
+[Non-text (image/gif) part of message omitted, filename dingusfish.gif]
diff --git a/Lib/test/test_email/data/msg_18.txt b/Lib/test/test_email/data/msg_18.txt
new file mode 100644
index 0000000..f9f4904
--- /dev/null
+++ b/Lib/test/test_email/data/msg_18.txt
@@ -0,0 +1,6 @@
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+X-Foobar-Spoink-Defrobnit: wasnipoop; giraffes="very-long-necked-animals";
+ spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"
+
diff --git a/Lib/test/test_email/data/msg_19.txt b/Lib/test/test_email/data/msg_19.txt
new file mode 100644
index 0000000..49bf7fc
--- /dev/null
+++ b/Lib/test/test_email/data/msg_19.txt
@@ -0,0 +1,43 @@
+Send Ppp mailing list submissions to
+ ppp@zzz.org
+
+To subscribe or unsubscribe via the World Wide Web, visit
+ http://www.zzz.org/mailman/listinfo/ppp
+or, via email, send a message with subject or body 'help' to
+ ppp-request@zzz.org
+
+You can reach the person managing the list at
+ ppp-admin@zzz.org
+
+When replying, please edit your Subject line so it is more specific
+than "Re: Contents of Ppp digest..."
+
+Today's Topics:
+
+ 1. testing #1 (Barry A. Warsaw)
+ 2. testing #2 (Barry A. Warsaw)
+ 3. testing #3 (Barry A. Warsaw)
+ 4. testing #4 (Barry A. Warsaw)
+ 5. testing #5 (Barry A. Warsaw)
+
+hello
+
+
+hello
+
+
+hello
+
+
+hello
+
+
+hello
+
+
+
+_______________________________________________
+Ppp mailing list
+Ppp@zzz.org
+http://www.zzz.org/mailman/listinfo/ppp
+
diff --git a/Lib/test/test_email/data/msg_20.txt b/Lib/test/test_email/data/msg_20.txt
new file mode 100644
index 0000000..1a6a887
--- /dev/null
+++ b/Lib/test/test_email/data/msg_20.txt
@@ -0,0 +1,22 @@
+Return-Path: <bbb@zzz.org>
+Delivered-To: bbb@zzz.org
+Received: by mail.zzz.org (Postfix, from userid 889)
+ id 27CEAD38CC; Fri, 4 May 2001 14:05:44 -0400 (EDT)
+MIME-Version: 1.0
+Content-Type: text/plain; charset=us-ascii
+Content-Transfer-Encoding: 7bit
+Message-ID: <15090.61304.110929.45684@aaa.zzz.org>
+From: bbb@ddd.com (John X. Doe)
+To: bbb@zzz.org
+Cc: ccc@zzz.org
+CC: ddd@zzz.org
+cc: eee@zzz.org
+Subject: This is a test message
+Date: Fri, 4 May 2001 14:05:44 -0400
+
+
+Hi,
+
+Do you like this message?
+
+-Me
diff --git a/Lib/test/test_email/data/msg_21.txt b/Lib/test/test_email/data/msg_21.txt
new file mode 100644
index 0000000..23590b2
--- /dev/null
+++ b/Lib/test/test_email/data/msg_21.txt
@@ -0,0 +1,20 @@
+From: aperson@dom.ain
+To: bperson@dom.ain
+Subject: Test
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+
+MIME message
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+One
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+Two
+--BOUNDARY--
+End of MIME message
diff --git a/Lib/test/test_email/data/msg_22.txt b/Lib/test/test_email/data/msg_22.txt
new file mode 100644
index 0000000..af9de5f
--- /dev/null
+++ b/Lib/test/test_email/data/msg_22.txt
@@ -0,0 +1,46 @@
+Mime-Version: 1.0
+Message-Id: <a05001902b7f1c33773e9@[134.84.183.138]>
+Date: Tue, 16 Oct 2001 13:59:25 +0300
+To: a@example.com
+From: b@example.com
+Content-Type: multipart/mixed; boundary="============_-1208892523==_============"
+
+--============_-1208892523==_============
+Content-Type: text/plain; charset="us-ascii" ; format="flowed"
+
+Text text text.
+--============_-1208892523==_============
+Content-Id: <a05001902b7f1c33773e9@[134.84.183.138].0.0>
+Content-Type: image/jpeg; name="wibble.JPG"
+ ; x-mac-type="4A504547"
+ ; x-mac-creator="474B4F4E"
+Content-Disposition: attachment; filename="wibble.JPG"
+Content-Transfer-Encoding: base64
+
+/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEB
+AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQH/wAALCAXABIEBAREA
+g6bCjjw/pIZSjO6FWFpldjySOmCNrO7DBZibUXhTwtCixw+GtAijVdqxxaPp0aKvmGXa
+qrbBQvms0mAMeYS/3iTV1dG0hHaRNK01XblnWxtVdjkHLMIgTyqnk9VB7CrP2KzIINpa
+4O7I+zxYO9WV8jZg71Zlb+8rMDkEirAVQFAUAKAFAAAUAYAUDgADgY6DjpRtXj5RxjHA
+4wQRj0wQCMdCAewpaKKK/9k=
+--============_-1208892523==_============
+Content-Id: <a05001902b7f1c33773e9@[134.84.183.138].0.1>
+Content-Type: image/jpeg; name="wibble2.JPG"
+ ; x-mac-type="4A504547"
+ ; x-mac-creator="474B4F4E"
+Content-Disposition: attachment; filename="wibble2.JPG"
+Content-Transfer-Encoding: base64
+
+/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEB
+AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQH/wAALCAXABJ0BAREA
+/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQA
+W6NFJJBEkU10kKGTcWMDwxuU+0JHvk8qAtOpNwqSR0n8c3BlDyXHlqsUltHEiTvdXLxR
+7vMiGDNJAJWkAMk8ZkCFp5G2oo5W++INrbQtNfTQxJAuXlupz9oS4d5Y1W+E2XlWZJJE
+Y7LWYQxTLE1zuMbfBPxw8X2fibVdIbSbI6nLZxX635t9TjtYreWR7WGKJTLJFFKSlozO
+0ShxIXM43uC3/9k=
+--============_-1208892523==_============
+Content-Type: text/plain; charset="us-ascii" ; format="flowed"
+
+Text text text.
+--============_-1208892523==_============--
+
diff --git a/Lib/test/test_email/data/msg_23.txt b/Lib/test/test_email/data/msg_23.txt
new file mode 100644
index 0000000..bb2e8ec
--- /dev/null
+++ b/Lib/test/test_email/data/msg_23.txt
@@ -0,0 +1,8 @@
+From: aperson@dom.ain
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+
+--BOUNDARY
+Content-Type: text/plain
+
+A message part
+--BOUNDARY--
diff --git a/Lib/test/test_email/data/msg_24.txt b/Lib/test/test_email/data/msg_24.txt
new file mode 100644
index 0000000..4e52339
--- /dev/null
+++ b/Lib/test/test_email/data/msg_24.txt
@@ -0,0 +1,10 @@
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson@dom.ain
+From: bperson@dom.ain
+
+--BOUNDARY
+
+
+--BOUNDARY--
diff --git a/Lib/test/test_email/data/msg_25.txt b/Lib/test/test_email/data/msg_25.txt
new file mode 100644
index 0000000..9e35275
--- /dev/null
+++ b/Lib/test/test_email/data/msg_25.txt
@@ -0,0 +1,117 @@
+From MAILER-DAEMON Fri Apr 06 16:46:09 2001
+Received: from [204.245.199.98] (helo=zinfandel.lacita.com)
+ by www.linux.org.uk with esmtp (Exim 3.13 #1)
+ id 14lYR6-0008Iv-00
+ for linuxuser-admin@www.linux.org.uk; Fri, 06 Apr 2001 16:46:09 +0100
+Received: from localhost (localhost) by zinfandel.lacita.com (8.7.3/8.6.10-MT4.00) with internal id JAB03225; Fri, 6 Apr 2001 09:23:06 -0800 (GMT-0800)
+Date: Fri, 6 Apr 2001 09:23:06 -0800 (GMT-0800)
+From: Mail Delivery Subsystem <MAILER-DAEMON@zinfandel.lacita.com>
+Subject: Returned mail: Too many hops 19 (17 max): from <linuxuser-admin@www.linux.org.uk> via [199.164.235.226], to <scoffman@wellpartner.com>
+Message-Id: <200104061723.JAB03225@zinfandel.lacita.com>
+To: <linuxuser-admin@www.linux.org.uk>
+To: postmaster@zinfandel.lacita.com
+MIME-Version: 1.0
+Content-Type: multipart/report; report-type=delivery-status;
+ bo
+Auto-Submitted: auto-generated (failure)
+
+This is a MIME-encapsulated message
+
+--JAB03225.986577786/zinfandel.lacita.com
+
+The original message was received at Fri, 6 Apr 2001 09:23:03 -0800 (GMT-0800)
+from [199.164.235.226]
+
+ ----- The following addresses have delivery notifications -----
+<scoffman@wellpartner.com> (unrecoverable error)
+
+ ----- Transcript of session follows -----
+554 Too many hops 19 (17 max): from <linuxuser-admin@www.linux.org.uk> via [199.164.235.226], to <scoffman@wellpartner.com>
+
+--JAB03225.986577786/zinfandel.lacita.com
+Content-Type: message/delivery-status
+
+Reporting-MTA: dns; zinfandel.lacita.com
+Received-From-MTA: dns; [199.164.235.226]
+Arrival-Date: Fri, 6 Apr 2001 09:23:03 -0800 (GMT-0800)
+
+Final-Recipient: rfc822; scoffman@wellpartner.com
+Action: failed
+Status: 5.4.6
+Last-Attempt-Date: Fri, 6 Apr 2001 09:23:06 -0800 (GMT-0800)
+
+--JAB03225.986577786/zinfandel.lacita.com
+Content-Type: text/rfc822-headers
+
+Return-Path: linuxuser-admin@www.linux.org.uk
+Received: from ns1.wellpartner.net ([199.164.235.226]) by zinfandel.lacita.com (8.7.3/8.6.10-MT4.00) with ESMTP id JAA03225 for <scoffman@wellpartner.com>; Fri, 6 Apr 2001 09:23:03 -0800 (GMT-0800)
+Received: from zinfandel.lacita.com ([204.245.199.98])
+ by
+ fo
+Received: from ns1.wellpartner.net ([199.164.235.226]) by zinfandel.lacita.com (8.7.3/8.6.10-MT4.00) with ESMTP id JAA03221 for <scoffman@wellpartner.com>; Fri, 6 Apr 2001 09:22:18 -0800 (GMT-0800)
+Received: from zinfandel.lacita.com ([204.245.199.98])
+ by
+ fo
+Received: from ns1.wellpartner.net ([199.164.235.226]) by zinfandel.lacita.com (8.7.3/8.6.10-MT4.00) with ESMTP id JAA03217 for <scoffman@wellpartner.com>; Fri, 6 Apr 2001 09:21:37 -0800 (GMT-0800)
+Received: from zinfandel.lacita.com ([204.245.199.98])
+ by
+ fo
+Received: from ns1.wellpartner.net ([199.164.235.226]) by zinfandel.lacita.com (8.7.3/8.6.10-MT4.00) with ESMTP id JAA03213 for <scoffman@wellpartner.com>; Fri, 6 Apr 2001 09:20:56 -0800 (GMT-0800)
+Received: from zinfandel.lacita.com ([204.245.199.98])
+ by
+ fo
+Received: from ns1.wellpartner.net ([199.164.235.226]) by zinfandel.lacita.com (8.7.3/8.6.10-MT4.00) with ESMTP id JAA03209 for <scoffman@wellpartner.com>; Fri, 6 Apr 2001 09:20:15 -0800 (GMT-0800)
+Received: from zinfandel.lacita.com ([204.245.199.98])
+ by
+ fo
+Received: from ns1.wellpartner.net ([199.164.235.226]) by zinfandel.lacita.com (8.7.3/8.6.10-MT4.00) with ESMTP id JAA03205 for <scoffman@wellpartner.com>; Fri, 6 Apr 2001 09:19:33 -0800 (GMT-0800)
+Received: from zinfandel.lacita.com ([204.245.199.98])
+ by
+ fo
+Received: from ns1.wellpartner.net ([199.164.235.226]) by zinfandel.lacita.com (8.7.3/8.6.10-MT4.00) with ESMTP id JAA03201 for <scoffman@wellpartner.com>; Fri, 6 Apr 2001 09:18:52 -0800 (GMT-0800)
+Received: from zinfandel.lacita.com ([204.245.199.98])
+ by
+ fo
+Received: from ns1.wellpartner.net ([199.164.235.226]) by zinfandel.lacita.com (8.7.3/8.6.10-MT4.00) with ESMTP id JAA03197 for <scoffman@wellpartner.com>; Fri, 6 Apr 2001 09:17:54 -0800 (GMT-0800)
+Received: from www.linux.org.uk (parcelfarce.linux.theplanet.co.uk [195.92.249.252])
+ by
+ fo
+Received: from localhost.localdomain
+ ([
+ by
+ id
+Received: from [212.1.130.11] (helo=s1.uklinux.net ident=root)
+ by
+ id
+ fo
+Received: from server (ppp-2-22.cvx4.telinco.net [212.1.149.22])
+ by
+ fo
+From: Daniel James <daniel@linuxuser.co.uk>
+Organization: LinuxUser
+To: linuxuser@www.linux.org.uk
+X-Mailer: KMail [version 1.1.99]
+Content-Type: text/plain;
+ c
+MIME-Version: 1.0
+Message-Id: <01040616033903.00962@server>
+Content-Transfer-Encoding: 8bit
+Subject: [LinuxUser] bulletin no. 45
+Sender: linuxuser-admin@www.linux.org.uk
+Errors-To: linuxuser-admin@www.linux.org.uk
+X-BeenThere: linuxuser@www.linux.org.uk
+X-Mailman-Version: 2.0.3
+Precedence: bulk
+List-Help: <mailto:linuxuser-request@www.linux.org.uk?subject=help>
+List-Post: <mailto:linuxuser@www.linux.org.uk>
+List-Subscribe: <http://www.linux.org.uk/mailman/listinfo/linuxuser>,
+ <m
+List-Id: bulletins from LinuxUser magazine <linuxuser.www.linux.org.uk>
+List-Unsubscribe: <http://www.linux.org.uk/mailman/listinfo/linuxuser>,
+ <m
+List-Archive: <http://www.linux.org.uk/pipermail/linuxuser/>
+Date: Fri, 6 Apr 2001 16:03:39 +0100
+
+--JAB03225.986577786/zinfandel.lacita.com--
+
+
diff --git a/Lib/test/test_email/data/msg_26.txt b/Lib/test/test_email/data/msg_26.txt
new file mode 100644
index 0000000..58efaa9
--- /dev/null
+++ b/Lib/test/test_email/data/msg_26.txt
@@ -0,0 +1,46 @@
+Received: from xcar [192.168.0.2] by jeeves.wooster.local
+ (SMTPD32-7.07 EVAL) id AFF92F0214; Sun, 12 May 2002 08:55:37 +0100
+Date: Sun, 12 May 2002 08:56:15 +0100
+From: Father Time <father.time@xcar.wooster.local>
+To: timbo@jeeves.wooster.local
+Subject: IMAP file test
+Message-ID: <6df65d354b.father.time@rpc.wooster.local>
+X-Organization: Home
+User-Agent: Messenger-Pro/2.50a (MsgServe/1.50) (RISC-OS/4.02) POPstar/2.03
+MIME-Version: 1.0
+Content-Type: multipart/mixed; boundary="1618492860--2051301190--113853680"
+Status: R
+X-UIDL: 319998302
+
+This message is in MIME format which your mailer apparently does not support.
+You either require a newer version of your software which supports MIME, or
+a separate MIME decoding utility. Alternatively, ask the sender of this
+message to resend it in a different format.
+
+--1618492860--2051301190--113853680
+Content-Type: text/plain; charset=us-ascii
+
+Simple email with attachment.
+
+
+--1618492860--2051301190--113853680
+Content-Type: application/riscos; name="clock.bmp,69c"; type=BMP;
+ load=&fff69c4b; exec=&355dd4d1; access=&03
+Content-Disposition: attachment; filename="clock.bmp"
+Content-Transfer-Encoding: base64
+
+Qk12AgAAAAAAAHYAAAAoAAAAIAAAACAAAAABAAQAAAAAAAAAAADXDQAA1w0AAAAAAAAA
+AAAAAAAAAAAAiAAAiAAAAIiIAIgAAACIAIgAiIgAALu7uwCIiIgAERHdACLuIgAz//8A
+zAAAAN0R3QDu7iIA////AAAAAAAAAAAAAAAAAAAAAAAAAAi3AAAAAAAAADeAAAAAAAAA
+C3ADMzMzMANwAAAAAAAAAAAHMAAAAANwAAAAAAAAAACAMAd3zPfwAwgAAAAAAAAIAwd/
+f8x/f3AwgAAAAAAAgDB0x/f3//zPAwgAAAAAAAcHfM9////8z/AwAAAAAAiwd/f3////
+////A4AAAAAAcEx/f///////zAMAAAAAiwfM9////3///8zwOAAAAAcHf3////B/////
+8DAAAAALB/f3///wd3d3//AwAAAABwTPf//wCQAAD/zAMAAAAAsEx/f///B////8wDAA
+AAAHB39////wf/////AwAAAACwf39///8H/////wMAAAAIcHfM9///B////M8DgAAAAA
+sHTH///wf///xAMAAAAACHB3f3//8H////cDgAAAAAALB3zH//D//M9wMAAAAAAAgLB0
+z39///xHAwgAAAAAAAgLB3d3RHd3cDCAAAAAAAAAgLAHd0R3cAMIAAAAAAAAgAgLcAAA
+AAMwgAgAAAAACDAAAAu7t7cwAAgDgAAAAABzcIAAAAAAAAgDMwAAAAAAN7uwgAAAAAgH
+MzMAAAAACH97tzAAAAALu3c3gAAAAAAL+7tzDABAu7f7cAAAAAAACA+3MA7EQAv/sIAA
+AAAAAAAIAAAAAAAAAIAAAAAA
+
+--1618492860--2051301190--113853680--
diff --git a/Lib/test/test_email/data/msg_27.txt b/Lib/test/test_email/data/msg_27.txt
new file mode 100644
index 0000000..d019176
--- /dev/null
+++ b/Lib/test/test_email/data/msg_27.txt
@@ -0,0 +1,15 @@
+Return-Path: <aperson@dom.ain>
+Received: by mail.dom.ain (Postfix, from userid 889)
+ id B9D0AD35DB; Tue, 4 Jun 2002 21:46:59 -0400 (EDT)
+Message-ID: <15613.28051.707126.569693@dom.ain>
+Date: Tue, 4 Jun 2002 21:46:59 -0400
+MIME-Version: 1.0
+Content-Type: text/plain; charset=us-ascii
+Content-Transfer-Encoding: 7bit
+Subject: bug demonstration
+ 12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789
+ more text
+From: aperson@dom.ain (Anne P. Erson)
+To: bperson@dom.ain (Barney P. Erson)
+
+test
diff --git a/Lib/test/test_email/data/msg_28.txt b/Lib/test/test_email/data/msg_28.txt
new file mode 100644
index 0000000..1e4824c
--- /dev/null
+++ b/Lib/test/test_email/data/msg_28.txt
@@ -0,0 +1,25 @@
+From: aperson@dom.ain
+MIME-Version: 1.0
+Content-Type: multipart/digest; boundary=BOUNDARY
+
+--BOUNDARY
+Content-Type: message/rfc822
+
+Content-Type: text/plain; charset=us-ascii
+To: aa@bb.org
+From: cc@dd.org
+Subject: ee
+
+message 1
+
+--BOUNDARY
+Content-Type: message/rfc822
+
+Content-Type: text/plain; charset=us-ascii
+To: aa@bb.org
+From: cc@dd.org
+Subject: ee
+
+message 2
+
+--BOUNDARY--
diff --git a/Lib/test/test_email/data/msg_29.txt b/Lib/test/test_email/data/msg_29.txt
new file mode 100644
index 0000000..1fab561
--- /dev/null
+++ b/Lib/test/test_email/data/msg_29.txt
@@ -0,0 +1,22 @@
+Return-Path: <bbb@zzz.org>
+Delivered-To: bbb@zzz.org
+Received: by mail.zzz.org (Postfix, from userid 889)
+ id 27CEAD38CC; Fri, 4 May 2001 14:05:44 -0400 (EDT)
+MIME-Version: 1.0
+Content-Type: text/plain; charset=us-ascii;
+ title*0*="us-ascii'en'This%20is%20even%20more%20";
+ title*1*="%2A%2A%2Afun%2A%2A%2A%20";
+ title*2="isn't it!"
+Content-Transfer-Encoding: 7bit
+Message-ID: <15090.61304.110929.45684@aaa.zzz.org>
+From: bbb@ddd.com (John X. Doe)
+To: bbb@zzz.org
+Subject: This is a test message
+Date: Fri, 4 May 2001 14:05:44 -0400
+
+
+Hi,
+
+Do you like this message?
+
+-Me
diff --git a/Lib/test/test_email/data/msg_30.txt b/Lib/test/test_email/data/msg_30.txt
new file mode 100644
index 0000000..4334bb6
--- /dev/null
+++ b/Lib/test/test_email/data/msg_30.txt
@@ -0,0 +1,23 @@
+From: aperson@dom.ain
+MIME-Version: 1.0
+Content-Type: multipart/digest; boundary=BOUNDARY
+
+--BOUNDARY
+
+Content-Type: text/plain; charset=us-ascii
+To: aa@bb.org
+From: cc@dd.org
+Subject: ee
+
+message 1
+
+--BOUNDARY
+
+Content-Type: text/plain; charset=us-ascii
+To: aa@bb.org
+From: cc@dd.org
+Subject: ee
+
+message 2
+
+--BOUNDARY--
diff --git a/Lib/test/test_email/data/msg_31.txt b/Lib/test/test_email/data/msg_31.txt
new file mode 100644
index 0000000..1e58e56
--- /dev/null
+++ b/Lib/test/test_email/data/msg_31.txt
@@ -0,0 +1,15 @@
+From: aperson@dom.ain
+MIME-Version: 1.0
+Content-Type: multipart/mixed; boundary=BOUNDARY_
+
+--BOUNDARY
+Content-Type: text/plain
+
+message 1
+
+--BOUNDARY
+Content-Type: text/plain
+
+message 2
+
+--BOUNDARY--
diff --git a/Lib/test/test_email/data/msg_32.txt b/Lib/test/test_email/data/msg_32.txt
new file mode 100644
index 0000000..07ec5af
--- /dev/null
+++ b/Lib/test/test_email/data/msg_32.txt
@@ -0,0 +1,14 @@
+Delivered-To: freebsd-isp@freebsd.org
+Date: Tue, 26 Sep 2000 12:23:03 -0500
+From: Anne Person <aperson@example.com>
+To: Barney Dude <bdude@example.com>
+Subject: Re: Limiting Perl CPU Utilization...
+Mime-Version: 1.0
+Content-Type: text/plain; charset*=ansi-x3.4-1968''us-ascii
+Content-Disposition: inline
+User-Agent: Mutt/1.3.8i
+Sender: owner-freebsd-isp@FreeBSD.ORG
+Precedence: bulk
+X-Loop: FreeBSD.org
+
+Some message.
diff --git a/Lib/test/test_email/data/msg_33.txt b/Lib/test/test_email/data/msg_33.txt
new file mode 100644
index 0000000..042787a
--- /dev/null
+++ b/Lib/test/test_email/data/msg_33.txt
@@ -0,0 +1,29 @@
+Delivered-To: freebsd-isp@freebsd.org
+Date: Wed, 27 Sep 2000 11:11:09 -0500
+From: Anne Person <aperson@example.com>
+To: Barney Dude <bdude@example.com>
+Subject: Re: Limiting Perl CPU Utilization...
+Mime-Version: 1.0
+Content-Type: multipart/signed; micalg*=ansi-x3.4-1968''pgp-md5;
+ protocol*=ansi-x3.4-1968''application%2Fpgp-signature;
+ boundary*="ansi-x3.4-1968''EeQfGwPcQSOJBaQU"
+Content-Disposition: inline
+Sender: owner-freebsd-isp@FreeBSD.ORG
+Precedence: bulk
+X-Loop: FreeBSD.org
+
+
+--EeQfGwPcQSOJBaQU
+Content-Type: text/plain; charset*=ansi-x3.4-1968''us-ascii
+Content-Disposition: inline
+Content-Transfer-Encoding: quoted-printable
+
+part 1
+
+--EeQfGwPcQSOJBaQU
+Content-Type: text/plain
+Content-Disposition: inline
+
+part 2
+
+--EeQfGwPcQSOJBaQU--
diff --git a/Lib/test/test_email/data/msg_34.txt b/Lib/test/test_email/data/msg_34.txt
new file mode 100644
index 0000000..055dfea
--- /dev/null
+++ b/Lib/test/test_email/data/msg_34.txt
@@ -0,0 +1,19 @@
+From: aperson@dom.ain
+To: bperson@dom.ain
+Content-Type: multipart/digest; boundary=XYZ
+
+--XYZ
+Content-Type: text/plain
+
+
+This is a text plain part that is counter to recommended practice in
+RFC 2046, $5.1.5, but is not illegal
+
+--XYZ
+
+From: cperson@dom.ain
+To: dperson@dom.ain
+
+A submessage
+
+--XYZ--
diff --git a/Lib/test/test_email/data/msg_35.txt b/Lib/test/test_email/data/msg_35.txt
new file mode 100644
index 0000000..be7d5a2
--- /dev/null
+++ b/Lib/test/test_email/data/msg_35.txt
@@ -0,0 +1,4 @@
+From: aperson@dom.ain
+To: bperson@dom.ain
+Subject: here's something interesting
+counter to RFC 2822, there's no separating newline here
diff --git a/Lib/test/test_email/data/msg_36.txt b/Lib/test/test_email/data/msg_36.txt
new file mode 100644
index 0000000..5632c30
--- /dev/null
+++ b/Lib/test/test_email/data/msg_36.txt
@@ -0,0 +1,40 @@
+Mime-Version: 1.0
+Content-Type: Multipart/Mixed; Boundary="NextPart"
+To: IETF-Announce:;
+From: Internet-Drafts@ietf.org
+Subject: I-D ACTION:draft-ietf-mboned-mix-00.txt
+Date: Tue, 22 Dec 1998 16:55:06 -0500
+
+--NextPart
+
+Blah blah blah
+
+--NextPart
+Content-Type: Multipart/Alternative; Boundary="OtherAccess"
+
+--OtherAccess
+Content-Type: Message/External-body;
+ access-type="mail-server";
+ server="mailserv@ietf.org"
+
+Content-Type: text/plain
+Content-ID: <19981222151406.I-D@ietf.org>
+
+ENCODING mime
+FILE /internet-drafts/draft-ietf-mboned-mix-00.txt
+
+--OtherAccess
+Content-Type: Message/External-body;
+ name="draft-ietf-mboned-mix-00.txt";
+ site="ftp.ietf.org";
+ access-type="anon-ftp";
+ directory="internet-drafts"
+
+Content-Type: text/plain
+Content-ID: <19981222151406.I-D@ietf.org>
+
+
+--OtherAccess--
+
+--NextPart--
+
diff --git a/Lib/test/test_email/data/msg_37.txt b/Lib/test/test_email/data/msg_37.txt
new file mode 100644
index 0000000..038d34a
--- /dev/null
+++ b/Lib/test/test_email/data/msg_37.txt
@@ -0,0 +1,22 @@
+Content-Type: multipart/mixed; boundary=ABCDE
+
+--ABCDE
+Content-Type: text/x-one
+
+Blah
+
+--ABCDE
+--ABCDE
+Content-Type: text/x-two
+
+Blah
+
+--ABCDE
+--ABCDE
+--ABCDE
+--ABCDE
+Content-Type: text/x-two
+
+Blah
+
+--ABCDE--
diff --git a/Lib/test/test_email/data/msg_38.txt b/Lib/test/test_email/data/msg_38.txt
new file mode 100644
index 0000000..006df81
--- /dev/null
+++ b/Lib/test/test_email/data/msg_38.txt
@@ -0,0 +1,101 @@
+MIME-Version: 1.0
+Content-Type: multipart/mixed; boundary="----- =_aaaaaaaaaa0"
+
+------- =_aaaaaaaaaa0
+Content-Type: multipart/mixed; boundary="----- =_aaaaaaaaaa1"
+Content-ID: <20592.1022586929.1@example.com>
+
+------- =_aaaaaaaaaa1
+Content-Type: multipart/alternative; boundary="----- =_aaaaaaaaaa2"
+Content-ID: <20592.1022586929.2@example.com>
+
+------- =_aaaaaaaaaa2
+Content-Type: text/plain
+Content-ID: <20592.1022586929.3@example.com>
+Content-Description: very tricky
+Content-Transfer-Encoding: 7bit
+
+
+Unlike the test test_nested-multiples-with-internal-boundary, this
+piece of text not only contains the outer boundary tags
+------- =_aaaaaaaaaa1
+and
+------- =_aaaaaaaaaa0
+but puts them at the start of a line! And, to be even nastier, it
+even includes a couple of end tags, such as this one:
+
+------- =_aaaaaaaaaa1--
+
+and this one, which is from a multipart we haven't even seen yet!
+
+------- =_aaaaaaaaaa4--
+
+This will, I'm sure, cause much breakage of MIME parsers. But, as
+far as I can tell, it's perfectly legal. I have not yet ever seen
+a case of this in the wild, but I've seen *similar* things.
+
+
+------- =_aaaaaaaaaa2
+Content-Type: application/octet-stream
+Content-ID: <20592.1022586929.4@example.com>
+Content-Description: patch2
+Content-Transfer-Encoding: base64
+
+XXX
+
+------- =_aaaaaaaaaa2--
+
+------- =_aaaaaaaaaa1
+Content-Type: multipart/alternative; boundary="----- =_aaaaaaaaaa3"
+Content-ID: <20592.1022586929.6@example.com>
+
+------- =_aaaaaaaaaa3
+Content-Type: application/octet-stream
+Content-ID: <20592.1022586929.7@example.com>
+Content-Description: patch3
+Content-Transfer-Encoding: base64
+
+XXX
+
+------- =_aaaaaaaaaa3
+Content-Type: application/octet-stream
+Content-ID: <20592.1022586929.8@example.com>
+Content-Description: patch4
+Content-Transfer-Encoding: base64
+
+XXX
+
+------- =_aaaaaaaaaa3--
+
+------- =_aaaaaaaaaa1
+Content-Type: multipart/alternative; boundary="----- =_aaaaaaaaaa4"
+Content-ID: <20592.1022586929.10@example.com>
+
+------- =_aaaaaaaaaa4
+Content-Type: application/octet-stream
+Content-ID: <20592.1022586929.11@example.com>
+Content-Description: patch5
+Content-Transfer-Encoding: base64
+
+XXX
+
+------- =_aaaaaaaaaa4
+Content-Type: application/octet-stream
+Content-ID: <20592.1022586929.12@example.com>
+Content-Description: patch6
+Content-Transfer-Encoding: base64
+
+XXX
+
+------- =_aaaaaaaaaa4--
+
+------- =_aaaaaaaaaa1--
+
+------- =_aaaaaaaaaa0
+Content-Type: text/plain; charset="us-ascii"
+Content-ID: <20592.1022586929.15@example.com>
+
+--
+It's never too late to have a happy childhood.
+
+------- =_aaaaaaaaaa0--
diff --git a/Lib/test/test_email/data/msg_39.txt b/Lib/test/test_email/data/msg_39.txt
new file mode 100644
index 0000000..124b269
--- /dev/null
+++ b/Lib/test/test_email/data/msg_39.txt
@@ -0,0 +1,83 @@
+MIME-Version: 1.0
+Content-Type: multipart/mixed; boundary="----- =_aaaaaaaaaa0"
+
+------- =_aaaaaaaaaa0
+Content-Type: multipart/mixed; boundary="----- =_aaaaaaaaaa1"
+Content-ID: <20592.1022586929.1@example.com>
+
+------- =_aaaaaaaaaa1
+Content-Type: multipart/alternative; boundary="----- =_aaaaaaaaaa1"
+Content-ID: <20592.1022586929.2@example.com>
+
+------- =_aaaaaaaaaa1
+Content-Type: application/octet-stream
+Content-ID: <20592.1022586929.3@example.com>
+Content-Description: patch1
+Content-Transfer-Encoding: base64
+
+XXX
+
+------- =_aaaaaaaaaa1
+Content-Type: application/octet-stream
+Content-ID: <20592.1022586929.4@example.com>
+Content-Description: patch2
+Content-Transfer-Encoding: base64
+
+XXX
+
+------- =_aaaaaaaaaa1--
+
+------- =_aaaaaaaaaa1
+Content-Type: multipart/alternative; boundary="----- =_aaaaaaaaaa1"
+Content-ID: <20592.1022586929.6@example.com>
+
+------- =_aaaaaaaaaa1
+Content-Type: application/octet-stream
+Content-ID: <20592.1022586929.7@example.com>
+Content-Description: patch3
+Content-Transfer-Encoding: base64
+
+XXX
+
+------- =_aaaaaaaaaa1
+Content-Type: application/octet-stream
+Content-ID: <20592.1022586929.8@example.com>
+Content-Description: patch4
+Content-Transfer-Encoding: base64
+
+XXX
+
+------- =_aaaaaaaaaa1--
+
+------- =_aaaaaaaaaa1
+Content-Type: multipart/alternative; boundary="----- =_aaaaaaaaaa1"
+Content-ID: <20592.1022586929.10@example.com>
+
+------- =_aaaaaaaaaa1
+Content-Type: application/octet-stream
+Content-ID: <20592.1022586929.11@example.com>
+Content-Description: patch5
+Content-Transfer-Encoding: base64
+
+XXX
+
+------- =_aaaaaaaaaa1
+Content-Type: application/octet-stream
+Content-ID: <20592.1022586929.12@example.com>
+Content-Description: patch6
+Content-Transfer-Encoding: base64
+
+XXX
+
+------- =_aaaaaaaaaa1--
+
+------- =_aaaaaaaaaa1--
+
+------- =_aaaaaaaaaa0
+Content-Type: text/plain; charset="us-ascii"
+Content-ID: <20592.1022586929.15@example.com>
+
+--
+It's never too late to have a happy childhood.
+
+------- =_aaaaaaaaaa0--
diff --git a/Lib/test/test_email/data/msg_40.txt b/Lib/test/test_email/data/msg_40.txt
new file mode 100644
index 0000000..1435fa1
--- /dev/null
+++ b/Lib/test/test_email/data/msg_40.txt
@@ -0,0 +1,10 @@
+MIME-Version: 1.0
+Content-Type: text/html; boundary="--961284236552522269"
+
+----961284236552522269
+Content-Type: text/html;
+Content-Transfer-Encoding: 7Bit
+
+<html></html>
+
+----961284236552522269--
diff --git a/Lib/test/test_email/data/msg_41.txt b/Lib/test/test_email/data/msg_41.txt
new file mode 100644
index 0000000..76cdd1c
--- /dev/null
+++ b/Lib/test/test_email/data/msg_41.txt
@@ -0,0 +1,8 @@
+From: "Allison Dunlap" <xxx@example.com>
+To: yyy@example.com
+Subject: 64423
+Date: Sun, 11 Jul 2004 16:09:27 -0300
+MIME-Version: 1.0
+Content-Type: multipart/alternative;
+
+Blah blah blah
diff --git a/Lib/test/test_email/data/msg_42.txt b/Lib/test/test_email/data/msg_42.txt
new file mode 100644
index 0000000..a75f8f4
--- /dev/null
+++ b/Lib/test/test_email/data/msg_42.txt
@@ -0,0 +1,20 @@
+Content-Type: multipart/mixed; boundary="AAA"
+From: Mail Delivery Subsystem <xxx@example.com>
+To: yyy@example.com
+
+This is a MIME-encapsulated message
+
+--AAA
+
+Stuff
+
+--AAA
+Content-Type: message/rfc822
+
+From: webmaster@python.org
+To: zzz@example.com
+Content-Type: multipart/mixed; boundary="BBB"
+
+--BBB--
+
+--AAA--
diff --git a/Lib/test/test_email/data/msg_43.txt b/Lib/test/test_email/data/msg_43.txt
new file mode 100644
index 0000000..797d12c
--- /dev/null
+++ b/Lib/test/test_email/data/msg_43.txt
@@ -0,0 +1,217 @@
+From SRS0=aO/p=ON=bag.python.org=None@bounce2.pobox.com Fri Nov 26 21:40:36 2004
+X-VM-v5-Data: ([nil nil nil nil nil nil nil nil nil]
+ [nil nil nil nil nil nil nil "MAILER DAEMON <>" "MAILER DAEMON <>" nil nil "Banned file: auto__mail.python.bat in mail from you" "^From:" nil nil nil nil "Banned file: auto__mail.python.bat in mail from you" nil nil nil nil nil nil nil]
+ nil)
+MIME-Version: 1.0
+Message-Id: <edab.7804f5cb8070@python.org>
+Content-Type: multipart/report; report-type=delivery-status;
+ charset=utf-8;
+ boundary="----------=_1101526904-1956-5"
+X-Virus-Scanned: by XS4ALL Virus Scanner
+X-UIDL: 4\G!!!<c"!UV["!M7C!!
+From: MAILER DAEMON <>
+To: <webmaster@python.org>
+Subject: Banned file: auto__mail.python.bat in mail from you
+Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+This is a multi-part message in MIME format...
+
+------------=_1101526904-1956-5
+Content-Type: text/plain; charset="utf-8"
+Content-Disposition: inline
+Content-Transfer-Encoding: 7bit
+
+BANNED FILENAME ALERT
+
+Your message to: xxxxxxx@dot.ca.gov, xxxxxxxxxxxxx@dot.ca.gov, xxxxxxxxxx@dot.ca.gov, xxxxxxxx@dot.ca.gov, xxxxxxxxxx@dot.ca.gov, xxxxxx@dot.ca.gov, xxxxxxxxxx@dot.ca.gov, xxxxxx@dot.ca.gov, xxxxxx@dot.ca.gov, xxxxxxxxxxxxxxxx@dot.ca.gov, xxxxxxxxxxx@dot.ca.gov, xxxxxxxxxx@dot.ca.gov, xxxxxxxxxx@dot.ca.gov, xxxxxxxxxxxx@dot.ca.gov, xxxxxxxxxxxx@dot.ca.gov, xxxxxxx@dot.ca.gov, xxxxxxxxx@dot.ca.gov, xxxxxxxxxx@dot.ca.gov, xxxxxx@dot.ca.gov, xxx@dot.ca.gov, xxxxxxx@dot.ca.gov, xxxxxxx@dot.ca.gov, xxxxxxxxxxxxxxx@dot.ca.gov, xxxxxxxxxx@dot.ca.gov, xxxxxxx@dot.ca.gov, xxx@dot.ca.gov, xxxxxxxx@dot.ca.gov, xxxxxxxxxxxxx@dot.ca.gov, xxxxxxxxxxxxx@dot.ca.gov, xxxxxxxxxxx@dot.ca.gov, xxxxxxxxx@dot.ca.gov, xxxxxxxxxx@dot.ca.gov, xxxxxxxxxxxx@dot.ca.gov, xxxxxxx@dot.ca.gov, xxxxxxxxxxxxxxx@dot.ca.gov, xxxxxxxxxxxxx@dot.ca.gov, xxxx@dot.ca.gov, xxxxxxxx@dot.ca.gov, xxxxxxxxxx@dot.ca.gov, xxxxxxxxxxxxxxxxxx@dot.ca.gov
+was blocked by our Spam Firewall. The email you sent with the following subject has NOT BEEN DELIVERED:
+
+Subject: Delivery_failure_notice
+
+An attachment in that mail was of a file type that the Spam Firewall is set to block.
+
+
+
+------------=_1101526904-1956-5
+Content-Type: message/delivery-status
+Content-Disposition: inline
+Content-Transfer-Encoding: 7bit
+Content-Description: Delivery error report
+
+Reporting-MTA: dns; sacspam01.dot.ca.gov
+Received-From-MTA: smtp; sacspam01.dot.ca.gov ([127.0.0.1])
+Arrival-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxxxxxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxxxxxxxxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxxxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxxxxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxxxxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxxxxxxxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+Final-Recipient: rfc822; xxxxxxx@dot.ca.gov
+Action: failed
+Status: 5.7.1
+Diagnostic-Code: smtp; 550 5.7.1 Message content rejected, id=01956-02-2 - BANNED: auto__mail.python.bat
+Last-Attempt-Date: Fri, 26 Nov 2004 19:41:44 -0800 (PST)
+
+------------=_1101526904-1956-5
+Content-Type: text/rfc822-headers
+Content-Disposition: inline
+Content-Transfer-Encoding: 7bit
+Content-Description: Undelivered-message headers
+
+Received: from kgsav.org (ppp-70-242-162-63.dsl.spfdmo.swbell.net [70.242.162.63])
+ by sacspam01.dot.ca.gov (Spam Firewall) with SMTP
+ id A232AD03DE3A; Fri, 26 Nov 2004 19:41:35 -0800 (PST)
+From: webmaster@python.org
+To: xxxxx@dot.ca.gov
+Date: Sat, 27 Nov 2004 03:35:30 UTC
+Subject: Delivery_failure_notice
+Importance: Normal
+X-Priority: 3 (Normal)
+X-MSMail-Priority: Normal
+Message-ID: <edab.7804f5cb8070@python.org>
+MIME-Version: 1.0
+Content-Type: multipart/mixed; boundary="====67bd2b7a5.f99f7"
+Content-Transfer-Encoding: 7bit
+
+------------=_1101526904-1956-5--
+
diff --git a/Lib/test/test_email/data/msg_44.txt b/Lib/test/test_email/data/msg_44.txt
new file mode 100644
index 0000000..15a2252
--- /dev/null
+++ b/Lib/test/test_email/data/msg_44.txt
@@ -0,0 +1,33 @@
+Return-Path: <barry@python.org>
+Delivered-To: barry@python.org
+Received: by mail.python.org (Postfix, from userid 889)
+ id C2BF0D37C6; Tue, 11 Sep 2001 00:05:05 -0400 (EDT)
+MIME-Version: 1.0
+Content-Type: multipart/mixed; boundary="h90VIIIKmx"
+Content-Transfer-Encoding: 7bit
+Message-ID: <15261.36209.358846.118674@anthem.python.org>
+From: barry@python.org (Barry A. Warsaw)
+To: barry@python.org
+Subject: a simple multipart
+Date: Tue, 11 Sep 2001 00:05:05 -0400
+X-Mailer: VM 6.95 under 21.4 (patch 4) "Artificial Intelligence" XEmacs Lucid
+X-Attribution: BAW
+X-Oblique-Strategy: Make a door into a window
+
+
+--h90VIIIKmx
+Content-Type: text/plain; name="msg.txt"
+Content-Transfer-Encoding: 7bit
+
+a simple kind of mirror
+to reflect upon our own
+
+--h90VIIIKmx
+Content-Type: text/plain; name="msg.txt"
+Content-Transfer-Encoding: 7bit
+
+a simple kind of mirror
+to reflect upon our own
+
+--h90VIIIKmx--
+
diff --git a/Lib/test/test_email/data/msg_45.txt b/Lib/test/test_email/data/msg_45.txt
new file mode 100644
index 0000000..58fde95
--- /dev/null
+++ b/Lib/test/test_email/data/msg_45.txt
@@ -0,0 +1,33 @@
+From: <foo@bar.baz>
+To: <baz@bar.foo>
+Subject: test
+X-Long-Line: Some really long line contains a lot of text and thus has to be rewrapped because it is some
+ really long
+ line
+MIME-Version: 1.0
+Content-Type: multipart/signed; boundary="borderline";
+ protocol="application/pgp-signature"; micalg=pgp-sha1
+
+This is an OpenPGP/MIME signed message (RFC 2440 and 3156)
+--borderline
+Content-Type: text/plain
+X-Long-Line: Another really long line contains a lot of text and thus has to be rewrapped because it is another
+ really long
+ line
+
+This is the signed contents.
+
+--borderline
+Content-Type: application/pgp-signature; name="signature.asc"
+Content-Description: OpenPGP digital signature
+Content-Disposition: attachment; filename="signature.asc"
+
+-----BEGIN PGP SIGNATURE-----
+Version: GnuPG v2.0.6 (GNU/Linux)
+
+iD8DBQFG03voRhp6o4m9dFsRApSZAKCCAN3IkJlVRg6NvAiMHlvvIuMGPQCeLZtj
+FGwfnRHFBFO/S4/DKysm0lI=
+=t7+s
+-----END PGP SIGNATURE-----
+
+--borderline--
diff --git a/Lib/test/test_email/data/msg_46.txt b/Lib/test/test_email/data/msg_46.txt
new file mode 100644
index 0000000..1e22c4f
--- /dev/null
+++ b/Lib/test/test_email/data/msg_46.txt
@@ -0,0 +1,23 @@
+Return-Path: <sender@example.net>
+Delivery-Date: Mon, 08 Feb 2010 14:05:16 +0100
+Received: from example.org (example.org [64.5.53.58])
+ by example.net (node=mxbap2) with ESMTP (Nemesis)
+ id UNIQUE for someone@example.com; Mon, 08 Feb 2010 14:05:16 +0100
+Date: Mon, 01 Feb 2010 12:21:16 +0100
+From: "Sender" <sender@example.net>
+To: <someone@example.com>
+Subject: GroupwiseForwardingTest
+Mime-Version: 1.0
+Content-Type: message/rfc822
+
+Return-path: <sender@example.net>
+Message-ID: <4B66B890.4070408@teconcept.de>
+Date: Mon, 01 Feb 2010 12:18:40 +0100
+From: "Dr. Sender" <sender@example.net>
+MIME-Version: 1.0
+To: "Recipient" <recipient@example.com>
+Subject: GroupwiseForwardingTest
+Content-Type: text/plain; charset=ISO-8859-15
+Content-Transfer-Encoding: 7bit
+
+Testing email forwarding with Groupwise 1.2.2010
diff --git a/Lib/test/test_email/test__encoded_words.py b/Lib/test/test_email/test__encoded_words.py
new file mode 100644
index 0000000..14395fe
--- /dev/null
+++ b/Lib/test/test_email/test__encoded_words.py
@@ -0,0 +1,187 @@
+import unittest
+from email import _encoded_words as _ew
+from email import errors
+from test.test_email import TestEmailBase
+
+
+class TestDecodeQ(TestEmailBase):
+
+ def _test(self, source, ex_result, ex_defects=[]):
+ result, defects = _ew.decode_q(source)
+ self.assertEqual(result, ex_result)
+ self.assertDefectsEqual(defects, ex_defects)
+
+ def test_no_encoded(self):
+ self._test(b'foobar', b'foobar')
+
+ def test_spaces(self):
+ self._test(b'foo=20bar=20', b'foo bar ')
+ self._test(b'foo_bar_', b'foo bar ')
+
+ def test_run_of_encoded(self):
+ self._test(b'foo=20=20=21=2Cbar', b'foo !,bar')
+
+
+class TestDecodeB(TestEmailBase):
+
+ def _test(self, source, ex_result, ex_defects=[]):
+ result, defects = _ew.decode_b(source)
+ self.assertEqual(result, ex_result)
+ self.assertDefectsEqual(defects, ex_defects)
+
+ def test_simple(self):
+ self._test(b'Zm9v', b'foo')
+
+ def test_missing_padding(self):
+ self._test(b'dmk', b'vi', [errors.InvalidBase64PaddingDefect])
+
+ def test_invalid_character(self):
+ self._test(b'dm\x01k===', b'vi', [errors.InvalidBase64CharactersDefect])
+
+ def test_invalid_character_and_bad_padding(self):
+ self._test(b'dm\x01k', b'vi', [errors.InvalidBase64CharactersDefect,
+ errors.InvalidBase64PaddingDefect])
+
+
+class TestDecode(TestEmailBase):
+
+ def test_wrong_format_input_raises(self):
+ with self.assertRaises(ValueError):
+ _ew.decode('=?badone?=')
+ with self.assertRaises(ValueError):
+ _ew.decode('=?')
+ with self.assertRaises(ValueError):
+ _ew.decode('')
+
+ def _test(self, source, result, charset='us-ascii', lang='', defects=[]):
+ res, char, l, d = _ew.decode(source)
+ self.assertEqual(res, result)
+ self.assertEqual(char, charset)
+ self.assertEqual(l, lang)
+ self.assertDefectsEqual(d, defects)
+
+ def test_simple_q(self):
+ self._test('=?us-ascii?q?foo?=', 'foo')
+
+ def test_simple_b(self):
+ self._test('=?us-ascii?b?dmk=?=', 'vi')
+
+ def test_q_case_ignored(self):
+ self._test('=?us-ascii?Q?foo?=', 'foo')
+
+ def test_b_case_ignored(self):
+ self._test('=?us-ascii?B?dmk=?=', 'vi')
+
+ def test_non_trivial_q(self):
+ self._test('=?latin-1?q?=20F=fcr=20Elise=20?=', ' Für Elise ', 'latin-1')
+
+ def test_q_escpaed_bytes_preserved(self):
+ self._test(b'=?us-ascii?q?=20\xACfoo?='.decode('us-ascii',
+ 'surrogateescape'),
+ ' \uDCACfoo',
+ defects = [errors.UndecodableBytesDefect])
+
+ def test_b_undecodable_bytes_ignored_with_defect(self):
+ self._test(b'=?us-ascii?b?dm\xACk?='.decode('us-ascii',
+ 'surrogateescape'),
+ 'vi',
+ defects = [
+ errors.InvalidBase64CharactersDefect,
+ errors.InvalidBase64PaddingDefect])
+
+ def test_b_invalid_bytes_ignored_with_defect(self):
+ self._test('=?us-ascii?b?dm\x01k===?=',
+ 'vi',
+ defects = [errors.InvalidBase64CharactersDefect])
+
+ def test_b_invalid_bytes_incorrect_padding(self):
+ self._test('=?us-ascii?b?dm\x01k?=',
+ 'vi',
+ defects = [
+ errors.InvalidBase64CharactersDefect,
+ errors.InvalidBase64PaddingDefect])
+
+ def test_b_padding_defect(self):
+ self._test('=?us-ascii?b?dmk?=',
+ 'vi',
+ defects = [errors.InvalidBase64PaddingDefect])
+
+ def test_nonnull_lang(self):
+ self._test('=?us-ascii*jive?q?test?=', 'test', lang='jive')
+
+ def test_unknown_8bit_charset(self):
+ self._test('=?unknown-8bit?q?foo=ACbar?=',
+ b'foo\xacbar'.decode('ascii', 'surrogateescape'),
+ charset = 'unknown-8bit',
+ defects = [])
+
+ def test_unknown_charset(self):
+ self._test('=?foobar?q?foo=ACbar?=',
+ b'foo\xacbar'.decode('ascii', 'surrogateescape'),
+ charset = 'foobar',
+ # XXX Should this be a new Defect instead?
+ defects = [errors.CharsetError])
+
+
+class TestEncodeQ(TestEmailBase):
+
+ def _test(self, src, expected):
+ self.assertEqual(_ew.encode_q(src), expected)
+
+ def test_all_safe(self):
+ self._test(b'foobar', 'foobar')
+
+ def test_spaces(self):
+ self._test(b'foo bar ', 'foo_bar_')
+
+ def test_run_of_encodables(self):
+ self._test(b'foo ,,bar', 'foo__=2C=2Cbar')
+
+
+class TestEncodeB(TestEmailBase):
+
+ def test_simple(self):
+ self.assertEqual(_ew.encode_b(b'foo'), 'Zm9v')
+
+ def test_padding(self):
+ self.assertEqual(_ew.encode_b(b'vi'), 'dmk=')
+
+
+class TestEncode(TestEmailBase):
+
+ def test_q(self):
+ self.assertEqual(_ew.encode('foo', 'utf-8', 'q'), '=?utf-8?q?foo?=')
+
+ def test_b(self):
+ self.assertEqual(_ew.encode('foo', 'utf-8', 'b'), '=?utf-8?b?Zm9v?=')
+
+ def test_auto_q(self):
+ self.assertEqual(_ew.encode('foo', 'utf-8'), '=?utf-8?q?foo?=')
+
+ def test_auto_q_if_short_mostly_safe(self):
+ self.assertEqual(_ew.encode('vi.', 'utf-8'), '=?utf-8?q?vi=2E?=')
+
+ def test_auto_b_if_enough_unsafe(self):
+ self.assertEqual(_ew.encode('.....', 'utf-8'), '=?utf-8?b?Li4uLi4=?=')
+
+ def test_auto_b_if_long_unsafe(self):
+ self.assertEqual(_ew.encode('vi.vi.vi.vi.vi.', 'utf-8'),
+ '=?utf-8?b?dmkudmkudmkudmkudmku?=')
+
+ def test_auto_q_if_long_mostly_safe(self):
+ self.assertEqual(_ew.encode('vi vi vi.vi ', 'utf-8'),
+ '=?utf-8?q?vi_vi_vi=2Evi_?=')
+
+ def test_utf8_default(self):
+ self.assertEqual(_ew.encode('foo'), '=?utf-8?q?foo?=')
+
+ def test_lang(self):
+ self.assertEqual(_ew.encode('foo', lang='jive'), '=?utf-8*jive?q?foo?=')
+
+ def test_unknown_8bit(self):
+ self.assertEqual(_ew.encode('foo\uDCACbar', charset='unknown-8bit'),
+ '=?unknown-8bit?q?foo=ACbar?=')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_email/test__header_value_parser.py b/Lib/test/test_email/test__header_value_parser.py
new file mode 100644
index 0000000..6101e19
--- /dev/null
+++ b/Lib/test/test_email/test__header_value_parser.py
@@ -0,0 +1,2552 @@
+import string
+import unittest
+from email import _header_value_parser as parser
+from email import errors
+from email import policy
+from test.test_email import TestEmailBase, parameterize
+
+class TestTokens(TestEmailBase):
+
+ # EWWhiteSpaceTerminal
+
+ def test_EWWhiteSpaceTerminal(self):
+ x = parser.EWWhiteSpaceTerminal(' \t', 'fws')
+ self.assertEqual(x, ' \t')
+ self.assertEqual(str(x), '')
+ self.assertEqual(x.value, '')
+ self.assertEqual(x.encoded, ' \t')
+
+ # UnstructuredTokenList
+
+ def test_undecodable_bytes_error_preserved(self):
+ badstr = b"le pouf c\xaflebre".decode('ascii', 'surrogateescape')
+ unst = parser.get_unstructured(badstr)
+ self.assertDefectsEqual(unst.all_defects, [errors.UndecodableBytesDefect])
+ parts = list(unst.parts)
+ self.assertDefectsEqual(parts[0].all_defects, [])
+ self.assertDefectsEqual(parts[1].all_defects, [])
+ self.assertDefectsEqual(parts[2].all_defects, [errors.UndecodableBytesDefect])
+
+
+class TestParserMixin:
+
+ def _assert_results(self, tl, rest, string, value, defects, remainder,
+ comments=None):
+ self.assertEqual(str(tl), string)
+ self.assertEqual(tl.value, value)
+ self.assertDefectsEqual(tl.all_defects, defects)
+ self.assertEqual(rest, remainder)
+ if comments is not None:
+ self.assertEqual(tl.comments, comments)
+
+ def _test_get_x(self, method, source, string, value, defects,
+ remainder, comments=None):
+ tl, rest = method(source)
+ self._assert_results(tl, rest, string, value, defects, remainder,
+ comments=None)
+ return tl
+
+ def _test_parse_x(self, method, input, string, value, defects,
+ comments=None):
+ tl = method(input)
+ self._assert_results(tl, '', string, value, defects, '', comments)
+ return tl
+
+
+class TestParser(TestParserMixin, TestEmailBase):
+
+ # _wsp_splitter
+
+ rfc_printable_ascii = bytes(range(33, 127)).decode('ascii')
+ rfc_atext_chars = (string.ascii_letters + string.digits +
+ "!#$%&\'*+-/=?^_`{}|~")
+ rfc_dtext_chars = rfc_printable_ascii.translate(str.maketrans('','',r'\[]'))
+
+ def test__wsp_splitter_one_word(self):
+ self.assertEqual(parser._wsp_splitter('foo', 1), ['foo'])
+
+ def test__wsp_splitter_two_words(self):
+ self.assertEqual(parser._wsp_splitter('foo def', 1),
+ ['foo', ' ', 'def'])
+
+ def test__wsp_splitter_ws_runs(self):
+ self.assertEqual(parser._wsp_splitter('foo \t def jik', 1),
+ ['foo', ' \t ', 'def jik'])
+
+
+ # get_fws
+
+ def test_get_fws_only(self):
+ fws = self._test_get_x(parser.get_fws, ' \t ', ' \t ', ' ', [], '')
+ self.assertEqual(fws.token_type, 'fws')
+
+ def test_get_fws_space(self):
+ self._test_get_x(parser.get_fws, ' foo', ' ', ' ', [], 'foo')
+
+ def test_get_fws_ws_run(self):
+ self._test_get_x(parser.get_fws, ' \t foo ', ' \t ', ' ', [], 'foo ')
+
+ # get_encoded_word
+
+ def test_get_encoded_word_missing_start_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_encoded_word('abc')
+
+ def test_get_encoded_word_missing_end_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_encoded_word('=?abc')
+
+ def test_get_encoded_word_missing_middle_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_encoded_word('=?abc?=')
+
+ def test_get_encoded_word_valid_ew(self):
+ self._test_get_x(parser.get_encoded_word,
+ '=?us-ascii?q?this_is_a_test?= bird',
+ 'this is a test',
+ 'this is a test',
+ [],
+ ' bird')
+
+ def test_get_encoded_word_internal_spaces(self):
+ self._test_get_x(parser.get_encoded_word,
+ '=?us-ascii?q?this is a test?= bird',
+ 'this is a test',
+ 'this is a test',
+ [errors.InvalidHeaderDefect],
+ ' bird')
+
+ def test_get_encoded_word_gets_first(self):
+ self._test_get_x(parser.get_encoded_word,
+ '=?us-ascii?q?first?= =?utf-8?q?second?=',
+ 'first',
+ 'first',
+ [],
+ ' =?utf-8?q?second?=')
+
+ def test_get_encoded_word_gets_first_even_if_no_space(self):
+ self._test_get_x(parser.get_encoded_word,
+ '=?us-ascii?q?first?==?utf-8?q?second?=',
+ 'first',
+ 'first',
+ [],
+ '=?utf-8?q?second?=')
+
+ def test_get_encoded_word_sets_extra_attributes(self):
+ ew = self._test_get_x(parser.get_encoded_word,
+ '=?us-ascii*jive?q?first_second?=',
+ 'first second',
+ 'first second',
+ [],
+ '')
+ self.assertEqual(ew.encoded, '=?us-ascii*jive?q?first_second?=')
+ self.assertEqual(ew.charset, 'us-ascii')
+ self.assertEqual(ew.lang, 'jive')
+
+ def test_get_encoded_word_lang_default_is_blank(self):
+ ew = self._test_get_x(parser.get_encoded_word,
+ '=?us-ascii?q?first_second?=',
+ 'first second',
+ 'first second',
+ [],
+ '')
+ self.assertEqual(ew.encoded, '=?us-ascii?q?first_second?=')
+ self.assertEqual(ew.charset, 'us-ascii')
+ self.assertEqual(ew.lang, '')
+
+ def test_get_encoded_word_non_printable_defect(self):
+ self._test_get_x(parser.get_encoded_word,
+ '=?us-ascii?q?first\x02second?=',
+ 'first\x02second',
+ 'first\x02second',
+ [errors.NonPrintableDefect],
+ '')
+
+ def test_get_encoded_word_leading_internal_space(self):
+ self._test_get_x(parser.get_encoded_word,
+ '=?us-ascii?q?=20foo?=',
+ ' foo',
+ ' foo',
+ [],
+ '')
+
+ # get_unstructured
+
+ def _get_unst(self, value):
+ token = parser.get_unstructured(value)
+ return token, ''
+
+ def test_get_unstructured_null(self):
+ self._test_get_x(self._get_unst, '', '', '', [], '')
+
+ def test_get_unstructured_one_word(self):
+ self._test_get_x(self._get_unst, 'foo', 'foo', 'foo', [], '')
+
+ def test_get_unstructured_normal_phrase(self):
+ self._test_get_x(self._get_unst, 'foo bar bird',
+ 'foo bar bird',
+ 'foo bar bird',
+ [],
+ '')
+
+ def test_get_unstructured_normal_phrase_with_whitespace(self):
+ self._test_get_x(self._get_unst, 'foo \t bar bird',
+ 'foo \t bar bird',
+ 'foo bar bird',
+ [],
+ '')
+
+ def test_get_unstructured_leading_whitespace(self):
+ self._test_get_x(self._get_unst, ' foo bar',
+ ' foo bar',
+ ' foo bar',
+ [],
+ '')
+
+ def test_get_unstructured_trailing_whitespace(self):
+ self._test_get_x(self._get_unst, 'foo bar ',
+ 'foo bar ',
+ 'foo bar ',
+ [],
+ '')
+
+ def test_get_unstructured_leading_and_trailing_whitespace(self):
+ self._test_get_x(self._get_unst, ' foo bar ',
+ ' foo bar ',
+ ' foo bar ',
+ [],
+ '')
+
+ def test_get_unstructured_one_valid_ew_no_ws(self):
+ self._test_get_x(self._get_unst, '=?us-ascii?q?bar?=',
+ 'bar',
+ 'bar',
+ [],
+ '')
+
+ def test_get_unstructured_one_ew_trailing_ws(self):
+ self._test_get_x(self._get_unst, '=?us-ascii?q?bar?= ',
+ 'bar ',
+ 'bar ',
+ [],
+ '')
+
+ def test_get_unstructured_one_valid_ew_trailing_text(self):
+ self._test_get_x(self._get_unst, '=?us-ascii?q?bar?= bird',
+ 'bar bird',
+ 'bar bird',
+ [],
+ '')
+
+ def test_get_unstructured_phrase_with_ew_in_middle_of_text(self):
+ self._test_get_x(self._get_unst, 'foo =?us-ascii?q?bar?= bird',
+ 'foo bar bird',
+ 'foo bar bird',
+ [],
+ '')
+
+ def test_get_unstructured_phrase_with_two_ew(self):
+ self._test_get_x(self._get_unst,
+ 'foo =?us-ascii?q?bar?= =?us-ascii?q?bird?=',
+ 'foo barbird',
+ 'foo barbird',
+ [],
+ '')
+
+ def test_get_unstructured_phrase_with_two_ew_trailing_ws(self):
+ self._test_get_x(self._get_unst,
+ 'foo =?us-ascii?q?bar?= =?us-ascii?q?bird?= ',
+ 'foo barbird ',
+ 'foo barbird ',
+ [],
+ '')
+
+ def test_get_unstructured_phrase_with_ew_with_leading_ws(self):
+ self._test_get_x(self._get_unst,
+ ' =?us-ascii?q?bar?=',
+ ' bar',
+ ' bar',
+ [],
+ '')
+
+ def test_get_unstructured_phrase_with_two_ew_extra_ws(self):
+ self._test_get_x(self._get_unst,
+ 'foo =?us-ascii?q?bar?= \t =?us-ascii?q?bird?=',
+ 'foo barbird',
+ 'foo barbird',
+ [],
+ '')
+
+ def test_get_unstructured_two_ew_extra_ws_trailing_text(self):
+ self._test_get_x(self._get_unst,
+ '=?us-ascii?q?test?= =?us-ascii?q?foo?= val',
+ 'testfoo val',
+ 'testfoo val',
+ [],
+ '')
+
+ def test_get_unstructured_ew_with_internal_ws(self):
+ self._test_get_x(self._get_unst,
+ '=?iso-8859-1?q?hello=20world?=',
+ 'hello world',
+ 'hello world',
+ [],
+ '')
+
+ def test_get_unstructured_ew_with_internal_leading_ws(self):
+ self._test_get_x(self._get_unst,
+ ' =?us-ascii?q?=20test?= =?us-ascii?q?=20foo?= val',
+ ' test foo val',
+ ' test foo val',
+ [],
+ '')
+
+ def test_get_unstructured_invaild_ew(self):
+ self._test_get_x(self._get_unst,
+ '=?test val',
+ '=?test val',
+ '=?test val',
+ [],
+ '')
+
+ def test_get_unstructured_undecodable_bytes(self):
+ self._test_get_x(self._get_unst,
+ b'test \xACfoo val'.decode('ascii', 'surrogateescape'),
+ 'test \uDCACfoo val',
+ 'test \uDCACfoo val',
+ [errors.UndecodableBytesDefect],
+ '')
+
+ def test_get_unstructured_undecodable_bytes_in_EW(self):
+ self._test_get_x(self._get_unst,
+ (b'=?us-ascii?q?=20test?= =?us-ascii?q?=20\xACfoo?='
+ b' val').decode('ascii', 'surrogateescape'),
+ ' test \uDCACfoo val',
+ ' test \uDCACfoo val',
+ [errors.UndecodableBytesDefect]*2,
+ '')
+
+ def test_get_unstructured_missing_base64_padding(self):
+ self._test_get_x(self._get_unst,
+ '=?utf-8?b?dmk?=',
+ 'vi',
+ 'vi',
+ [errors.InvalidBase64PaddingDefect],
+ '')
+
+ def test_get_unstructured_invalid_base64_character(self):
+ self._test_get_x(self._get_unst,
+ '=?utf-8?b?dm\x01k===?=',
+ 'vi',
+ 'vi',
+ [errors.InvalidBase64CharactersDefect],
+ '')
+
+ def test_get_unstructured_invalid_base64_character_and_bad_padding(self):
+ self._test_get_x(self._get_unst,
+ '=?utf-8?b?dm\x01k?=',
+ 'vi',
+ 'vi',
+ [errors.InvalidBase64CharactersDefect,
+ errors.InvalidBase64PaddingDefect],
+ '')
+
+ def test_get_unstructured_no_whitespace_between_ews(self):
+ self._test_get_x(self._get_unst,
+ '=?utf-8?q?foo?==?utf-8?q?bar?=',
+ 'foobar',
+ 'foobar',
+ [errors.InvalidHeaderDefect],
+ '')
+
+ # get_qp_ctext
+
+ def test_get_qp_ctext_only(self):
+ ptext = self._test_get_x(parser.get_qp_ctext,
+ 'foobar', 'foobar', ' ', [], '')
+ self.assertEqual(ptext.token_type, 'ptext')
+
+ def test_get_qp_ctext_all_printables(self):
+ with_qp = self.rfc_printable_ascii.replace('\\', '\\\\')
+ with_qp = with_qp. replace('(', r'\(')
+ with_qp = with_qp.replace(')', r'\)')
+ ptext = self._test_get_x(parser.get_qp_ctext,
+ with_qp, self.rfc_printable_ascii, ' ', [], '')
+
+ def test_get_qp_ctext_two_words_gets_first(self):
+ self._test_get_x(parser.get_qp_ctext,
+ 'foo de', 'foo', ' ', [], ' de')
+
+ def test_get_qp_ctext_following_wsp_preserved(self):
+ self._test_get_x(parser.get_qp_ctext,
+ 'foo \t\tde', 'foo', ' ', [], ' \t\tde')
+
+ def test_get_qp_ctext_up_to_close_paren_only(self):
+ self._test_get_x(parser.get_qp_ctext,
+ 'foo)', 'foo', ' ', [], ')')
+
+ def test_get_qp_ctext_wsp_before_close_paren_preserved(self):
+ self._test_get_x(parser.get_qp_ctext,
+ 'foo )', 'foo', ' ', [], ' )')
+
+ def test_get_qp_ctext_close_paren_mid_word(self):
+ self._test_get_x(parser.get_qp_ctext,
+ 'foo)bar', 'foo', ' ', [], ')bar')
+
+ def test_get_qp_ctext_up_to_open_paren_only(self):
+ self._test_get_x(parser.get_qp_ctext,
+ 'foo(', 'foo', ' ', [], '(')
+
+ def test_get_qp_ctext_wsp_before_open_paren_preserved(self):
+ self._test_get_x(parser.get_qp_ctext,
+ 'foo (', 'foo', ' ', [], ' (')
+
+ def test_get_qp_ctext_open_paren_mid_word(self):
+ self._test_get_x(parser.get_qp_ctext,
+ 'foo(bar', 'foo', ' ', [], '(bar')
+
+ def test_get_qp_ctext_non_printables(self):
+ ptext = self._test_get_x(parser.get_qp_ctext,
+ 'foo\x00bar)', 'foo\x00bar', ' ',
+ [errors.NonPrintableDefect], ')')
+ self.assertEqual(ptext.defects[0].non_printables[0], '\x00')
+
+ # get_qcontent
+
+ def test_get_qcontent_only(self):
+ ptext = self._test_get_x(parser.get_qcontent,
+ 'foobar', 'foobar', 'foobar', [], '')
+ self.assertEqual(ptext.token_type, 'ptext')
+
+ def test_get_qcontent_all_printables(self):
+ with_qp = self.rfc_printable_ascii.replace('\\', '\\\\')
+ with_qp = with_qp. replace('"', r'\"')
+ ptext = self._test_get_x(parser.get_qcontent, with_qp,
+ self.rfc_printable_ascii,
+ self.rfc_printable_ascii, [], '')
+
+ def test_get_qcontent_two_words_gets_first(self):
+ self._test_get_x(parser.get_qcontent,
+ 'foo de', 'foo', 'foo', [], ' de')
+
+ def test_get_qcontent_following_wsp_preserved(self):
+ self._test_get_x(parser.get_qcontent,
+ 'foo \t\tde', 'foo', 'foo', [], ' \t\tde')
+
+ def test_get_qcontent_up_to_dquote_only(self):
+ self._test_get_x(parser.get_qcontent,
+ 'foo"', 'foo', 'foo', [], '"')
+
+ def test_get_qcontent_wsp_before_close_paren_preserved(self):
+ self._test_get_x(parser.get_qcontent,
+ 'foo "', 'foo', 'foo', [], ' "')
+
+ def test_get_qcontent_close_paren_mid_word(self):
+ self._test_get_x(parser.get_qcontent,
+ 'foo"bar', 'foo', 'foo', [], '"bar')
+
+ def test_get_qcontent_non_printables(self):
+ ptext = self._test_get_x(parser.get_qcontent,
+ 'foo\x00fg"', 'foo\x00fg', 'foo\x00fg',
+ [errors.NonPrintableDefect], '"')
+ self.assertEqual(ptext.defects[0].non_printables[0], '\x00')
+
+ # get_atext
+
+ def test_get_atext_only(self):
+ atext = self._test_get_x(parser.get_atext,
+ 'foobar', 'foobar', 'foobar', [], '')
+ self.assertEqual(atext.token_type, 'atext')
+
+ def test_get_atext_all_atext(self):
+ atext = self._test_get_x(parser.get_atext, self.rfc_atext_chars,
+ self.rfc_atext_chars,
+ self.rfc_atext_chars, [], '')
+
+ def test_get_atext_two_words_gets_first(self):
+ self._test_get_x(parser.get_atext,
+ 'foo bar', 'foo', 'foo', [], ' bar')
+
+ def test_get_atext_following_wsp_preserved(self):
+ self._test_get_x(parser.get_atext,
+ 'foo \t\tbar', 'foo', 'foo', [], ' \t\tbar')
+
+ def test_get_atext_up_to_special(self):
+ self._test_get_x(parser.get_atext,
+ 'foo@bar', 'foo', 'foo', [], '@bar')
+
+ def test_get_atext_non_printables(self):
+ atext = self._test_get_x(parser.get_atext,
+ 'foo\x00bar(', 'foo\x00bar', 'foo\x00bar',
+ [errors.NonPrintableDefect], '(')
+ self.assertEqual(atext.defects[0].non_printables[0], '\x00')
+
+ # get_bare_quoted_string
+
+ def test_get_bare_quoted_string_only(self):
+ bqs = self._test_get_x(parser.get_bare_quoted_string,
+ '"foo"', '"foo"', 'foo', [], '')
+ self.assertEqual(bqs.token_type, 'bare-quoted-string')
+
+ def test_get_bare_quoted_string_must_start_with_dquote(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_bare_quoted_string('foo"')
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_bare_quoted_string(' "foo"')
+
+ def test_get_bare_quoted_string_following_wsp_preserved(self):
+ self._test_get_x(parser.get_bare_quoted_string,
+ '"foo"\t bar', '"foo"', 'foo', [], '\t bar')
+
+ def test_get_bare_quoted_string_multiple_words(self):
+ self._test_get_x(parser.get_bare_quoted_string,
+ '"foo bar moo"', '"foo bar moo"', 'foo bar moo', [], '')
+
+ def test_get_bare_quoted_string_multiple_words_wsp_preserved(self):
+ self._test_get_x(parser.get_bare_quoted_string,
+ '" foo moo\t"', '" foo moo\t"', ' foo moo\t', [], '')
+
+ def test_get_bare_quoted_string_end_dquote_mid_word(self):
+ self._test_get_x(parser.get_bare_quoted_string,
+ '"foo"bar', '"foo"', 'foo', [], 'bar')
+
+ def test_get_bare_quoted_string_quoted_dquote(self):
+ self._test_get_x(parser.get_bare_quoted_string,
+ r'"foo\"in"a', r'"foo\"in"', 'foo"in', [], 'a')
+
+ def test_get_bare_quoted_string_non_printables(self):
+ self._test_get_x(parser.get_bare_quoted_string,
+ '"a\x01a"', '"a\x01a"', 'a\x01a',
+ [errors.NonPrintableDefect], '')
+
+ def test_get_bare_quoted_string_no_end_dquote(self):
+ self._test_get_x(parser.get_bare_quoted_string,
+ '"foo', '"foo"', 'foo',
+ [errors.InvalidHeaderDefect], '')
+ self._test_get_x(parser.get_bare_quoted_string,
+ '"foo ', '"foo "', 'foo ',
+ [errors.InvalidHeaderDefect], '')
+
+ def test_get_bare_quoted_string_empty_quotes(self):
+ self._test_get_x(parser.get_bare_quoted_string,
+ '""', '""', '', [], '')
+
+ # get_comment
+
+ def test_get_comment_only(self):
+ comment = self._test_get_x(parser.get_comment,
+ '(comment)', '(comment)', ' ', [], '', ['comment'])
+ self.assertEqual(comment.token_type, 'comment')
+
+ def test_get_comment_must_start_with_paren(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_comment('foo"')
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_comment(' (foo"')
+
+ def test_get_comment_following_wsp_preserved(self):
+ self._test_get_x(parser.get_comment,
+ '(comment) \t', '(comment)', ' ', [], ' \t', ['comment'])
+
+ def test_get_comment_multiple_words(self):
+ self._test_get_x(parser.get_comment,
+ '(foo bar) \t', '(foo bar)', ' ', [], ' \t', ['foo bar'])
+
+ def test_get_comment_multiple_words_wsp_preserved(self):
+ self._test_get_x(parser.get_comment,
+ '( foo bar\t ) \t', '( foo bar\t )', ' ', [], ' \t',
+ [' foo bar\t '])
+
+ def test_get_comment_end_paren_mid_word(self):
+ self._test_get_x(parser.get_comment,
+ '(foo)bar', '(foo)', ' ', [], 'bar', ['foo'])
+
+ def test_get_comment_quoted_parens(self):
+ self._test_get_x(parser.get_comment,
+ '(foo\) \(\)bar)', '(foo\) \(\)bar)', ' ', [], '', ['foo) ()bar'])
+
+ def test_get_comment_non_printable(self):
+ self._test_get_x(parser.get_comment,
+ '(foo\x7Fbar)', '(foo\x7Fbar)', ' ',
+ [errors.NonPrintableDefect], '', ['foo\x7Fbar'])
+
+ def test_get_comment_no_end_paren(self):
+ self._test_get_x(parser.get_comment,
+ '(foo bar', '(foo bar)', ' ',
+ [errors.InvalidHeaderDefect], '', ['foo bar'])
+ self._test_get_x(parser.get_comment,
+ '(foo bar ', '(foo bar )', ' ',
+ [errors.InvalidHeaderDefect], '', ['foo bar '])
+
+ def test_get_comment_nested_comment(self):
+ comment = self._test_get_x(parser.get_comment,
+ '(foo(bar))', '(foo(bar))', ' ', [], '', ['foo(bar)'])
+ self.assertEqual(comment[1].content, 'bar')
+
+ def test_get_comment_nested_comment_wsp(self):
+ comment = self._test_get_x(parser.get_comment,
+ '(foo ( bar ) )', '(foo ( bar ) )', ' ', [], '', ['foo ( bar ) '])
+ self.assertEqual(comment[2].content, ' bar ')
+
+ def test_get_comment_empty_comment(self):
+ self._test_get_x(parser.get_comment,
+ '()', '()', ' ', [], '', [''])
+
+ def test_get_comment_multiple_nesting(self):
+ comment = self._test_get_x(parser.get_comment,
+ '(((((foo)))))', '(((((foo)))))', ' ', [], '', ['((((foo))))'])
+ for i in range(4, 0, -1):
+ self.assertEqual(comment[0].content, '('*(i-1)+'foo'+')'*(i-1))
+ comment = comment[0]
+ self.assertEqual(comment.content, 'foo')
+
+ def test_get_comment_missing_end_of_nesting(self):
+ self._test_get_x(parser.get_comment,
+ '(((((foo)))', '(((((foo)))))', ' ',
+ [errors.InvalidHeaderDefect]*2, '', ['((((foo))))'])
+
+ def test_get_comment_qs_in_nested_comment(self):
+ comment = self._test_get_x(parser.get_comment,
+ '(foo (b\)))', '(foo (b\)))', ' ', [], '', ['foo (b\))'])
+ self.assertEqual(comment[2].content, 'b)')
+
+ # get_cfws
+
+ def test_get_cfws_only_ws(self):
+ cfws = self._test_get_x(parser.get_cfws,
+ ' \t \t', ' \t \t', ' ', [], '', [])
+ self.assertEqual(cfws.token_type, 'cfws')
+
+ def test_get_cfws_only_comment(self):
+ cfws = self._test_get_x(parser.get_cfws,
+ '(foo)', '(foo)', ' ', [], '', ['foo'])
+ self.assertEqual(cfws[0].content, 'foo')
+
+ def test_get_cfws_only_mixed(self):
+ cfws = self._test_get_x(parser.get_cfws,
+ ' (foo ) ( bar) ', ' (foo ) ( bar) ', ' ', [], '',
+ ['foo ', ' bar'])
+ self.assertEqual(cfws[1].content, 'foo ')
+ self.assertEqual(cfws[3].content, ' bar')
+
+ def test_get_cfws_ends_at_non_leader(self):
+ cfws = self._test_get_x(parser.get_cfws,
+ '(foo) bar', '(foo) ', ' ', [], 'bar', ['foo'])
+ self.assertEqual(cfws[0].content, 'foo')
+
+ def test_get_cfws_ends_at_non_printable(self):
+ cfws = self._test_get_x(parser.get_cfws,
+ '(foo) \x07', '(foo) ', ' ', [], '\x07', ['foo'])
+ self.assertEqual(cfws[0].content, 'foo')
+
+ def test_get_cfws_non_printable_in_comment(self):
+ cfws = self._test_get_x(parser.get_cfws,
+ '(foo \x07) "test"', '(foo \x07) ', ' ',
+ [errors.NonPrintableDefect], '"test"', ['foo \x07'])
+ self.assertEqual(cfws[0].content, 'foo \x07')
+
+ def test_get_cfws_header_ends_in_comment(self):
+ cfws = self._test_get_x(parser.get_cfws,
+ ' (foo ', ' (foo )', ' ',
+ [errors.InvalidHeaderDefect], '', ['foo '])
+ self.assertEqual(cfws[1].content, 'foo ')
+
+ def test_get_cfws_multiple_nested_comments(self):
+ cfws = self._test_get_x(parser.get_cfws,
+ '(foo (bar)) ((a)(a))', '(foo (bar)) ((a)(a))', ' ', [],
+ '', ['foo (bar)', '(a)(a)'])
+ self.assertEqual(cfws[0].comments, ['foo (bar)'])
+ self.assertEqual(cfws[2].comments, ['(a)(a)'])
+
+ # get_quoted_string
+
+ def test_get_quoted_string_only(self):
+ qs = self._test_get_x(parser.get_quoted_string,
+ '"bob"', '"bob"', 'bob', [], '')
+ self.assertEqual(qs.token_type, 'quoted-string')
+ self.assertEqual(qs.quoted_value, '"bob"')
+ self.assertEqual(qs.content, 'bob')
+
+ def test_get_quoted_string_with_wsp(self):
+ qs = self._test_get_x(parser.get_quoted_string,
+ '\t "bob" ', '\t "bob" ', ' bob ', [], '')
+ self.assertEqual(qs.quoted_value, ' "bob" ')
+ self.assertEqual(qs.content, 'bob')
+
+ def test_get_quoted_string_with_comments_and_wsp(self):
+ qs = self._test_get_x(parser.get_quoted_string,
+ ' (foo) "bob"(bar)', ' (foo) "bob"(bar)', ' bob ', [], '')
+ self.assertEqual(qs[0][1].content, 'foo')
+ self.assertEqual(qs[2][0].content, 'bar')
+ self.assertEqual(qs.content, 'bob')
+ self.assertEqual(qs.quoted_value, ' "bob" ')
+
+ def test_get_quoted_string_with_multiple_comments(self):
+ qs = self._test_get_x(parser.get_quoted_string,
+ ' (foo) (bar) "bob"(bird)', ' (foo) (bar) "bob"(bird)', ' bob ',
+ [], '')
+ self.assertEqual(qs[0].comments, ['foo', 'bar'])
+ self.assertEqual(qs[2].comments, ['bird'])
+ self.assertEqual(qs.content, 'bob')
+ self.assertEqual(qs.quoted_value, ' "bob" ')
+
+ def test_get_quoted_string_non_printable_in_comment(self):
+ qs = self._test_get_x(parser.get_quoted_string,
+ ' (\x0A) "bob"', ' (\x0A) "bob"', ' bob',
+ [errors.NonPrintableDefect], '')
+ self.assertEqual(qs[0].comments, ['\x0A'])
+ self.assertEqual(qs.content, 'bob')
+ self.assertEqual(qs.quoted_value, ' "bob"')
+
+ def test_get_quoted_string_non_printable_in_qcontent(self):
+ qs = self._test_get_x(parser.get_quoted_string,
+ ' (a) "a\x0B"', ' (a) "a\x0B"', ' a\x0B',
+ [errors.NonPrintableDefect], '')
+ self.assertEqual(qs[0].comments, ['a'])
+ self.assertEqual(qs.content, 'a\x0B')
+ self.assertEqual(qs.quoted_value, ' "a\x0B"')
+
+ def test_get_quoted_string_internal_ws(self):
+ qs = self._test_get_x(parser.get_quoted_string,
+ ' (a) "foo bar "', ' (a) "foo bar "', ' foo bar ',
+ [], '')
+ self.assertEqual(qs[0].comments, ['a'])
+ self.assertEqual(qs.content, 'foo bar ')
+ self.assertEqual(qs.quoted_value, ' "foo bar "')
+
+ def test_get_quoted_string_header_ends_in_comment(self):
+ qs = self._test_get_x(parser.get_quoted_string,
+ ' (a) "bob" (a', ' (a) "bob" (a)', ' bob ',
+ [errors.InvalidHeaderDefect], '')
+ self.assertEqual(qs[0].comments, ['a'])
+ self.assertEqual(qs[2].comments, ['a'])
+ self.assertEqual(qs.content, 'bob')
+ self.assertEqual(qs.quoted_value, ' "bob" ')
+
+ def test_get_quoted_string_header_ends_in_qcontent(self):
+ qs = self._test_get_x(parser.get_quoted_string,
+ ' (a) "bob', ' (a) "bob"', ' bob',
+ [errors.InvalidHeaderDefect], '')
+ self.assertEqual(qs[0].comments, ['a'])
+ self.assertEqual(qs.content, 'bob')
+ self.assertEqual(qs.quoted_value, ' "bob"')
+
+ def test_get_quoted_string_no_quoted_string(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_quoted_string(' (ab) xyz')
+
+ def test_get_quoted_string_qs_ends_at_noncfws(self):
+ qs = self._test_get_x(parser.get_quoted_string,
+ '\t "bob" fee', '\t "bob" ', ' bob ', [], 'fee')
+ self.assertEqual(qs.content, 'bob')
+ self.assertEqual(qs.quoted_value, ' "bob" ')
+
+ # get_atom
+
+ def test_get_atom_only(self):
+ atom = self._test_get_x(parser.get_atom,
+ 'bob', 'bob', 'bob', [], '')
+ self.assertEqual(atom.token_type, 'atom')
+
+ def test_get_atom_with_wsp(self):
+ self._test_get_x(parser.get_atom,
+ '\t bob ', '\t bob ', ' bob ', [], '')
+
+ def test_get_atom_with_comments_and_wsp(self):
+ atom = self._test_get_x(parser.get_atom,
+ ' (foo) bob(bar)', ' (foo) bob(bar)', ' bob ', [], '')
+ self.assertEqual(atom[0][1].content, 'foo')
+ self.assertEqual(atom[2][0].content, 'bar')
+
+ def test_get_atom_with_multiple_comments(self):
+ atom = self._test_get_x(parser.get_atom,
+ ' (foo) (bar) bob(bird)', ' (foo) (bar) bob(bird)', ' bob ',
+ [], '')
+ self.assertEqual(atom[0].comments, ['foo', 'bar'])
+ self.assertEqual(atom[2].comments, ['bird'])
+
+ def test_get_atom_non_printable_in_comment(self):
+ atom = self._test_get_x(parser.get_atom,
+ ' (\x0A) bob', ' (\x0A) bob', ' bob',
+ [errors.NonPrintableDefect], '')
+ self.assertEqual(atom[0].comments, ['\x0A'])
+
+ def test_get_atom_non_printable_in_atext(self):
+ atom = self._test_get_x(parser.get_atom,
+ ' (a) a\x0B', ' (a) a\x0B', ' a\x0B',
+ [errors.NonPrintableDefect], '')
+ self.assertEqual(atom[0].comments, ['a'])
+
+ def test_get_atom_header_ends_in_comment(self):
+ atom = self._test_get_x(parser.get_atom,
+ ' (a) bob (a', ' (a) bob (a)', ' bob ',
+ [errors.InvalidHeaderDefect], '')
+ self.assertEqual(atom[0].comments, ['a'])
+ self.assertEqual(atom[2].comments, ['a'])
+
+ def test_get_atom_no_atom(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_atom(' (ab) ')
+
+ def test_get_atom_no_atom_before_special(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_atom(' (ab) @')
+
+ def test_get_atom_atom_ends_at_special(self):
+ atom = self._test_get_x(parser.get_atom,
+ ' (foo) bob(bar) @bang', ' (foo) bob(bar) ', ' bob ', [], '@bang')
+ self.assertEqual(atom[0].comments, ['foo'])
+ self.assertEqual(atom[2].comments, ['bar'])
+
+ def test_get_atom_atom_ends_at_noncfws(self):
+ atom = self._test_get_x(parser.get_atom,
+ 'bob fred', 'bob ', 'bob ', [], 'fred')
+
+ # get_dot_atom_text
+
+ def test_get_dot_atom_text(self):
+ dot_atom_text = self._test_get_x(parser.get_dot_atom_text,
+ 'foo.bar.bang', 'foo.bar.bang', 'foo.bar.bang', [], '')
+ self.assertEqual(dot_atom_text.token_type, 'dot-atom-text')
+ self.assertEqual(len(dot_atom_text), 5)
+
+ def test_get_dot_atom_text_lone_atom_is_valid(self):
+ dot_atom_text = self._test_get_x(parser.get_dot_atom_text,
+ 'foo', 'foo', 'foo', [], '')
+
+ def test_get_dot_atom_text_raises_on_leading_dot(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_dot_atom_text('.foo.bar')
+
+ def test_get_dot_atom_text_raises_on_trailing_dot(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_dot_atom_text('foo.bar.')
+
+ def test_get_dot_atom_text_raises_on_leading_non_atext(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_dot_atom_text(' foo.bar')
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_dot_atom_text('@foo.bar')
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_dot_atom_text('"foo.bar"')
+
+ def test_get_dot_atom_text_trailing_text_preserved(self):
+ dot_atom_text = self._test_get_x(parser.get_dot_atom_text,
+ 'foo@bar', 'foo', 'foo', [], '@bar')
+
+ def test_get_dot_atom_text_trailing_ws_preserved(self):
+ dot_atom_text = self._test_get_x(parser.get_dot_atom_text,
+ 'foo .bar', 'foo', 'foo', [], ' .bar')
+
+ # get_dot_atom
+
+ def test_get_dot_atom_only(self):
+ dot_atom = self._test_get_x(parser.get_dot_atom,
+ 'foo.bar.bing', 'foo.bar.bing', 'foo.bar.bing', [], '')
+ self.assertEqual(dot_atom.token_type, 'dot-atom')
+ self.assertEqual(len(dot_atom), 1)
+
+ def test_get_dot_atom_with_wsp(self):
+ self._test_get_x(parser.get_dot_atom,
+ '\t foo.bar.bing ', '\t foo.bar.bing ', ' foo.bar.bing ', [], '')
+
+ def test_get_dot_atom_with_comments_and_wsp(self):
+ self._test_get_x(parser.get_dot_atom,
+ ' (sing) foo.bar.bing (here) ', ' (sing) foo.bar.bing (here) ',
+ ' foo.bar.bing ', [], '')
+
+ def test_get_dot_atom_space_ends_dot_atom(self):
+ self._test_get_x(parser.get_dot_atom,
+ ' (sing) foo.bar .bing (here) ', ' (sing) foo.bar ',
+ ' foo.bar ', [], '.bing (here) ')
+
+ def test_get_dot_atom_no_atom_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_dot_atom(' (foo) ')
+
+ def test_get_dot_atom_leading_dot_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_dot_atom(' (foo) .bar')
+
+ def test_get_dot_atom_two_dots_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_dot_atom('bar..bang')
+
+ def test_get_dot_atom_trailing_dot_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_dot_atom(' (foo) bar.bang. foo')
+
+ # get_word (if this were black box we'd repeat all the qs/atom tests)
+
+ def test_get_word_atom_yields_atom(self):
+ word = self._test_get_x(parser.get_word,
+ ' (foo) bar (bang) :ah', ' (foo) bar (bang) ', ' bar ', [], ':ah')
+ self.assertEqual(word.token_type, 'atom')
+ self.assertEqual(word[0].token_type, 'cfws')
+
+ def test_get_word_qs_yields_qs(self):
+ word = self._test_get_x(parser.get_word,
+ '"bar " (bang) ah', '"bar " (bang) ', 'bar ', [], 'ah')
+ self.assertEqual(word.token_type, 'quoted-string')
+ self.assertEqual(word[0].token_type, 'bare-quoted-string')
+ self.assertEqual(word[0].value, 'bar ')
+ self.assertEqual(word.content, 'bar ')
+
+ def test_get_word_ends_at_dot(self):
+ self._test_get_x(parser.get_word,
+ 'foo.', 'foo', 'foo', [], '.')
+
+ # get_phrase
+
+ def test_get_phrase_simple(self):
+ phrase = self._test_get_x(parser.get_phrase,
+ '"Fred A. Johnson" is his name, oh.',
+ '"Fred A. Johnson" is his name',
+ 'Fred A. Johnson is his name',
+ [],
+ ', oh.')
+ self.assertEqual(phrase.token_type, 'phrase')
+
+ def test_get_phrase_complex(self):
+ phrase = self._test_get_x(parser.get_phrase,
+ ' (A) bird (in (my|your)) "hand " is messy\t<>\t',
+ ' (A) bird (in (my|your)) "hand " is messy\t',
+ ' bird hand is messy ',
+ [],
+ '<>\t')
+ self.assertEqual(phrase[0][0].comments, ['A'])
+ self.assertEqual(phrase[0][2].comments, ['in (my|your)'])
+
+ def test_get_phrase_obsolete(self):
+ phrase = self._test_get_x(parser.get_phrase,
+ 'Fred A.(weird).O Johnson',
+ 'Fred A.(weird).O Johnson',
+ 'Fred A. .O Johnson',
+ [errors.ObsoleteHeaderDefect]*3,
+ '')
+ self.assertEqual(len(phrase), 7)
+ self.assertEqual(phrase[3].comments, ['weird'])
+
+ def test_get_phrase_pharse_must_start_with_word(self):
+ phrase = self._test_get_x(parser.get_phrase,
+ '(even weirder).name',
+ '(even weirder).name',
+ ' .name',
+ [errors.InvalidHeaderDefect] + [errors.ObsoleteHeaderDefect]*2,
+ '')
+ self.assertEqual(len(phrase), 3)
+ self.assertEqual(phrase[0].comments, ['even weirder'])
+
+ def test_get_phrase_ending_with_obsolete(self):
+ phrase = self._test_get_x(parser.get_phrase,
+ 'simple phrase.(with trailing comment):boo',
+ 'simple phrase.(with trailing comment)',
+ 'simple phrase. ',
+ [errors.ObsoleteHeaderDefect]*2,
+ ':boo')
+ self.assertEqual(len(phrase), 4)
+ self.assertEqual(phrase[3].comments, ['with trailing comment'])
+
+ def get_phrase_cfws_only_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_phrase(' (foo) ')
+
+ # get_local_part
+
+ def test_get_local_part_simple(self):
+ local_part = self._test_get_x(parser.get_local_part,
+ 'dinsdale@python.org', 'dinsdale', 'dinsdale', [], '@python.org')
+ self.assertEqual(local_part.token_type, 'local-part')
+ self.assertEqual(local_part.local_part, 'dinsdale')
+
+ def test_get_local_part_with_dot(self):
+ local_part = self._test_get_x(parser.get_local_part,
+ 'Fred.A.Johnson@python.org',
+ 'Fred.A.Johnson',
+ 'Fred.A.Johnson',
+ [],
+ '@python.org')
+ self.assertEqual(local_part.local_part, 'Fred.A.Johnson')
+
+ def test_get_local_part_with_whitespace(self):
+ local_part = self._test_get_x(parser.get_local_part,
+ ' Fred.A.Johnson @python.org',
+ ' Fred.A.Johnson ',
+ ' Fred.A.Johnson ',
+ [],
+ '@python.org')
+ self.assertEqual(local_part.local_part, 'Fred.A.Johnson')
+
+ def test_get_local_part_with_cfws(self):
+ local_part = self._test_get_x(parser.get_local_part,
+ ' (foo) Fred.A.Johnson (bar (bird)) @python.org',
+ ' (foo) Fred.A.Johnson (bar (bird)) ',
+ ' Fred.A.Johnson ',
+ [],
+ '@python.org')
+ self.assertEqual(local_part.local_part, 'Fred.A.Johnson')
+ self.assertEqual(local_part[0][0].comments, ['foo'])
+ self.assertEqual(local_part[0][2].comments, ['bar (bird)'])
+
+ def test_get_local_part_simple_quoted(self):
+ local_part = self._test_get_x(parser.get_local_part,
+ '"dinsdale"@python.org', '"dinsdale"', '"dinsdale"', [], '@python.org')
+ self.assertEqual(local_part.token_type, 'local-part')
+ self.assertEqual(local_part.local_part, 'dinsdale')
+
+ def test_get_local_part_with_quoted_dot(self):
+ local_part = self._test_get_x(parser.get_local_part,
+ '"Fred.A.Johnson"@python.org',
+ '"Fred.A.Johnson"',
+ '"Fred.A.Johnson"',
+ [],
+ '@python.org')
+ self.assertEqual(local_part.local_part, 'Fred.A.Johnson')
+
+ def test_get_local_part_quoted_with_whitespace(self):
+ local_part = self._test_get_x(parser.get_local_part,
+ ' "Fred A. Johnson" @python.org',
+ ' "Fred A. Johnson" ',
+ ' "Fred A. Johnson" ',
+ [],
+ '@python.org')
+ self.assertEqual(local_part.local_part, 'Fred A. Johnson')
+
+ def test_get_local_part_quoted_with_cfws(self):
+ local_part = self._test_get_x(parser.get_local_part,
+ ' (foo) " Fred A. Johnson " (bar (bird)) @python.org',
+ ' (foo) " Fred A. Johnson " (bar (bird)) ',
+ ' " Fred A. Johnson " ',
+ [],
+ '@python.org')
+ self.assertEqual(local_part.local_part, ' Fred A. Johnson ')
+ self.assertEqual(local_part[0][0].comments, ['foo'])
+ self.assertEqual(local_part[0][2].comments, ['bar (bird)'])
+
+
+ def test_get_local_part_simple_obsolete(self):
+ local_part = self._test_get_x(parser.get_local_part,
+ 'Fred. A.Johnson@python.org',
+ 'Fred. A.Johnson',
+ 'Fred. A.Johnson',
+ [errors.ObsoleteHeaderDefect],
+ '@python.org')
+ self.assertEqual(local_part.local_part, 'Fred.A.Johnson')
+
+ def test_get_local_part_complex_obsolete_1(self):
+ local_part = self._test_get_x(parser.get_local_part,
+ ' (foo )Fred (bar).(bird) A.(sheep)Johnson."and dogs "@python.org',
+ ' (foo )Fred (bar).(bird) A.(sheep)Johnson."and dogs "',
+ ' Fred . A. Johnson.and dogs ',
+ [errors.ObsoleteHeaderDefect],
+ '@python.org')
+ self.assertEqual(local_part.local_part, 'Fred.A.Johnson.and dogs ')
+
+ def test_get_local_part_complex_obsolete_invalid(self):
+ local_part = self._test_get_x(parser.get_local_part,
+ ' (foo )Fred (bar).(bird) A.(sheep)Johnson "and dogs"@python.org',
+ ' (foo )Fred (bar).(bird) A.(sheep)Johnson "and dogs"',
+ ' Fred . A. Johnson and dogs',
+ [errors.InvalidHeaderDefect]*2,
+ '@python.org')
+ self.assertEqual(local_part.local_part, 'Fred.A.Johnson and dogs')
+
+ def test_get_local_part_no_part_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_local_part(' (foo) ')
+
+ def test_get_local_part_special_instead_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_local_part(' (foo) @python.org')
+
+ def test_get_local_part_trailing_dot(self):
+ local_part = self._test_get_x(parser.get_local_part,
+ ' borris.@python.org',
+ ' borris.',
+ ' borris.',
+ [errors.InvalidHeaderDefect]*2,
+ '@python.org')
+ self.assertEqual(local_part.local_part, 'borris.')
+
+ def test_get_local_part_trailing_dot_with_ws(self):
+ local_part = self._test_get_x(parser.get_local_part,
+ ' borris. @python.org',
+ ' borris. ',
+ ' borris. ',
+ [errors.InvalidHeaderDefect]*2,
+ '@python.org')
+ self.assertEqual(local_part.local_part, 'borris.')
+
+ def test_get_local_part_leading_dot(self):
+ local_part = self._test_get_x(parser.get_local_part,
+ '.borris@python.org',
+ '.borris',
+ '.borris',
+ [errors.InvalidHeaderDefect]*2,
+ '@python.org')
+ self.assertEqual(local_part.local_part, '.borris')
+
+ def test_get_local_part_leading_dot_after_ws(self):
+ local_part = self._test_get_x(parser.get_local_part,
+ ' .borris@python.org',
+ ' .borris',
+ ' .borris',
+ [errors.InvalidHeaderDefect]*2,
+ '@python.org')
+ self.assertEqual(local_part.local_part, '.borris')
+
+ def test_get_local_part_double_dot_raises(self):
+ local_part = self._test_get_x(parser.get_local_part,
+ ' borris.(foo).natasha@python.org',
+ ' borris.(foo).natasha',
+ ' borris. .natasha',
+ [errors.InvalidHeaderDefect]*2,
+ '@python.org')
+ self.assertEqual(local_part.local_part, 'borris..natasha')
+
+ def test_get_local_part_quoted_strings_in_atom_list(self):
+ local_part = self._test_get_x(parser.get_local_part,
+ '""example" example"@example.com',
+ '""example" example"',
+ 'example example',
+ [errors.InvalidHeaderDefect]*3,
+ '@example.com')
+ self.assertEqual(local_part.local_part, 'example example')
+
+ def test_get_local_part_valid_and_invalid_qp_in_atom_list(self):
+ local_part = self._test_get_x(parser.get_local_part,
+ r'"\\"example\\" example"@example.com',
+ r'"\\"example\\" example"',
+ r'\example\\ example',
+ [errors.InvalidHeaderDefect]*5,
+ '@example.com')
+ self.assertEqual(local_part.local_part, r'\example\\ example')
+
+ def test_get_local_part_unicode_defect(self):
+ # Currently this only happens when parsing unicode, not when parsing
+ # stuff that was originally binary.
+ local_part = self._test_get_x(parser.get_local_part,
+ 'exámple@example.com',
+ 'exámple',
+ 'exámple',
+ [errors.NonASCIILocalPartDefect],
+ '@example.com')
+ self.assertEqual(local_part.local_part, 'exámple')
+
+ # get_dtext
+
+ def test_get_dtext_only(self):
+ dtext = self._test_get_x(parser.get_dtext,
+ 'foobar', 'foobar', 'foobar', [], '')
+ self.assertEqual(dtext.token_type, 'ptext')
+
+ def test_get_dtext_all_dtext(self):
+ dtext = self._test_get_x(parser.get_dtext, self.rfc_dtext_chars,
+ self.rfc_dtext_chars,
+ self.rfc_dtext_chars, [], '')
+
+ def test_get_dtext_two_words_gets_first(self):
+ self._test_get_x(parser.get_dtext,
+ 'foo bar', 'foo', 'foo', [], ' bar')
+
+ def test_get_dtext_following_wsp_preserved(self):
+ self._test_get_x(parser.get_dtext,
+ 'foo \t\tbar', 'foo', 'foo', [], ' \t\tbar')
+
+ def test_get_dtext_non_printables(self):
+ dtext = self._test_get_x(parser.get_dtext,
+ 'foo\x00bar]', 'foo\x00bar', 'foo\x00bar',
+ [errors.NonPrintableDefect], ']')
+ self.assertEqual(dtext.defects[0].non_printables[0], '\x00')
+
+ def test_get_dtext_with_qp(self):
+ ptext = self._test_get_x(parser.get_dtext,
+ r'foo\]\[\\bar\b\e\l\l',
+ r'foo][\barbell',
+ r'foo][\barbell',
+ [errors.ObsoleteHeaderDefect],
+ '')
+
+ def test_get_dtext_up_to_close_bracket_only(self):
+ self._test_get_x(parser.get_dtext,
+ 'foo]', 'foo', 'foo', [], ']')
+
+ def test_get_dtext_wsp_before_close_bracket_preserved(self):
+ self._test_get_x(parser.get_dtext,
+ 'foo ]', 'foo', 'foo', [], ' ]')
+
+ def test_get_dtext_close_bracket_mid_word(self):
+ self._test_get_x(parser.get_dtext,
+ 'foo]bar', 'foo', 'foo', [], ']bar')
+
+ def test_get_dtext_up_to_open_bracket_only(self):
+ self._test_get_x(parser.get_dtext,
+ 'foo[', 'foo', 'foo', [], '[')
+
+ def test_get_dtext_wsp_before_open_bracket_preserved(self):
+ self._test_get_x(parser.get_dtext,
+ 'foo [', 'foo', 'foo', [], ' [')
+
+ def test_get_dtext_open_bracket_mid_word(self):
+ self._test_get_x(parser.get_dtext,
+ 'foo[bar', 'foo', 'foo', [], '[bar')
+
+ # get_domain_literal
+
+ def test_get_domain_literal_only(self):
+ domain_literal = domain_literal = self._test_get_x(parser.get_domain_literal,
+ '[127.0.0.1]',
+ '[127.0.0.1]',
+ '[127.0.0.1]',
+ [],
+ '')
+ self.assertEqual(domain_literal.token_type, 'domain-literal')
+ self.assertEqual(domain_literal.domain, '[127.0.0.1]')
+ self.assertEqual(domain_literal.ip, '127.0.0.1')
+
+ def test_get_domain_literal_with_internal_ws(self):
+ domain_literal = self._test_get_x(parser.get_domain_literal,
+ '[ 127.0.0.1\t ]',
+ '[ 127.0.0.1\t ]',
+ '[ 127.0.0.1 ]',
+ [],
+ '')
+ self.assertEqual(domain_literal.domain, '[127.0.0.1]')
+ self.assertEqual(domain_literal.ip, '127.0.0.1')
+
+ def test_get_domain_literal_with_surrounding_cfws(self):
+ domain_literal = self._test_get_x(parser.get_domain_literal,
+ '(foo)[ 127.0.0.1] (bar)',
+ '(foo)[ 127.0.0.1] (bar)',
+ ' [ 127.0.0.1] ',
+ [],
+ '')
+ self.assertEqual(domain_literal.domain, '[127.0.0.1]')
+ self.assertEqual(domain_literal.ip, '127.0.0.1')
+
+ def test_get_domain_literal_no_start_char_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_domain_literal('(foo) ')
+
+ def test_get_domain_literal_no_start_char_before_special_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_domain_literal('(foo) @')
+
+ def test_get_domain_literal_bad_dtext_char_before_special_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_domain_literal('(foo) [abc[@')
+
+ # get_domain
+
+ def test_get_domain_regular_domain_only(self):
+ domain = self._test_get_x(parser.get_domain,
+ 'example.com',
+ 'example.com',
+ 'example.com',
+ [],
+ '')
+ self.assertEqual(domain.token_type, 'domain')
+ self.assertEqual(domain.domain, 'example.com')
+
+ def test_get_domain_domain_literal_only(self):
+ domain = self._test_get_x(parser.get_domain,
+ '[127.0.0.1]',
+ '[127.0.0.1]',
+ '[127.0.0.1]',
+ [],
+ '')
+ self.assertEqual(domain.token_type, 'domain')
+ self.assertEqual(domain.domain, '[127.0.0.1]')
+
+ def test_get_domain_with_cfws(self):
+ domain = self._test_get_x(parser.get_domain,
+ '(foo) example.com(bar)\t',
+ '(foo) example.com(bar)\t',
+ ' example.com ',
+ [],
+ '')
+ self.assertEqual(domain.domain, 'example.com')
+
+ def test_get_domain_domain_literal_with_cfws(self):
+ domain = self._test_get_x(parser.get_domain,
+ '(foo)[127.0.0.1]\t(bar)',
+ '(foo)[127.0.0.1]\t(bar)',
+ ' [127.0.0.1] ',
+ [],
+ '')
+ self.assertEqual(domain.domain, '[127.0.0.1]')
+
+ def test_get_domain_domain_with_cfws_ends_at_special(self):
+ domain = self._test_get_x(parser.get_domain,
+ '(foo)example.com\t(bar), next',
+ '(foo)example.com\t(bar)',
+ ' example.com ',
+ [],
+ ', next')
+ self.assertEqual(domain.domain, 'example.com')
+
+ def test_get_domain_domain_literal_with_cfws_ends_at_special(self):
+ domain = self._test_get_x(parser.get_domain,
+ '(foo)[127.0.0.1]\t(bar), next',
+ '(foo)[127.0.0.1]\t(bar)',
+ ' [127.0.0.1] ',
+ [],
+ ', next')
+ self.assertEqual(domain.domain, '[127.0.0.1]')
+
+ def test_get_domain_obsolete(self):
+ domain = self._test_get_x(parser.get_domain,
+ '(foo) example . (bird)com(bar)\t',
+ '(foo) example . (bird)com(bar)\t',
+ ' example . com ',
+ [errors.ObsoleteHeaderDefect],
+ '')
+ self.assertEqual(domain.domain, 'example.com')
+
+ def test_get_domain_no_non_cfws_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_domain(" (foo)\t")
+
+ def test_get_domain_no_atom_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_domain(" (foo)\t, broken")
+
+
+ # get_addr_spec
+
+ def test_get_addr_spec_normal(self):
+ addr_spec = self._test_get_x(parser.get_addr_spec,
+ 'dinsdale@example.com',
+ 'dinsdale@example.com',
+ 'dinsdale@example.com',
+ [],
+ '')
+ self.assertEqual(addr_spec.token_type, 'addr-spec')
+ self.assertEqual(addr_spec.local_part, 'dinsdale')
+ self.assertEqual(addr_spec.domain, 'example.com')
+ self.assertEqual(addr_spec.addr_spec, 'dinsdale@example.com')
+
+ def test_get_addr_spec_with_doamin_literal(self):
+ addr_spec = self._test_get_x(parser.get_addr_spec,
+ 'dinsdale@[127.0.0.1]',
+ 'dinsdale@[127.0.0.1]',
+ 'dinsdale@[127.0.0.1]',
+ [],
+ '')
+ self.assertEqual(addr_spec.local_part, 'dinsdale')
+ self.assertEqual(addr_spec.domain, '[127.0.0.1]')
+ self.assertEqual(addr_spec.addr_spec, 'dinsdale@[127.0.0.1]')
+
+ def test_get_addr_spec_with_cfws(self):
+ addr_spec = self._test_get_x(parser.get_addr_spec,
+ '(foo) dinsdale(bar)@ (bird) example.com (bog)',
+ '(foo) dinsdale(bar)@ (bird) example.com (bog)',
+ ' dinsdale@example.com ',
+ [],
+ '')
+ self.assertEqual(addr_spec.local_part, 'dinsdale')
+ self.assertEqual(addr_spec.domain, 'example.com')
+ self.assertEqual(addr_spec.addr_spec, 'dinsdale@example.com')
+
+ def test_get_addr_spec_with_qouoted_string_and_cfws(self):
+ addr_spec = self._test_get_x(parser.get_addr_spec,
+ '(foo) "roy a bug"(bar)@ (bird) example.com (bog)',
+ '(foo) "roy a bug"(bar)@ (bird) example.com (bog)',
+ ' "roy a bug"@example.com ',
+ [],
+ '')
+ self.assertEqual(addr_spec.local_part, 'roy a bug')
+ self.assertEqual(addr_spec.domain, 'example.com')
+ self.assertEqual(addr_spec.addr_spec, '"roy a bug"@example.com')
+
+ def test_get_addr_spec_ends_at_special(self):
+ addr_spec = self._test_get_x(parser.get_addr_spec,
+ '(foo) "roy a bug"(bar)@ (bird) example.com (bog) , next',
+ '(foo) "roy a bug"(bar)@ (bird) example.com (bog) ',
+ ' "roy a bug"@example.com ',
+ [],
+ ', next')
+ self.assertEqual(addr_spec.local_part, 'roy a bug')
+ self.assertEqual(addr_spec.domain, 'example.com')
+ self.assertEqual(addr_spec.addr_spec, '"roy a bug"@example.com')
+
+ def test_get_addr_spec_quoted_strings_in_atom_list(self):
+ addr_spec = self._test_get_x(parser.get_addr_spec,
+ '""example" example"@example.com',
+ '""example" example"@example.com',
+ 'example example@example.com',
+ [errors.InvalidHeaderDefect]*3,
+ '')
+ self.assertEqual(addr_spec.local_part, 'example example')
+ self.assertEqual(addr_spec.domain, 'example.com')
+ self.assertEqual(addr_spec.addr_spec, '"example example"@example.com')
+
+ def test_get_addr_spec_dot_atom(self):
+ addr_spec = self._test_get_x(parser.get_addr_spec,
+ 'star.a.star@example.com',
+ 'star.a.star@example.com',
+ 'star.a.star@example.com',
+ [],
+ '')
+ self.assertEqual(addr_spec.local_part, 'star.a.star')
+ self.assertEqual(addr_spec.domain, 'example.com')
+ self.assertEqual(addr_spec.addr_spec, 'star.a.star@example.com')
+
+ # get_obs_route
+
+ def test_get_obs_route_simple(self):
+ obs_route = self._test_get_x(parser.get_obs_route,
+ '@example.com, @two.example.com:',
+ '@example.com, @two.example.com:',
+ '@example.com, @two.example.com:',
+ [],
+ '')
+ self.assertEqual(obs_route.token_type, 'obs-route')
+ self.assertEqual(obs_route.domains, ['example.com', 'two.example.com'])
+
+ def test_get_obs_route_complex(self):
+ obs_route = self._test_get_x(parser.get_obs_route,
+ '(foo),, (blue)@example.com (bar),@two.(foo) example.com (bird):',
+ '(foo),, (blue)@example.com (bar),@two.(foo) example.com (bird):',
+ ' ,, @example.com ,@two. example.com :',
+ [errors.ObsoleteHeaderDefect], # This is the obs-domain
+ '')
+ self.assertEqual(obs_route.token_type, 'obs-route')
+ self.assertEqual(obs_route.domains, ['example.com', 'two.example.com'])
+
+ def test_get_obs_route_no_route_before_end_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_obs_route('(foo) @example.com,')
+
+ def test_get_obs_route_no_route_before_special_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_obs_route('(foo) [abc],')
+
+ def test_get_obs_route_no_route_before_special_raises2(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_obs_route('(foo) @example.com [abc],')
+
+ # get_angle_addr
+
+ def test_get_angle_addr_simple(self):
+ angle_addr = self._test_get_x(parser.get_angle_addr,
+ '<dinsdale@example.com>',
+ '<dinsdale@example.com>',
+ '<dinsdale@example.com>',
+ [],
+ '')
+ self.assertEqual(angle_addr.token_type, 'angle-addr')
+ self.assertEqual(angle_addr.local_part, 'dinsdale')
+ self.assertEqual(angle_addr.domain, 'example.com')
+ self.assertIsNone(angle_addr.route)
+ self.assertEqual(angle_addr.addr_spec, 'dinsdale@example.com')
+
+ def test_get_angle_addr_empty(self):
+ angle_addr = self._test_get_x(parser.get_angle_addr,
+ '<>',
+ '<>',
+ '<>',
+ [errors.InvalidHeaderDefect],
+ '')
+ self.assertEqual(angle_addr.token_type, 'angle-addr')
+ self.assertIsNone(angle_addr.local_part)
+ self.assertIsNone(angle_addr.domain)
+ self.assertIsNone(angle_addr.route)
+ self.assertEqual(angle_addr.addr_spec, '<>')
+
+ def test_get_angle_addr_with_cfws(self):
+ angle_addr = self._test_get_x(parser.get_angle_addr,
+ ' (foo) <dinsdale@example.com>(bar)',
+ ' (foo) <dinsdale@example.com>(bar)',
+ ' <dinsdale@example.com> ',
+ [],
+ '')
+ self.assertEqual(angle_addr.token_type, 'angle-addr')
+ self.assertEqual(angle_addr.local_part, 'dinsdale')
+ self.assertEqual(angle_addr.domain, 'example.com')
+ self.assertIsNone(angle_addr.route)
+ self.assertEqual(angle_addr.addr_spec, 'dinsdale@example.com')
+
+ def test_get_angle_addr_qs_and_domain_literal(self):
+ angle_addr = self._test_get_x(parser.get_angle_addr,
+ '<"Fred Perfect"@[127.0.0.1]>',
+ '<"Fred Perfect"@[127.0.0.1]>',
+ '<"Fred Perfect"@[127.0.0.1]>',
+ [],
+ '')
+ self.assertEqual(angle_addr.local_part, 'Fred Perfect')
+ self.assertEqual(angle_addr.domain, '[127.0.0.1]')
+ self.assertIsNone(angle_addr.route)
+ self.assertEqual(angle_addr.addr_spec, '"Fred Perfect"@[127.0.0.1]')
+
+ def test_get_angle_addr_internal_cfws(self):
+ angle_addr = self._test_get_x(parser.get_angle_addr,
+ '<(foo) dinsdale@example.com(bar)>',
+ '<(foo) dinsdale@example.com(bar)>',
+ '< dinsdale@example.com >',
+ [],
+ '')
+ self.assertEqual(angle_addr.local_part, 'dinsdale')
+ self.assertEqual(angle_addr.domain, 'example.com')
+ self.assertIsNone(angle_addr.route)
+ self.assertEqual(angle_addr.addr_spec, 'dinsdale@example.com')
+
+ def test_get_angle_addr_obs_route(self):
+ angle_addr = self._test_get_x(parser.get_angle_addr,
+ '(foo)<@example.com, (bird) @two.example.com: dinsdale@example.com> (bar) ',
+ '(foo)<@example.com, (bird) @two.example.com: dinsdale@example.com> (bar) ',
+ ' <@example.com, @two.example.com: dinsdale@example.com> ',
+ [errors.ObsoleteHeaderDefect],
+ '')
+ self.assertEqual(angle_addr.local_part, 'dinsdale')
+ self.assertEqual(angle_addr.domain, 'example.com')
+ self.assertEqual(angle_addr.route, ['example.com', 'two.example.com'])
+ self.assertEqual(angle_addr.addr_spec, 'dinsdale@example.com')
+
+ def test_get_angle_addr_missing_closing_angle(self):
+ angle_addr = self._test_get_x(parser.get_angle_addr,
+ '<dinsdale@example.com',
+ '<dinsdale@example.com>',
+ '<dinsdale@example.com>',
+ [errors.InvalidHeaderDefect],
+ '')
+ self.assertEqual(angle_addr.local_part, 'dinsdale')
+ self.assertEqual(angle_addr.domain, 'example.com')
+ self.assertIsNone(angle_addr.route)
+ self.assertEqual(angle_addr.addr_spec, 'dinsdale@example.com')
+
+ def test_get_angle_addr_missing_closing_angle_with_cfws(self):
+ angle_addr = self._test_get_x(parser.get_angle_addr,
+ '<dinsdale@example.com (foo)',
+ '<dinsdale@example.com (foo)>',
+ '<dinsdale@example.com >',
+ [errors.InvalidHeaderDefect],
+ '')
+ self.assertEqual(angle_addr.local_part, 'dinsdale')
+ self.assertEqual(angle_addr.domain, 'example.com')
+ self.assertIsNone(angle_addr.route)
+ self.assertEqual(angle_addr.addr_spec, 'dinsdale@example.com')
+
+ def test_get_angle_addr_ends_at_special(self):
+ angle_addr = self._test_get_x(parser.get_angle_addr,
+ '<dinsdale@example.com> (foo), next',
+ '<dinsdale@example.com> (foo)',
+ '<dinsdale@example.com> ',
+ [],
+ ', next')
+ self.assertEqual(angle_addr.local_part, 'dinsdale')
+ self.assertEqual(angle_addr.domain, 'example.com')
+ self.assertIsNone(angle_addr.route)
+ self.assertEqual(angle_addr.addr_spec, 'dinsdale@example.com')
+
+ def test_get_angle_addr_no_angle_raise(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_angle_addr('(foo) ')
+
+ def test_get_angle_addr_no_angle_before_special_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_angle_addr('(foo) , next')
+
+ def test_get_angle_addr_no_angle_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_angle_addr('bar')
+
+ def test_get_angle_addr_special_after_angle_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_angle_addr('(foo) <, bar')
+
+ # get_display_name This is phrase but with a different value.
+
+ def test_get_display_name_simple(self):
+ display_name = self._test_get_x(parser.get_display_name,
+ 'Fred A Johnson',
+ 'Fred A Johnson',
+ 'Fred A Johnson',
+ [],
+ '')
+ self.assertEqual(display_name.token_type, 'display-name')
+ self.assertEqual(display_name.display_name, 'Fred A Johnson')
+
+ def test_get_display_name_complex1(self):
+ display_name = self._test_get_x(parser.get_display_name,
+ '"Fred A. Johnson" is his name, oh.',
+ '"Fred A. Johnson" is his name',
+ '"Fred A. Johnson is his name"',
+ [],
+ ', oh.')
+ self.assertEqual(display_name.token_type, 'display-name')
+ self.assertEqual(display_name.display_name, 'Fred A. Johnson is his name')
+
+ def test_get_display_name_complex2(self):
+ display_name = self._test_get_x(parser.get_display_name,
+ ' (A) bird (in (my|your)) "hand " is messy\t<>\t',
+ ' (A) bird (in (my|your)) "hand " is messy\t',
+ ' "bird hand is messy" ',
+ [],
+ '<>\t')
+ self.assertEqual(display_name[0][0].comments, ['A'])
+ self.assertEqual(display_name[0][2].comments, ['in (my|your)'])
+ self.assertEqual(display_name.display_name, 'bird hand is messy')
+
+ def test_get_display_name_obsolete(self):
+ display_name = self._test_get_x(parser.get_display_name,
+ 'Fred A.(weird).O Johnson',
+ 'Fred A.(weird).O Johnson',
+ '"Fred A. .O Johnson"',
+ [errors.ObsoleteHeaderDefect]*3,
+ '')
+ self.assertEqual(len(display_name), 7)
+ self.assertEqual(display_name[3].comments, ['weird'])
+ self.assertEqual(display_name.display_name, 'Fred A. .O Johnson')
+
+ def test_get_display_name_pharse_must_start_with_word(self):
+ display_name = self._test_get_x(parser.get_display_name,
+ '(even weirder).name',
+ '(even weirder).name',
+ ' ".name"',
+ [errors.InvalidHeaderDefect] + [errors.ObsoleteHeaderDefect]*2,
+ '')
+ self.assertEqual(len(display_name), 3)
+ self.assertEqual(display_name[0].comments, ['even weirder'])
+ self.assertEqual(display_name.display_name, '.name')
+
+ def test_get_display_name_ending_with_obsolete(self):
+ display_name = self._test_get_x(parser.get_display_name,
+ 'simple phrase.(with trailing comment):boo',
+ 'simple phrase.(with trailing comment)',
+ '"simple phrase." ',
+ [errors.ObsoleteHeaderDefect]*2,
+ ':boo')
+ self.assertEqual(len(display_name), 4)
+ self.assertEqual(display_name[3].comments, ['with trailing comment'])
+ self.assertEqual(display_name.display_name, 'simple phrase.')
+
+ # get_name_addr
+
+ def test_get_name_addr_angle_addr_only(self):
+ name_addr = self._test_get_x(parser.get_name_addr,
+ '<dinsdale@example.com>',
+ '<dinsdale@example.com>',
+ '<dinsdale@example.com>',
+ [],
+ '')
+ self.assertEqual(name_addr.token_type, 'name-addr')
+ self.assertIsNone(name_addr.display_name)
+ self.assertEqual(name_addr.local_part, 'dinsdale')
+ self.assertEqual(name_addr.domain, 'example.com')
+ self.assertIsNone(name_addr.route)
+ self.assertEqual(name_addr.addr_spec, 'dinsdale@example.com')
+
+ def test_get_name_addr_atom_name(self):
+ name_addr = self._test_get_x(parser.get_name_addr,
+ 'Dinsdale <dinsdale@example.com>',
+ 'Dinsdale <dinsdale@example.com>',
+ 'Dinsdale <dinsdale@example.com>',
+ [],
+ '')
+ self.assertEqual(name_addr.token_type, 'name-addr')
+ self.assertEqual(name_addr.display_name, 'Dinsdale')
+ self.assertEqual(name_addr.local_part, 'dinsdale')
+ self.assertEqual(name_addr.domain, 'example.com')
+ self.assertIsNone(name_addr.route)
+ self.assertEqual(name_addr.addr_spec, 'dinsdale@example.com')
+
+ def test_get_name_addr_atom_name_with_cfws(self):
+ name_addr = self._test_get_x(parser.get_name_addr,
+ '(foo) Dinsdale (bar) <dinsdale@example.com> (bird)',
+ '(foo) Dinsdale (bar) <dinsdale@example.com> (bird)',
+ ' Dinsdale <dinsdale@example.com> ',
+ [],
+ '')
+ self.assertEqual(name_addr.display_name, 'Dinsdale')
+ self.assertEqual(name_addr.local_part, 'dinsdale')
+ self.assertEqual(name_addr.domain, 'example.com')
+ self.assertIsNone(name_addr.route)
+ self.assertEqual(name_addr.addr_spec, 'dinsdale@example.com')
+
+ def test_get_name_addr_name_with_cfws_and_dots(self):
+ name_addr = self._test_get_x(parser.get_name_addr,
+ '(foo) Roy.A.Bear (bar) <dinsdale@example.com> (bird)',
+ '(foo) Roy.A.Bear (bar) <dinsdale@example.com> (bird)',
+ ' "Roy.A.Bear" <dinsdale@example.com> ',
+ [errors.ObsoleteHeaderDefect]*2,
+ '')
+ self.assertEqual(name_addr.display_name, 'Roy.A.Bear')
+ self.assertEqual(name_addr.local_part, 'dinsdale')
+ self.assertEqual(name_addr.domain, 'example.com')
+ self.assertIsNone(name_addr.route)
+ self.assertEqual(name_addr.addr_spec, 'dinsdale@example.com')
+
+ def test_get_name_addr_qs_name(self):
+ name_addr = self._test_get_x(parser.get_name_addr,
+ '"Roy.A.Bear" <dinsdale@example.com>',
+ '"Roy.A.Bear" <dinsdale@example.com>',
+ '"Roy.A.Bear" <dinsdale@example.com>',
+ [],
+ '')
+ self.assertEqual(name_addr.display_name, 'Roy.A.Bear')
+ self.assertEqual(name_addr.local_part, 'dinsdale')
+ self.assertEqual(name_addr.domain, 'example.com')
+ self.assertIsNone(name_addr.route)
+ self.assertEqual(name_addr.addr_spec, 'dinsdale@example.com')
+
+ def test_get_name_addr_with_route(self):
+ name_addr = self._test_get_x(parser.get_name_addr,
+ '"Roy.A.Bear" <@two.example.com: dinsdale@example.com>',
+ '"Roy.A.Bear" <@two.example.com: dinsdale@example.com>',
+ '"Roy.A.Bear" <@two.example.com: dinsdale@example.com>',
+ [errors.ObsoleteHeaderDefect],
+ '')
+ self.assertEqual(name_addr.display_name, 'Roy.A.Bear')
+ self.assertEqual(name_addr.local_part, 'dinsdale')
+ self.assertEqual(name_addr.domain, 'example.com')
+ self.assertEqual(name_addr.route, ['two.example.com'])
+ self.assertEqual(name_addr.addr_spec, 'dinsdale@example.com')
+
+ def test_get_name_addr_ends_at_special(self):
+ name_addr = self._test_get_x(parser.get_name_addr,
+ '"Roy.A.Bear" <dinsdale@example.com>, next',
+ '"Roy.A.Bear" <dinsdale@example.com>',
+ '"Roy.A.Bear" <dinsdale@example.com>',
+ [],
+ ', next')
+ self.assertEqual(name_addr.display_name, 'Roy.A.Bear')
+ self.assertEqual(name_addr.local_part, 'dinsdale')
+ self.assertEqual(name_addr.domain, 'example.com')
+ self.assertIsNone(name_addr.route)
+ self.assertEqual(name_addr.addr_spec, 'dinsdale@example.com')
+
+ def test_get_name_addr_no_content_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_name_addr(' (foo) ')
+
+ def test_get_name_addr_no_content_before_special_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_name_addr(' (foo) ,')
+
+ def test_get_name_addr_no_angle_after_display_name_raises(self):
+ with self.assertRaises(errors.HeaderParseError):
+ parser.get_name_addr('foo bar')
+
+ # get_mailbox
+
+ def test_get_mailbox_addr_spec_only(self):
+ mailbox = self._test_get_x(parser.get_mailbox,
+ 'dinsdale@example.com',
+ 'dinsdale@example.com',
+ 'dinsdale@example.com',
+ [],
+ '')
+ self.assertEqual(mailbox.token_type, 'mailbox')
+ self.assertIsNone(mailbox.display_name)
+ self.assertEqual(mailbox.local_part, 'dinsdale')
+ self.assertEqual(mailbox.domain, 'example.com')
+ self.assertIsNone(mailbox.route)
+ self.assertEqual(mailbox.addr_spec, 'dinsdale@example.com')
+
+ def test_get_mailbox_angle_addr_only(self):
+ mailbox = self._test_get_x(parser.get_mailbox,
+ '<dinsdale@example.com>',
+ '<dinsdale@example.com>',
+ '<dinsdale@example.com>',
+ [],
+ '')
+ self.assertEqual(mailbox.token_type, 'mailbox')
+ self.assertIsNone(mailbox.display_name)
+ self.assertEqual(mailbox.local_part, 'dinsdale')
+ self.assertEqual(mailbox.domain, 'example.com')
+ self.assertIsNone(mailbox.route)
+ self.assertEqual(mailbox.addr_spec, 'dinsdale@example.com')
+
+ def test_get_mailbox_name_addr(self):
+ mailbox = self._test_get_x(parser.get_mailbox,
+ '"Roy A. Bear" <dinsdale@example.com>',
+ '"Roy A. Bear" <dinsdale@example.com>',
+ '"Roy A. Bear" <dinsdale@example.com>',
+ [],
+ '')
+ self.assertEqual(mailbox.token_type, 'mailbox')
+ self.assertEqual(mailbox.display_name, 'Roy A. Bear')
+ self.assertEqual(mailbox.local_part, 'dinsdale')
+ self.assertEqual(mailbox.domain, 'example.com')
+ self.assertIsNone(mailbox.route)
+ self.assertEqual(mailbox.addr_spec, 'dinsdale@example.com')
+
+ def test_get_mailbox_ends_at_special(self):
+ mailbox = self._test_get_x(parser.get_mailbox,
+ '"Roy A. Bear" <dinsdale@example.com>, rest',
+ '"Roy A. Bear" <dinsdale@example.com>',
+ '"Roy A. Bear" <dinsdale@example.com>',
+ [],
+ ', rest')
+ self.assertEqual(mailbox.token_type, 'mailbox')
+ self.assertEqual(mailbox.display_name, 'Roy A. Bear')
+ self.assertEqual(mailbox.local_part, 'dinsdale')
+ self.assertEqual(mailbox.domain, 'example.com')
+ self.assertIsNone(mailbox.route)
+ self.assertEqual(mailbox.addr_spec, 'dinsdale@example.com')
+
+ def test_get_mailbox_quoted_strings_in_atom_list(self):
+ mailbox = self._test_get_x(parser.get_mailbox,
+ '""example" example"@example.com',
+ '""example" example"@example.com',
+ 'example example@example.com',
+ [errors.InvalidHeaderDefect]*3,
+ '')
+ self.assertEqual(mailbox.local_part, 'example example')
+ self.assertEqual(mailbox.domain, 'example.com')
+ self.assertEqual(mailbox.addr_spec, '"example example"@example.com')
+
+ # get_mailbox_list
+
+ def test_get_mailbox_list_single_addr(self):
+ mailbox_list = self._test_get_x(parser.get_mailbox_list,
+ 'dinsdale@example.com',
+ 'dinsdale@example.com',
+ 'dinsdale@example.com',
+ [],
+ '')
+ self.assertEqual(mailbox_list.token_type, 'mailbox-list')
+ self.assertEqual(len(mailbox_list.mailboxes), 1)
+ mailbox = mailbox_list.mailboxes[0]
+ self.assertIsNone(mailbox.display_name)
+ self.assertEqual(mailbox.local_part, 'dinsdale')
+ self.assertEqual(mailbox.domain, 'example.com')
+ self.assertIsNone(mailbox.route)
+ self.assertEqual(mailbox.addr_spec, 'dinsdale@example.com')
+ self.assertEqual(mailbox_list.mailboxes,
+ mailbox_list.all_mailboxes)
+
+ def test_get_mailbox_list_two_simple_addr(self):
+ mailbox_list = self._test_get_x(parser.get_mailbox_list,
+ 'dinsdale@example.com, dinsdale@test.example.com',
+ 'dinsdale@example.com, dinsdale@test.example.com',
+ 'dinsdale@example.com, dinsdale@test.example.com',
+ [],
+ '')
+ self.assertEqual(mailbox_list.token_type, 'mailbox-list')
+ self.assertEqual(len(mailbox_list.mailboxes), 2)
+ self.assertEqual(mailbox_list.mailboxes[0].addr_spec,
+ 'dinsdale@example.com')
+ self.assertEqual(mailbox_list.mailboxes[1].addr_spec,
+ 'dinsdale@test.example.com')
+ self.assertEqual(mailbox_list.mailboxes,
+ mailbox_list.all_mailboxes)
+
+ def test_get_mailbox_list_two_name_addr(self):
+ mailbox_list = self._test_get_x(parser.get_mailbox_list,
+ ('"Roy A. Bear" <dinsdale@example.com>,'
+ ' "Fred Flintstone" <dinsdale@test.example.com>'),
+ ('"Roy A. Bear" <dinsdale@example.com>,'
+ ' "Fred Flintstone" <dinsdale@test.example.com>'),
+ ('"Roy A. Bear" <dinsdale@example.com>,'
+ ' "Fred Flintstone" <dinsdale@test.example.com>'),
+ [],
+ '')
+ self.assertEqual(len(mailbox_list.mailboxes), 2)
+ self.assertEqual(mailbox_list.mailboxes[0].addr_spec,
+ 'dinsdale@example.com')
+ self.assertEqual(mailbox_list.mailboxes[0].display_name,
+ 'Roy A. Bear')
+ self.assertEqual(mailbox_list.mailboxes[1].addr_spec,
+ 'dinsdale@test.example.com')
+ self.assertEqual(mailbox_list.mailboxes[1].display_name,
+ 'Fred Flintstone')
+ self.assertEqual(mailbox_list.mailboxes,
+ mailbox_list.all_mailboxes)
+
+ def test_get_mailbox_list_two_complex(self):
+ mailbox_list = self._test_get_x(parser.get_mailbox_list,
+ ('(foo) "Roy A. Bear" <dinsdale@example.com>(bar),'
+ ' "Fred Flintstone" <dinsdale@test.(bird)example.com>'),
+ ('(foo) "Roy A. Bear" <dinsdale@example.com>(bar),'
+ ' "Fred Flintstone" <dinsdale@test.(bird)example.com>'),
+ (' "Roy A. Bear" <dinsdale@example.com> ,'
+ ' "Fred Flintstone" <dinsdale@test. example.com>'),
+ [errors.ObsoleteHeaderDefect],
+ '')
+ self.assertEqual(len(mailbox_list.mailboxes), 2)
+ self.assertEqual(mailbox_list.mailboxes[0].addr_spec,
+ 'dinsdale@example.com')
+ self.assertEqual(mailbox_list.mailboxes[0].display_name,
+ 'Roy A. Bear')
+ self.assertEqual(mailbox_list.mailboxes[1].addr_spec,
+ 'dinsdale@test.example.com')
+ self.assertEqual(mailbox_list.mailboxes[1].display_name,
+ 'Fred Flintstone')
+ self.assertEqual(mailbox_list.mailboxes,
+ mailbox_list.all_mailboxes)
+
+ def test_get_mailbox_list_unparseable_mailbox_null(self):
+ mailbox_list = self._test_get_x(parser.get_mailbox_list,
+ ('"Roy A. Bear"[] dinsdale@example.com,'
+ ' "Fred Flintstone" <dinsdale@test.(bird)example.com>'),
+ ('"Roy A. Bear"[] dinsdale@example.com,'
+ ' "Fred Flintstone" <dinsdale@test.(bird)example.com>'),
+ ('"Roy A. Bear"[] dinsdale@example.com,'
+ ' "Fred Flintstone" <dinsdale@test. example.com>'),
+ [errors.InvalidHeaderDefect, # the 'extra' text after the local part
+ errors.InvalidHeaderDefect, # the local part with no angle-addr
+ errors.ObsoleteHeaderDefect, # period in extra text (example.com)
+ errors.ObsoleteHeaderDefect], # (bird) in valid address.
+ '')
+ self.assertEqual(len(mailbox_list.mailboxes), 1)
+ self.assertEqual(len(mailbox_list.all_mailboxes), 2)
+ self.assertEqual(mailbox_list.all_mailboxes[0].token_type,
+ 'invalid-mailbox')
+ self.assertIsNone(mailbox_list.all_mailboxes[0].display_name)
+ self.assertEqual(mailbox_list.all_mailboxes[0].local_part,
+ 'Roy A. Bear')
+ self.assertIsNone(mailbox_list.all_mailboxes[0].domain)
+ self.assertEqual(mailbox_list.all_mailboxes[0].addr_spec,
+ '"Roy A. Bear"')
+ self.assertIs(mailbox_list.all_mailboxes[1],
+ mailbox_list.mailboxes[0])
+ self.assertEqual(mailbox_list.mailboxes[0].addr_spec,
+ 'dinsdale@test.example.com')
+ self.assertEqual(mailbox_list.mailboxes[0].display_name,
+ 'Fred Flintstone')
+
+ def test_get_mailbox_list_junk_after_valid_address(self):
+ mailbox_list = self._test_get_x(parser.get_mailbox_list,
+ ('"Roy A. Bear" <dinsdale@example.com>@@,'
+ ' "Fred Flintstone" <dinsdale@test.example.com>'),
+ ('"Roy A. Bear" <dinsdale@example.com>@@,'
+ ' "Fred Flintstone" <dinsdale@test.example.com>'),
+ ('"Roy A. Bear" <dinsdale@example.com>@@,'
+ ' "Fred Flintstone" <dinsdale@test.example.com>'),
+ [errors.InvalidHeaderDefect],
+ '')
+ self.assertEqual(len(mailbox_list.mailboxes), 1)
+ self.assertEqual(len(mailbox_list.all_mailboxes), 2)
+ self.assertEqual(mailbox_list.all_mailboxes[0].addr_spec,
+ 'dinsdale@example.com')
+ self.assertEqual(mailbox_list.all_mailboxes[0].display_name,
+ 'Roy A. Bear')
+ self.assertEqual(mailbox_list.all_mailboxes[0].token_type,
+ 'invalid-mailbox')
+ self.assertIs(mailbox_list.all_mailboxes[1],
+ mailbox_list.mailboxes[0])
+ self.assertEqual(mailbox_list.mailboxes[0].addr_spec,
+ 'dinsdale@test.example.com')
+ self.assertEqual(mailbox_list.mailboxes[0].display_name,
+ 'Fred Flintstone')
+
+ def test_get_mailbox_list_empty_list_element(self):
+ mailbox_list = self._test_get_x(parser.get_mailbox_list,
+ ('"Roy A. Bear" <dinsdale@example.com>, (bird),,'
+ ' "Fred Flintstone" <dinsdale@test.example.com>'),
+ ('"Roy A. Bear" <dinsdale@example.com>, (bird),,'
+ ' "Fred Flintstone" <dinsdale@test.example.com>'),
+ ('"Roy A. Bear" <dinsdale@example.com>, ,,'
+ ' "Fred Flintstone" <dinsdale@test.example.com>'),
+ [errors.ObsoleteHeaderDefect]*2,
+ '')
+ self.assertEqual(len(mailbox_list.mailboxes), 2)
+ self.assertEqual(mailbox_list.all_mailboxes,
+ mailbox_list.mailboxes)
+ self.assertEqual(mailbox_list.all_mailboxes[0].addr_spec,
+ 'dinsdale@example.com')
+ self.assertEqual(mailbox_list.all_mailboxes[0].display_name,
+ 'Roy A. Bear')
+ self.assertEqual(mailbox_list.mailboxes[1].addr_spec,
+ 'dinsdale@test.example.com')
+ self.assertEqual(mailbox_list.mailboxes[1].display_name,
+ 'Fred Flintstone')
+
+ def test_get_mailbox_list_only_empty_elements(self):
+ mailbox_list = self._test_get_x(parser.get_mailbox_list,
+ '(foo),, (bar)',
+ '(foo),, (bar)',
+ ' ,, ',
+ [errors.ObsoleteHeaderDefect]*3,
+ '')
+ self.assertEqual(len(mailbox_list.mailboxes), 0)
+ self.assertEqual(mailbox_list.all_mailboxes,
+ mailbox_list.mailboxes)
+
+ # get_group_list
+
+ def test_get_group_list_cfws_only(self):
+ group_list = self._test_get_x(parser.get_group_list,
+ '(hidden);',
+ '(hidden)',
+ ' ',
+ [],
+ ';')
+ self.assertEqual(group_list.token_type, 'group-list')
+ self.assertEqual(len(group_list.mailboxes), 0)
+ self.assertEqual(group_list.mailboxes,
+ group_list.all_mailboxes)
+
+ def test_get_group_list_mailbox_list(self):
+ group_list = self._test_get_x(parser.get_group_list,
+ 'dinsdale@example.org, "Fred A. Bear" <dinsdale@example.org>',
+ 'dinsdale@example.org, "Fred A. Bear" <dinsdale@example.org>',
+ 'dinsdale@example.org, "Fred A. Bear" <dinsdale@example.org>',
+ [],
+ '')
+ self.assertEqual(group_list.token_type, 'group-list')
+ self.assertEqual(len(group_list.mailboxes), 2)
+ self.assertEqual(group_list.mailboxes,
+ group_list.all_mailboxes)
+ self.assertEqual(group_list.mailboxes[1].display_name,
+ 'Fred A. Bear')
+
+ def test_get_group_list_obs_group_list(self):
+ group_list = self._test_get_x(parser.get_group_list,
+ ', (foo),,(bar)',
+ ', (foo),,(bar)',
+ ', ,, ',
+ [errors.ObsoleteHeaderDefect],
+ '')
+ self.assertEqual(group_list.token_type, 'group-list')
+ self.assertEqual(len(group_list.mailboxes), 0)
+ self.assertEqual(group_list.mailboxes,
+ group_list.all_mailboxes)
+
+ def test_get_group_list_comment_only_invalid(self):
+ group_list = self._test_get_x(parser.get_group_list,
+ '(bar)',
+ '(bar)',
+ ' ',
+ [errors.InvalidHeaderDefect],
+ '')
+ self.assertEqual(group_list.token_type, 'group-list')
+ self.assertEqual(len(group_list.mailboxes), 0)
+ self.assertEqual(group_list.mailboxes,
+ group_list.all_mailboxes)
+
+ # get_group
+
+ def test_get_group_empty(self):
+ group = self._test_get_x(parser.get_group,
+ 'Monty Python:;',
+ 'Monty Python:;',
+ 'Monty Python:;',
+ [],
+ '')
+ self.assertEqual(group.token_type, 'group')
+ self.assertEqual(group.display_name, 'Monty Python')
+ self.assertEqual(len(group.mailboxes), 0)
+ self.assertEqual(group.mailboxes,
+ group.all_mailboxes)
+
+ def test_get_group_null_addr_spec(self):
+ group = self._test_get_x(parser.get_group,
+ 'foo: <>;',
+ 'foo: <>;',
+ 'foo: <>;',
+ [errors.InvalidHeaderDefect],
+ '')
+ self.assertEqual(group.display_name, 'foo')
+ self.assertEqual(len(group.mailboxes), 0)
+ self.assertEqual(len(group.all_mailboxes), 1)
+ self.assertEqual(group.all_mailboxes[0].value, '<>')
+
+ def test_get_group_cfws_only(self):
+ group = self._test_get_x(parser.get_group,
+ 'Monty Python: (hidden);',
+ 'Monty Python: (hidden);',
+ 'Monty Python: ;',
+ [],
+ '')
+ self.assertEqual(group.token_type, 'group')
+ self.assertEqual(group.display_name, 'Monty Python')
+ self.assertEqual(len(group.mailboxes), 0)
+ self.assertEqual(group.mailboxes,
+ group.all_mailboxes)
+
+ def test_get_group_single_mailbox(self):
+ group = self._test_get_x(parser.get_group,
+ 'Monty Python: "Fred A. Bear" <dinsdale@example.com>;',
+ 'Monty Python: "Fred A. Bear" <dinsdale@example.com>;',
+ 'Monty Python: "Fred A. Bear" <dinsdale@example.com>;',
+ [],
+ '')
+ self.assertEqual(group.token_type, 'group')
+ self.assertEqual(group.display_name, 'Monty Python')
+ self.assertEqual(len(group.mailboxes), 1)
+ self.assertEqual(group.mailboxes,
+ group.all_mailboxes)
+ self.assertEqual(group.mailboxes[0].addr_spec,
+ 'dinsdale@example.com')
+
+ def test_get_group_mixed_list(self):
+ group = self._test_get_x(parser.get_group,
+ ('Monty Python: "Fred A. Bear" <dinsdale@example.com>,'
+ '(foo) Roger <ping@exampele.com>, x@test.example.com;'),
+ ('Monty Python: "Fred A. Bear" <dinsdale@example.com>,'
+ '(foo) Roger <ping@exampele.com>, x@test.example.com;'),
+ ('Monty Python: "Fred A. Bear" <dinsdale@example.com>,'
+ ' Roger <ping@exampele.com>, x@test.example.com;'),
+ [],
+ '')
+ self.assertEqual(group.token_type, 'group')
+ self.assertEqual(group.display_name, 'Monty Python')
+ self.assertEqual(len(group.mailboxes), 3)
+ self.assertEqual(group.mailboxes,
+ group.all_mailboxes)
+ self.assertEqual(group.mailboxes[0].display_name,
+ 'Fred A. Bear')
+ self.assertEqual(group.mailboxes[1].display_name,
+ 'Roger')
+ self.assertEqual(group.mailboxes[2].local_part, 'x')
+
+ def test_get_group_one_invalid(self):
+ group = self._test_get_x(parser.get_group,
+ ('Monty Python: "Fred A. Bear" <dinsdale@example.com>,'
+ '(foo) Roger ping@exampele.com, x@test.example.com;'),
+ ('Monty Python: "Fred A. Bear" <dinsdale@example.com>,'
+ '(foo) Roger ping@exampele.com, x@test.example.com;'),
+ ('Monty Python: "Fred A. Bear" <dinsdale@example.com>,'
+ ' Roger ping@exampele.com, x@test.example.com;'),
+ [errors.InvalidHeaderDefect, # non-angle addr makes local part invalid
+ errors.InvalidHeaderDefect], # and its not obs-local either: no dots.
+ '')
+ self.assertEqual(group.token_type, 'group')
+ self.assertEqual(group.display_name, 'Monty Python')
+ self.assertEqual(len(group.mailboxes), 2)
+ self.assertEqual(len(group.all_mailboxes), 3)
+ self.assertEqual(group.mailboxes[0].display_name,
+ 'Fred A. Bear')
+ self.assertEqual(group.mailboxes[1].local_part, 'x')
+ self.assertIsNone(group.all_mailboxes[1].display_name)
+
+ # get_address
+
+ def test_get_address_simple(self):
+ address = self._test_get_x(parser.get_address,
+ 'dinsdale@example.com',
+ 'dinsdale@example.com',
+ 'dinsdale@example.com',
+ [],
+ '')
+ self.assertEqual(address.token_type, 'address')
+ self.assertEqual(len(address.mailboxes), 1)
+ self.assertEqual(address.mailboxes,
+ address.all_mailboxes)
+ self.assertEqual(address.mailboxes[0].domain,
+ 'example.com')
+ self.assertEqual(address[0].token_type,
+ 'mailbox')
+
+ def test_get_address_complex(self):
+ address = self._test_get_x(parser.get_address,
+ '(foo) "Fred A. Bear" <(bird)dinsdale@example.com>',
+ '(foo) "Fred A. Bear" <(bird)dinsdale@example.com>',
+ ' "Fred A. Bear" < dinsdale@example.com>',
+ [],
+ '')
+ self.assertEqual(address.token_type, 'address')
+ self.assertEqual(len(address.mailboxes), 1)
+ self.assertEqual(address.mailboxes,
+ address.all_mailboxes)
+ self.assertEqual(address.mailboxes[0].display_name,
+ 'Fred A. Bear')
+ self.assertEqual(address[0].token_type,
+ 'mailbox')
+
+ def test_get_address_empty_group(self):
+ address = self._test_get_x(parser.get_address,
+ 'Monty Python:;',
+ 'Monty Python:;',
+ 'Monty Python:;',
+ [],
+ '')
+ self.assertEqual(address.token_type, 'address')
+ self.assertEqual(len(address.mailboxes), 0)
+ self.assertEqual(address.mailboxes,
+ address.all_mailboxes)
+ self.assertEqual(address[0].token_type,
+ 'group')
+ self.assertEqual(address[0].display_name,
+ 'Monty Python')
+
+ def test_get_address_group(self):
+ address = self._test_get_x(parser.get_address,
+ 'Monty Python: x@example.com, y@example.com;',
+ 'Monty Python: x@example.com, y@example.com;',
+ 'Monty Python: x@example.com, y@example.com;',
+ [],
+ '')
+ self.assertEqual(address.token_type, 'address')
+ self.assertEqual(len(address.mailboxes), 2)
+ self.assertEqual(address.mailboxes,
+ address.all_mailboxes)
+ self.assertEqual(address[0].token_type,
+ 'group')
+ self.assertEqual(address[0].display_name,
+ 'Monty Python')
+ self.assertEqual(address.mailboxes[0].local_part, 'x')
+
+ def test_get_address_quoted_local_part(self):
+ address = self._test_get_x(parser.get_address,
+ '"foo bar"@example.com',
+ '"foo bar"@example.com',
+ '"foo bar"@example.com',
+ [],
+ '')
+ self.assertEqual(address.token_type, 'address')
+ self.assertEqual(len(address.mailboxes), 1)
+ self.assertEqual(address.mailboxes,
+ address.all_mailboxes)
+ self.assertEqual(address.mailboxes[0].domain,
+ 'example.com')
+ self.assertEqual(address.mailboxes[0].local_part,
+ 'foo bar')
+ self.assertEqual(address[0].token_type, 'mailbox')
+
+ def test_get_address_ends_at_special(self):
+ address = self._test_get_x(parser.get_address,
+ 'dinsdale@example.com, next',
+ 'dinsdale@example.com',
+ 'dinsdale@example.com',
+ [],
+ ', next')
+ self.assertEqual(address.token_type, 'address')
+ self.assertEqual(len(address.mailboxes), 1)
+ self.assertEqual(address.mailboxes,
+ address.all_mailboxes)
+ self.assertEqual(address.mailboxes[0].domain,
+ 'example.com')
+ self.assertEqual(address[0].token_type, 'mailbox')
+
+ def test_get_address_invalid_mailbox_invalid(self):
+ address = self._test_get_x(parser.get_address,
+ 'ping example.com, next',
+ 'ping example.com',
+ 'ping example.com',
+ [errors.InvalidHeaderDefect, # addr-spec with no domain
+ errors.InvalidHeaderDefect, # invalid local-part
+ errors.InvalidHeaderDefect, # missing .s in local-part
+ ],
+ ', next')
+ self.assertEqual(address.token_type, 'address')
+ self.assertEqual(len(address.mailboxes), 0)
+ self.assertEqual(len(address.all_mailboxes), 1)
+ self.assertIsNone(address.all_mailboxes[0].domain)
+ self.assertEqual(address.all_mailboxes[0].local_part, 'ping example.com')
+ self.assertEqual(address[0].token_type, 'invalid-mailbox')
+
+ def test_get_address_quoted_strings_in_atom_list(self):
+ address = self._test_get_x(parser.get_address,
+ '""example" example"@example.com',
+ '""example" example"@example.com',
+ 'example example@example.com',
+ [errors.InvalidHeaderDefect]*3,
+ '')
+ self.assertEqual(address.all_mailboxes[0].local_part, 'example example')
+ self.assertEqual(address.all_mailboxes[0].domain, 'example.com')
+ self.assertEqual(address.all_mailboxes[0].addr_spec, '"example example"@example.com')
+
+
+ # get_address_list
+
+ def test_get_address_list_mailboxes_simple(self):
+ address_list = self._test_get_x(parser.get_address_list,
+ 'dinsdale@example.com',
+ 'dinsdale@example.com',
+ 'dinsdale@example.com',
+ [],
+ '')
+ self.assertEqual(address_list.token_type, 'address-list')
+ self.assertEqual(len(address_list.mailboxes), 1)
+ self.assertEqual(address_list.mailboxes,
+ address_list.all_mailboxes)
+ self.assertEqual([str(x) for x in address_list.mailboxes],
+ [str(x) for x in address_list.addresses])
+ self.assertEqual(address_list.mailboxes[0].domain, 'example.com')
+ self.assertEqual(address_list[0].token_type, 'address')
+ self.assertIsNone(address_list[0].display_name)
+
+ def test_get_address_list_mailboxes_two_simple(self):
+ address_list = self._test_get_x(parser.get_address_list,
+ 'foo@example.com, "Fred A. Bar" <bar@example.com>',
+ 'foo@example.com, "Fred A. Bar" <bar@example.com>',
+ 'foo@example.com, "Fred A. Bar" <bar@example.com>',
+ [],
+ '')
+ self.assertEqual(address_list.token_type, 'address-list')
+ self.assertEqual(len(address_list.mailboxes), 2)
+ self.assertEqual(address_list.mailboxes,
+ address_list.all_mailboxes)
+ self.assertEqual([str(x) for x in address_list.mailboxes],
+ [str(x) for x in address_list.addresses])
+ self.assertEqual(address_list.mailboxes[0].local_part, 'foo')
+ self.assertEqual(address_list.mailboxes[1].display_name, "Fred A. Bar")
+
+ def test_get_address_list_mailboxes_complex(self):
+ address_list = self._test_get_x(parser.get_address_list,
+ ('"Roy A. Bear" <dinsdale@example.com>, '
+ '(ping) Foo <x@example.com>,'
+ 'Nobody Is. Special <y@(bird)example.(bad)com>'),
+ ('"Roy A. Bear" <dinsdale@example.com>, '
+ '(ping) Foo <x@example.com>,'
+ 'Nobody Is. Special <y@(bird)example.(bad)com>'),
+ ('"Roy A. Bear" <dinsdale@example.com>, '
+ 'Foo <x@example.com>,'
+ '"Nobody Is. Special" <y@example. com>'),
+ [errors.ObsoleteHeaderDefect, # period in Is.
+ errors.ObsoleteHeaderDefect], # cfws in domain
+ '')
+ self.assertEqual(address_list.token_type, 'address-list')
+ self.assertEqual(len(address_list.mailboxes), 3)
+ self.assertEqual(address_list.mailboxes,
+ address_list.all_mailboxes)
+ self.assertEqual([str(x) for x in address_list.mailboxes],
+ [str(x) for x in address_list.addresses])
+ self.assertEqual(address_list.mailboxes[0].domain, 'example.com')
+ self.assertEqual(address_list.mailboxes[0].token_type, 'mailbox')
+ self.assertEqual(address_list.addresses[0].token_type, 'address')
+ self.assertEqual(address_list.mailboxes[1].local_part, 'x')
+ self.assertEqual(address_list.mailboxes[2].display_name,
+ 'Nobody Is. Special')
+
+ def test_get_address_list_mailboxes_invalid_addresses(self):
+ address_list = self._test_get_x(parser.get_address_list,
+ ('"Roy A. Bear" <dinsdale@example.com>, '
+ '(ping) Foo x@example.com[],'
+ 'Nobody Is. Special <(bird)example.(bad)com>'),
+ ('"Roy A. Bear" <dinsdale@example.com>, '
+ '(ping) Foo x@example.com[],'
+ 'Nobody Is. Special <(bird)example.(bad)com>'),
+ ('"Roy A. Bear" <dinsdale@example.com>, '
+ 'Foo x@example.com[],'
+ '"Nobody Is. Special" < example. com>'),
+ [errors.InvalidHeaderDefect, # invalid address in list
+ errors.InvalidHeaderDefect, # 'Foo x' local part invalid.
+ errors.InvalidHeaderDefect, # Missing . in 'Foo x' local part
+ errors.ObsoleteHeaderDefect, # period in 'Is.' disp-name phrase
+ errors.InvalidHeaderDefect, # no domain part in addr-spec
+ errors.ObsoleteHeaderDefect], # addr-spec has comment in it
+ '')
+ self.assertEqual(address_list.token_type, 'address-list')
+ self.assertEqual(len(address_list.mailboxes), 1)
+ self.assertEqual(len(address_list.all_mailboxes), 3)
+ self.assertEqual([str(x) for x in address_list.all_mailboxes],
+ [str(x) for x in address_list.addresses])
+ self.assertEqual(address_list.mailboxes[0].domain, 'example.com')
+ self.assertEqual(address_list.mailboxes[0].token_type, 'mailbox')
+ self.assertEqual(address_list.addresses[0].token_type, 'address')
+ self.assertEqual(address_list.addresses[1].token_type, 'address')
+ self.assertEqual(len(address_list.addresses[0].mailboxes), 1)
+ self.assertEqual(len(address_list.addresses[1].mailboxes), 0)
+ self.assertEqual(len(address_list.addresses[1].mailboxes), 0)
+ self.assertEqual(
+ address_list.addresses[1].all_mailboxes[0].local_part, 'Foo x')
+ self.assertEqual(
+ address_list.addresses[2].all_mailboxes[0].display_name,
+ "Nobody Is. Special")
+
+ def test_get_address_list_group_empty(self):
+ address_list = self._test_get_x(parser.get_address_list,
+ 'Monty Python: ;',
+ 'Monty Python: ;',
+ 'Monty Python: ;',
+ [],
+ '')
+ self.assertEqual(address_list.token_type, 'address-list')
+ self.assertEqual(len(address_list.mailboxes), 0)
+ self.assertEqual(address_list.mailboxes,
+ address_list.all_mailboxes)
+ self.assertEqual(len(address_list.addresses), 1)
+ self.assertEqual(address_list.addresses[0].token_type, 'address')
+ self.assertEqual(address_list.addresses[0].display_name, 'Monty Python')
+ self.assertEqual(len(address_list.addresses[0].mailboxes), 0)
+
+ def test_get_address_list_group_simple(self):
+ address_list = self._test_get_x(parser.get_address_list,
+ 'Monty Python: dinsdale@example.com;',
+ 'Monty Python: dinsdale@example.com;',
+ 'Monty Python: dinsdale@example.com;',
+ [],
+ '')
+ self.assertEqual(address_list.token_type, 'address-list')
+ self.assertEqual(len(address_list.mailboxes), 1)
+ self.assertEqual(address_list.mailboxes,
+ address_list.all_mailboxes)
+ self.assertEqual(address_list.mailboxes[0].domain, 'example.com')
+ self.assertEqual(address_list.addresses[0].display_name,
+ 'Monty Python')
+ self.assertEqual(address_list.addresses[0].mailboxes[0].domain,
+ 'example.com')
+
+ def test_get_address_list_group_and_mailboxes(self):
+ address_list = self._test_get_x(parser.get_address_list,
+ ('Monty Python: dinsdale@example.com, "Fred" <flint@example.com>;, '
+ 'Abe <x@example.com>, Bee <y@example.com>'),
+ ('Monty Python: dinsdale@example.com, "Fred" <flint@example.com>;, '
+ 'Abe <x@example.com>, Bee <y@example.com>'),
+ ('Monty Python: dinsdale@example.com, "Fred" <flint@example.com>;, '
+ 'Abe <x@example.com>, Bee <y@example.com>'),
+ [],
+ '')
+ self.assertEqual(address_list.token_type, 'address-list')
+ self.assertEqual(len(address_list.mailboxes), 4)
+ self.assertEqual(address_list.mailboxes,
+ address_list.all_mailboxes)
+ self.assertEqual(len(address_list.addresses), 3)
+ self.assertEqual(address_list.mailboxes[0].local_part, 'dinsdale')
+ self.assertEqual(address_list.addresses[0].display_name,
+ 'Monty Python')
+ self.assertEqual(address_list.addresses[0].mailboxes[0].domain,
+ 'example.com')
+ self.assertEqual(address_list.addresses[0].mailboxes[1].local_part,
+ 'flint')
+ self.assertEqual(address_list.addresses[1].mailboxes[0].local_part,
+ 'x')
+ self.assertEqual(address_list.addresses[2].mailboxes[0].local_part,
+ 'y')
+ self.assertEqual(str(address_list.addresses[1]),
+ str(address_list.mailboxes[2]))
+
+
+@parameterize
+class Test_parse_mime_version(TestParserMixin, TestEmailBase):
+
+ def mime_version_as_value(self,
+ value,
+ tl_str,
+ tl_value,
+ major,
+ minor,
+ defects):
+ mime_version = self._test_parse_x(parser.parse_mime_version,
+ value, tl_str, tl_value, defects)
+ self.assertEqual(mime_version.major, major)
+ self.assertEqual(mime_version.minor, minor)
+
+ mime_version_params = {
+
+ 'rfc_2045_1': (
+ '1.0',
+ '1.0',
+ '1.0',
+ 1,
+ 0,
+ []),
+
+ 'RFC_2045_2': (
+ '1.0 (produced by MetaSend Vx.x)',
+ '1.0 (produced by MetaSend Vx.x)',
+ '1.0 ',
+ 1,
+ 0,
+ []),
+
+ 'RFC_2045_3': (
+ '(produced by MetaSend Vx.x) 1.0',
+ '(produced by MetaSend Vx.x) 1.0',
+ ' 1.0',
+ 1,
+ 0,
+ []),
+
+ 'RFC_2045_4': (
+ '1.(produced by MetaSend Vx.x)0',
+ '1.(produced by MetaSend Vx.x)0',
+ '1. 0',
+ 1,
+ 0,
+ []),
+
+ 'empty': (
+ '',
+ '',
+ '',
+ None,
+ None,
+ [errors.HeaderMissingRequiredValue]),
+
+ }
+
+
+
+class TestFolding(TestEmailBase):
+
+ policy = policy.default
+
+ def _test(self, tl, folded, policy=policy):
+ self.assertEqual(tl.fold(policy=policy), folded, tl.ppstr())
+
+ def test_simple_unstructured_no_folds(self):
+ self._test(parser.get_unstructured("This is a test"),
+ "This is a test\n")
+
+ def test_simple_unstructured_folded(self):
+ self._test(parser.get_unstructured("This is also a test, but this "
+ "time there are enough words (and even some "
+ "symbols) to make it wrap; at least in theory."),
+ "This is also a test, but this time there are enough "
+ "words (and even some\n"
+ " symbols) to make it wrap; at least in theory.\n")
+
+ def test_unstructured_with_unicode_no_folds(self):
+ self._test(parser.get_unstructured("hübsch kleiner beißt"),
+ "=?utf-8?q?h=C3=BCbsch_kleiner_bei=C3=9Ft?=\n")
+
+ def test_one_ew_on_each_of_two_wrapped_lines(self):
+ self._test(parser.get_unstructured("Mein kleiner Kaktus ist sehr "
+ "hübsch. Es hat viele Stacheln "
+ "und oft beißt mich."),
+ "Mein kleiner Kaktus ist sehr =?utf-8?q?h=C3=BCbsch=2E?= "
+ "Es hat viele Stacheln\n"
+ " und oft =?utf-8?q?bei=C3=9Ft?= mich.\n")
+
+ def test_ews_combined_before_wrap(self):
+ self._test(parser.get_unstructured("Mein Kaktus ist hübsch. "
+ "Es beißt mich. "
+ "And that's all I'm sayin."),
+ "Mein Kaktus ist =?utf-8?q?h=C3=BCbsch=2E__Es_bei=C3=9Ft?= "
+ "mich. And that's\n"
+ " all I'm sayin.\n")
+
+ # XXX Need test of an encoded word so long that it needs to be wrapped
+
+ def test_simple_address(self):
+ self._test(parser.get_address_list("abc <xyz@example.com>")[0],
+ "abc <xyz@example.com>\n")
+
+ def test_address_list_folding_at_commas(self):
+ self._test(parser.get_address_list('abc <xyz@example.com>, '
+ '"Fred Blunt" <sharp@example.com>, '
+ '"J.P.Cool" <hot@example.com>, '
+ '"K<>y" <key@example.com>, '
+ 'Firesale <cheap@example.com>, '
+ '<end@example.com>')[0],
+ 'abc <xyz@example.com>, "Fred Blunt" <sharp@example.com>,\n'
+ ' "J.P.Cool" <hot@example.com>, "K<>y" <key@example.com>,\n'
+ ' Firesale <cheap@example.com>, <end@example.com>\n')
+
+ def test_address_list_with_unicode_names(self):
+ self._test(parser.get_address_list(
+ 'Hübsch Kaktus <beautiful@example.com>, '
+ 'beißt beißt <biter@example.com>')[0],
+ '=?utf-8?q?H=C3=BCbsch?= Kaktus <beautiful@example.com>,\n'
+ ' =?utf-8?q?bei=C3=9Ft_bei=C3=9Ft?= <biter@example.com>\n')
+
+ def test_address_list_with_unicode_names_in_quotes(self):
+ self._test(parser.get_address_list(
+ '"Hübsch Kaktus" <beautiful@example.com>, '
+ '"beißt" beißt <biter@example.com>')[0],
+ '=?utf-8?q?H=C3=BCbsch?= Kaktus <beautiful@example.com>,\n'
+ ' =?utf-8?q?bei=C3=9Ft_bei=C3=9Ft?= <biter@example.com>\n')
+
+ # XXX Need tests with comments on various sides of a unicode token,
+ # and with unicode tokens in the comments. Spaces inside the quotes
+ # currently don't do the right thing.
+
+ def test_initial_whitespace_splitting(self):
+ body = parser.get_unstructured(' ' + 'x'*77)
+ header = parser.Header([
+ parser.HeaderLabel([parser.ValueTerminal('test:', 'atext')]),
+ parser.CFWSList([parser.WhiteSpaceTerminal(' ', 'fws')]), body])
+ self._test(header, 'test: \n ' + 'x'*77 + '\n')
+
+ def test_whitespace_splitting(self):
+ self._test(parser.get_unstructured('xxx ' + 'y'*77),
+ 'xxx \n ' + 'y'*77 + '\n')
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_email/test_asian_codecs.py b/Lib/test/test_email/test_asian_codecs.py
new file mode 100644
index 0000000..089269f
--- /dev/null
+++ b/Lib/test/test_email/test_asian_codecs.py
@@ -0,0 +1,82 @@
+# Copyright (C) 2002-2006 Python Software Foundation
+# Contact: email-sig@python.org
+# email package unit tests for (optional) Asian codecs
+
+import unittest
+from test.support import run_unittest
+
+from test.test_email.test_email import TestEmailBase
+from email.charset import Charset
+from email.header import Header, decode_header
+from email.message import Message
+
+# We're compatible with Python 2.3, but it doesn't have the built-in Asian
+# codecs, so we have to skip all these tests.
+try:
+ str(b'foo', 'euc-jp')
+except LookupError:
+ raise unittest.SkipTest
+
+
+
+class TestEmailAsianCodecs(TestEmailBase):
+ def test_japanese_codecs(self):
+ eq = self.ndiffAssertEqual
+ jcode = "euc-jp"
+ gcode = "iso-8859-1"
+ j = Charset(jcode)
+ g = Charset(gcode)
+ h = Header("Hello World!")
+ jhello = str(b'\xa5\xcf\xa5\xed\xa1\xbc\xa5\xef\xa1\xbc'
+ b'\xa5\xeb\xa5\xc9\xa1\xaa', jcode)
+ ghello = str(b'Gr\xfc\xdf Gott!', gcode)
+ h.append(jhello, j)
+ h.append(ghello, g)
+ # BAW: This used to -- and maybe should -- fold the two iso-8859-1
+ # chunks into a single encoded word. However it doesn't violate the
+ # standard to have them as two encoded chunks and maybe it's
+ # reasonable <wink> for each .append() call to result in a separate
+ # encoded word.
+ eq(h.encode(), """\
+Hello World! =?iso-2022-jp?b?GyRCJU8lbSE8JW8hPCVrJUkhKhsoQg==?=
+ =?iso-8859-1?q?Gr=FC=DF_Gott!?=""")
+ eq(decode_header(h.encode()),
+ [(b'Hello World! ', None),
+ (b'\x1b$B%O%m!<%o!<%k%I!*\x1b(B', 'iso-2022-jp'),
+ (b'Gr\xfc\xdf Gott!', gcode)])
+ subject_bytes = (b'test-ja \xa4\xd8\xc5\xea\xb9\xc6\xa4\xb5'
+ b'\xa4\xec\xa4\xbf\xa5\xe1\xa1\xbc\xa5\xeb\xa4\xcf\xbb\xca\xb2'
+ b'\xf1\xbc\xd4\xa4\xce\xbe\xb5\xc7\xa7\xa4\xf2\xc2\xd4\xa4\xc3'
+ b'\xa4\xc6\xa4\xa4\xa4\xde\xa4\xb9')
+ subject = str(subject_bytes, jcode)
+ h = Header(subject, j, header_name="Subject")
+ # test a very long header
+ enc = h.encode()
+ # TK: splitting point may differ by codec design and/or Header encoding
+ eq(enc , """\
+=?iso-2022-jp?b?dGVzdC1qYSAbJEIkWEVqOUYkNSRsJD8lYSE8JWskTztKGyhC?=
+ =?iso-2022-jp?b?GyRCMnE8VCROPjVHJyRyQlQkQyRGJCQkXiQ5GyhC?=""")
+ # TK: full decode comparison
+ eq(str(h).encode(jcode), subject_bytes)
+
+ def test_payload_encoding_utf8(self):
+ jhello = str(b'\xa5\xcf\xa5\xed\xa1\xbc\xa5\xef\xa1\xbc'
+ b'\xa5\xeb\xa5\xc9\xa1\xaa', 'euc-jp')
+ msg = Message()
+ msg.set_payload(jhello, 'utf-8')
+ ustr = msg.get_payload(decode=True).decode(msg.get_content_charset())
+ self.assertEqual(jhello, ustr)
+
+ def test_payload_encoding(self):
+ jcode = 'euc-jp'
+ jhello = str(b'\xa5\xcf\xa5\xed\xa1\xbc\xa5\xef\xa1\xbc'
+ b'\xa5\xeb\xa5\xc9\xa1\xaa', jcode)
+ msg = Message()
+ msg.set_payload(jhello, jcode)
+ ustr = msg.get_payload(decode=True).decode(msg.get_content_charset())
+ self.assertEqual(jhello, ustr)
+
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_email/test_defect_handling.py b/Lib/test/test_email/test_defect_handling.py
new file mode 100644
index 0000000..305432d
--- /dev/null
+++ b/Lib/test/test_email/test_defect_handling.py
@@ -0,0 +1,320 @@
+import textwrap
+import unittest
+import contextlib
+from email import policy
+from email import errors
+from test.test_email import TestEmailBase
+
+
+class TestDefectsBase:
+
+ policy = policy.default
+ raise_expected = False
+
+ @contextlib.contextmanager
+ def _raise_point(self, defect):
+ yield
+
+ def test_same_boundary_inner_outer(self):
+ source = textwrap.dedent("""\
+ Subject: XX
+ From: xx@xx.dk
+ To: XX
+ Mime-version: 1.0
+ Content-type: multipart/mixed;
+ boundary="MS_Mac_OE_3071477847_720252_MIME_Part"
+
+ --MS_Mac_OE_3071477847_720252_MIME_Part
+ Content-type: multipart/alternative;
+ boundary="MS_Mac_OE_3071477847_720252_MIME_Part"
+
+ --MS_Mac_OE_3071477847_720252_MIME_Part
+ Content-type: text/plain; charset="ISO-8859-1"
+ Content-transfer-encoding: quoted-printable
+
+ text
+
+ --MS_Mac_OE_3071477847_720252_MIME_Part
+ Content-type: text/html; charset="ISO-8859-1"
+ Content-transfer-encoding: quoted-printable
+
+ <HTML></HTML>
+
+ --MS_Mac_OE_3071477847_720252_MIME_Part--
+
+ --MS_Mac_OE_3071477847_720252_MIME_Part
+ Content-type: image/gif; name="xx.gif";
+ Content-disposition: attachment
+ Content-transfer-encoding: base64
+
+ Some removed base64 encoded chars.
+
+ --MS_Mac_OE_3071477847_720252_MIME_Part--
+
+ """)
+ # XXX better would be to actually detect the duplicate.
+ with self._raise_point(errors.StartBoundaryNotFoundDefect):
+ msg = self._str_msg(source)
+ if self.raise_expected: return
+ inner = msg.get_payload(0)
+ self.assertTrue(hasattr(inner, 'defects'))
+ self.assertEqual(len(self.get_defects(inner)), 1)
+ self.assertTrue(isinstance(self.get_defects(inner)[0],
+ errors.StartBoundaryNotFoundDefect))
+
+ def test_multipart_no_boundary(self):
+ source = textwrap.dedent("""\
+ Date: Fri, 6 Apr 2001 09:23:06 -0800 (GMT-0800)
+ From: foobar
+ Subject: broken mail
+ MIME-Version: 1.0
+ Content-Type: multipart/report; report-type=delivery-status;
+
+ --JAB03225.986577786/zinfandel.lacita.com
+
+ One part
+
+ --JAB03225.986577786/zinfandel.lacita.com
+ Content-Type: message/delivery-status
+
+ Header: Another part
+
+ --JAB03225.986577786/zinfandel.lacita.com--
+ """)
+ with self._raise_point(errors.NoBoundaryInMultipartDefect):
+ msg = self._str_msg(source)
+ if self.raise_expected: return
+ self.assertTrue(isinstance(msg.get_payload(), str))
+ self.assertEqual(len(self.get_defects(msg)), 2)
+ self.assertTrue(isinstance(self.get_defects(msg)[0],
+ errors.NoBoundaryInMultipartDefect))
+ self.assertTrue(isinstance(self.get_defects(msg)[1],
+ errors.MultipartInvariantViolationDefect))
+
+ multipart_msg = textwrap.dedent("""\
+ Date: Wed, 14 Nov 2007 12:56:23 GMT
+ From: foo@bar.invalid
+ To: foo@bar.invalid
+ Subject: Content-Transfer-Encoding: base64 and multipart
+ MIME-Version: 1.0
+ Content-Type: multipart/mixed;
+ boundary="===============3344438784458119861=="{}
+
+ --===============3344438784458119861==
+ Content-Type: text/plain
+
+ Test message
+
+ --===============3344438784458119861==
+ Content-Type: application/octet-stream
+ Content-Transfer-Encoding: base64
+
+ YWJj
+
+ --===============3344438784458119861==--
+ """)
+
+ def test_multipart_invalid_cte(self):
+ with self._raise_point(
+ errors.InvalidMultipartContentTransferEncodingDefect):
+ msg = self._str_msg(
+ self.multipart_msg.format(
+ "\nContent-Transfer-Encoding: base64"))
+ if self.raise_expected: return
+ self.assertEqual(len(self.get_defects(msg)), 1)
+ self.assertIsInstance(self.get_defects(msg)[0],
+ errors.InvalidMultipartContentTransferEncodingDefect)
+
+ def test_multipart_no_cte_no_defect(self):
+ if self.raise_expected: return
+ msg = self._str_msg(self.multipart_msg.format(''))
+ self.assertEqual(len(self.get_defects(msg)), 0)
+
+ def test_multipart_valid_cte_no_defect(self):
+ if self.raise_expected: return
+ for cte in ('7bit', '8bit', 'BINary'):
+ msg = self._str_msg(
+ self.multipart_msg.format("\nContent-Transfer-Encoding: "+cte))
+ self.assertEqual(len(self.get_defects(msg)), 0, "cte="+cte)
+
+ def test_lying_multipart(self):
+ source = textwrap.dedent("""\
+ From: "Allison Dunlap" <xxx@example.com>
+ To: yyy@example.com
+ Subject: 64423
+ Date: Sun, 11 Jul 2004 16:09:27 -0300
+ MIME-Version: 1.0
+ Content-Type: multipart/alternative;
+
+ Blah blah blah
+ """)
+ with self._raise_point(errors.NoBoundaryInMultipartDefect):
+ msg = self._str_msg(source)
+ if self.raise_expected: return
+ self.assertTrue(hasattr(msg, 'defects'))
+ self.assertEqual(len(self.get_defects(msg)), 2)
+ self.assertTrue(isinstance(self.get_defects(msg)[0],
+ errors.NoBoundaryInMultipartDefect))
+ self.assertTrue(isinstance(self.get_defects(msg)[1],
+ errors.MultipartInvariantViolationDefect))
+
+ def test_missing_start_boundary(self):
+ source = textwrap.dedent("""\
+ Content-Type: multipart/mixed; boundary="AAA"
+ From: Mail Delivery Subsystem <xxx@example.com>
+ To: yyy@example.com
+
+ --AAA
+
+ Stuff
+
+ --AAA
+ Content-Type: message/rfc822
+
+ From: webmaster@python.org
+ To: zzz@example.com
+ Content-Type: multipart/mixed; boundary="BBB"
+
+ --BBB--
+
+ --AAA--
+
+ """)
+ # The message structure is:
+ #
+ # multipart/mixed
+ # text/plain
+ # message/rfc822
+ # multipart/mixed [*]
+ #
+ # [*] This message is missing its start boundary
+ with self._raise_point(errors.StartBoundaryNotFoundDefect):
+ outer = self._str_msg(source)
+ if self.raise_expected: return
+ bad = outer.get_payload(1).get_payload(0)
+ self.assertEqual(len(self.get_defects(bad)), 1)
+ self.assertTrue(isinstance(self.get_defects(bad)[0],
+ errors.StartBoundaryNotFoundDefect))
+
+ def test_first_line_is_continuation_header(self):
+ with self._raise_point(errors.FirstHeaderLineIsContinuationDefect):
+ msg = self._str_msg(' Line 1\nSubject: test\n\nbody')
+ if self.raise_expected: return
+ self.assertEqual(msg.keys(), ['Subject'])
+ self.assertEqual(msg.get_payload(), 'body')
+ self.assertEqual(len(self.get_defects(msg)), 1)
+ self.assertDefectsEqual(self.get_defects(msg),
+ [errors.FirstHeaderLineIsContinuationDefect])
+ self.assertEqual(self.get_defects(msg)[0].line, ' Line 1\n')
+
+ def test_missing_header_body_separator(self):
+ # Our heuristic if we see a line that doesn't look like a header (no
+ # leading whitespace but no ':') is to assume that the blank line that
+ # separates the header from the body is missing, and to stop parsing
+ # headers and start parsing the body.
+ with self._raise_point(errors.MissingHeaderBodySeparatorDefect):
+ msg = self._str_msg('Subject: test\nnot a header\nTo: abc\n\nb\n')
+ if self.raise_expected: return
+ self.assertEqual(msg.keys(), ['Subject'])
+ self.assertEqual(msg.get_payload(), 'not a header\nTo: abc\n\nb\n')
+ self.assertDefectsEqual(self.get_defects(msg),
+ [errors.MissingHeaderBodySeparatorDefect])
+
+ def test_bad_padding_in_base64_payload(self):
+ source = textwrap.dedent("""\
+ Subject: test
+ MIME-Version: 1.0
+ Content-Type: text/plain; charset="utf-8"
+ Content-Transfer-Encoding: base64
+
+ dmk
+ """)
+ msg = self._str_msg(source)
+ with self._raise_point(errors.InvalidBase64PaddingDefect):
+ payload = msg.get_payload(decode=True)
+ if self.raise_expected: return
+ self.assertEqual(payload, b'vi')
+ self.assertDefectsEqual(self.get_defects(msg),
+ [errors.InvalidBase64PaddingDefect])
+
+ def test_invalid_chars_in_base64_payload(self):
+ source = textwrap.dedent("""\
+ Subject: test
+ MIME-Version: 1.0
+ Content-Type: text/plain; charset="utf-8"
+ Content-Transfer-Encoding: base64
+
+ dm\x01k===
+ """)
+ msg = self._str_msg(source)
+ with self._raise_point(errors.InvalidBase64CharactersDefect):
+ payload = msg.get_payload(decode=True)
+ if self.raise_expected: return
+ self.assertEqual(payload, b'vi')
+ self.assertDefectsEqual(self.get_defects(msg),
+ [errors.InvalidBase64CharactersDefect])
+
+ def test_missing_ending_boundary(self):
+ source = textwrap.dedent("""\
+ To: 1@harrydomain4.com
+ Subject: Fwd: 1
+ MIME-Version: 1.0
+ Content-Type: multipart/alternative;
+ boundary="------------000101020201080900040301"
+
+ --------------000101020201080900040301
+ Content-Type: text/plain; charset=ISO-8859-1
+ Content-Transfer-Encoding: 7bit
+
+ Alternative 1
+
+ --------------000101020201080900040301
+ Content-Type: text/html; charset=ISO-8859-1
+ Content-Transfer-Encoding: 7bit
+
+ Alternative 2
+
+ """)
+ with self._raise_point(errors.CloseBoundaryNotFoundDefect):
+ msg = self._str_msg(source)
+ if self.raise_expected: return
+ self.assertEqual(len(msg.get_payload()), 2)
+ self.assertEqual(msg.get_payload(1).get_payload(), 'Alternative 2\n')
+ self.assertDefectsEqual(self.get_defects(msg),
+ [errors.CloseBoundaryNotFoundDefect])
+
+
+class TestDefectDetection(TestDefectsBase, TestEmailBase):
+
+ def get_defects(self, obj):
+ return obj.defects
+
+
+class TestDefectCapture(TestDefectsBase, TestEmailBase):
+
+ class CapturePolicy(policy.EmailPolicy):
+ captured = None
+ def register_defect(self, obj, defect):
+ self.captured.append(defect)
+
+ def setUp(self):
+ self.policy = self.CapturePolicy(captured=list())
+
+ def get_defects(self, obj):
+ return self.policy.captured
+
+
+class TestDefectRaising(TestDefectsBase, TestEmailBase):
+
+ policy = TestDefectsBase.policy
+ policy = policy.clone(raise_on_defect=True)
+ raise_expected = True
+
+ @contextlib.contextmanager
+ def _raise_point(self, defect):
+ with self.assertRaises(defect):
+ yield
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_email/test_email.py b/Lib/test/test_email/test_email.py
new file mode 100644
index 0000000..593d27b
--- /dev/null
+++ b/Lib/test/test_email/test_email.py
@@ -0,0 +1,5009 @@
+# Copyright (C) 2001-2010 Python Software Foundation
+# Contact: email-sig@python.org
+# email package unit tests
+
+import os
+import re
+import sys
+import time
+import base64
+import difflib
+import unittest
+import warnings
+import textwrap
+
+from io import StringIO, BytesIO
+from itertools import chain
+
+import email
+import email.policy
+
+from email.charset import Charset
+from email.header import Header, decode_header, make_header
+from email.parser import Parser, HeaderParser
+from email.generator import Generator, DecodedGenerator, BytesGenerator
+from email.message import Message
+from email.mime.application import MIMEApplication
+from email.mime.audio import MIMEAudio
+from email.mime.text import MIMEText
+from email.mime.image import MIMEImage
+from email.mime.base import MIMEBase
+from email.mime.message import MIMEMessage
+from email.mime.multipart import MIMEMultipart
+from email import utils
+from email import errors
+from email import encoders
+from email import iterators
+from email import base64mime
+from email import quoprimime
+
+from test.support import run_unittest, unlink
+from test.test_email import openfile, TestEmailBase
+
+NL = '\n'
+EMPTYSTRING = ''
+SPACE = ' '
+
+
+# Test various aspects of the Message class's API
+class TestMessageAPI(TestEmailBase):
+ def test_get_all(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_20.txt')
+ eq(msg.get_all('cc'), ['ccc@zzz.org', 'ddd@zzz.org', 'eee@zzz.org'])
+ eq(msg.get_all('xx', 'n/a'), 'n/a')
+
+ def test_getset_charset(self):
+ eq = self.assertEqual
+ msg = Message()
+ eq(msg.get_charset(), None)
+ charset = Charset('iso-8859-1')
+ msg.set_charset(charset)
+ eq(msg['mime-version'], '1.0')
+ eq(msg.get_content_type(), 'text/plain')
+ eq(msg['content-type'], 'text/plain; charset="iso-8859-1"')
+ eq(msg.get_param('charset'), 'iso-8859-1')
+ eq(msg['content-transfer-encoding'], 'quoted-printable')
+ eq(msg.get_charset().input_charset, 'iso-8859-1')
+ # Remove the charset
+ msg.set_charset(None)
+ eq(msg.get_charset(), None)
+ eq(msg['content-type'], 'text/plain')
+ # Try adding a charset when there's already MIME headers present
+ msg = Message()
+ msg['MIME-Version'] = '2.0'
+ msg['Content-Type'] = 'text/x-weird'
+ msg['Content-Transfer-Encoding'] = 'quinted-puntable'
+ msg.set_charset(charset)
+ eq(msg['mime-version'], '2.0')
+ eq(msg['content-type'], 'text/x-weird; charset="iso-8859-1"')
+ eq(msg['content-transfer-encoding'], 'quinted-puntable')
+
+ def test_set_charset_from_string(self):
+ eq = self.assertEqual
+ msg = Message()
+ msg.set_charset('us-ascii')
+ eq(msg.get_charset().input_charset, 'us-ascii')
+ eq(msg['content-type'], 'text/plain; charset="us-ascii"')
+
+ def test_set_payload_with_charset(self):
+ msg = Message()
+ charset = Charset('iso-8859-1')
+ msg.set_payload('This is a string payload', charset)
+ self.assertEqual(msg.get_charset().input_charset, 'iso-8859-1')
+
+ def test_get_charsets(self):
+ eq = self.assertEqual
+
+ msg = self._msgobj('msg_08.txt')
+ charsets = msg.get_charsets()
+ eq(charsets, [None, 'us-ascii', 'iso-8859-1', 'iso-8859-2', 'koi8-r'])
+
+ msg = self._msgobj('msg_09.txt')
+ charsets = msg.get_charsets('dingbat')
+ eq(charsets, ['dingbat', 'us-ascii', 'iso-8859-1', 'dingbat',
+ 'koi8-r'])
+
+ msg = self._msgobj('msg_12.txt')
+ charsets = msg.get_charsets()
+ eq(charsets, [None, 'us-ascii', 'iso-8859-1', None, 'iso-8859-2',
+ 'iso-8859-3', 'us-ascii', 'koi8-r'])
+
+ def test_get_filename(self):
+ eq = self.assertEqual
+
+ msg = self._msgobj('msg_04.txt')
+ filenames = [p.get_filename() for p in msg.get_payload()]
+ eq(filenames, ['msg.txt', 'msg.txt'])
+
+ msg = self._msgobj('msg_07.txt')
+ subpart = msg.get_payload(1)
+ eq(subpart.get_filename(), 'dingusfish.gif')
+
+ def test_get_filename_with_name_parameter(self):
+ eq = self.assertEqual
+
+ msg = self._msgobj('msg_44.txt')
+ filenames = [p.get_filename() for p in msg.get_payload()]
+ eq(filenames, ['msg.txt', 'msg.txt'])
+
+ def test_get_boundary(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_07.txt')
+ # No quotes!
+ eq(msg.get_boundary(), 'BOUNDARY')
+
+ def test_set_boundary(self):
+ eq = self.assertEqual
+ # This one has no existing boundary parameter, but the Content-Type:
+ # header appears fifth.
+ msg = self._msgobj('msg_01.txt')
+ msg.set_boundary('BOUNDARY')
+ header, value = msg.items()[4]
+ eq(header.lower(), 'content-type')
+ eq(value, 'text/plain; charset="us-ascii"; boundary="BOUNDARY"')
+ # This one has a Content-Type: header, with a boundary, stuck in the
+ # middle of its headers. Make sure the order is preserved; it should
+ # be fifth.
+ msg = self._msgobj('msg_04.txt')
+ msg.set_boundary('BOUNDARY')
+ header, value = msg.items()[4]
+ eq(header.lower(), 'content-type')
+ eq(value, 'multipart/mixed; boundary="BOUNDARY"')
+ # And this one has no Content-Type: header at all.
+ msg = self._msgobj('msg_03.txt')
+ self.assertRaises(errors.HeaderParseError,
+ msg.set_boundary, 'BOUNDARY')
+
+ def test_make_boundary(self):
+ msg = MIMEMultipart('form-data')
+ # Note that when the boundary gets created is an implementation
+ # detail and might change.
+ self.assertEqual(msg.items()[0][1], 'multipart/form-data')
+ # Trigger creation of boundary
+ msg.as_string()
+ self.assertEqual(msg.items()[0][1][:33],
+ 'multipart/form-data; boundary="==')
+ # XXX: there ought to be tests of the uniqueness of the boundary, too.
+
+ def test_message_rfc822_only(self):
+ # Issue 7970: message/rfc822 not in multipart parsed by
+ # HeaderParser caused an exception when flattened.
+ with openfile('msg_46.txt') as fp:
+ msgdata = fp.read()
+ parser = HeaderParser()
+ msg = parser.parsestr(msgdata)
+ out = StringIO()
+ gen = Generator(out, True, 0)
+ gen.flatten(msg, False)
+ self.assertEqual(out.getvalue(), msgdata)
+
+ def test_byte_message_rfc822_only(self):
+ # Make sure new bytes header parser also passes this.
+ with openfile('msg_46.txt', 'rb') as fp:
+ msgdata = fp.read()
+ parser = email.parser.BytesHeaderParser()
+ msg = parser.parsebytes(msgdata)
+ out = BytesIO()
+ gen = email.generator.BytesGenerator(out)
+ gen.flatten(msg)
+ self.assertEqual(out.getvalue(), msgdata)
+
+ def test_get_decoded_payload(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_10.txt')
+ # The outer message is a multipart
+ eq(msg.get_payload(decode=True), None)
+ # Subpart 1 is 7bit encoded
+ eq(msg.get_payload(0).get_payload(decode=True),
+ b'This is a 7bit encoded message.\n')
+ # Subpart 2 is quopri
+ eq(msg.get_payload(1).get_payload(decode=True),
+ b'\xa1This is a Quoted Printable encoded message!\n')
+ # Subpart 3 is base64
+ eq(msg.get_payload(2).get_payload(decode=True),
+ b'This is a Base64 encoded message.')
+ # Subpart 4 is base64 with a trailing newline, which
+ # used to be stripped (issue 7143).
+ eq(msg.get_payload(3).get_payload(decode=True),
+ b'This is a Base64 encoded message.\n')
+ # Subpart 5 has no Content-Transfer-Encoding: header.
+ eq(msg.get_payload(4).get_payload(decode=True),
+ b'This has no Content-Transfer-Encoding: header.\n')
+
+ def test_get_decoded_uu_payload(self):
+ eq = self.assertEqual
+ msg = Message()
+ msg.set_payload('begin 666 -\n+:&5L;&\\@=V]R;&0 \n \nend\n')
+ for cte in ('x-uuencode', 'uuencode', 'uue', 'x-uue'):
+ msg['content-transfer-encoding'] = cte
+ eq(msg.get_payload(decode=True), b'hello world')
+ # Now try some bogus data
+ msg.set_payload('foo')
+ eq(msg.get_payload(decode=True), b'foo')
+
+ def test_get_payload_n_raises_on_non_multipart(self):
+ msg = Message()
+ self.assertRaises(TypeError, msg.get_payload, 1)
+
+ def test_decoded_generator(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_07.txt')
+ with openfile('msg_17.txt') as fp:
+ text = fp.read()
+ s = StringIO()
+ g = DecodedGenerator(s)
+ g.flatten(msg)
+ eq(s.getvalue(), text)
+
+ def test__contains__(self):
+ msg = Message()
+ msg['From'] = 'Me'
+ msg['to'] = 'You'
+ # Check for case insensitivity
+ self.assertTrue('from' in msg)
+ self.assertTrue('From' in msg)
+ self.assertTrue('FROM' in msg)
+ self.assertTrue('to' in msg)
+ self.assertTrue('To' in msg)
+ self.assertTrue('TO' in msg)
+
+ def test_as_string(self):
+ eq = self.ndiffAssertEqual
+ msg = self._msgobj('msg_01.txt')
+ with openfile('msg_01.txt') as fp:
+ text = fp.read()
+ eq(text, str(msg))
+ fullrepr = msg.as_string(unixfrom=True)
+ lines = fullrepr.split('\n')
+ self.assertTrue(lines[0].startswith('From '))
+ eq(text, NL.join(lines[1:]))
+
+ # test_headerregistry.TestContentTypeHeader.bad_params
+ def test_bad_param(self):
+ msg = email.message_from_string("Content-Type: blarg; baz; boo\n")
+ self.assertEqual(msg.get_param('baz'), '')
+
+ def test_missing_filename(self):
+ msg = email.message_from_string("From: foo\n")
+ self.assertEqual(msg.get_filename(), None)
+
+ def test_bogus_filename(self):
+ msg = email.message_from_string(
+ "Content-Disposition: blarg; filename\n")
+ self.assertEqual(msg.get_filename(), '')
+
+ def test_missing_boundary(self):
+ msg = email.message_from_string("From: foo\n")
+ self.assertEqual(msg.get_boundary(), None)
+
+ def test_get_params(self):
+ eq = self.assertEqual
+ msg = email.message_from_string(
+ 'X-Header: foo=one; bar=two; baz=three\n')
+ eq(msg.get_params(header='x-header'),
+ [('foo', 'one'), ('bar', 'two'), ('baz', 'three')])
+ msg = email.message_from_string(
+ 'X-Header: foo; bar=one; baz=two\n')
+ eq(msg.get_params(header='x-header'),
+ [('foo', ''), ('bar', 'one'), ('baz', 'two')])
+ eq(msg.get_params(), None)
+ msg = email.message_from_string(
+ 'X-Header: foo; bar="one"; baz=two\n')
+ eq(msg.get_params(header='x-header'),
+ [('foo', ''), ('bar', 'one'), ('baz', 'two')])
+
+ # test_headerregistry.TestContentTypeHeader.spaces_around_param_equals
+ def test_get_param_liberal(self):
+ msg = Message()
+ msg['Content-Type'] = 'Content-Type: Multipart/mixed; boundary = "CPIMSSMTPC06p5f3tG"'
+ self.assertEqual(msg.get_param('boundary'), 'CPIMSSMTPC06p5f3tG')
+
+ def test_get_param(self):
+ eq = self.assertEqual
+ msg = email.message_from_string(
+ "X-Header: foo=one; bar=two; baz=three\n")
+ eq(msg.get_param('bar', header='x-header'), 'two')
+ eq(msg.get_param('quuz', header='x-header'), None)
+ eq(msg.get_param('quuz'), None)
+ msg = email.message_from_string(
+ 'X-Header: foo; bar="one"; baz=two\n')
+ eq(msg.get_param('foo', header='x-header'), '')
+ eq(msg.get_param('bar', header='x-header'), 'one')
+ eq(msg.get_param('baz', header='x-header'), 'two')
+ # XXX: We are not RFC-2045 compliant! We cannot parse:
+ # msg["Content-Type"] = 'text/plain; weird="hey; dolly? [you] @ <\\"home\\">?"'
+ # msg.get_param("weird")
+ # yet.
+
+ # test_headerregistry.TestContentTypeHeader.spaces_around_semis
+ def test_get_param_funky_continuation_lines(self):
+ msg = self._msgobj('msg_22.txt')
+ self.assertEqual(msg.get_payload(1).get_param('name'), 'wibble.JPG')
+
+ # test_headerregistry.TestContentTypeHeader.semis_inside_quotes
+ def test_get_param_with_semis_in_quotes(self):
+ msg = email.message_from_string(
+ 'Content-Type: image/pjpeg; name="Jim&amp;&amp;Jill"\n')
+ self.assertEqual(msg.get_param('name'), 'Jim&amp;&amp;Jill')
+ self.assertEqual(msg.get_param('name', unquote=False),
+ '"Jim&amp;&amp;Jill"')
+
+ # test_headerregistry.TestContentTypeHeader.quotes_inside_rfc2231_value
+ def test_get_param_with_quotes(self):
+ msg = email.message_from_string(
+ 'Content-Type: foo; bar*0="baz\\"foobar"; bar*1="\\"baz"')
+ self.assertEqual(msg.get_param('bar'), 'baz"foobar"baz')
+ msg = email.message_from_string(
+ "Content-Type: foo; bar*0=\"baz\\\"foobar\"; bar*1=\"\\\"baz\"")
+ self.assertEqual(msg.get_param('bar'), 'baz"foobar"baz')
+
+ def test_field_containment(self):
+ unless = self.assertTrue
+ msg = email.message_from_string('Header: exists')
+ unless('header' in msg)
+ unless('Header' in msg)
+ unless('HEADER' in msg)
+ self.assertFalse('headerx' in msg)
+
+ def test_set_param(self):
+ eq = self.assertEqual
+ msg = Message()
+ msg.set_param('charset', 'iso-2022-jp')
+ eq(msg.get_param('charset'), 'iso-2022-jp')
+ msg.set_param('importance', 'high value')
+ eq(msg.get_param('importance'), 'high value')
+ eq(msg.get_param('importance', unquote=False), '"high value"')
+ eq(msg.get_params(), [('text/plain', ''),
+ ('charset', 'iso-2022-jp'),
+ ('importance', 'high value')])
+ eq(msg.get_params(unquote=False), [('text/plain', ''),
+ ('charset', '"iso-2022-jp"'),
+ ('importance', '"high value"')])
+ msg.set_param('charset', 'iso-9999-xx', header='X-Jimmy')
+ eq(msg.get_param('charset', header='X-Jimmy'), 'iso-9999-xx')
+
+ def test_del_param(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_05.txt')
+ eq(msg.get_params(),
+ [('multipart/report', ''), ('report-type', 'delivery-status'),
+ ('boundary', 'D1690A7AC1.996856090/mail.example.com')])
+ old_val = msg.get_param("report-type")
+ msg.del_param("report-type")
+ eq(msg.get_params(),
+ [('multipart/report', ''),
+ ('boundary', 'D1690A7AC1.996856090/mail.example.com')])
+ msg.set_param("report-type", old_val)
+ eq(msg.get_params(),
+ [('multipart/report', ''),
+ ('boundary', 'D1690A7AC1.996856090/mail.example.com'),
+ ('report-type', old_val)])
+
+ def test_del_param_on_other_header(self):
+ msg = Message()
+ msg.add_header('Content-Disposition', 'attachment', filename='bud.gif')
+ msg.del_param('filename', 'content-disposition')
+ self.assertEqual(msg['content-disposition'], 'attachment')
+
+ def test_del_param_on_nonexistent_header(self):
+ msg = Message()
+ msg.del_param('filename', 'content-disposition')
+
+ def test_del_nonexistent_param(self):
+ msg = Message()
+ msg.add_header('Content-Type', 'text/plain', charset='utf-8')
+ existing_header = msg['Content-Type']
+ msg.del_param('foobar', header='Content-Type')
+ self.assertEqual(msg['Content-Type'], 'text/plain; charset="utf-8"')
+
+ def test_set_type(self):
+ eq = self.assertEqual
+ msg = Message()
+ self.assertRaises(ValueError, msg.set_type, 'text')
+ msg.set_type('text/plain')
+ eq(msg['content-type'], 'text/plain')
+ msg.set_param('charset', 'us-ascii')
+ eq(msg['content-type'], 'text/plain; charset="us-ascii"')
+ msg.set_type('text/html')
+ eq(msg['content-type'], 'text/html; charset="us-ascii"')
+
+ def test_set_type_on_other_header(self):
+ msg = Message()
+ msg['X-Content-Type'] = 'text/plain'
+ msg.set_type('application/octet-stream', 'X-Content-Type')
+ self.assertEqual(msg['x-content-type'], 'application/octet-stream')
+
+ def test_get_content_type_missing(self):
+ msg = Message()
+ self.assertEqual(msg.get_content_type(), 'text/plain')
+
+ def test_get_content_type_missing_with_default_type(self):
+ msg = Message()
+ msg.set_default_type('message/rfc822')
+ self.assertEqual(msg.get_content_type(), 'message/rfc822')
+
+ def test_get_content_type_from_message_implicit(self):
+ msg = self._msgobj('msg_30.txt')
+ self.assertEqual(msg.get_payload(0).get_content_type(),
+ 'message/rfc822')
+
+ def test_get_content_type_from_message_explicit(self):
+ msg = self._msgobj('msg_28.txt')
+ self.assertEqual(msg.get_payload(0).get_content_type(),
+ 'message/rfc822')
+
+ def test_get_content_type_from_message_text_plain_implicit(self):
+ msg = self._msgobj('msg_03.txt')
+ self.assertEqual(msg.get_content_type(), 'text/plain')
+
+ def test_get_content_type_from_message_text_plain_explicit(self):
+ msg = self._msgobj('msg_01.txt')
+ self.assertEqual(msg.get_content_type(), 'text/plain')
+
+ def test_get_content_maintype_missing(self):
+ msg = Message()
+ self.assertEqual(msg.get_content_maintype(), 'text')
+
+ def test_get_content_maintype_missing_with_default_type(self):
+ msg = Message()
+ msg.set_default_type('message/rfc822')
+ self.assertEqual(msg.get_content_maintype(), 'message')
+
+ def test_get_content_maintype_from_message_implicit(self):
+ msg = self._msgobj('msg_30.txt')
+ self.assertEqual(msg.get_payload(0).get_content_maintype(), 'message')
+
+ def test_get_content_maintype_from_message_explicit(self):
+ msg = self._msgobj('msg_28.txt')
+ self.assertEqual(msg.get_payload(0).get_content_maintype(), 'message')
+
+ def test_get_content_maintype_from_message_text_plain_implicit(self):
+ msg = self._msgobj('msg_03.txt')
+ self.assertEqual(msg.get_content_maintype(), 'text')
+
+ def test_get_content_maintype_from_message_text_plain_explicit(self):
+ msg = self._msgobj('msg_01.txt')
+ self.assertEqual(msg.get_content_maintype(), 'text')
+
+ def test_get_content_subtype_missing(self):
+ msg = Message()
+ self.assertEqual(msg.get_content_subtype(), 'plain')
+
+ def test_get_content_subtype_missing_with_default_type(self):
+ msg = Message()
+ msg.set_default_type('message/rfc822')
+ self.assertEqual(msg.get_content_subtype(), 'rfc822')
+
+ def test_get_content_subtype_from_message_implicit(self):
+ msg = self._msgobj('msg_30.txt')
+ self.assertEqual(msg.get_payload(0).get_content_subtype(), 'rfc822')
+
+ def test_get_content_subtype_from_message_explicit(self):
+ msg = self._msgobj('msg_28.txt')
+ self.assertEqual(msg.get_payload(0).get_content_subtype(), 'rfc822')
+
+ def test_get_content_subtype_from_message_text_plain_implicit(self):
+ msg = self._msgobj('msg_03.txt')
+ self.assertEqual(msg.get_content_subtype(), 'plain')
+
+ def test_get_content_subtype_from_message_text_plain_explicit(self):
+ msg = self._msgobj('msg_01.txt')
+ self.assertEqual(msg.get_content_subtype(), 'plain')
+
+ def test_get_content_maintype_error(self):
+ msg = Message()
+ msg['Content-Type'] = 'no-slash-in-this-string'
+ self.assertEqual(msg.get_content_maintype(), 'text')
+
+ def test_get_content_subtype_error(self):
+ msg = Message()
+ msg['Content-Type'] = 'no-slash-in-this-string'
+ self.assertEqual(msg.get_content_subtype(), 'plain')
+
+ def test_replace_header(self):
+ eq = self.assertEqual
+ msg = Message()
+ msg.add_header('First', 'One')
+ msg.add_header('Second', 'Two')
+ msg.add_header('Third', 'Three')
+ eq(msg.keys(), ['First', 'Second', 'Third'])
+ eq(msg.values(), ['One', 'Two', 'Three'])
+ msg.replace_header('Second', 'Twenty')
+ eq(msg.keys(), ['First', 'Second', 'Third'])
+ eq(msg.values(), ['One', 'Twenty', 'Three'])
+ msg.add_header('First', 'Eleven')
+ msg.replace_header('First', 'One Hundred')
+ eq(msg.keys(), ['First', 'Second', 'Third', 'First'])
+ eq(msg.values(), ['One Hundred', 'Twenty', 'Three', 'Eleven'])
+ self.assertRaises(KeyError, msg.replace_header, 'Fourth', 'Missing')
+
+ # test_defect_handling:test_invalid_chars_in_base64_payload
+ def test_broken_base64_payload(self):
+ x = 'AwDp0P7//y6LwKEAcPa/6Q=9'
+ msg = Message()
+ msg['content-type'] = 'audio/x-midi'
+ msg['content-transfer-encoding'] = 'base64'
+ msg.set_payload(x)
+ self.assertEqual(msg.get_payload(decode=True),
+ (b'\x03\x00\xe9\xd0\xfe\xff\xff.\x8b\xc0'
+ b'\xa1\x00p\xf6\xbf\xe9\x0f'))
+ self.assertIsInstance(msg.defects[0],
+ errors.InvalidBase64CharactersDefect)
+
+ def test_broken_unicode_payload(self):
+ # This test improves coverage but is not a compliance test.
+ # The behavior in this situation is currently undefined by the API.
+ x = 'this is a br\xf6ken thing to do'
+ msg = Message()
+ msg['content-type'] = 'text/plain'
+ msg['content-transfer-encoding'] = '8bit'
+ msg.set_payload(x)
+ self.assertEqual(msg.get_payload(decode=True),
+ bytes(x, 'raw-unicode-escape'))
+
+ def test_questionable_bytes_payload(self):
+ # This test improves coverage but is not a compliance test,
+ # since it involves poking inside the black box.
+ x = 'this is a quéstionable thing to do'.encode('utf-8')
+ msg = Message()
+ msg['content-type'] = 'text/plain; charset="utf-8"'
+ msg['content-transfer-encoding'] = '8bit'
+ msg._payload = x
+ self.assertEqual(msg.get_payload(decode=True), x)
+
+ # Issue 1078919
+ def test_ascii_add_header(self):
+ msg = Message()
+ msg.add_header('Content-Disposition', 'attachment',
+ filename='bud.gif')
+ self.assertEqual('attachment; filename="bud.gif"',
+ msg['Content-Disposition'])
+
+ def test_noascii_add_header(self):
+ msg = Message()
+ msg.add_header('Content-Disposition', 'attachment',
+ filename="Fußballer.ppt")
+ self.assertEqual(
+ 'attachment; filename*=utf-8\'\'Fu%C3%9Fballer.ppt',
+ msg['Content-Disposition'])
+
+ def test_nonascii_add_header_via_triple(self):
+ msg = Message()
+ msg.add_header('Content-Disposition', 'attachment',
+ filename=('iso-8859-1', '', 'Fußballer.ppt'))
+ self.assertEqual(
+ 'attachment; filename*=iso-8859-1\'\'Fu%DFballer.ppt',
+ msg['Content-Disposition'])
+
+ def test_ascii_add_header_with_tspecial(self):
+ msg = Message()
+ msg.add_header('Content-Disposition', 'attachment',
+ filename="windows [filename].ppt")
+ self.assertEqual(
+ 'attachment; filename="windows [filename].ppt"',
+ msg['Content-Disposition'])
+
+ def test_nonascii_add_header_with_tspecial(self):
+ msg = Message()
+ msg.add_header('Content-Disposition', 'attachment',
+ filename="Fußballer [filename].ppt")
+ self.assertEqual(
+ "attachment; filename*=utf-8''Fu%C3%9Fballer%20%5Bfilename%5D.ppt",
+ msg['Content-Disposition'])
+
+ def test_add_header_with_name_only_param(self):
+ msg = Message()
+ msg.add_header('Content-Disposition', 'inline', foo_bar=None)
+ self.assertEqual("inline; foo-bar", msg['Content-Disposition'])
+
+ def test_add_header_with_no_value(self):
+ msg = Message()
+ msg.add_header('X-Status', None)
+ self.assertEqual('', msg['X-Status'])
+
+ # Issue 5871: reject an attempt to embed a header inside a header value
+ # (header injection attack).
+ def test_embeded_header_via_Header_rejected(self):
+ msg = Message()
+ msg['Dummy'] = Header('dummy\nX-Injected-Header: test')
+ self.assertRaises(errors.HeaderParseError, msg.as_string)
+
+ def test_embeded_header_via_string_rejected(self):
+ msg = Message()
+ msg['Dummy'] = 'dummy\nX-Injected-Header: test'
+ self.assertRaises(errors.HeaderParseError, msg.as_string)
+
+ def test_unicode_header_defaults_to_utf8_encoding(self):
+ # Issue 14291
+ m = MIMEText('abc\n')
+ m['Subject'] = 'É test'
+ self.assertEqual(str(m),textwrap.dedent("""\
+ Content-Type: text/plain; charset="us-ascii"
+ MIME-Version: 1.0
+ Content-Transfer-Encoding: 7bit
+ Subject: =?utf-8?q?=C3=89_test?=
+
+ abc
+ """))
+
+ def test_unicode_body_defaults_to_utf8_encoding(self):
+ # Issue 14291
+ m = MIMEText('É testabc\n')
+ self.assertEqual(str(m),textwrap.dedent("""\
+ Content-Type: text/plain; charset="utf-8"
+ MIME-Version: 1.0
+ Content-Transfer-Encoding: base64
+
+ w4kgdGVzdGFiYwo=
+ """))
+
+
+# Test the email.encoders module
+class TestEncoders(unittest.TestCase):
+
+ def test_EncodersEncode_base64(self):
+ with openfile('PyBanner048.gif', 'rb') as fp:
+ bindata = fp.read()
+ mimed = email.mime.image.MIMEImage(bindata)
+ base64ed = mimed.get_payload()
+ # the transfer-encoded body lines should all be <=76 characters
+ lines = base64ed.split('\n')
+ self.assertLessEqual(max([ len(x) for x in lines ]), 76)
+
+ def test_encode_empty_payload(self):
+ eq = self.assertEqual
+ msg = Message()
+ msg.set_charset('us-ascii')
+ eq(msg['content-transfer-encoding'], '7bit')
+
+ def test_default_cte(self):
+ eq = self.assertEqual
+ # 7bit data and the default us-ascii _charset
+ msg = MIMEText('hello world')
+ eq(msg['content-transfer-encoding'], '7bit')
+ # Similar, but with 8bit data
+ msg = MIMEText('hello \xf8 world')
+ eq(msg['content-transfer-encoding'], 'base64')
+ # And now with a different charset
+ msg = MIMEText('hello \xf8 world', _charset='iso-8859-1')
+ eq(msg['content-transfer-encoding'], 'quoted-printable')
+
+ def test_encode7or8bit(self):
+ # Make sure a charset whose input character set is 8bit but
+ # whose output character set is 7bit gets a transfer-encoding
+ # of 7bit.
+ eq = self.assertEqual
+ msg = MIMEText('文', _charset='euc-jp')
+ eq(msg['content-transfer-encoding'], '7bit')
+
+ def test_qp_encode_latin1(self):
+ msg = MIMEText('\xe1\xf6\n', 'text', 'ISO-8859-1')
+ self.assertEqual(str(msg), textwrap.dedent("""\
+ MIME-Version: 1.0
+ Content-Type: text/text; charset="iso-8859-1"
+ Content-Transfer-Encoding: quoted-printable
+
+ =E1=F6
+ """))
+
+ def test_qp_encode_non_latin1(self):
+ # Issue 16948
+ msg = MIMEText('\u017c\n', 'text', 'ISO-8859-2')
+ self.assertEqual(str(msg), textwrap.dedent("""\
+ MIME-Version: 1.0
+ Content-Type: text/text; charset="iso-8859-2"
+ Content-Transfer-Encoding: quoted-printable
+
+ =BF
+ """))
+
+
+# Test long header wrapping
+class TestLongHeaders(TestEmailBase):
+
+ maxDiff = None
+
+ def test_split_long_continuation(self):
+ eq = self.ndiffAssertEqual
+ msg = email.message_from_string("""\
+Subject: bug demonstration
+\t12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789
+\tmore text
+
+test
+""")
+ sfp = StringIO()
+ g = Generator(sfp)
+ g.flatten(msg)
+ eq(sfp.getvalue(), """\
+Subject: bug demonstration
+\t12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789
+\tmore text
+
+test
+""")
+
+ def test_another_long_almost_unsplittable_header(self):
+ eq = self.ndiffAssertEqual
+ hstr = """\
+bug demonstration
+\t12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789
+\tmore text"""
+ h = Header(hstr, continuation_ws='\t')
+ eq(h.encode(), """\
+bug demonstration
+\t12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789
+\tmore text""")
+ h = Header(hstr.replace('\t', ' '))
+ eq(h.encode(), """\
+bug demonstration
+ 12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789
+ more text""")
+
+ def test_long_nonstring(self):
+ eq = self.ndiffAssertEqual
+ g = Charset("iso-8859-1")
+ cz = Charset("iso-8859-2")
+ utf8 = Charset("utf-8")
+ g_head = (b'Die Mieter treten hier ein werden mit einem Foerderband '
+ b'komfortabel den Korridor entlang, an s\xfcdl\xfcndischen '
+ b'Wandgem\xe4lden vorbei, gegen die rotierenden Klingen '
+ b'bef\xf6rdert. ')
+ cz_head = (b'Finan\xe8ni metropole se hroutily pod tlakem jejich '
+ b'd\xf9vtipu.. ')
+ utf8_head = ('\u6b63\u78ba\u306b\u8a00\u3046\u3068\u7ffb\u8a33\u306f'
+ '\u3055\u308c\u3066\u3044\u307e\u305b\u3093\u3002\u4e00'
+ '\u90e8\u306f\u30c9\u30a4\u30c4\u8a9e\u3067\u3059\u304c'
+ '\u3001\u3042\u3068\u306f\u3067\u305f\u3089\u3081\u3067'
+ '\u3059\u3002\u5b9f\u969b\u306b\u306f\u300cWenn ist das '
+ 'Nunstuck git und Slotermeyer? Ja! Beiherhund das Oder '
+ 'die Flipperwaldt gersput.\u300d\u3068\u8a00\u3063\u3066'
+ '\u3044\u307e\u3059\u3002')
+ h = Header(g_head, g, header_name='Subject')
+ h.append(cz_head, cz)
+ h.append(utf8_head, utf8)
+ msg = Message()
+ msg['Subject'] = h
+ sfp = StringIO()
+ g = Generator(sfp)
+ g.flatten(msg)
+ eq(sfp.getvalue(), """\
+Subject: =?iso-8859-1?q?Die_Mieter_treten_hier_ein_werden_mit_einem_Foerderb?=
+ =?iso-8859-1?q?and_komfortabel_den_Korridor_entlang=2C_an_s=FCdl=FCndischen?=
+ =?iso-8859-1?q?_Wandgem=E4lden_vorbei=2C_gegen_die_rotierenden_Klingen_bef?=
+ =?iso-8859-1?q?=F6rdert=2E_?= =?iso-8859-2?q?Finan=E8ni_metropole_se_hrouti?=
+ =?iso-8859-2?q?ly_pod_tlakem_jejich_d=F9vtipu=2E=2E_?= =?utf-8?b?5q2j56K6?=
+ =?utf-8?b?44Gr6KiA44GG44Go57+76Kiz44Gv44GV44KM44Gm44GE44G+44Gb44KT44CC5LiA?=
+ =?utf-8?b?6YOo44Gv44OJ44Kk44OE6Kqe44Gn44GZ44GM44CB44GC44Go44Gv44Gn44Gf44KJ?=
+ =?utf-8?b?44KB44Gn44GZ44CC5a6f6Zqb44Gr44Gv44CMV2VubiBpc3QgZGFzIE51bnN0dWNr?=
+ =?utf-8?b?IGdpdCB1bmQgU2xvdGVybWV5ZXI/IEphISBCZWloZXJodW5kIGRhcyBPZGVyIGRp?=
+ =?utf-8?b?ZSBGbGlwcGVyd2FsZHQgZ2Vyc3B1dC7jgI3jgajoqIDjgaPjgabjgYTjgb7jgZk=?=
+ =?utf-8?b?44CC?=
+
+""")
+ eq(h.encode(maxlinelen=76), """\
+=?iso-8859-1?q?Die_Mieter_treten_hier_ein_werden_mit_einem_Foerde?=
+ =?iso-8859-1?q?rband_komfortabel_den_Korridor_entlang=2C_an_s=FCdl=FCndis?=
+ =?iso-8859-1?q?chen_Wandgem=E4lden_vorbei=2C_gegen_die_rotierenden_Klinge?=
+ =?iso-8859-1?q?n_bef=F6rdert=2E_?= =?iso-8859-2?q?Finan=E8ni_metropole_se?=
+ =?iso-8859-2?q?_hroutily_pod_tlakem_jejich_d=F9vtipu=2E=2E_?=
+ =?utf-8?b?5q2j56K644Gr6KiA44GG44Go57+76Kiz44Gv44GV44KM44Gm44GE44G+44Gb?=
+ =?utf-8?b?44KT44CC5LiA6YOo44Gv44OJ44Kk44OE6Kqe44Gn44GZ44GM44CB44GC44Go?=
+ =?utf-8?b?44Gv44Gn44Gf44KJ44KB44Gn44GZ44CC5a6f6Zqb44Gr44Gv44CMV2VubiBp?=
+ =?utf-8?b?c3QgZGFzIE51bnN0dWNrIGdpdCB1bmQgU2xvdGVybWV5ZXI/IEphISBCZWlo?=
+ =?utf-8?b?ZXJodW5kIGRhcyBPZGVyIGRpZSBGbGlwcGVyd2FsZHQgZ2Vyc3B1dC7jgI0=?=
+ =?utf-8?b?44Go6KiA44Gj44Gm44GE44G+44GZ44CC?=""")
+
+ def test_long_header_encode(self):
+ eq = self.ndiffAssertEqual
+ h = Header('wasnipoop; giraffes="very-long-necked-animals"; '
+ 'spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"',
+ header_name='X-Foobar-Spoink-Defrobnit')
+ eq(h.encode(), '''\
+wasnipoop; giraffes="very-long-necked-animals";
+ spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"''')
+
+ def test_long_header_encode_with_tab_continuation_is_just_a_hint(self):
+ eq = self.ndiffAssertEqual
+ h = Header('wasnipoop; giraffes="very-long-necked-animals"; '
+ 'spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"',
+ header_name='X-Foobar-Spoink-Defrobnit',
+ continuation_ws='\t')
+ eq(h.encode(), '''\
+wasnipoop; giraffes="very-long-necked-animals";
+ spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"''')
+
+ def test_long_header_encode_with_tab_continuation(self):
+ eq = self.ndiffAssertEqual
+ h = Header('wasnipoop; giraffes="very-long-necked-animals";\t'
+ 'spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"',
+ header_name='X-Foobar-Spoink-Defrobnit',
+ continuation_ws='\t')
+ eq(h.encode(), '''\
+wasnipoop; giraffes="very-long-necked-animals";
+\tspooge="yummy"; hippos="gargantuan"; marshmallows="gooey"''')
+
+ def test_header_encode_with_different_output_charset(self):
+ h = Header('文', 'euc-jp')
+ self.assertEqual(h.encode(), "=?iso-2022-jp?b?GyRCSjgbKEI=?=")
+
+ def test_long_header_encode_with_different_output_charset(self):
+ h = Header(b'test-ja \xa4\xd8\xc5\xea\xb9\xc6\xa4\xb5\xa4\xec\xa4'
+ b'\xbf\xa5\xe1\xa1\xbc\xa5\xeb\xa4\xcf\xbb\xca\xb2\xf1\xbc\xd4'
+ b'\xa4\xce\xbe\xb5\xc7\xa7\xa4\xf2\xc2\xd4\xa4\xc3\xa4\xc6\xa4'
+ b'\xa4\xa4\xde\xa4\xb9'.decode('euc-jp'), 'euc-jp')
+ res = """\
+=?iso-2022-jp?b?dGVzdC1qYSAbJEIkWEVqOUYkNSRsJD8lYSE8JWskTztKMnE8VCROPjUbKEI=?=
+ =?iso-2022-jp?b?GyRCRyckckJUJEMkRiQkJF4kORsoQg==?="""
+ self.assertEqual(h.encode(), res)
+
+ def test_header_splitter(self):
+ eq = self.ndiffAssertEqual
+ msg = MIMEText('')
+ # It'd be great if we could use add_header() here, but that doesn't
+ # guarantee an order of the parameters.
+ msg['X-Foobar-Spoink-Defrobnit'] = (
+ 'wasnipoop; giraffes="very-long-necked-animals"; '
+ 'spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"')
+ sfp = StringIO()
+ g = Generator(sfp)
+ g.flatten(msg)
+ eq(sfp.getvalue(), '''\
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+X-Foobar-Spoink-Defrobnit: wasnipoop; giraffes="very-long-necked-animals";
+ spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"
+
+''')
+
+ def test_no_semis_header_splitter(self):
+ eq = self.ndiffAssertEqual
+ msg = Message()
+ msg['From'] = 'test@dom.ain'
+ msg['References'] = SPACE.join('<%d@dom.ain>' % i for i in range(10))
+ msg.set_payload('Test')
+ sfp = StringIO()
+ g = Generator(sfp)
+ g.flatten(msg)
+ eq(sfp.getvalue(), """\
+From: test@dom.ain
+References: <0@dom.ain> <1@dom.ain> <2@dom.ain> <3@dom.ain> <4@dom.ain>
+ <5@dom.ain> <6@dom.ain> <7@dom.ain> <8@dom.ain> <9@dom.ain>
+
+Test""")
+
+ def test_last_split_chunk_does_not_fit(self):
+ eq = self.ndiffAssertEqual
+ h = Header('Subject: the first part of this is short, but_the_second'
+ '_part_does_not_fit_within_maxlinelen_and_thus_should_be_on_a_line'
+ '_all_by_itself')
+ eq(h.encode(), """\
+Subject: the first part of this is short,
+ but_the_second_part_does_not_fit_within_maxlinelen_and_thus_should_be_on_a_line_all_by_itself""")
+
+ def test_splittable_leading_char_followed_by_overlong_unsplitable(self):
+ eq = self.ndiffAssertEqual
+ h = Header(', but_the_second'
+ '_part_does_not_fit_within_maxlinelen_and_thus_should_be_on_a_line'
+ '_all_by_itself')
+ eq(h.encode(), """\
+,
+ but_the_second_part_does_not_fit_within_maxlinelen_and_thus_should_be_on_a_line_all_by_itself""")
+
+ def test_multiple_splittable_leading_char_followed_by_overlong_unsplitable(self):
+ eq = self.ndiffAssertEqual
+ h = Header(', , but_the_second'
+ '_part_does_not_fit_within_maxlinelen_and_thus_should_be_on_a_line'
+ '_all_by_itself')
+ eq(h.encode(), """\
+, ,
+ but_the_second_part_does_not_fit_within_maxlinelen_and_thus_should_be_on_a_line_all_by_itself""")
+
+ def test_trailing_splitable_on_overlong_unsplitable(self):
+ eq = self.ndiffAssertEqual
+ h = Header('this_part_does_not_fit_within_maxlinelen_and_thus_should_'
+ 'be_on_a_line_all_by_itself;')
+ eq(h.encode(), "this_part_does_not_fit_within_maxlinelen_and_thus_should_"
+ "be_on_a_line_all_by_itself;")
+
+ def test_trailing_splitable_on_overlong_unsplitable_with_leading_splitable(self):
+ eq = self.ndiffAssertEqual
+ h = Header('; '
+ 'this_part_does_not_fit_within_maxlinelen_and_thus_should_'
+ 'be_on_a_line_all_by_itself; ')
+ eq(h.encode(), """\
+;
+ this_part_does_not_fit_within_maxlinelen_and_thus_should_be_on_a_line_all_by_itself; """)
+
+ def test_long_header_with_multiple_sequential_split_chars(self):
+ eq = self.ndiffAssertEqual
+ h = Header('This is a long line that has two whitespaces in a row. '
+ 'This used to cause truncation of the header when folded')
+ eq(h.encode(), """\
+This is a long line that has two whitespaces in a row. This used to cause
+ truncation of the header when folded""")
+
+ def test_splitter_split_on_punctuation_only_if_fws(self):
+ eq = self.ndiffAssertEqual
+ h = Header('thisverylongheaderhas;semicolons;and,commas,but'
+ 'they;arenotlegal;fold,points')
+ eq(h.encode(), "thisverylongheaderhas;semicolons;and,commas,butthey;"
+ "arenotlegal;fold,points")
+
+ def test_leading_splittable_in_the_middle_just_before_overlong_last_part(self):
+ eq = self.ndiffAssertEqual
+ h = Header('this is a test where we need to have more than one line '
+ 'before; our final line that is just too big to fit;; '
+ 'this_part_does_not_fit_within_maxlinelen_and_thus_should_'
+ 'be_on_a_line_all_by_itself;')
+ eq(h.encode(), """\
+this is a test where we need to have more than one line before;
+ our final line that is just too big to fit;;
+ this_part_does_not_fit_within_maxlinelen_and_thus_should_be_on_a_line_all_by_itself;""")
+
+ def test_overlong_last_part_followed_by_split_point(self):
+ eq = self.ndiffAssertEqual
+ h = Header('this_part_does_not_fit_within_maxlinelen_and_thus_should_'
+ 'be_on_a_line_all_by_itself ')
+ eq(h.encode(), "this_part_does_not_fit_within_maxlinelen_and_thus_"
+ "should_be_on_a_line_all_by_itself ")
+
+ def test_multiline_with_overlong_parts_separated_by_two_split_points(self):
+ eq = self.ndiffAssertEqual
+ h = Header('this_is_a__test_where_we_need_to_have_more_than_one_line_'
+ 'before_our_final_line_; ; '
+ 'this_part_does_not_fit_within_maxlinelen_and_thus_should_'
+ 'be_on_a_line_all_by_itself; ')
+ eq(h.encode(), """\
+this_is_a__test_where_we_need_to_have_more_than_one_line_before_our_final_line_;
+ ;
+ this_part_does_not_fit_within_maxlinelen_and_thus_should_be_on_a_line_all_by_itself; """)
+
+ def test_multiline_with_overlong_last_part_followed_by_split_point(self):
+ eq = self.ndiffAssertEqual
+ h = Header('this is a test where we need to have more than one line '
+ 'before our final line; ; '
+ 'this_part_does_not_fit_within_maxlinelen_and_thus_should_'
+ 'be_on_a_line_all_by_itself; ')
+ eq(h.encode(), """\
+this is a test where we need to have more than one line before our final line;
+ ;
+ this_part_does_not_fit_within_maxlinelen_and_thus_should_be_on_a_line_all_by_itself; """)
+
+ def test_long_header_with_whitespace_runs(self):
+ eq = self.ndiffAssertEqual
+ msg = Message()
+ msg['From'] = 'test@dom.ain'
+ msg['References'] = SPACE.join(['<foo@dom.ain> '] * 10)
+ msg.set_payload('Test')
+ sfp = StringIO()
+ g = Generator(sfp)
+ g.flatten(msg)
+ eq(sfp.getvalue(), """\
+From: test@dom.ain
+References: <foo@dom.ain> <foo@dom.ain> <foo@dom.ain> <foo@dom.ain>
+ <foo@dom.ain> <foo@dom.ain> <foo@dom.ain> <foo@dom.ain>
+ <foo@dom.ain> <foo@dom.ain>\x20\x20
+
+Test""")
+
+ def test_long_run_with_semi_header_splitter(self):
+ eq = self.ndiffAssertEqual
+ msg = Message()
+ msg['From'] = 'test@dom.ain'
+ msg['References'] = SPACE.join(['<foo@dom.ain>'] * 10) + '; abc'
+ msg.set_payload('Test')
+ sfp = StringIO()
+ g = Generator(sfp)
+ g.flatten(msg)
+ eq(sfp.getvalue(), """\
+From: test@dom.ain
+References: <foo@dom.ain> <foo@dom.ain> <foo@dom.ain> <foo@dom.ain>
+ <foo@dom.ain> <foo@dom.ain> <foo@dom.ain> <foo@dom.ain> <foo@dom.ain>
+ <foo@dom.ain>; abc
+
+Test""")
+
+ def test_splitter_split_on_punctuation_only_if_fws(self):
+ eq = self.ndiffAssertEqual
+ msg = Message()
+ msg['From'] = 'test@dom.ain'
+ msg['References'] = ('thisverylongheaderhas;semicolons;and,commas,but'
+ 'they;arenotlegal;fold,points')
+ msg.set_payload('Test')
+ sfp = StringIO()
+ g = Generator(sfp)
+ g.flatten(msg)
+ # XXX the space after the header should not be there.
+ eq(sfp.getvalue(), """\
+From: test@dom.ain
+References:\x20
+ thisverylongheaderhas;semicolons;and,commas,butthey;arenotlegal;fold,points
+
+Test""")
+
+ def test_no_split_long_header(self):
+ eq = self.ndiffAssertEqual
+ hstr = 'References: ' + 'x' * 80
+ h = Header(hstr)
+ # These come on two lines because Headers are really field value
+ # classes and don't really know about their field names.
+ eq(h.encode(), """\
+References:
+ xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx""")
+ h = Header('x' * 80)
+ eq(h.encode(), 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
+
+ def test_splitting_multiple_long_lines(self):
+ eq = self.ndiffAssertEqual
+ hstr = """\
+from babylon.socal-raves.org (localhost [127.0.0.1]); by babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81; for <mailman-admin@babylon.socal-raves.org>; Sat, 2 Feb 2002 17:00:06 -0800 (PST)
+\tfrom babylon.socal-raves.org (localhost [127.0.0.1]); by babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81; for <mailman-admin@babylon.socal-raves.org>; Sat, 2 Feb 2002 17:00:06 -0800 (PST)
+\tfrom babylon.socal-raves.org (localhost [127.0.0.1]); by babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81; for <mailman-admin@babylon.socal-raves.org>; Sat, 2 Feb 2002 17:00:06 -0800 (PST)
+"""
+ h = Header(hstr, continuation_ws='\t')
+ eq(h.encode(), """\
+from babylon.socal-raves.org (localhost [127.0.0.1]);
+ by babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81;
+ for <mailman-admin@babylon.socal-raves.org>;
+ Sat, 2 Feb 2002 17:00:06 -0800 (PST)
+\tfrom babylon.socal-raves.org (localhost [127.0.0.1]);
+ by babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81;
+ for <mailman-admin@babylon.socal-raves.org>;
+ Sat, 2 Feb 2002 17:00:06 -0800 (PST)
+\tfrom babylon.socal-raves.org (localhost [127.0.0.1]);
+ by babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81;
+ for <mailman-admin@babylon.socal-raves.org>;
+ Sat, 2 Feb 2002 17:00:06 -0800 (PST)""")
+
+ def test_splitting_first_line_only_is_long(self):
+ eq = self.ndiffAssertEqual
+ hstr = """\
+from modemcable093.139-201-24.que.mc.videotron.ca ([24.201.139.93] helo=cthulhu.gerg.ca)
+\tby kronos.mems-exchange.org with esmtp (Exim 4.05)
+\tid 17k4h5-00034i-00
+\tfor test@mems-exchange.org; Wed, 28 Aug 2002 11:25:20 -0400"""
+ h = Header(hstr, maxlinelen=78, header_name='Received',
+ continuation_ws='\t')
+ eq(h.encode(), """\
+from modemcable093.139-201-24.que.mc.videotron.ca ([24.201.139.93]
+ helo=cthulhu.gerg.ca)
+\tby kronos.mems-exchange.org with esmtp (Exim 4.05)
+\tid 17k4h5-00034i-00
+\tfor test@mems-exchange.org; Wed, 28 Aug 2002 11:25:20 -0400""")
+
+ def test_long_8bit_header(self):
+ eq = self.ndiffAssertEqual
+ msg = Message()
+ h = Header('Britische Regierung gibt', 'iso-8859-1',
+ header_name='Subject')
+ h.append('gr\xfcnes Licht f\xfcr Offshore-Windkraftprojekte')
+ eq(h.encode(maxlinelen=76), """\
+=?iso-8859-1?q?Britische_Regierung_gibt_gr=FCnes_Licht_f=FCr_Offs?=
+ =?iso-8859-1?q?hore-Windkraftprojekte?=""")
+ msg['Subject'] = h
+ eq(msg.as_string(maxheaderlen=76), """\
+Subject: =?iso-8859-1?q?Britische_Regierung_gibt_gr=FCnes_Licht_f=FCr_Offs?=
+ =?iso-8859-1?q?hore-Windkraftprojekte?=
+
+""")
+ eq(msg.as_string(maxheaderlen=0), """\
+Subject: =?iso-8859-1?q?Britische_Regierung_gibt_gr=FCnes_Licht_f=FCr_Offshore-Windkraftprojekte?=
+
+""")
+
+ def test_long_8bit_header_no_charset(self):
+ eq = self.ndiffAssertEqual
+ msg = Message()
+ header_string = ('Britische Regierung gibt gr\xfcnes Licht '
+ 'f\xfcr Offshore-Windkraftprojekte '
+ '<a-very-long-address@example.com>')
+ msg['Reply-To'] = header_string
+ eq(msg.as_string(maxheaderlen=78), """\
+Reply-To: =?utf-8?q?Britische_Regierung_gibt_gr=C3=BCnes_Licht_f=C3=BCr_Offs?=
+ =?utf-8?q?hore-Windkraftprojekte_=3Ca-very-long-address=40example=2Ecom=3E?=
+
+""")
+ msg = Message()
+ msg['Reply-To'] = Header(header_string,
+ header_name='Reply-To')
+ eq(msg.as_string(maxheaderlen=78), """\
+Reply-To: =?utf-8?q?Britische_Regierung_gibt_gr=C3=BCnes_Licht_f=C3=BCr_Offs?=
+ =?utf-8?q?hore-Windkraftprojekte_=3Ca-very-long-address=40example=2Ecom=3E?=
+
+""")
+
+ def test_long_to_header(self):
+ eq = self.ndiffAssertEqual
+ to = ('"Someone Test #A" <someone@eecs.umich.edu>,'
+ '<someone@eecs.umich.edu>, '
+ '"Someone Test #B" <someone@umich.edu>, '
+ '"Someone Test #C" <someone@eecs.umich.edu>, '
+ '"Someone Test #D" <someone@eecs.umich.edu>')
+ msg = Message()
+ msg['To'] = to
+ eq(msg.as_string(maxheaderlen=78), '''\
+To: "Someone Test #A" <someone@eecs.umich.edu>,<someone@eecs.umich.edu>,
+ "Someone Test #B" <someone@umich.edu>,
+ "Someone Test #C" <someone@eecs.umich.edu>,
+ "Someone Test #D" <someone@eecs.umich.edu>
+
+''')
+
+ def test_long_line_after_append(self):
+ eq = self.ndiffAssertEqual
+ s = 'This is an example of string which has almost the limit of header length.'
+ h = Header(s)
+ h.append('Add another line.')
+ eq(h.encode(maxlinelen=76), """\
+This is an example of string which has almost the limit of header length.
+ Add another line.""")
+
+ def test_shorter_line_with_append(self):
+ eq = self.ndiffAssertEqual
+ s = 'This is a shorter line.'
+ h = Header(s)
+ h.append('Add another sentence. (Surprise?)')
+ eq(h.encode(),
+ 'This is a shorter line. Add another sentence. (Surprise?)')
+
+ def test_long_field_name(self):
+ eq = self.ndiffAssertEqual
+ fn = 'X-Very-Very-Very-Long-Header-Name'
+ gs = ('Die Mieter treten hier ein werden mit einem Foerderband '
+ 'komfortabel den Korridor entlang, an s\xfcdl\xfcndischen '
+ 'Wandgem\xe4lden vorbei, gegen die rotierenden Klingen '
+ 'bef\xf6rdert. ')
+ h = Header(gs, 'iso-8859-1', header_name=fn)
+ # BAW: this seems broken because the first line is too long
+ eq(h.encode(maxlinelen=76), """\
+=?iso-8859-1?q?Die_Mieter_treten_hier_e?=
+ =?iso-8859-1?q?in_werden_mit_einem_Foerderband_komfortabel_den_Korridor_e?=
+ =?iso-8859-1?q?ntlang=2C_an_s=FCdl=FCndischen_Wandgem=E4lden_vorbei=2C_ge?=
+ =?iso-8859-1?q?gen_die_rotierenden_Klingen_bef=F6rdert=2E_?=""")
+
+ def test_long_received_header(self):
+ h = ('from FOO.TLD (vizworld.acl.foo.tld [123.452.678.9]) '
+ 'by hrothgar.la.mastaler.com (tmda-ofmipd) with ESMTP; '
+ 'Wed, 05 Mar 2003 18:10:18 -0700')
+ msg = Message()
+ msg['Received-1'] = Header(h, continuation_ws='\t')
+ msg['Received-2'] = h
+ # This should be splitting on spaces not semicolons.
+ self.ndiffAssertEqual(msg.as_string(maxheaderlen=78), """\
+Received-1: from FOO.TLD (vizworld.acl.foo.tld [123.452.678.9]) by
+ hrothgar.la.mastaler.com (tmda-ofmipd) with ESMTP;
+ Wed, 05 Mar 2003 18:10:18 -0700
+Received-2: from FOO.TLD (vizworld.acl.foo.tld [123.452.678.9]) by
+ hrothgar.la.mastaler.com (tmda-ofmipd) with ESMTP;
+ Wed, 05 Mar 2003 18:10:18 -0700
+
+""")
+
+ def test_string_headerinst_eq(self):
+ h = ('<15975.17901.207240.414604@sgigritzmann1.mathematik.'
+ 'tu-muenchen.de> (David Bremner\'s message of '
+ '"Thu, 6 Mar 2003 13:58:21 +0100")')
+ msg = Message()
+ msg['Received-1'] = Header(h, header_name='Received-1',
+ continuation_ws='\t')
+ msg['Received-2'] = h
+ # XXX The space after the ':' should not be there.
+ self.ndiffAssertEqual(msg.as_string(maxheaderlen=78), """\
+Received-1:\x20
+ <15975.17901.207240.414604@sgigritzmann1.mathematik.tu-muenchen.de> (David
+ Bremner's message of \"Thu, 6 Mar 2003 13:58:21 +0100\")
+Received-2:\x20
+ <15975.17901.207240.414604@sgigritzmann1.mathematik.tu-muenchen.de> (David
+ Bremner's message of \"Thu, 6 Mar 2003 13:58:21 +0100\")
+
+""")
+
+ def test_long_unbreakable_lines_with_continuation(self):
+ eq = self.ndiffAssertEqual
+ msg = Message()
+ t = """\
+iVBORw0KGgoAAAANSUhEUgAAADAAAAAwBAMAAAClLOS0AAAAGFBMVEUAAAAkHiJeRUIcGBi9
+ locQDQ4zJykFBAXJfWDjAAACYUlEQVR4nF2TQY/jIAyFc6lydlG5x8Nyp1Y69wj1PN2I5gzp"""
+ msg['Face-1'] = t
+ msg['Face-2'] = Header(t, header_name='Face-2')
+ msg['Face-3'] = ' ' + t
+ # XXX This splitting is all wrong. It the first value line should be
+ # snug against the field name or the space after the header not there.
+ eq(msg.as_string(maxheaderlen=78), """\
+Face-1:\x20
+ iVBORw0KGgoAAAANSUhEUgAAADAAAAAwBAMAAAClLOS0AAAAGFBMVEUAAAAkHiJeRUIcGBi9
+ locQDQ4zJykFBAXJfWDjAAACYUlEQVR4nF2TQY/jIAyFc6lydlG5x8Nyp1Y69wj1PN2I5gzp
+Face-2:\x20
+ iVBORw0KGgoAAAANSUhEUgAAADAAAAAwBAMAAAClLOS0AAAAGFBMVEUAAAAkHiJeRUIcGBi9
+ locQDQ4zJykFBAXJfWDjAAACYUlEQVR4nF2TQY/jIAyFc6lydlG5x8Nyp1Y69wj1PN2I5gzp
+Face-3:\x20
+ iVBORw0KGgoAAAANSUhEUgAAADAAAAAwBAMAAAClLOS0AAAAGFBMVEUAAAAkHiJeRUIcGBi9
+ locQDQ4zJykFBAXJfWDjAAACYUlEQVR4nF2TQY/jIAyFc6lydlG5x8Nyp1Y69wj1PN2I5gzp
+
+""")
+
+ def test_another_long_multiline_header(self):
+ eq = self.ndiffAssertEqual
+ m = ('Received: from siimage.com '
+ '([172.25.1.3]) by zima.siliconimage.com with '
+ 'Microsoft SMTPSVC(5.0.2195.4905); '
+ 'Wed, 16 Oct 2002 07:41:11 -0700')
+ msg = email.message_from_string(m)
+ eq(msg.as_string(maxheaderlen=78), '''\
+Received: from siimage.com ([172.25.1.3]) by zima.siliconimage.com with
+ Microsoft SMTPSVC(5.0.2195.4905); Wed, 16 Oct 2002 07:41:11 -0700
+
+''')
+
+ def test_long_lines_with_different_header(self):
+ eq = self.ndiffAssertEqual
+ h = ('List-Unsubscribe: '
+ '<http://lists.sourceforge.net/lists/listinfo/spamassassin-talk>,'
+ ' <mailto:spamassassin-talk-request@lists.sourceforge.net'
+ '?subject=unsubscribe>')
+ msg = Message()
+ msg['List'] = h
+ msg['List'] = Header(h, header_name='List')
+ eq(msg.as_string(maxheaderlen=78), """\
+List: List-Unsubscribe:
+ <http://lists.sourceforge.net/lists/listinfo/spamassassin-talk>,
+ <mailto:spamassassin-talk-request@lists.sourceforge.net?subject=unsubscribe>
+List: List-Unsubscribe:
+ <http://lists.sourceforge.net/lists/listinfo/spamassassin-talk>,
+ <mailto:spamassassin-talk-request@lists.sourceforge.net?subject=unsubscribe>
+
+""")
+
+ def test_long_rfc2047_header_with_embedded_fws(self):
+ h = Header(textwrap.dedent("""\
+ We're going to pretend this header is in a non-ascii character set
+ \tto see if line wrapping with encoded words and embedded
+ folding white space works"""),
+ charset='utf-8',
+ header_name='Test')
+ self.assertEqual(h.encode()+'\n', textwrap.dedent("""\
+ =?utf-8?q?We=27re_going_to_pretend_this_header_is_in_a_non-ascii_chara?=
+ =?utf-8?q?cter_set?=
+ =?utf-8?q?_to_see_if_line_wrapping_with_encoded_words_and_embedded?=
+ =?utf-8?q?_folding_white_space_works?=""")+'\n')
+
+
+
+# Test mangling of "From " lines in the body of a message
+class TestFromMangling(unittest.TestCase):
+ def setUp(self):
+ self.msg = Message()
+ self.msg['From'] = 'aaa@bbb.org'
+ self.msg.set_payload("""\
+From the desk of A.A.A.:
+Blah blah blah
+""")
+
+ def test_mangled_from(self):
+ s = StringIO()
+ g = Generator(s, mangle_from_=True)
+ g.flatten(self.msg)
+ self.assertEqual(s.getvalue(), """\
+From: aaa@bbb.org
+
+>From the desk of A.A.A.:
+Blah blah blah
+""")
+
+ def test_dont_mangle_from(self):
+ s = StringIO()
+ g = Generator(s, mangle_from_=False)
+ g.flatten(self.msg)
+ self.assertEqual(s.getvalue(), """\
+From: aaa@bbb.org
+
+From the desk of A.A.A.:
+Blah blah blah
+""")
+
+ def test_mangle_from_in_preamble_and_epilog(self):
+ s = StringIO()
+ g = Generator(s, mangle_from_=True)
+ msg = email.message_from_string(textwrap.dedent("""\
+ From: foo@bar.com
+ Mime-Version: 1.0
+ Content-Type: multipart/mixed; boundary=XXX
+
+ From somewhere unknown
+
+ --XXX
+ Content-Type: text/plain
+
+ foo
+
+ --XXX--
+
+ From somewhere unknowable
+ """))
+ g.flatten(msg)
+ self.assertEqual(len([1 for x in s.getvalue().split('\n')
+ if x.startswith('>From ')]), 2)
+
+ def test_mangled_from_with_bad_bytes(self):
+ source = textwrap.dedent("""\
+ Content-Type: text/plain; charset="utf-8"
+ MIME-Version: 1.0
+ Content-Transfer-Encoding: 8bit
+ From: aaa@bbb.org
+
+ """).encode('utf-8')
+ msg = email.message_from_bytes(source + b'From R\xc3\xb6lli\n')
+ b = BytesIO()
+ g = BytesGenerator(b, mangle_from_=True)
+ g.flatten(msg)
+ self.assertEqual(b.getvalue(), source + b'>From R\xc3\xb6lli\n')
+
+
+# Test the basic MIMEAudio class
+class TestMIMEAudio(unittest.TestCase):
+ def setUp(self):
+ with openfile('audiotest.au', 'rb') as fp:
+ self._audiodata = fp.read()
+ self._au = MIMEAudio(self._audiodata)
+
+ def test_guess_minor_type(self):
+ self.assertEqual(self._au.get_content_type(), 'audio/basic')
+
+ def test_encoding(self):
+ payload = self._au.get_payload()
+ self.assertEqual(base64.decodebytes(bytes(payload, 'ascii')),
+ self._audiodata)
+
+ def test_checkSetMinor(self):
+ au = MIMEAudio(self._audiodata, 'fish')
+ self.assertEqual(au.get_content_type(), 'audio/fish')
+
+ def test_add_header(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ self._au.add_header('Content-Disposition', 'attachment',
+ filename='audiotest.au')
+ eq(self._au['content-disposition'],
+ 'attachment; filename="audiotest.au"')
+ eq(self._au.get_params(header='content-disposition'),
+ [('attachment', ''), ('filename', 'audiotest.au')])
+ eq(self._au.get_param('filename', header='content-disposition'),
+ 'audiotest.au')
+ missing = []
+ eq(self._au.get_param('attachment', header='content-disposition'), '')
+ unless(self._au.get_param('foo', failobj=missing,
+ header='content-disposition') is missing)
+ # Try some missing stuff
+ unless(self._au.get_param('foobar', missing) is missing)
+ unless(self._au.get_param('attachment', missing,
+ header='foobar') is missing)
+
+
+
+# Test the basic MIMEImage class
+class TestMIMEImage(unittest.TestCase):
+ def setUp(self):
+ with openfile('PyBanner048.gif', 'rb') as fp:
+ self._imgdata = fp.read()
+ self._im = MIMEImage(self._imgdata)
+
+ def test_guess_minor_type(self):
+ self.assertEqual(self._im.get_content_type(), 'image/gif')
+
+ def test_encoding(self):
+ payload = self._im.get_payload()
+ self.assertEqual(base64.decodebytes(bytes(payload, 'ascii')),
+ self._imgdata)
+
+ def test_checkSetMinor(self):
+ im = MIMEImage(self._imgdata, 'fish')
+ self.assertEqual(im.get_content_type(), 'image/fish')
+
+ def test_add_header(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ self._im.add_header('Content-Disposition', 'attachment',
+ filename='dingusfish.gif')
+ eq(self._im['content-disposition'],
+ 'attachment; filename="dingusfish.gif"')
+ eq(self._im.get_params(header='content-disposition'),
+ [('attachment', ''), ('filename', 'dingusfish.gif')])
+ eq(self._im.get_param('filename', header='content-disposition'),
+ 'dingusfish.gif')
+ missing = []
+ eq(self._im.get_param('attachment', header='content-disposition'), '')
+ unless(self._im.get_param('foo', failobj=missing,
+ header='content-disposition') is missing)
+ # Try some missing stuff
+ unless(self._im.get_param('foobar', missing) is missing)
+ unless(self._im.get_param('attachment', missing,
+ header='foobar') is missing)
+
+
+
+# Test the basic MIMEApplication class
+class TestMIMEApplication(unittest.TestCase):
+ def test_headers(self):
+ eq = self.assertEqual
+ msg = MIMEApplication(b'\xfa\xfb\xfc\xfd\xfe\xff')
+ eq(msg.get_content_type(), 'application/octet-stream')
+ eq(msg['content-transfer-encoding'], 'base64')
+
+ def test_body(self):
+ eq = self.assertEqual
+ bytesdata = b'\xfa\xfb\xfc\xfd\xfe\xff'
+ msg = MIMEApplication(bytesdata)
+ # whitespace in the cte encoded block is RFC-irrelevant.
+ eq(msg.get_payload().strip(), '+vv8/f7/')
+ eq(msg.get_payload(decode=True), bytesdata)
+
+ def test_binary_body_with_encode_7or8bit(self):
+ # Issue 17171.
+ bytesdata = b'\xfa\xfb\xfc\xfd\xfe\xff'
+ msg = MIMEApplication(bytesdata, _encoder=encoders.encode_7or8bit)
+ # Treated as a string, this will be invalid code points.
+ self.assertEqual(msg.get_payload(), '\uFFFD' * len(bytesdata))
+ self.assertEqual(msg.get_payload(decode=True), bytesdata)
+ self.assertEqual(msg['Content-Transfer-Encoding'], '8bit')
+ s = BytesIO()
+ g = BytesGenerator(s)
+ g.flatten(msg)
+ wireform = s.getvalue()
+ msg2 = email.message_from_bytes(wireform)
+ self.assertEqual(msg.get_payload(), '\uFFFD' * len(bytesdata))
+ self.assertEqual(msg2.get_payload(decode=True), bytesdata)
+ self.assertEqual(msg2['Content-Transfer-Encoding'], '8bit')
+
+ def test_binary_body_with_encode_noop(self):
+ # Issue 16564: This does not produce an RFC valid message, since to be
+ # valid it should have a CTE of binary. But the below works in
+ # Python2, and is documented as working this way.
+ bytesdata = b'\xfa\xfb\xfc\xfd\xfe\xff'
+ msg = MIMEApplication(bytesdata, _encoder=encoders.encode_noop)
+ # Treated as a string, this will be invalid code points.
+ self.assertEqual(msg.get_payload(), '\uFFFD' * len(bytesdata))
+ self.assertEqual(msg.get_payload(decode=True), bytesdata)
+ s = BytesIO()
+ g = BytesGenerator(s)
+ g.flatten(msg)
+ wireform = s.getvalue()
+ msg2 = email.message_from_bytes(wireform)
+ self.assertEqual(msg.get_payload(), '\uFFFD' * len(bytesdata))
+ self.assertEqual(msg2.get_payload(decode=True), bytesdata)
+
+
+# Test the basic MIMEText class
+class TestMIMEText(unittest.TestCase):
+ def setUp(self):
+ self._msg = MIMEText('hello there')
+
+ def test_types(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ eq(self._msg.get_content_type(), 'text/plain')
+ eq(self._msg.get_param('charset'), 'us-ascii')
+ missing = []
+ unless(self._msg.get_param('foobar', missing) is missing)
+ unless(self._msg.get_param('charset', missing, header='foobar')
+ is missing)
+
+ def test_payload(self):
+ self.assertEqual(self._msg.get_payload(), 'hello there')
+ self.assertTrue(not self._msg.is_multipart())
+
+ def test_charset(self):
+ eq = self.assertEqual
+ msg = MIMEText('hello there', _charset='us-ascii')
+ eq(msg.get_charset().input_charset, 'us-ascii')
+ eq(msg['content-type'], 'text/plain; charset="us-ascii"')
+
+ def test_7bit_input(self):
+ eq = self.assertEqual
+ msg = MIMEText('hello there', _charset='us-ascii')
+ eq(msg.get_charset().input_charset, 'us-ascii')
+ eq(msg['content-type'], 'text/plain; charset="us-ascii"')
+
+ def test_7bit_input_no_charset(self):
+ eq = self.assertEqual
+ msg = MIMEText('hello there')
+ eq(msg.get_charset(), 'us-ascii')
+ eq(msg['content-type'], 'text/plain; charset="us-ascii"')
+ self.assertTrue('hello there' in msg.as_string())
+
+ def test_utf8_input(self):
+ teststr = '\u043a\u0438\u0440\u0438\u043b\u0438\u0446\u0430'
+ eq = self.assertEqual
+ msg = MIMEText(teststr, _charset='utf-8')
+ eq(msg.get_charset().output_charset, 'utf-8')
+ eq(msg['content-type'], 'text/plain; charset="utf-8"')
+ eq(msg.get_payload(decode=True), teststr.encode('utf-8'))
+
+ @unittest.skip("can't fix because of backward compat in email5, "
+ "will fix in email6")
+ def test_utf8_input_no_charset(self):
+ teststr = '\u043a\u0438\u0440\u0438\u043b\u0438\u0446\u0430'
+ self.assertRaises(UnicodeEncodeError, MIMEText, teststr)
+
+
+
+# Test complicated multipart/* messages
+class TestMultipart(TestEmailBase):
+ def setUp(self):
+ with openfile('PyBanner048.gif', 'rb') as fp:
+ data = fp.read()
+ container = MIMEBase('multipart', 'mixed', boundary='BOUNDARY')
+ image = MIMEImage(data, name='dingusfish.gif')
+ image.add_header('content-disposition', 'attachment',
+ filename='dingusfish.gif')
+ intro = MIMEText('''\
+Hi there,
+
+This is the dingus fish.
+''')
+ container.attach(intro)
+ container.attach(image)
+ container['From'] = 'Barry <barry@digicool.com>'
+ container['To'] = 'Dingus Lovers <cravindogs@cravindogs.com>'
+ container['Subject'] = 'Here is your dingus fish'
+
+ now = 987809702.54848599
+ timetuple = time.localtime(now)
+ if timetuple[-1] == 0:
+ tzsecs = time.timezone
+ else:
+ tzsecs = time.altzone
+ if tzsecs > 0:
+ sign = '-'
+ else:
+ sign = '+'
+ tzoffset = ' %s%04d' % (sign, tzsecs / 36)
+ container['Date'] = time.strftime(
+ '%a, %d %b %Y %H:%M:%S',
+ time.localtime(now)) + tzoffset
+ self._msg = container
+ self._im = image
+ self._txt = intro
+
+ def test_hierarchy(self):
+ # convenience
+ eq = self.assertEqual
+ unless = self.assertTrue
+ raises = self.assertRaises
+ # tests
+ m = self._msg
+ unless(m.is_multipart())
+ eq(m.get_content_type(), 'multipart/mixed')
+ eq(len(m.get_payload()), 2)
+ raises(IndexError, m.get_payload, 2)
+ m0 = m.get_payload(0)
+ m1 = m.get_payload(1)
+ unless(m0 is self._txt)
+ unless(m1 is self._im)
+ eq(m.get_payload(), [m0, m1])
+ unless(not m0.is_multipart())
+ unless(not m1.is_multipart())
+
+ def test_empty_multipart_idempotent(self):
+ text = """\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson@dom.ain
+From: bperson@dom.ain
+
+
+--BOUNDARY
+
+
+--BOUNDARY--
+"""
+ msg = Parser().parsestr(text)
+ self.ndiffAssertEqual(text, msg.as_string())
+
+ def test_no_parts_in_a_multipart_with_none_epilogue(self):
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson@dom.ain'
+ outer['From'] = 'bperson@dom.ain'
+ outer.set_boundary('BOUNDARY')
+ self.ndiffAssertEqual(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson@dom.ain
+From: bperson@dom.ain
+
+--BOUNDARY
+
+--BOUNDARY--''')
+
+ def test_no_parts_in_a_multipart_with_empty_epilogue(self):
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson@dom.ain'
+ outer['From'] = 'bperson@dom.ain'
+ outer.preamble = ''
+ outer.epilogue = ''
+ outer.set_boundary('BOUNDARY')
+ self.ndiffAssertEqual(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson@dom.ain
+From: bperson@dom.ain
+
+
+--BOUNDARY
+
+--BOUNDARY--
+''')
+
+ def test_one_part_in_a_multipart(self):
+ eq = self.ndiffAssertEqual
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson@dom.ain'
+ outer['From'] = 'bperson@dom.ain'
+ outer.set_boundary('BOUNDARY')
+ msg = MIMEText('hello world')
+ outer.attach(msg)
+ eq(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson@dom.ain
+From: bperson@dom.ain
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+hello world
+--BOUNDARY--''')
+
+ def test_seq_parts_in_a_multipart_with_empty_preamble(self):
+ eq = self.ndiffAssertEqual
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson@dom.ain'
+ outer['From'] = 'bperson@dom.ain'
+ outer.preamble = ''
+ msg = MIMEText('hello world')
+ outer.attach(msg)
+ outer.set_boundary('BOUNDARY')
+ eq(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson@dom.ain
+From: bperson@dom.ain
+
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+hello world
+--BOUNDARY--''')
+
+
+ def test_seq_parts_in_a_multipart_with_none_preamble(self):
+ eq = self.ndiffAssertEqual
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson@dom.ain'
+ outer['From'] = 'bperson@dom.ain'
+ outer.preamble = None
+ msg = MIMEText('hello world')
+ outer.attach(msg)
+ outer.set_boundary('BOUNDARY')
+ eq(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson@dom.ain
+From: bperson@dom.ain
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+hello world
+--BOUNDARY--''')
+
+
+ def test_seq_parts_in_a_multipart_with_none_epilogue(self):
+ eq = self.ndiffAssertEqual
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson@dom.ain'
+ outer['From'] = 'bperson@dom.ain'
+ outer.epilogue = None
+ msg = MIMEText('hello world')
+ outer.attach(msg)
+ outer.set_boundary('BOUNDARY')
+ eq(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson@dom.ain
+From: bperson@dom.ain
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+hello world
+--BOUNDARY--''')
+
+
+ def test_seq_parts_in_a_multipart_with_empty_epilogue(self):
+ eq = self.ndiffAssertEqual
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson@dom.ain'
+ outer['From'] = 'bperson@dom.ain'
+ outer.epilogue = ''
+ msg = MIMEText('hello world')
+ outer.attach(msg)
+ outer.set_boundary('BOUNDARY')
+ eq(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson@dom.ain
+From: bperson@dom.ain
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+hello world
+--BOUNDARY--
+''')
+
+
+ def test_seq_parts_in_a_multipart_with_nl_epilogue(self):
+ eq = self.ndiffAssertEqual
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson@dom.ain'
+ outer['From'] = 'bperson@dom.ain'
+ outer.epilogue = '\n'
+ msg = MIMEText('hello world')
+ outer.attach(msg)
+ outer.set_boundary('BOUNDARY')
+ eq(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson@dom.ain
+From: bperson@dom.ain
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+hello world
+--BOUNDARY--
+
+''')
+
+ def test_message_external_body(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_36.txt')
+ eq(len(msg.get_payload()), 2)
+ msg1 = msg.get_payload(1)
+ eq(msg1.get_content_type(), 'multipart/alternative')
+ eq(len(msg1.get_payload()), 2)
+ for subpart in msg1.get_payload():
+ eq(subpart.get_content_type(), 'message/external-body')
+ eq(len(subpart.get_payload()), 1)
+ subsubpart = subpart.get_payload(0)
+ eq(subsubpart.get_content_type(), 'text/plain')
+
+ def test_double_boundary(self):
+ # msg_37.txt is a multipart that contains two dash-boundary's in a
+ # row. Our interpretation of RFC 2046 calls for ignoring the second
+ # and subsequent boundaries.
+ msg = self._msgobj('msg_37.txt')
+ self.assertEqual(len(msg.get_payload()), 3)
+
+ def test_nested_inner_contains_outer_boundary(self):
+ eq = self.ndiffAssertEqual
+ # msg_38.txt has an inner part that contains outer boundaries. My
+ # interpretation of RFC 2046 (based on sections 5.1 and 5.1.2) say
+ # these are illegal and should be interpreted as unterminated inner
+ # parts.
+ msg = self._msgobj('msg_38.txt')
+ sfp = StringIO()
+ iterators._structure(msg, sfp)
+ eq(sfp.getvalue(), """\
+multipart/mixed
+ multipart/mixed
+ multipart/alternative
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+""")
+
+ def test_nested_with_same_boundary(self):
+ eq = self.ndiffAssertEqual
+ # msg 39.txt is similarly evil in that it's got inner parts that use
+ # the same boundary as outer parts. Again, I believe the way this is
+ # parsed is closest to the spirit of RFC 2046
+ msg = self._msgobj('msg_39.txt')
+ sfp = StringIO()
+ iterators._structure(msg, sfp)
+ eq(sfp.getvalue(), """\
+multipart/mixed
+ multipart/mixed
+ multipart/alternative
+ application/octet-stream
+ application/octet-stream
+ text/plain
+""")
+
+ def test_boundary_in_non_multipart(self):
+ msg = self._msgobj('msg_40.txt')
+ self.assertEqual(msg.as_string(), '''\
+MIME-Version: 1.0
+Content-Type: text/html; boundary="--961284236552522269"
+
+----961284236552522269
+Content-Type: text/html;
+Content-Transfer-Encoding: 7Bit
+
+<html></html>
+
+----961284236552522269--
+''')
+
+ def test_boundary_with_leading_space(self):
+ eq = self.assertEqual
+ msg = email.message_from_string('''\
+MIME-Version: 1.0
+Content-Type: multipart/mixed; boundary=" XXXX"
+
+-- XXXX
+Content-Type: text/plain
+
+
+-- XXXX
+Content-Type: text/plain
+
+-- XXXX--
+''')
+ self.assertTrue(msg.is_multipart())
+ eq(msg.get_boundary(), ' XXXX')
+ eq(len(msg.get_payload()), 2)
+
+ def test_boundary_without_trailing_newline(self):
+ m = Parser().parsestr("""\
+Content-Type: multipart/mixed; boundary="===============0012394164=="
+MIME-Version: 1.0
+
+--===============0012394164==
+Content-Type: image/file1.jpg
+MIME-Version: 1.0
+Content-Transfer-Encoding: base64
+
+YXNkZg==
+--===============0012394164==--""")
+ self.assertEqual(m.get_payload(0).get_payload(), 'YXNkZg==')
+
+
+
+# Test some badly formatted messages
+class TestNonConformant(TestEmailBase):
+
+ def test_parse_missing_minor_type(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_14.txt')
+ eq(msg.get_content_type(), 'text/plain')
+ eq(msg.get_content_maintype(), 'text')
+ eq(msg.get_content_subtype(), 'plain')
+
+ # test_defect_handling
+ def test_same_boundary_inner_outer(self):
+ unless = self.assertTrue
+ msg = self._msgobj('msg_15.txt')
+ # XXX We can probably eventually do better
+ inner = msg.get_payload(0)
+ unless(hasattr(inner, 'defects'))
+ self.assertEqual(len(inner.defects), 1)
+ unless(isinstance(inner.defects[0],
+ errors.StartBoundaryNotFoundDefect))
+
+ # test_defect_handling
+ def test_multipart_no_boundary(self):
+ unless = self.assertTrue
+ msg = self._msgobj('msg_25.txt')
+ unless(isinstance(msg.get_payload(), str))
+ self.assertEqual(len(msg.defects), 2)
+ unless(isinstance(msg.defects[0],
+ errors.NoBoundaryInMultipartDefect))
+ unless(isinstance(msg.defects[1],
+ errors.MultipartInvariantViolationDefect))
+
+ multipart_msg = textwrap.dedent("""\
+ Date: Wed, 14 Nov 2007 12:56:23 GMT
+ From: foo@bar.invalid
+ To: foo@bar.invalid
+ Subject: Content-Transfer-Encoding: base64 and multipart
+ MIME-Version: 1.0
+ Content-Type: multipart/mixed;
+ boundary="===============3344438784458119861=="{}
+
+ --===============3344438784458119861==
+ Content-Type: text/plain
+
+ Test message
+
+ --===============3344438784458119861==
+ Content-Type: application/octet-stream
+ Content-Transfer-Encoding: base64
+
+ YWJj
+
+ --===============3344438784458119861==--
+ """)
+
+ # test_defect_handling
+ def test_multipart_invalid_cte(self):
+ msg = self._str_msg(
+ self.multipart_msg.format("\nContent-Transfer-Encoding: base64"))
+ self.assertEqual(len(msg.defects), 1)
+ self.assertIsInstance(msg.defects[0],
+ errors.InvalidMultipartContentTransferEncodingDefect)
+
+ # test_defect_handling
+ def test_multipart_no_cte_no_defect(self):
+ msg = self._str_msg(self.multipart_msg.format(''))
+ self.assertEqual(len(msg.defects), 0)
+
+ # test_defect_handling
+ def test_multipart_valid_cte_no_defect(self):
+ for cte in ('7bit', '8bit', 'BINary'):
+ msg = self._str_msg(
+ self.multipart_msg.format(
+ "\nContent-Transfer-Encoding: {}".format(cte)))
+ self.assertEqual(len(msg.defects), 0)
+
+ # test_headerregistry.TestContentTyopeHeader invalid_1 and invalid_2.
+ def test_invalid_content_type(self):
+ eq = self.assertEqual
+ neq = self.ndiffAssertEqual
+ msg = Message()
+ # RFC 2045, $5.2 says invalid yields text/plain
+ msg['Content-Type'] = 'text'
+ eq(msg.get_content_maintype(), 'text')
+ eq(msg.get_content_subtype(), 'plain')
+ eq(msg.get_content_type(), 'text/plain')
+ # Clear the old value and try something /really/ invalid
+ del msg['content-type']
+ msg['Content-Type'] = 'foo'
+ eq(msg.get_content_maintype(), 'text')
+ eq(msg.get_content_subtype(), 'plain')
+ eq(msg.get_content_type(), 'text/plain')
+ # Still, make sure that the message is idempotently generated
+ s = StringIO()
+ g = Generator(s)
+ g.flatten(msg)
+ neq(s.getvalue(), 'Content-Type: foo\n\n')
+
+ def test_no_start_boundary(self):
+ eq = self.ndiffAssertEqual
+ msg = self._msgobj('msg_31.txt')
+ eq(msg.get_payload(), """\
+--BOUNDARY
+Content-Type: text/plain
+
+message 1
+
+--BOUNDARY
+Content-Type: text/plain
+
+message 2
+
+--BOUNDARY--
+""")
+
+ def test_no_separating_blank_line(self):
+ eq = self.ndiffAssertEqual
+ msg = self._msgobj('msg_35.txt')
+ eq(msg.as_string(), """\
+From: aperson@dom.ain
+To: bperson@dom.ain
+Subject: here's something interesting
+
+counter to RFC 2822, there's no separating newline here
+""")
+
+ # test_defect_handling
+ def test_lying_multipart(self):
+ unless = self.assertTrue
+ msg = self._msgobj('msg_41.txt')
+ unless(hasattr(msg, 'defects'))
+ self.assertEqual(len(msg.defects), 2)
+ unless(isinstance(msg.defects[0],
+ errors.NoBoundaryInMultipartDefect))
+ unless(isinstance(msg.defects[1],
+ errors.MultipartInvariantViolationDefect))
+
+ # test_defect_handling
+ def test_missing_start_boundary(self):
+ outer = self._msgobj('msg_42.txt')
+ # The message structure is:
+ #
+ # multipart/mixed
+ # text/plain
+ # message/rfc822
+ # multipart/mixed [*]
+ #
+ # [*] This message is missing its start boundary
+ bad = outer.get_payload(1).get_payload(0)
+ self.assertEqual(len(bad.defects), 1)
+ self.assertTrue(isinstance(bad.defects[0],
+ errors.StartBoundaryNotFoundDefect))
+
+ # test_defect_handling
+ def test_first_line_is_continuation_header(self):
+ eq = self.assertEqual
+ m = ' Line 1\nSubject: test\n\nbody'
+ msg = email.message_from_string(m)
+ eq(msg.keys(), ['Subject'])
+ eq(msg.get_payload(), 'body')
+ eq(len(msg.defects), 1)
+ self.assertDefectsEqual(msg.defects,
+ [errors.FirstHeaderLineIsContinuationDefect])
+ eq(msg.defects[0].line, ' Line 1\n')
+
+ # test_defect_handling
+ def test_missing_header_body_separator(self):
+ # Our heuristic if we see a line that doesn't look like a header (no
+ # leading whitespace but no ':') is to assume that the blank line that
+ # separates the header from the body is missing, and to stop parsing
+ # headers and start parsing the body.
+ msg = self._str_msg('Subject: test\nnot a header\nTo: abc\n\nb\n')
+ self.assertEqual(msg.keys(), ['Subject'])
+ self.assertEqual(msg.get_payload(), 'not a header\nTo: abc\n\nb\n')
+ self.assertDefectsEqual(msg.defects,
+ [errors.MissingHeaderBodySeparatorDefect])
+
+
+# Test RFC 2047 header encoding and decoding
+class TestRFC2047(TestEmailBase):
+ def test_rfc2047_multiline(self):
+ eq = self.assertEqual
+ s = """Re: =?mac-iceland?q?r=8Aksm=9Arg=8Cs?= baz
+ foo bar =?mac-iceland?q?r=8Aksm=9Arg=8Cs?="""
+ dh = decode_header(s)
+ eq(dh, [
+ (b'Re: ', None),
+ (b'r\x8aksm\x9arg\x8cs', 'mac-iceland'),
+ (b' baz foo bar ', None),
+ (b'r\x8aksm\x9arg\x8cs', 'mac-iceland')])
+ header = make_header(dh)
+ eq(str(header),
+ 'Re: r\xe4ksm\xf6rg\xe5s baz foo bar r\xe4ksm\xf6rg\xe5s')
+ self.ndiffAssertEqual(header.encode(maxlinelen=76), """\
+Re: =?mac-iceland?q?r=8Aksm=9Arg=8Cs?= baz foo bar =?mac-iceland?q?r=8Aksm?=
+ =?mac-iceland?q?=9Arg=8Cs?=""")
+
+ def test_whitespace_keeper_unicode(self):
+ eq = self.assertEqual
+ s = '=?ISO-8859-1?Q?Andr=E9?= Pirard <pirard@dom.ain>'
+ dh = decode_header(s)
+ eq(dh, [(b'Andr\xe9', 'iso-8859-1'),
+ (b' Pirard <pirard@dom.ain>', None)])
+ header = str(make_header(dh))
+ eq(header, 'Andr\xe9 Pirard <pirard@dom.ain>')
+
+ def test_whitespace_keeper_unicode_2(self):
+ eq = self.assertEqual
+ s = 'The =?iso-8859-1?b?cXVpY2sgYnJvd24gZm94?= jumped over the =?iso-8859-1?b?bGF6eSBkb2c=?='
+ dh = decode_header(s)
+ eq(dh, [(b'The ', None), (b'quick brown fox', 'iso-8859-1'),
+ (b' jumped over the ', None), (b'lazy dog', 'iso-8859-1')])
+ hu = str(make_header(dh))
+ eq(hu, 'The quick brown fox jumped over the lazy dog')
+
+ def test_rfc2047_missing_whitespace(self):
+ s = 'Sm=?ISO-8859-1?B?9g==?=rg=?ISO-8859-1?B?5Q==?=sbord'
+ dh = decode_header(s)
+ self.assertEqual(dh, [(b'Sm', None), (b'\xf6', 'iso-8859-1'),
+ (b'rg', None), (b'\xe5', 'iso-8859-1'),
+ (b'sbord', None)])
+
+ def test_rfc2047_with_whitespace(self):
+ s = 'Sm =?ISO-8859-1?B?9g==?= rg =?ISO-8859-1?B?5Q==?= sbord'
+ dh = decode_header(s)
+ self.assertEqual(dh, [(b'Sm ', None), (b'\xf6', 'iso-8859-1'),
+ (b' rg ', None), (b'\xe5', 'iso-8859-1'),
+ (b' sbord', None)])
+
+ def test_rfc2047_B_bad_padding(self):
+ s = '=?iso-8859-1?B?%s?='
+ data = [ # only test complete bytes
+ ('dm==', b'v'), ('dm=', b'v'), ('dm', b'v'),
+ ('dmk=', b'vi'), ('dmk', b'vi')
+ ]
+ for q, a in data:
+ dh = decode_header(s % q)
+ self.assertEqual(dh, [(a, 'iso-8859-1')])
+
+ def test_rfc2047_Q_invalid_digits(self):
+ # issue 10004.
+ s = '=?iso-8659-1?Q?andr=e9=zz?='
+ self.assertEqual(decode_header(s),
+ [(b'andr\xe9=zz', 'iso-8659-1')])
+
+ def test_rfc2047_rfc2047_1(self):
+ # 1st testcase at end of rfc2047
+ s = '(=?ISO-8859-1?Q?a?=)'
+ self.assertEqual(decode_header(s),
+ [(b'(', None), (b'a', 'iso-8859-1'), (b')', None)])
+
+ def test_rfc2047_rfc2047_2(self):
+ # 2nd testcase at end of rfc2047
+ s = '(=?ISO-8859-1?Q?a?= b)'
+ self.assertEqual(decode_header(s),
+ [(b'(', None), (b'a', 'iso-8859-1'), (b' b)', None)])
+
+ def test_rfc2047_rfc2047_3(self):
+ # 3rd testcase at end of rfc2047
+ s = '(=?ISO-8859-1?Q?a?= =?ISO-8859-1?Q?b?=)'
+ self.assertEqual(decode_header(s),
+ [(b'(', None), (b'ab', 'iso-8859-1'), (b')', None)])
+
+ def test_rfc2047_rfc2047_4(self):
+ # 4th testcase at end of rfc2047
+ s = '(=?ISO-8859-1?Q?a?= =?ISO-8859-1?Q?b?=)'
+ self.assertEqual(decode_header(s),
+ [(b'(', None), (b'ab', 'iso-8859-1'), (b')', None)])
+
+ def test_rfc2047_rfc2047_5a(self):
+ # 5th testcase at end of rfc2047 newline is \r\n
+ s = '(=?ISO-8859-1?Q?a?=\r\n =?ISO-8859-1?Q?b?=)'
+ self.assertEqual(decode_header(s),
+ [(b'(', None), (b'ab', 'iso-8859-1'), (b')', None)])
+
+ def test_rfc2047_rfc2047_5b(self):
+ # 5th testcase at end of rfc2047 newline is \n
+ s = '(=?ISO-8859-1?Q?a?=\n =?ISO-8859-1?Q?b?=)'
+ self.assertEqual(decode_header(s),
+ [(b'(', None), (b'ab', 'iso-8859-1'), (b')', None)])
+
+ def test_rfc2047_rfc2047_6(self):
+ # 6th testcase at end of rfc2047
+ s = '(=?ISO-8859-1?Q?a_b?=)'
+ self.assertEqual(decode_header(s),
+ [(b'(', None), (b'a b', 'iso-8859-1'), (b')', None)])
+
+ def test_rfc2047_rfc2047_7(self):
+ # 7th testcase at end of rfc2047
+ s = '(=?ISO-8859-1?Q?a?= =?ISO-8859-2?Q?_b?=)'
+ self.assertEqual(decode_header(s),
+ [(b'(', None), (b'a', 'iso-8859-1'), (b' b', 'iso-8859-2'),
+ (b')', None)])
+ self.assertEqual(make_header(decode_header(s)).encode(), s.lower())
+ self.assertEqual(str(make_header(decode_header(s))), '(a b)')
+
+ def test_multiline_header(self):
+ s = '=?windows-1252?q?=22M=FCller_T=22?=\r\n <T.Mueller@xxx.com>'
+ self.assertEqual(decode_header(s),
+ [(b'"M\xfcller T"', 'windows-1252'),
+ (b'<T.Mueller@xxx.com>', None)])
+ self.assertEqual(make_header(decode_header(s)).encode(),
+ ''.join(s.splitlines()))
+ self.assertEqual(str(make_header(decode_header(s))),
+ '"Müller T" <T.Mueller@xxx.com>')
+
+
+# Test the MIMEMessage class
+class TestMIMEMessage(TestEmailBase):
+ def setUp(self):
+ with openfile('msg_11.txt') as fp:
+ self._text = fp.read()
+
+ def test_type_error(self):
+ self.assertRaises(TypeError, MIMEMessage, 'a plain string')
+
+ def test_valid_argument(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ subject = 'A sub-message'
+ m = Message()
+ m['Subject'] = subject
+ r = MIMEMessage(m)
+ eq(r.get_content_type(), 'message/rfc822')
+ payload = r.get_payload()
+ unless(isinstance(payload, list))
+ eq(len(payload), 1)
+ subpart = payload[0]
+ unless(subpart is m)
+ eq(subpart['subject'], subject)
+
+ def test_bad_multipart(self):
+ eq = self.assertEqual
+ msg1 = Message()
+ msg1['Subject'] = 'subpart 1'
+ msg2 = Message()
+ msg2['Subject'] = 'subpart 2'
+ r = MIMEMessage(msg1)
+ self.assertRaises(errors.MultipartConversionError, r.attach, msg2)
+
+ def test_generate(self):
+ # First craft the message to be encapsulated
+ m = Message()
+ m['Subject'] = 'An enclosed message'
+ m.set_payload('Here is the body of the message.\n')
+ r = MIMEMessage(m)
+ r['Subject'] = 'The enclosing message'
+ s = StringIO()
+ g = Generator(s)
+ g.flatten(r)
+ self.assertEqual(s.getvalue(), """\
+Content-Type: message/rfc822
+MIME-Version: 1.0
+Subject: The enclosing message
+
+Subject: An enclosed message
+
+Here is the body of the message.
+""")
+
+ def test_parse_message_rfc822(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ msg = self._msgobj('msg_11.txt')
+ eq(msg.get_content_type(), 'message/rfc822')
+ payload = msg.get_payload()
+ unless(isinstance(payload, list))
+ eq(len(payload), 1)
+ submsg = payload[0]
+ self.assertTrue(isinstance(submsg, Message))
+ eq(submsg['subject'], 'An enclosed message')
+ eq(submsg.get_payload(), 'Here is the body of the message.\n')
+
+ def test_dsn(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ # msg 16 is a Delivery Status Notification, see RFC 1894
+ msg = self._msgobj('msg_16.txt')
+ eq(msg.get_content_type(), 'multipart/report')
+ unless(msg.is_multipart())
+ eq(len(msg.get_payload()), 3)
+ # Subpart 1 is a text/plain, human readable section
+ subpart = msg.get_payload(0)
+ eq(subpart.get_content_type(), 'text/plain')
+ eq(subpart.get_payload(), """\
+This report relates to a message you sent with the following header fields:
+
+ Message-id: <002001c144a6$8752e060$56104586@oxy.edu>
+ Date: Sun, 23 Sep 2001 20:10:55 -0700
+ From: "Ian T. Henry" <henryi@oxy.edu>
+ To: SoCal Raves <scr@socal-raves.org>
+ Subject: [scr] yeah for Ians!!
+
+Your message cannot be delivered to the following recipients:
+
+ Recipient address: jangel1@cougar.noc.ucla.edu
+ Reason: recipient reached disk quota
+
+""")
+ # Subpart 2 contains the machine parsable DSN information. It
+ # consists of two blocks of headers, represented by two nested Message
+ # objects.
+ subpart = msg.get_payload(1)
+ eq(subpart.get_content_type(), 'message/delivery-status')
+ eq(len(subpart.get_payload()), 2)
+ # message/delivery-status should treat each block as a bunch of
+ # headers, i.e. a bunch of Message objects.
+ dsn1 = subpart.get_payload(0)
+ unless(isinstance(dsn1, Message))
+ eq(dsn1['original-envelope-id'], '0GK500B4HD0888@cougar.noc.ucla.edu')
+ eq(dsn1.get_param('dns', header='reporting-mta'), '')
+ # Try a missing one <wink>
+ eq(dsn1.get_param('nsd', header='reporting-mta'), None)
+ dsn2 = subpart.get_payload(1)
+ unless(isinstance(dsn2, Message))
+ eq(dsn2['action'], 'failed')
+ eq(dsn2.get_params(header='original-recipient'),
+ [('rfc822', ''), ('jangel1@cougar.noc.ucla.edu', '')])
+ eq(dsn2.get_param('rfc822', header='final-recipient'), '')
+ # Subpart 3 is the original message
+ subpart = msg.get_payload(2)
+ eq(subpart.get_content_type(), 'message/rfc822')
+ payload = subpart.get_payload()
+ unless(isinstance(payload, list))
+ eq(len(payload), 1)
+ subsubpart = payload[0]
+ unless(isinstance(subsubpart, Message))
+ eq(subsubpart.get_content_type(), 'text/plain')
+ eq(subsubpart['message-id'],
+ '<002001c144a6$8752e060$56104586@oxy.edu>')
+
+ def test_epilogue(self):
+ eq = self.ndiffAssertEqual
+ with openfile('msg_21.txt') as fp:
+ text = fp.read()
+ msg = Message()
+ msg['From'] = 'aperson@dom.ain'
+ msg['To'] = 'bperson@dom.ain'
+ msg['Subject'] = 'Test'
+ msg.preamble = 'MIME message'
+ msg.epilogue = 'End of MIME message\n'
+ msg1 = MIMEText('One')
+ msg2 = MIMEText('Two')
+ msg.add_header('Content-Type', 'multipart/mixed', boundary='BOUNDARY')
+ msg.attach(msg1)
+ msg.attach(msg2)
+ sfp = StringIO()
+ g = Generator(sfp)
+ g.flatten(msg)
+ eq(sfp.getvalue(), text)
+
+ def test_no_nl_preamble(self):
+ eq = self.ndiffAssertEqual
+ msg = Message()
+ msg['From'] = 'aperson@dom.ain'
+ msg['To'] = 'bperson@dom.ain'
+ msg['Subject'] = 'Test'
+ msg.preamble = 'MIME message'
+ msg.epilogue = ''
+ msg1 = MIMEText('One')
+ msg2 = MIMEText('Two')
+ msg.add_header('Content-Type', 'multipart/mixed', boundary='BOUNDARY')
+ msg.attach(msg1)
+ msg.attach(msg2)
+ eq(msg.as_string(), """\
+From: aperson@dom.ain
+To: bperson@dom.ain
+Subject: Test
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+
+MIME message
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+One
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+Two
+--BOUNDARY--
+""")
+
+ def test_default_type(self):
+ eq = self.assertEqual
+ with openfile('msg_30.txt') as fp:
+ msg = email.message_from_file(fp)
+ container1 = msg.get_payload(0)
+ eq(container1.get_default_type(), 'message/rfc822')
+ eq(container1.get_content_type(), 'message/rfc822')
+ container2 = msg.get_payload(1)
+ eq(container2.get_default_type(), 'message/rfc822')
+ eq(container2.get_content_type(), 'message/rfc822')
+ container1a = container1.get_payload(0)
+ eq(container1a.get_default_type(), 'text/plain')
+ eq(container1a.get_content_type(), 'text/plain')
+ container2a = container2.get_payload(0)
+ eq(container2a.get_default_type(), 'text/plain')
+ eq(container2a.get_content_type(), 'text/plain')
+
+ def test_default_type_with_explicit_container_type(self):
+ eq = self.assertEqual
+ with openfile('msg_28.txt') as fp:
+ msg = email.message_from_file(fp)
+ container1 = msg.get_payload(0)
+ eq(container1.get_default_type(), 'message/rfc822')
+ eq(container1.get_content_type(), 'message/rfc822')
+ container2 = msg.get_payload(1)
+ eq(container2.get_default_type(), 'message/rfc822')
+ eq(container2.get_content_type(), 'message/rfc822')
+ container1a = container1.get_payload(0)
+ eq(container1a.get_default_type(), 'text/plain')
+ eq(container1a.get_content_type(), 'text/plain')
+ container2a = container2.get_payload(0)
+ eq(container2a.get_default_type(), 'text/plain')
+ eq(container2a.get_content_type(), 'text/plain')
+
+ def test_default_type_non_parsed(self):
+ eq = self.assertEqual
+ neq = self.ndiffAssertEqual
+ # Set up container
+ container = MIMEMultipart('digest', 'BOUNDARY')
+ container.epilogue = ''
+ # Set up subparts
+ subpart1a = MIMEText('message 1\n')
+ subpart2a = MIMEText('message 2\n')
+ subpart1 = MIMEMessage(subpart1a)
+ subpart2 = MIMEMessage(subpart2a)
+ container.attach(subpart1)
+ container.attach(subpart2)
+ eq(subpart1.get_content_type(), 'message/rfc822')
+ eq(subpart1.get_default_type(), 'message/rfc822')
+ eq(subpart2.get_content_type(), 'message/rfc822')
+ eq(subpart2.get_default_type(), 'message/rfc822')
+ neq(container.as_string(0), '''\
+Content-Type: multipart/digest; boundary="BOUNDARY"
+MIME-Version: 1.0
+
+--BOUNDARY
+Content-Type: message/rfc822
+MIME-Version: 1.0
+
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+message 1
+
+--BOUNDARY
+Content-Type: message/rfc822
+MIME-Version: 1.0
+
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+message 2
+
+--BOUNDARY--
+''')
+ del subpart1['content-type']
+ del subpart1['mime-version']
+ del subpart2['content-type']
+ del subpart2['mime-version']
+ eq(subpart1.get_content_type(), 'message/rfc822')
+ eq(subpart1.get_default_type(), 'message/rfc822')
+ eq(subpart2.get_content_type(), 'message/rfc822')
+ eq(subpart2.get_default_type(), 'message/rfc822')
+ neq(container.as_string(0), '''\
+Content-Type: multipart/digest; boundary="BOUNDARY"
+MIME-Version: 1.0
+
+--BOUNDARY
+
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+message 1
+
+--BOUNDARY
+
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+message 2
+
+--BOUNDARY--
+''')
+
+ def test_mime_attachments_in_constructor(self):
+ eq = self.assertEqual
+ text1 = MIMEText('')
+ text2 = MIMEText('')
+ msg = MIMEMultipart(_subparts=(text1, text2))
+ eq(len(msg.get_payload()), 2)
+ eq(msg.get_payload(0), text1)
+ eq(msg.get_payload(1), text2)
+
+ def test_default_multipart_constructor(self):
+ msg = MIMEMultipart()
+ self.assertTrue(msg.is_multipart())
+
+
+# A general test of parser->model->generator idempotency. IOW, read a message
+# in, parse it into a message object tree, then without touching the tree,
+# regenerate the plain text. The original text and the transformed text
+# should be identical. Note: that we ignore the Unix-From since that may
+# contain a changed date.
+class TestIdempotent(TestEmailBase):
+
+ linesep = '\n'
+
+ def _msgobj(self, filename):
+ with openfile(filename) as fp:
+ data = fp.read()
+ msg = email.message_from_string(data)
+ return msg, data
+
+ def _idempotent(self, msg, text, unixfrom=False):
+ eq = self.ndiffAssertEqual
+ s = StringIO()
+ g = Generator(s, maxheaderlen=0)
+ g.flatten(msg, unixfrom=unixfrom)
+ eq(text, s.getvalue())
+
+ def test_parse_text_message(self):
+ eq = self.assertEqual
+ msg, text = self._msgobj('msg_01.txt')
+ eq(msg.get_content_type(), 'text/plain')
+ eq(msg.get_content_maintype(), 'text')
+ eq(msg.get_content_subtype(), 'plain')
+ eq(msg.get_params()[1], ('charset', 'us-ascii'))
+ eq(msg.get_param('charset'), 'us-ascii')
+ eq(msg.preamble, None)
+ eq(msg.epilogue, None)
+ self._idempotent(msg, text)
+
+ def test_parse_untyped_message(self):
+ eq = self.assertEqual
+ msg, text = self._msgobj('msg_03.txt')
+ eq(msg.get_content_type(), 'text/plain')
+ eq(msg.get_params(), None)
+ eq(msg.get_param('charset'), None)
+ self._idempotent(msg, text)
+
+ def test_simple_multipart(self):
+ msg, text = self._msgobj('msg_04.txt')
+ self._idempotent(msg, text)
+
+ def test_MIME_digest(self):
+ msg, text = self._msgobj('msg_02.txt')
+ self._idempotent(msg, text)
+
+ def test_long_header(self):
+ msg, text = self._msgobj('msg_27.txt')
+ self._idempotent(msg, text)
+
+ def test_MIME_digest_with_part_headers(self):
+ msg, text = self._msgobj('msg_28.txt')
+ self._idempotent(msg, text)
+
+ def test_mixed_with_image(self):
+ msg, text = self._msgobj('msg_06.txt')
+ self._idempotent(msg, text)
+
+ def test_multipart_report(self):
+ msg, text = self._msgobj('msg_05.txt')
+ self._idempotent(msg, text)
+
+ def test_dsn(self):
+ msg, text = self._msgobj('msg_16.txt')
+ self._idempotent(msg, text)
+
+ def test_preamble_epilogue(self):
+ msg, text = self._msgobj('msg_21.txt')
+ self._idempotent(msg, text)
+
+ def test_multipart_one_part(self):
+ msg, text = self._msgobj('msg_23.txt')
+ self._idempotent(msg, text)
+
+ def test_multipart_no_parts(self):
+ msg, text = self._msgobj('msg_24.txt')
+ self._idempotent(msg, text)
+
+ def test_no_start_boundary(self):
+ msg, text = self._msgobj('msg_31.txt')
+ self._idempotent(msg, text)
+
+ def test_rfc2231_charset(self):
+ msg, text = self._msgobj('msg_32.txt')
+ self._idempotent(msg, text)
+
+ def test_more_rfc2231_parameters(self):
+ msg, text = self._msgobj('msg_33.txt')
+ self._idempotent(msg, text)
+
+ def test_text_plain_in_a_multipart_digest(self):
+ msg, text = self._msgobj('msg_34.txt')
+ self._idempotent(msg, text)
+
+ def test_nested_multipart_mixeds(self):
+ msg, text = self._msgobj('msg_12a.txt')
+ self._idempotent(msg, text)
+
+ def test_message_external_body_idempotent(self):
+ msg, text = self._msgobj('msg_36.txt')
+ self._idempotent(msg, text)
+
+ def test_message_delivery_status(self):
+ msg, text = self._msgobj('msg_43.txt')
+ self._idempotent(msg, text, unixfrom=True)
+
+ def test_message_signed_idempotent(self):
+ msg, text = self._msgobj('msg_45.txt')
+ self._idempotent(msg, text)
+
+ def test_content_type(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ # Get a message object and reset the seek pointer for other tests
+ msg, text = self._msgobj('msg_05.txt')
+ eq(msg.get_content_type(), 'multipart/report')
+ # Test the Content-Type: parameters
+ params = {}
+ for pk, pv in msg.get_params():
+ params[pk] = pv
+ eq(params['report-type'], 'delivery-status')
+ eq(params['boundary'], 'D1690A7AC1.996856090/mail.example.com')
+ eq(msg.preamble, 'This is a MIME-encapsulated message.' + self.linesep)
+ eq(msg.epilogue, self.linesep)
+ eq(len(msg.get_payload()), 3)
+ # Make sure the subparts are what we expect
+ msg1 = msg.get_payload(0)
+ eq(msg1.get_content_type(), 'text/plain')
+ eq(msg1.get_payload(), 'Yadda yadda yadda' + self.linesep)
+ msg2 = msg.get_payload(1)
+ eq(msg2.get_content_type(), 'text/plain')
+ eq(msg2.get_payload(), 'Yadda yadda yadda' + self.linesep)
+ msg3 = msg.get_payload(2)
+ eq(msg3.get_content_type(), 'message/rfc822')
+ self.assertTrue(isinstance(msg3, Message))
+ payload = msg3.get_payload()
+ unless(isinstance(payload, list))
+ eq(len(payload), 1)
+ msg4 = payload[0]
+ unless(isinstance(msg4, Message))
+ eq(msg4.get_payload(), 'Yadda yadda yadda' + self.linesep)
+
+ def test_parser(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ msg, text = self._msgobj('msg_06.txt')
+ # Check some of the outer headers
+ eq(msg.get_content_type(), 'message/rfc822')
+ # Make sure the payload is a list of exactly one sub-Message, and that
+ # that submessage has a type of text/plain
+ payload = msg.get_payload()
+ unless(isinstance(payload, list))
+ eq(len(payload), 1)
+ msg1 = payload[0]
+ self.assertTrue(isinstance(msg1, Message))
+ eq(msg1.get_content_type(), 'text/plain')
+ self.assertTrue(isinstance(msg1.get_payload(), str))
+ eq(msg1.get_payload(), self.linesep)
+
+
+
+# Test various other bits of the package's functionality
+class TestMiscellaneous(TestEmailBase):
+ def test_message_from_string(self):
+ with openfile('msg_01.txt') as fp:
+ text = fp.read()
+ msg = email.message_from_string(text)
+ s = StringIO()
+ # Don't wrap/continue long headers since we're trying to test
+ # idempotency.
+ g = Generator(s, maxheaderlen=0)
+ g.flatten(msg)
+ self.assertEqual(text, s.getvalue())
+
+ def test_message_from_file(self):
+ with openfile('msg_01.txt') as fp:
+ text = fp.read()
+ fp.seek(0)
+ msg = email.message_from_file(fp)
+ s = StringIO()
+ # Don't wrap/continue long headers since we're trying to test
+ # idempotency.
+ g = Generator(s, maxheaderlen=0)
+ g.flatten(msg)
+ self.assertEqual(text, s.getvalue())
+
+ def test_message_from_string_with_class(self):
+ unless = self.assertTrue
+ with openfile('msg_01.txt') as fp:
+ text = fp.read()
+
+ # Create a subclass
+ class MyMessage(Message):
+ pass
+
+ msg = email.message_from_string(text, MyMessage)
+ unless(isinstance(msg, MyMessage))
+ # Try something more complicated
+ with openfile('msg_02.txt') as fp:
+ text = fp.read()
+ msg = email.message_from_string(text, MyMessage)
+ for subpart in msg.walk():
+ unless(isinstance(subpart, MyMessage))
+
+ def test_message_from_file_with_class(self):
+ unless = self.assertTrue
+ # Create a subclass
+ class MyMessage(Message):
+ pass
+
+ with openfile('msg_01.txt') as fp:
+ msg = email.message_from_file(fp, MyMessage)
+ unless(isinstance(msg, MyMessage))
+ # Try something more complicated
+ with openfile('msg_02.txt') as fp:
+ msg = email.message_from_file(fp, MyMessage)
+ for subpart in msg.walk():
+ unless(isinstance(subpart, MyMessage))
+
+ def test_custom_message_does_not_require_arguments(self):
+ class MyMessage(Message):
+ def __init__(self):
+ super().__init__()
+ msg = self._str_msg("Subject: test\n\ntest", MyMessage)
+ self.assertTrue(isinstance(msg, MyMessage))
+
+ def test__all__(self):
+ module = __import__('email')
+ self.assertEqual(sorted(module.__all__), [
+ 'base64mime', 'charset', 'encoders', 'errors', 'feedparser',
+ 'generator', 'header', 'iterators', 'message',
+ 'message_from_binary_file', 'message_from_bytes',
+ 'message_from_file', 'message_from_string', 'mime', 'parser',
+ 'quoprimime', 'utils',
+ ])
+
+ def test_formatdate(self):
+ now = time.time()
+ self.assertEqual(utils.parsedate(utils.formatdate(now))[:6],
+ time.gmtime(now)[:6])
+
+ def test_formatdate_localtime(self):
+ now = time.time()
+ self.assertEqual(
+ utils.parsedate(utils.formatdate(now, localtime=True))[:6],
+ time.localtime(now)[:6])
+
+ def test_formatdate_usegmt(self):
+ now = time.time()
+ self.assertEqual(
+ utils.formatdate(now, localtime=False),
+ time.strftime('%a, %d %b %Y %H:%M:%S -0000', time.gmtime(now)))
+ self.assertEqual(
+ utils.formatdate(now, localtime=False, usegmt=True),
+ time.strftime('%a, %d %b %Y %H:%M:%S GMT', time.gmtime(now)))
+
+ # parsedate and parsedate_tz will become deprecated interfaces someday
+ def test_parsedate_returns_None_for_invalid_strings(self):
+ self.assertIsNone(utils.parsedate(''))
+ self.assertIsNone(utils.parsedate_tz(''))
+ self.assertIsNone(utils.parsedate('0'))
+ self.assertIsNone(utils.parsedate_tz('0'))
+ self.assertIsNone(utils.parsedate('A Complete Waste of Time'))
+ self.assertIsNone(utils.parsedate_tz('A Complete Waste of Time'))
+ # Not a part of the spec but, but this has historically worked:
+ self.assertIsNone(utils.parsedate(None))
+ self.assertIsNone(utils.parsedate_tz(None))
+
+ def test_parsedate_compact(self):
+ # The FWS after the comma is optional
+ self.assertEqual(utils.parsedate('Wed,3 Apr 2002 14:58:26 +0800'),
+ utils.parsedate('Wed, 3 Apr 2002 14:58:26 +0800'))
+
+ def test_parsedate_no_dayofweek(self):
+ eq = self.assertEqual
+ eq(utils.parsedate_tz('25 Feb 2003 13:47:26 -0800'),
+ (2003, 2, 25, 13, 47, 26, 0, 1, -1, -28800))
+
+ def test_parsedate_compact_no_dayofweek(self):
+ eq = self.assertEqual
+ eq(utils.parsedate_tz('5 Feb 2003 13:47:26 -0800'),
+ (2003, 2, 5, 13, 47, 26, 0, 1, -1, -28800))
+
+ def test_parsedate_no_space_before_positive_offset(self):
+ self.assertEqual(utils.parsedate_tz('Wed, 3 Apr 2002 14:58:26+0800'),
+ (2002, 4, 3, 14, 58, 26, 0, 1, -1, 28800))
+
+ def test_parsedate_no_space_before_negative_offset(self):
+ # Issue 1155362: we already handled '+' for this case.
+ self.assertEqual(utils.parsedate_tz('Wed, 3 Apr 2002 14:58:26-0800'),
+ (2002, 4, 3, 14, 58, 26, 0, 1, -1, -28800))
+
+
+ def test_parsedate_accepts_time_with_dots(self):
+ eq = self.assertEqual
+ eq(utils.parsedate_tz('5 Feb 2003 13.47.26 -0800'),
+ (2003, 2, 5, 13, 47, 26, 0, 1, -1, -28800))
+ eq(utils.parsedate_tz('5 Feb 2003 13.47 -0800'),
+ (2003, 2, 5, 13, 47, 0, 0, 1, -1, -28800))
+
+ def test_parsedate_acceptable_to_time_functions(self):
+ eq = self.assertEqual
+ timetup = utils.parsedate('5 Feb 2003 13:47:26 -0800')
+ t = int(time.mktime(timetup))
+ eq(time.localtime(t)[:6], timetup[:6])
+ eq(int(time.strftime('%Y', timetup)), 2003)
+ timetup = utils.parsedate_tz('5 Feb 2003 13:47:26 -0800')
+ t = int(time.mktime(timetup[:9]))
+ eq(time.localtime(t)[:6], timetup[:6])
+ eq(int(time.strftime('%Y', timetup[:9])), 2003)
+
+ def test_mktime_tz(self):
+ self.assertEqual(utils.mktime_tz((1970, 1, 1, 0, 0, 0,
+ -1, -1, -1, 0)), 0)
+ self.assertEqual(utils.mktime_tz((1970, 1, 1, 0, 0, 0,
+ -1, -1, -1, 1234)), -1234)
+
+ def test_parsedate_y2k(self):
+ """Test for parsing a date with a two-digit year.
+
+ Parsing a date with a two-digit year should return the correct
+ four-digit year. RFC822 allows two-digit years, but RFC2822 (which
+ obsoletes RFC822) requires four-digit years.
+
+ """
+ self.assertEqual(utils.parsedate_tz('25 Feb 03 13:47:26 -0800'),
+ utils.parsedate_tz('25 Feb 2003 13:47:26 -0800'))
+ self.assertEqual(utils.parsedate_tz('25 Feb 71 13:47:26 -0800'),
+ utils.parsedate_tz('25 Feb 1971 13:47:26 -0800'))
+
+ def test_parseaddr_empty(self):
+ self.assertEqual(utils.parseaddr('<>'), ('', ''))
+ self.assertEqual(utils.formataddr(utils.parseaddr('<>')), '')
+
+ def test_noquote_dump(self):
+ self.assertEqual(
+ utils.formataddr(('A Silly Person', 'person@dom.ain')),
+ 'A Silly Person <person@dom.ain>')
+
+ def test_escape_dump(self):
+ self.assertEqual(
+ utils.formataddr(('A (Very) Silly Person', 'person@dom.ain')),
+ r'"A (Very) Silly Person" <person@dom.ain>')
+ self.assertEqual(
+ utils.parseaddr(r'"A \(Very\) Silly Person" <person@dom.ain>'),
+ ('A (Very) Silly Person', 'person@dom.ain'))
+ a = r'A \(Special\) Person'
+ b = 'person@dom.ain'
+ self.assertEqual(utils.parseaddr(utils.formataddr((a, b))), (a, b))
+
+ def test_escape_backslashes(self):
+ self.assertEqual(
+ utils.formataddr(('Arthur \Backslash\ Foobar', 'person@dom.ain')),
+ r'"Arthur \\Backslash\\ Foobar" <person@dom.ain>')
+ a = r'Arthur \Backslash\ Foobar'
+ b = 'person@dom.ain'
+ self.assertEqual(utils.parseaddr(utils.formataddr((a, b))), (a, b))
+
+ def test_quotes_unicode_names(self):
+ # issue 1690608. email.utils.formataddr() should be rfc2047 aware.
+ name = "H\u00e4ns W\u00fcrst"
+ addr = 'person@dom.ain'
+ utf8_base64 = "=?utf-8?b?SMOkbnMgV8O8cnN0?= <person@dom.ain>"
+ latin1_quopri = "=?iso-8859-1?q?H=E4ns_W=FCrst?= <person@dom.ain>"
+ self.assertEqual(utils.formataddr((name, addr)), utf8_base64)
+ self.assertEqual(utils.formataddr((name, addr), 'iso-8859-1'),
+ latin1_quopri)
+
+ def test_accepts_any_charset_like_object(self):
+ # issue 1690608. email.utils.formataddr() should be rfc2047 aware.
+ name = "H\u00e4ns W\u00fcrst"
+ addr = 'person@dom.ain'
+ utf8_base64 = "=?utf-8?b?SMOkbnMgV8O8cnN0?= <person@dom.ain>"
+ foobar = "FOOBAR"
+ class CharsetMock:
+ def header_encode(self, string):
+ return foobar
+ mock = CharsetMock()
+ mock_expected = "%s <%s>" % (foobar, addr)
+ self.assertEqual(utils.formataddr((name, addr), mock), mock_expected)
+ self.assertEqual(utils.formataddr((name, addr), Charset('utf-8')),
+ utf8_base64)
+
+ def test_invalid_charset_like_object_raises_error(self):
+ # issue 1690608. email.utils.formataddr() should be rfc2047 aware.
+ name = "H\u00e4ns W\u00fcrst"
+ addr = 'person@dom.ain'
+ # A object without a header_encode method:
+ bad_charset = object()
+ self.assertRaises(AttributeError, utils.formataddr, (name, addr),
+ bad_charset)
+
+ def test_unicode_address_raises_error(self):
+ # issue 1690608. email.utils.formataddr() should be rfc2047 aware.
+ addr = 'pers\u00f6n@dom.in'
+ self.assertRaises(UnicodeError, utils.formataddr, (None, addr))
+ self.assertRaises(UnicodeError, utils.formataddr, ("Name", addr))
+
+ def test_name_with_dot(self):
+ x = 'John X. Doe <jxd@example.com>'
+ y = '"John X. Doe" <jxd@example.com>'
+ a, b = ('John X. Doe', 'jxd@example.com')
+ self.assertEqual(utils.parseaddr(x), (a, b))
+ self.assertEqual(utils.parseaddr(y), (a, b))
+ # formataddr() quotes the name if there's a dot in it
+ self.assertEqual(utils.formataddr((a, b)), y)
+
+ def test_parseaddr_preserves_quoted_pairs_in_addresses(self):
+ # issue 10005. Note that in the third test the second pair of
+ # backslashes is not actually a quoted pair because it is not inside a
+ # comment or quoted string: the address being parsed has a quoted
+ # string containing a quoted backslash, followed by 'example' and two
+ # backslashes, followed by another quoted string containing a space and
+ # the word 'example'. parseaddr copies those two backslashes
+ # literally. Per rfc5322 this is not technically correct since a \ may
+ # not appear in an address outside of a quoted string. It is probably
+ # a sensible Postel interpretation, though.
+ eq = self.assertEqual
+ eq(utils.parseaddr('""example" example"@example.com'),
+ ('', '""example" example"@example.com'))
+ eq(utils.parseaddr('"\\"example\\" example"@example.com'),
+ ('', '"\\"example\\" example"@example.com'))
+ eq(utils.parseaddr('"\\\\"example\\\\" example"@example.com'),
+ ('', '"\\\\"example\\\\" example"@example.com'))
+
+ def test_parseaddr_preserves_spaces_in_local_part(self):
+ # issue 9286. A normal RFC5322 local part should not contain any
+ # folding white space, but legacy local parts can (they are a sequence
+ # of atoms, not dotatoms). On the other hand we strip whitespace from
+ # before the @ and around dots, on the assumption that the whitespace
+ # around the punctuation is a mistake in what would otherwise be
+ # an RFC5322 local part. Leading whitespace is, usual, stripped as well.
+ self.assertEqual(('', "merwok wok@xample.com"),
+ utils.parseaddr("merwok wok@xample.com"))
+ self.assertEqual(('', "merwok wok@xample.com"),
+ utils.parseaddr("merwok wok@xample.com"))
+ self.assertEqual(('', "merwok wok@xample.com"),
+ utils.parseaddr(" merwok wok @xample.com"))
+ self.assertEqual(('', 'merwok"wok" wok@xample.com'),
+ utils.parseaddr('merwok"wok" wok@xample.com'))
+ self.assertEqual(('', 'merwok.wok.wok@xample.com'),
+ utils.parseaddr('merwok. wok . wok@xample.com'))
+
+ def test_formataddr_does_not_quote_parens_in_quoted_string(self):
+ addr = ("'foo@example.com' (foo@example.com)",
+ 'foo@example.com')
+ addrstr = ('"\'foo@example.com\' '
+ '(foo@example.com)" <foo@example.com>')
+ self.assertEqual(utils.parseaddr(addrstr), addr)
+ self.assertEqual(utils.formataddr(addr), addrstr)
+
+
+ def test_multiline_from_comment(self):
+ x = """\
+Foo
+\tBar <foo@example.com>"""
+ self.assertEqual(utils.parseaddr(x), ('Foo Bar', 'foo@example.com'))
+
+ def test_quote_dump(self):
+ self.assertEqual(
+ utils.formataddr(('A Silly; Person', 'person@dom.ain')),
+ r'"A Silly; Person" <person@dom.ain>')
+
+ def test_charset_richcomparisons(self):
+ eq = self.assertEqual
+ ne = self.assertNotEqual
+ cset1 = Charset()
+ cset2 = Charset()
+ eq(cset1, 'us-ascii')
+ eq(cset1, 'US-ASCII')
+ eq(cset1, 'Us-AsCiI')
+ eq('us-ascii', cset1)
+ eq('US-ASCII', cset1)
+ eq('Us-AsCiI', cset1)
+ ne(cset1, 'usascii')
+ ne(cset1, 'USASCII')
+ ne(cset1, 'UsAsCiI')
+ ne('usascii', cset1)
+ ne('USASCII', cset1)
+ ne('UsAsCiI', cset1)
+ eq(cset1, cset2)
+ eq(cset2, cset1)
+
+ def test_getaddresses(self):
+ eq = self.assertEqual
+ eq(utils.getaddresses(['aperson@dom.ain (Al Person)',
+ 'Bud Person <bperson@dom.ain>']),
+ [('Al Person', 'aperson@dom.ain'),
+ ('Bud Person', 'bperson@dom.ain')])
+
+ def test_getaddresses_nasty(self):
+ eq = self.assertEqual
+ eq(utils.getaddresses(['foo: ;']), [('', '')])
+ eq(utils.getaddresses(
+ ['[]*-- =~$']),
+ [('', ''), ('', ''), ('', '*--')])
+ eq(utils.getaddresses(
+ ['foo: ;', '"Jason R. Mastaler" <jason@dom.ain>']),
+ [('', ''), ('Jason R. Mastaler', 'jason@dom.ain')])
+
+ def test_getaddresses_embedded_comment(self):
+ """Test proper handling of a nested comment"""
+ eq = self.assertEqual
+ addrs = utils.getaddresses(['User ((nested comment)) <foo@bar.com>'])
+ eq(addrs[0][1], 'foo@bar.com')
+
+ def test_utils_quote_unquote(self):
+ eq = self.assertEqual
+ msg = Message()
+ msg.add_header('content-disposition', 'attachment',
+ filename='foo\\wacky"name')
+ eq(msg.get_filename(), 'foo\\wacky"name')
+
+ def test_get_body_encoding_with_bogus_charset(self):
+ charset = Charset('not a charset')
+ self.assertEqual(charset.get_body_encoding(), 'base64')
+
+ def test_get_body_encoding_with_uppercase_charset(self):
+ eq = self.assertEqual
+ msg = Message()
+ msg['Content-Type'] = 'text/plain; charset=UTF-8'
+ eq(msg['content-type'], 'text/plain; charset=UTF-8')
+ charsets = msg.get_charsets()
+ eq(len(charsets), 1)
+ eq(charsets[0], 'utf-8')
+ charset = Charset(charsets[0])
+ eq(charset.get_body_encoding(), 'base64')
+ msg.set_payload(b'hello world', charset=charset)
+ eq(msg.get_payload(), 'aGVsbG8gd29ybGQ=\n')
+ eq(msg.get_payload(decode=True), b'hello world')
+ eq(msg['content-transfer-encoding'], 'base64')
+ # Try another one
+ msg = Message()
+ msg['Content-Type'] = 'text/plain; charset="US-ASCII"'
+ charsets = msg.get_charsets()
+ eq(len(charsets), 1)
+ eq(charsets[0], 'us-ascii')
+ charset = Charset(charsets[0])
+ eq(charset.get_body_encoding(), encoders.encode_7or8bit)
+ msg.set_payload('hello world', charset=charset)
+ eq(msg.get_payload(), 'hello world')
+ eq(msg['content-transfer-encoding'], '7bit')
+
+ def test_charsets_case_insensitive(self):
+ lc = Charset('us-ascii')
+ uc = Charset('US-ASCII')
+ self.assertEqual(lc.get_body_encoding(), uc.get_body_encoding())
+
+ def test_partial_falls_inside_message_delivery_status(self):
+ eq = self.ndiffAssertEqual
+ # The Parser interface provides chunks of data to FeedParser in 8192
+ # byte gulps. SF bug #1076485 found one of those chunks inside
+ # message/delivery-status header block, which triggered an
+ # unreadline() of NeedMoreData.
+ msg = self._msgobj('msg_43.txt')
+ sfp = StringIO()
+ iterators._structure(msg, sfp)
+ eq(sfp.getvalue(), """\
+multipart/report
+ text/plain
+ message/delivery-status
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/rfc822-headers
+""")
+
+ def test_make_msgid_domain(self):
+ self.assertEqual(
+ email.utils.make_msgid(domain='testdomain-string')[-19:],
+ '@testdomain-string>')
+
+
+# Test the iterator/generators
+class TestIterators(TestEmailBase):
+ def test_body_line_iterator(self):
+ eq = self.assertEqual
+ neq = self.ndiffAssertEqual
+ # First a simple non-multipart message
+ msg = self._msgobj('msg_01.txt')
+ it = iterators.body_line_iterator(msg)
+ lines = list(it)
+ eq(len(lines), 6)
+ neq(EMPTYSTRING.join(lines), msg.get_payload())
+ # Now a more complicated multipart
+ msg = self._msgobj('msg_02.txt')
+ it = iterators.body_line_iterator(msg)
+ lines = list(it)
+ eq(len(lines), 43)
+ with openfile('msg_19.txt') as fp:
+ neq(EMPTYSTRING.join(lines), fp.read())
+
+ def test_typed_subpart_iterator(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_04.txt')
+ it = iterators.typed_subpart_iterator(msg, 'text')
+ lines = []
+ subparts = 0
+ for subpart in it:
+ subparts += 1
+ lines.append(subpart.get_payload())
+ eq(subparts, 2)
+ eq(EMPTYSTRING.join(lines), """\
+a simple kind of mirror
+to reflect upon our own
+a simple kind of mirror
+to reflect upon our own
+""")
+
+ def test_typed_subpart_iterator_default_type(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_03.txt')
+ it = iterators.typed_subpart_iterator(msg, 'text', 'plain')
+ lines = []
+ subparts = 0
+ for subpart in it:
+ subparts += 1
+ lines.append(subpart.get_payload())
+ eq(subparts, 1)
+ eq(EMPTYSTRING.join(lines), """\
+
+Hi,
+
+Do you like this message?
+
+-Me
+""")
+
+ def test_pushCR_LF(self):
+ '''FeedParser BufferedSubFile.push() assumed it received complete
+ line endings. A CR ending one push() followed by a LF starting
+ the next push() added an empty line.
+ '''
+ imt = [
+ ("a\r \n", 2),
+ ("b", 0),
+ ("c\n", 1),
+ ("", 0),
+ ("d\r\n", 1),
+ ("e\r", 0),
+ ("\nf", 1),
+ ("\r\n", 1),
+ ]
+ from email.feedparser import BufferedSubFile, NeedMoreData
+ bsf = BufferedSubFile()
+ om = []
+ nt = 0
+ for il, n in imt:
+ bsf.push(il)
+ nt += n
+ n1 = 0
+ while True:
+ ol = bsf.readline()
+ if ol == NeedMoreData:
+ break
+ om.append(ol)
+ n1 += 1
+ self.assertTrue(n == n1)
+ self.assertTrue(len(om) == nt)
+ self.assertTrue(''.join([il for il, n in imt]) == ''.join(om))
+
+
+
+class TestParsers(TestEmailBase):
+
+ def test_header_parser(self):
+ eq = self.assertEqual
+ # Parse only the headers of a complex multipart MIME document
+ with openfile('msg_02.txt') as fp:
+ msg = HeaderParser().parse(fp)
+ eq(msg['from'], 'ppp-request@zzz.org')
+ eq(msg['to'], 'ppp@zzz.org')
+ eq(msg.get_content_type(), 'multipart/mixed')
+ self.assertFalse(msg.is_multipart())
+ self.assertTrue(isinstance(msg.get_payload(), str))
+
+ def test_bytes_header_parser(self):
+ eq = self.assertEqual
+ # Parse only the headers of a complex multipart MIME document
+ with openfile('msg_02.txt', 'rb') as fp:
+ msg = email.parser.BytesHeaderParser().parse(fp)
+ eq(msg['from'], 'ppp-request@zzz.org')
+ eq(msg['to'], 'ppp@zzz.org')
+ eq(msg.get_content_type(), 'multipart/mixed')
+ self.assertFalse(msg.is_multipart())
+ self.assertTrue(isinstance(msg.get_payload(), str))
+ self.assertTrue(isinstance(msg.get_payload(decode=True), bytes))
+
+ def test_whitespace_continuation(self):
+ eq = self.assertEqual
+ # This message contains a line after the Subject: header that has only
+ # whitespace, but it is not empty!
+ msg = email.message_from_string("""\
+From: aperson@dom.ain
+To: bperson@dom.ain
+Subject: the next line has a space on it
+\x20
+Date: Mon, 8 Apr 2002 15:09:19 -0400
+Message-ID: spam
+
+Here's the message body
+""")
+ eq(msg['subject'], 'the next line has a space on it\n ')
+ eq(msg['message-id'], 'spam')
+ eq(msg.get_payload(), "Here's the message body\n")
+
+ def test_whitespace_continuation_last_header(self):
+ eq = self.assertEqual
+ # Like the previous test, but the subject line is the last
+ # header.
+ msg = email.message_from_string("""\
+From: aperson@dom.ain
+To: bperson@dom.ain
+Date: Mon, 8 Apr 2002 15:09:19 -0400
+Message-ID: spam
+Subject: the next line has a space on it
+\x20
+
+Here's the message body
+""")
+ eq(msg['subject'], 'the next line has a space on it\n ')
+ eq(msg['message-id'], 'spam')
+ eq(msg.get_payload(), "Here's the message body\n")
+
+ def test_crlf_separation(self):
+ eq = self.assertEqual
+ with openfile('msg_26.txt', newline='\n') as fp:
+ msg = Parser().parse(fp)
+ eq(len(msg.get_payload()), 2)
+ part1 = msg.get_payload(0)
+ eq(part1.get_content_type(), 'text/plain')
+ eq(part1.get_payload(), 'Simple email with attachment.\r\n\r\n')
+ part2 = msg.get_payload(1)
+ eq(part2.get_content_type(), 'application/riscos')
+
+ def test_crlf_flatten(self):
+ # Using newline='\n' preserves the crlfs in this input file.
+ with openfile('msg_26.txt', newline='\n') as fp:
+ text = fp.read()
+ msg = email.message_from_string(text)
+ s = StringIO()
+ g = Generator(s)
+ g.flatten(msg, linesep='\r\n')
+ self.assertEqual(s.getvalue(), text)
+
+ maxDiff = None
+
+ def test_multipart_digest_with_extra_mime_headers(self):
+ eq = self.assertEqual
+ neq = self.ndiffAssertEqual
+ with openfile('msg_28.txt') as fp:
+ msg = email.message_from_file(fp)
+ # Structure is:
+ # multipart/digest
+ # message/rfc822
+ # text/plain
+ # message/rfc822
+ # text/plain
+ eq(msg.is_multipart(), 1)
+ eq(len(msg.get_payload()), 2)
+ part1 = msg.get_payload(0)
+ eq(part1.get_content_type(), 'message/rfc822')
+ eq(part1.is_multipart(), 1)
+ eq(len(part1.get_payload()), 1)
+ part1a = part1.get_payload(0)
+ eq(part1a.is_multipart(), 0)
+ eq(part1a.get_content_type(), 'text/plain')
+ neq(part1a.get_payload(), 'message 1\n')
+ # next message/rfc822
+ part2 = msg.get_payload(1)
+ eq(part2.get_content_type(), 'message/rfc822')
+ eq(part2.is_multipart(), 1)
+ eq(len(part2.get_payload()), 1)
+ part2a = part2.get_payload(0)
+ eq(part2a.is_multipart(), 0)
+ eq(part2a.get_content_type(), 'text/plain')
+ neq(part2a.get_payload(), 'message 2\n')
+
+ def test_three_lines(self):
+ # A bug report by Andrew McNamara
+ lines = ['From: Andrew Person <aperson@dom.ain',
+ 'Subject: Test',
+ 'Date: Tue, 20 Aug 2002 16:43:45 +1000']
+ msg = email.message_from_string(NL.join(lines))
+ self.assertEqual(msg['date'], 'Tue, 20 Aug 2002 16:43:45 +1000')
+
+ def test_strip_line_feed_and_carriage_return_in_headers(self):
+ eq = self.assertEqual
+ # For [ 1002475 ] email message parser doesn't handle \r\n correctly
+ value1 = 'text'
+ value2 = 'more text'
+ m = 'Header: %s\r\nNext-Header: %s\r\n\r\nBody\r\n\r\n' % (
+ value1, value2)
+ msg = email.message_from_string(m)
+ eq(msg.get('Header'), value1)
+ eq(msg.get('Next-Header'), value2)
+
+ def test_rfc2822_header_syntax(self):
+ eq = self.assertEqual
+ m = '>From: foo\nFrom: bar\n!"#QUX;~: zoo\n\nbody'
+ msg = email.message_from_string(m)
+ eq(len(msg), 3)
+ eq(sorted(field for field in msg), ['!"#QUX;~', '>From', 'From'])
+ eq(msg.get_payload(), 'body')
+
+ def test_rfc2822_space_not_allowed_in_header(self):
+ eq = self.assertEqual
+ m = '>From foo@example.com 11:25:53\nFrom: bar\n!"#QUX;~: zoo\n\nbody'
+ msg = email.message_from_string(m)
+ eq(len(msg.keys()), 0)
+
+ def test_rfc2822_one_character_header(self):
+ eq = self.assertEqual
+ m = 'A: first header\nB: second header\nCC: third header\n\nbody'
+ msg = email.message_from_string(m)
+ headers = msg.keys()
+ headers.sort()
+ eq(headers, ['A', 'B', 'CC'])
+ eq(msg.get_payload(), 'body')
+
+ def test_CRLFLF_at_end_of_part(self):
+ # issue 5610: feedparser should not eat two chars from body part ending
+ # with "\r\n\n".
+ m = (
+ "From: foo@bar.com\n"
+ "To: baz\n"
+ "Mime-Version: 1.0\n"
+ "Content-Type: multipart/mixed; boundary=BOUNDARY\n"
+ "\n"
+ "--BOUNDARY\n"
+ "Content-Type: text/plain\n"
+ "\n"
+ "body ending with CRLF newline\r\n"
+ "\n"
+ "--BOUNDARY--\n"
+ )
+ msg = email.message_from_string(m)
+ self.assertTrue(msg.get_payload(0).get_payload().endswith('\r\n'))
+
+
+class Test8BitBytesHandling(unittest.TestCase):
+ # In Python3 all input is string, but that doesn't work if the actual input
+ # uses an 8bit transfer encoding. To hack around that, in email 5.1 we
+ # decode byte streams using the surrogateescape error handler, and
+ # reconvert to binary at appropriate places if we detect surrogates. This
+ # doesn't allow us to transform headers with 8bit bytes (they get munged),
+ # but it does allow us to parse and preserve them, and to decode body
+ # parts that use an 8bit CTE.
+
+ bodytest_msg = textwrap.dedent("""\
+ From: foo@bar.com
+ To: baz
+ Mime-Version: 1.0
+ Content-Type: text/plain; charset={charset}
+ Content-Transfer-Encoding: {cte}
+
+ {bodyline}
+ """)
+
+ def test_known_8bit_CTE(self):
+ m = self.bodytest_msg.format(charset='utf-8',
+ cte='8bit',
+ bodyline='pöstal').encode('utf-8')
+ msg = email.message_from_bytes(m)
+ self.assertEqual(msg.get_payload(), "pöstal\n")
+ self.assertEqual(msg.get_payload(decode=True),
+ "pöstal\n".encode('utf-8'))
+
+ def test_unknown_8bit_CTE(self):
+ m = self.bodytest_msg.format(charset='notavalidcharset',
+ cte='8bit',
+ bodyline='pöstal').encode('utf-8')
+ msg = email.message_from_bytes(m)
+ self.assertEqual(msg.get_payload(), "p\uFFFD\uFFFDstal\n")
+ self.assertEqual(msg.get_payload(decode=True),
+ "pöstal\n".encode('utf-8'))
+
+ def test_8bit_in_quopri_body(self):
+ # This is non-RFC compliant data...without 'decode' the library code
+ # decodes the body using the charset from the headers, and because the
+ # source byte really is utf-8 this works. This is likely to fail
+ # against real dirty data (ie: produce mojibake), but the data is
+ # invalid anyway so it is as good a guess as any. But this means that
+ # this test just confirms the current behavior; that behavior is not
+ # necessarily the best possible behavior. With 'decode' it is
+ # returning the raw bytes, so that test should be of correct behavior,
+ # or at least produce the same result that email4 did.
+ m = self.bodytest_msg.format(charset='utf-8',
+ cte='quoted-printable',
+ bodyline='p=C3=B6stál').encode('utf-8')
+ msg = email.message_from_bytes(m)
+ self.assertEqual(msg.get_payload(), 'p=C3=B6stál\n')
+ self.assertEqual(msg.get_payload(decode=True),
+ 'pöstál\n'.encode('utf-8'))
+
+ def test_invalid_8bit_in_non_8bit_cte_uses_replace(self):
+ # This is similar to the previous test, but proves that if the 8bit
+ # byte is undecodeable in the specified charset, it gets replaced
+ # by the unicode 'unknown' character. Again, this may or may not
+ # be the ideal behavior. Note that if decode=False none of the
+ # decoders will get involved, so this is the only test we need
+ # for this behavior.
+ m = self.bodytest_msg.format(charset='ascii',
+ cte='quoted-printable',
+ bodyline='p=C3=B6stál').encode('utf-8')
+ msg = email.message_from_bytes(m)
+ self.assertEqual(msg.get_payload(), 'p=C3=B6st\uFFFD\uFFFDl\n')
+ self.assertEqual(msg.get_payload(decode=True),
+ 'pöstál\n'.encode('utf-8'))
+
+ # test_defect_handling:test_invalid_chars_in_base64_payload
+ def test_8bit_in_base64_body(self):
+ # If we get 8bit bytes in a base64 body, we can just ignore them
+ # as being outside the base64 alphabet and decode anyway. But
+ # we register a defect.
+ m = self.bodytest_msg.format(charset='utf-8',
+ cte='base64',
+ bodyline='cMO2c3RhbAá=').encode('utf-8')
+ msg = email.message_from_bytes(m)
+ self.assertEqual(msg.get_payload(decode=True),
+ 'pöstal'.encode('utf-8'))
+ self.assertIsInstance(msg.defects[0],
+ errors.InvalidBase64CharactersDefect)
+
+ def test_8bit_in_uuencode_body(self):
+ # Sticking an 8bit byte in a uuencode block makes it undecodable by
+ # normal means, so the block is returned undecoded, but as bytes.
+ m = self.bodytest_msg.format(charset='utf-8',
+ cte='uuencode',
+ bodyline='<,.V<W1A; á ').encode('utf-8')
+ msg = email.message_from_bytes(m)
+ self.assertEqual(msg.get_payload(decode=True),
+ '<,.V<W1A; á \n'.encode('utf-8'))
+
+
+ headertest_headers = (
+ ('From: foo@bar.com', ('From', 'foo@bar.com')),
+ ('To: báz', ('To', '=?unknown-8bit?q?b=C3=A1z?=')),
+ ('Subject: Maintenant je vous présente mon collègue, le pouf célèbre\n'
+ '\tJean de Baddie',
+ ('Subject', '=?unknown-8bit?q?Maintenant_je_vous_pr=C3=A9sente_mon_'
+ 'coll=C3=A8gue=2C_le_pouf_c=C3=A9l=C3=A8bre?=\n'
+ ' =?unknown-8bit?q?_Jean_de_Baddie?=')),
+ ('From: göst', ('From', '=?unknown-8bit?b?Z8O2c3Q=?=')),
+ )
+ headertest_msg = ('\n'.join([src for (src, _) in headertest_headers]) +
+ '\nYes, they are flying.\n').encode('utf-8')
+
+ def test_get_8bit_header(self):
+ msg = email.message_from_bytes(self.headertest_msg)
+ self.assertEqual(str(msg.get('to')), 'b\uFFFD\uFFFDz')
+ self.assertEqual(str(msg['to']), 'b\uFFFD\uFFFDz')
+
+ def test_print_8bit_headers(self):
+ msg = email.message_from_bytes(self.headertest_msg)
+ self.assertEqual(str(msg),
+ textwrap.dedent("""\
+ From: {}
+ To: {}
+ Subject: {}
+ From: {}
+
+ Yes, they are flying.
+ """).format(*[expected[1] for (_, expected) in
+ self.headertest_headers]))
+
+ def test_values_with_8bit_headers(self):
+ msg = email.message_from_bytes(self.headertest_msg)
+ self.assertListEqual([str(x) for x in msg.values()],
+ ['foo@bar.com',
+ 'b\uFFFD\uFFFDz',
+ 'Maintenant je vous pr\uFFFD\uFFFDsente mon '
+ 'coll\uFFFD\uFFFDgue, le pouf '
+ 'c\uFFFD\uFFFDl\uFFFD\uFFFDbre\n'
+ '\tJean de Baddie',
+ "g\uFFFD\uFFFDst"])
+
+ def test_items_with_8bit_headers(self):
+ msg = email.message_from_bytes(self.headertest_msg)
+ self.assertListEqual([(str(x), str(y)) for (x, y) in msg.items()],
+ [('From', 'foo@bar.com'),
+ ('To', 'b\uFFFD\uFFFDz'),
+ ('Subject', 'Maintenant je vous '
+ 'pr\uFFFD\uFFFDsente '
+ 'mon coll\uFFFD\uFFFDgue, le pouf '
+ 'c\uFFFD\uFFFDl\uFFFD\uFFFDbre\n'
+ '\tJean de Baddie'),
+ ('From', 'g\uFFFD\uFFFDst')])
+
+ def test_get_all_with_8bit_headers(self):
+ msg = email.message_from_bytes(self.headertest_msg)
+ self.assertListEqual([str(x) for x in msg.get_all('from')],
+ ['foo@bar.com',
+ 'g\uFFFD\uFFFDst'])
+
+ def test_get_content_type_with_8bit(self):
+ msg = email.message_from_bytes(textwrap.dedent("""\
+ Content-Type: text/pl\xA7in; charset=utf-8
+ """).encode('latin-1'))
+ self.assertEqual(msg.get_content_type(), "text/pl\uFFFDin")
+ self.assertEqual(msg.get_content_maintype(), "text")
+ self.assertEqual(msg.get_content_subtype(), "pl\uFFFDin")
+
+ # test_headerregistry.TestContentTypeHeader.non_ascii_in_params
+ def test_get_params_with_8bit(self):
+ msg = email.message_from_bytes(
+ 'X-Header: foo=\xa7ne; b\xa7r=two; baz=three\n'.encode('latin-1'))
+ self.assertEqual(msg.get_params(header='x-header'),
+ [('foo', '\uFFFDne'), ('b\uFFFDr', 'two'), ('baz', 'three')])
+ self.assertEqual(msg.get_param('Foo', header='x-header'), '\uFFFdne')
+ # XXX: someday you might be able to get 'b\xa7r', for now you can't.
+ self.assertEqual(msg.get_param('b\xa7r', header='x-header'), None)
+
+ # test_headerregistry.TestContentTypeHeader.non_ascii_in_rfc2231_value
+ def test_get_rfc2231_params_with_8bit(self):
+ msg = email.message_from_bytes(textwrap.dedent("""\
+ Content-Type: text/plain; charset=us-ascii;
+ title*=us-ascii'en'This%20is%20not%20f\xa7n"""
+ ).encode('latin-1'))
+ self.assertEqual(msg.get_param('title'),
+ ('us-ascii', 'en', 'This is not f\uFFFDn'))
+
+ def test_set_rfc2231_params_with_8bit(self):
+ msg = email.message_from_bytes(textwrap.dedent("""\
+ Content-Type: text/plain; charset=us-ascii;
+ title*=us-ascii'en'This%20is%20not%20f\xa7n"""
+ ).encode('latin-1'))
+ msg.set_param('title', 'test')
+ self.assertEqual(msg.get_param('title'), 'test')
+
+ def test_del_rfc2231_params_with_8bit(self):
+ msg = email.message_from_bytes(textwrap.dedent("""\
+ Content-Type: text/plain; charset=us-ascii;
+ title*=us-ascii'en'This%20is%20not%20f\xa7n"""
+ ).encode('latin-1'))
+ msg.del_param('title')
+ self.assertEqual(msg.get_param('title'), None)
+ self.assertEqual(msg.get_content_maintype(), 'text')
+
+ def test_get_payload_with_8bit_cte_header(self):
+ msg = email.message_from_bytes(textwrap.dedent("""\
+ Content-Transfer-Encoding: b\xa7se64
+ Content-Type: text/plain; charset=latin-1
+
+ payload
+ """).encode('latin-1'))
+ self.assertEqual(msg.get_payload(), 'payload\n')
+ self.assertEqual(msg.get_payload(decode=True), b'payload\n')
+
+ non_latin_bin_msg = textwrap.dedent("""\
+ From: foo@bar.com
+ To: báz
+ Subject: Maintenant je vous présente mon collègue, le pouf célèbre
+ \tJean de Baddie
+ Mime-Version: 1.0
+ Content-Type: text/plain; charset="utf-8"
+ Content-Transfer-Encoding: 8bit
+
+ Да, они летят.
+ """).encode('utf-8')
+
+ def test_bytes_generator(self):
+ msg = email.message_from_bytes(self.non_latin_bin_msg)
+ out = BytesIO()
+ email.generator.BytesGenerator(out).flatten(msg)
+ self.assertEqual(out.getvalue(), self.non_latin_bin_msg)
+
+ def test_bytes_generator_handles_None_body(self):
+ #Issue 11019
+ msg = email.message.Message()
+ out = BytesIO()
+ email.generator.BytesGenerator(out).flatten(msg)
+ self.assertEqual(out.getvalue(), b"\n")
+
+ non_latin_bin_msg_as7bit_wrapped = textwrap.dedent("""\
+ From: foo@bar.com
+ To: =?unknown-8bit?q?b=C3=A1z?=
+ Subject: =?unknown-8bit?q?Maintenant_je_vous_pr=C3=A9sente_mon_coll=C3=A8gue?=
+ =?unknown-8bit?q?=2C_le_pouf_c=C3=A9l=C3=A8bre?=
+ =?unknown-8bit?q?_Jean_de_Baddie?=
+ Mime-Version: 1.0
+ Content-Type: text/plain; charset="utf-8"
+ Content-Transfer-Encoding: base64
+
+ 0JTQsCwg0L7QvdC4INC70LXRgtGP0YIuCg==
+ """)
+
+ def test_generator_handles_8bit(self):
+ msg = email.message_from_bytes(self.non_latin_bin_msg)
+ out = StringIO()
+ email.generator.Generator(out).flatten(msg)
+ self.assertEqual(out.getvalue(), self.non_latin_bin_msg_as7bit_wrapped)
+
+ def test_bytes_generator_with_unix_from(self):
+ # The unixfrom contains a current date, so we can't check it
+ # literally. Just make sure the first word is 'From' and the
+ # rest of the message matches the input.
+ msg = email.message_from_bytes(self.non_latin_bin_msg)
+ out = BytesIO()
+ email.generator.BytesGenerator(out).flatten(msg, unixfrom=True)
+ lines = out.getvalue().split(b'\n')
+ self.assertEqual(lines[0].split()[0], b'From')
+ self.assertEqual(b'\n'.join(lines[1:]), self.non_latin_bin_msg)
+
+ non_latin_bin_msg_as7bit = non_latin_bin_msg_as7bit_wrapped.split('\n')
+ non_latin_bin_msg_as7bit[2:4] = [
+ 'Subject: =?unknown-8bit?q?Maintenant_je_vous_pr=C3=A9sente_mon_'
+ 'coll=C3=A8gue=2C_le_pouf_c=C3=A9l=C3=A8bre?=']
+ non_latin_bin_msg_as7bit = '\n'.join(non_latin_bin_msg_as7bit)
+
+ def test_message_from_binary_file(self):
+ fn = 'test.msg'
+ self.addCleanup(unlink, fn)
+ with open(fn, 'wb') as testfile:
+ testfile.write(self.non_latin_bin_msg)
+ with open(fn, 'rb') as testfile:
+ m = email.parser.BytesParser().parse(testfile)
+ self.assertEqual(str(m), self.non_latin_bin_msg_as7bit)
+
+ latin_bin_msg = textwrap.dedent("""\
+ From: foo@bar.com
+ To: Dinsdale
+ Subject: Nudge nudge, wink, wink
+ Mime-Version: 1.0
+ Content-Type: text/plain; charset="latin-1"
+ Content-Transfer-Encoding: 8bit
+
+ oh là là, know what I mean, know what I mean?
+ """).encode('latin-1')
+
+ latin_bin_msg_as7bit = textwrap.dedent("""\
+ From: foo@bar.com
+ To: Dinsdale
+ Subject: Nudge nudge, wink, wink
+ Mime-Version: 1.0
+ Content-Type: text/plain; charset="iso-8859-1"
+ Content-Transfer-Encoding: quoted-printable
+
+ oh l=E0 l=E0, know what I mean, know what I mean?
+ """)
+
+ def test_string_generator_reencodes_to_quopri_when_appropriate(self):
+ m = email.message_from_bytes(self.latin_bin_msg)
+ self.assertEqual(str(m), self.latin_bin_msg_as7bit)
+
+ def test_decoded_generator_emits_unicode_body(self):
+ m = email.message_from_bytes(self.latin_bin_msg)
+ out = StringIO()
+ email.generator.DecodedGenerator(out).flatten(m)
+ #DecodedHeader output contains an extra blank line compared
+ #to the input message. RDM: not sure if this is a bug or not,
+ #but it is not specific to the 8bit->7bit conversion.
+ self.assertEqual(out.getvalue(),
+ self.latin_bin_msg.decode('latin-1')+'\n')
+
+ def test_bytes_feedparser(self):
+ bfp = email.feedparser.BytesFeedParser()
+ for i in range(0, len(self.latin_bin_msg), 10):
+ bfp.feed(self.latin_bin_msg[i:i+10])
+ m = bfp.close()
+ self.assertEqual(str(m), self.latin_bin_msg_as7bit)
+
+ def test_crlf_flatten(self):
+ with openfile('msg_26.txt', 'rb') as fp:
+ text = fp.read()
+ msg = email.message_from_bytes(text)
+ s = BytesIO()
+ g = email.generator.BytesGenerator(s)
+ g.flatten(msg, linesep='\r\n')
+ self.assertEqual(s.getvalue(), text)
+
+ def test_8bit_multipart(self):
+ # Issue 11605
+ source = textwrap.dedent("""\
+ Date: Fri, 18 Mar 2011 17:15:43 +0100
+ To: foo@example.com
+ From: foodwatch-Newsletter <bar@example.com>
+ Subject: Aktuelles zu Japan, Klonfleisch und Smiley-System
+ Message-ID: <76a486bee62b0d200f33dc2ca08220ad@localhost.localdomain>
+ MIME-Version: 1.0
+ Content-Type: multipart/alternative;
+ boundary="b1_76a486bee62b0d200f33dc2ca08220ad"
+
+ --b1_76a486bee62b0d200f33dc2ca08220ad
+ Content-Type: text/plain; charset="utf-8"
+ Content-Transfer-Encoding: 8bit
+
+ Guten Tag, ,
+
+ mit großer Betroffenheit verfolgen auch wir im foodwatch-Team die
+ Nachrichten aus Japan.
+
+
+ --b1_76a486bee62b0d200f33dc2ca08220ad
+ Content-Type: text/html; charset="utf-8"
+ Content-Transfer-Encoding: 8bit
+
+ <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN"
+ "http://www.w3.org/TR/html4/loose.dtd">
+ <html lang="de">
+ <head>
+ <title>foodwatch - Newsletter</title>
+ </head>
+ <body>
+ <p>mit gro&szlig;er Betroffenheit verfolgen auch wir im foodwatch-Team
+ die Nachrichten aus Japan.</p>
+ </body>
+ </html>
+ --b1_76a486bee62b0d200f33dc2ca08220ad--
+
+ """).encode('utf-8')
+ msg = email.message_from_bytes(source)
+ s = BytesIO()
+ g = email.generator.BytesGenerator(s)
+ g.flatten(msg)
+ self.assertEqual(s.getvalue(), source)
+
+ def test_bytes_generator_b_encoding_linesep(self):
+ # Issue 14062: b encoding was tacking on an extra \n.
+ m = Message()
+ # This has enough non-ascii that it should always end up b encoded.
+ m['Subject'] = Header('žluťoučký kůň')
+ s = BytesIO()
+ g = email.generator.BytesGenerator(s)
+ g.flatten(m, linesep='\r\n')
+ self.assertEqual(
+ s.getvalue(),
+ b'Subject: =?utf-8?b?xb5sdcWlb3XEjWvDvSBrxa/FiA==?=\r\n\r\n')
+
+ def test_generator_b_encoding_linesep(self):
+ # Since this broke in ByteGenerator, test Generator for completeness.
+ m = Message()
+ # This has enough non-ascii that it should always end up b encoded.
+ m['Subject'] = Header('žluťoučký kůň')
+ s = StringIO()
+ g = email.generator.Generator(s)
+ g.flatten(m, linesep='\r\n')
+ self.assertEqual(
+ s.getvalue(),
+ 'Subject: =?utf-8?b?xb5sdcWlb3XEjWvDvSBrxa/FiA==?=\r\n\r\n')
+
+ maxDiff = None
+
+
+class BaseTestBytesGeneratorIdempotent:
+
+ maxDiff = None
+
+ def _msgobj(self, filename):
+ with openfile(filename, 'rb') as fp:
+ data = fp.read()
+ data = self.normalize_linesep_regex.sub(self.blinesep, data)
+ msg = email.message_from_bytes(data)
+ return msg, data
+
+ def _idempotent(self, msg, data, unixfrom=False):
+ b = BytesIO()
+ g = email.generator.BytesGenerator(b, maxheaderlen=0)
+ g.flatten(msg, unixfrom=unixfrom, linesep=self.linesep)
+ self.assertEqual(data, b.getvalue())
+
+
+class TestBytesGeneratorIdempotentNL(BaseTestBytesGeneratorIdempotent,
+ TestIdempotent):
+ linesep = '\n'
+ blinesep = b'\n'
+ normalize_linesep_regex = re.compile(br'\r\n')
+
+
+class TestBytesGeneratorIdempotentCRLF(BaseTestBytesGeneratorIdempotent,
+ TestIdempotent):
+ linesep = '\r\n'
+ blinesep = b'\r\n'
+ normalize_linesep_regex = re.compile(br'(?<!\r)\n')
+
+
+class TestBase64(unittest.TestCase):
+ def test_len(self):
+ eq = self.assertEqual
+ eq(base64mime.header_length('hello'),
+ len(base64mime.body_encode(b'hello', eol='')))
+ for size in range(15):
+ if size == 0 : bsize = 0
+ elif size <= 3 : bsize = 4
+ elif size <= 6 : bsize = 8
+ elif size <= 9 : bsize = 12
+ elif size <= 12: bsize = 16
+ else : bsize = 20
+ eq(base64mime.header_length('x' * size), bsize)
+
+ def test_decode(self):
+ eq = self.assertEqual
+ eq(base64mime.decode(''), b'')
+ eq(base64mime.decode('aGVsbG8='), b'hello')
+
+ def test_encode(self):
+ eq = self.assertEqual
+ eq(base64mime.body_encode(b''), b'')
+ eq(base64mime.body_encode(b'hello'), 'aGVsbG8=\n')
+ # Test the binary flag
+ eq(base64mime.body_encode(b'hello\n'), 'aGVsbG8K\n')
+ # Test the maxlinelen arg
+ eq(base64mime.body_encode(b'xxxx ' * 20, maxlinelen=40), """\
+eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg
+eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg
+eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg
+eHh4eCB4eHh4IA==
+""")
+ # Test the eol argument
+ eq(base64mime.body_encode(b'xxxx ' * 20, maxlinelen=40, eol='\r\n'),
+ """\
+eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg\r
+eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg\r
+eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg\r
+eHh4eCB4eHh4IA==\r
+""")
+
+ def test_header_encode(self):
+ eq = self.assertEqual
+ he = base64mime.header_encode
+ eq(he('hello'), '=?iso-8859-1?b?aGVsbG8=?=')
+ eq(he('hello\r\nworld'), '=?iso-8859-1?b?aGVsbG8NCndvcmxk?=')
+ eq(he('hello\nworld'), '=?iso-8859-1?b?aGVsbG8Kd29ybGQ=?=')
+ # Test the charset option
+ eq(he('hello', charset='iso-8859-2'), '=?iso-8859-2?b?aGVsbG8=?=')
+ eq(he('hello\nworld'), '=?iso-8859-1?b?aGVsbG8Kd29ybGQ=?=')
+
+
+
+class TestQuopri(unittest.TestCase):
+ def setUp(self):
+ # Set of characters (as byte integers) that don't need to be encoded
+ # in headers.
+ self.hlit = list(chain(
+ range(ord('a'), ord('z') + 1),
+ range(ord('A'), ord('Z') + 1),
+ range(ord('0'), ord('9') + 1),
+ (c for c in b'!*+-/')))
+ # Set of characters (as byte integers) that do need to be encoded in
+ # headers.
+ self.hnon = [c for c in range(256) if c not in self.hlit]
+ assert len(self.hlit) + len(self.hnon) == 256
+ # Set of characters (as byte integers) that don't need to be encoded
+ # in bodies.
+ self.blit = list(range(ord(' '), ord('~') + 1))
+ self.blit.append(ord('\t'))
+ self.blit.remove(ord('='))
+ # Set of characters (as byte integers) that do need to be encoded in
+ # bodies.
+ self.bnon = [c for c in range(256) if c not in self.blit]
+ assert len(self.blit) + len(self.bnon) == 256
+
+ def test_quopri_header_check(self):
+ for c in self.hlit:
+ self.assertFalse(quoprimime.header_check(c),
+ 'Should not be header quopri encoded: %s' % chr(c))
+ for c in self.hnon:
+ self.assertTrue(quoprimime.header_check(c),
+ 'Should be header quopri encoded: %s' % chr(c))
+
+ def test_quopri_body_check(self):
+ for c in self.blit:
+ self.assertFalse(quoprimime.body_check(c),
+ 'Should not be body quopri encoded: %s' % chr(c))
+ for c in self.bnon:
+ self.assertTrue(quoprimime.body_check(c),
+ 'Should be body quopri encoded: %s' % chr(c))
+
+ def test_header_quopri_len(self):
+ eq = self.assertEqual
+ eq(quoprimime.header_length(b'hello'), 5)
+ # RFC 2047 chrome is not included in header_length().
+ eq(len(quoprimime.header_encode(b'hello', charset='xxx')),
+ quoprimime.header_length(b'hello') +
+ # =?xxx?q?...?= means 10 extra characters
+ 10)
+ eq(quoprimime.header_length(b'h@e@l@l@o@'), 20)
+ # RFC 2047 chrome is not included in header_length().
+ eq(len(quoprimime.header_encode(b'h@e@l@l@o@', charset='xxx')),
+ quoprimime.header_length(b'h@e@l@l@o@') +
+ # =?xxx?q?...?= means 10 extra characters
+ 10)
+ for c in self.hlit:
+ eq(quoprimime.header_length(bytes([c])), 1,
+ 'expected length 1 for %r' % chr(c))
+ for c in self.hnon:
+ # Space is special; it's encoded to _
+ if c == ord(' '):
+ continue
+ eq(quoprimime.header_length(bytes([c])), 3,
+ 'expected length 3 for %r' % chr(c))
+ eq(quoprimime.header_length(b' '), 1)
+
+ def test_body_quopri_len(self):
+ eq = self.assertEqual
+ for c in self.blit:
+ eq(quoprimime.body_length(bytes([c])), 1)
+ for c in self.bnon:
+ eq(quoprimime.body_length(bytes([c])), 3)
+
+ def test_quote_unquote_idempotent(self):
+ for x in range(256):
+ c = chr(x)
+ self.assertEqual(quoprimime.unquote(quoprimime.quote(c)), c)
+
+ def _test_header_encode(self, header, expected_encoded_header, charset=None):
+ if charset is None:
+ encoded_header = quoprimime.header_encode(header)
+ else:
+ encoded_header = quoprimime.header_encode(header, charset)
+ self.assertEqual(encoded_header, expected_encoded_header)
+
+ def test_header_encode_null(self):
+ self._test_header_encode(b'', '')
+
+ def test_header_encode_one_word(self):
+ self._test_header_encode(b'hello', '=?iso-8859-1?q?hello?=')
+
+ def test_header_encode_two_lines(self):
+ self._test_header_encode(b'hello\nworld',
+ '=?iso-8859-1?q?hello=0Aworld?=')
+
+ def test_header_encode_non_ascii(self):
+ self._test_header_encode(b'hello\xc7there',
+ '=?iso-8859-1?q?hello=C7there?=')
+
+ def test_header_encode_alt_charset(self):
+ self._test_header_encode(b'hello', '=?iso-8859-2?q?hello?=',
+ charset='iso-8859-2')
+
+ def _test_header_decode(self, encoded_header, expected_decoded_header):
+ decoded_header = quoprimime.header_decode(encoded_header)
+ self.assertEqual(decoded_header, expected_decoded_header)
+
+ def test_header_decode_null(self):
+ self._test_header_decode('', '')
+
+ def test_header_decode_one_word(self):
+ self._test_header_decode('hello', 'hello')
+
+ def test_header_decode_two_lines(self):
+ self._test_header_decode('hello=0Aworld', 'hello\nworld')
+
+ def test_header_decode_non_ascii(self):
+ self._test_header_decode('hello=C7there', 'hello\xc7there')
+
+ def _test_decode(self, encoded, expected_decoded, eol=None):
+ if eol is None:
+ decoded = quoprimime.decode(encoded)
+ else:
+ decoded = quoprimime.decode(encoded, eol=eol)
+ self.assertEqual(decoded, expected_decoded)
+
+ def test_decode_null_word(self):
+ self._test_decode('', '')
+
+ def test_decode_null_line_null_word(self):
+ self._test_decode('\r\n', '\n')
+
+ def test_decode_one_word(self):
+ self._test_decode('hello', 'hello')
+
+ def test_decode_one_word_eol(self):
+ self._test_decode('hello', 'hello', eol='X')
+
+ def test_decode_one_line(self):
+ self._test_decode('hello\r\n', 'hello\n')
+
+ def test_decode_one_line_lf(self):
+ self._test_decode('hello\n', 'hello\n')
+
+ def test_decode_one_line_cr(self):
+ self._test_decode('hello\r', 'hello\n')
+
+ def test_decode_one_line_nl(self):
+ self._test_decode('hello\n', 'helloX', eol='X')
+
+ def test_decode_one_line_crnl(self):
+ self._test_decode('hello\r\n', 'helloX', eol='X')
+
+ def test_decode_one_line_one_word(self):
+ self._test_decode('hello\r\nworld', 'hello\nworld')
+
+ def test_decode_one_line_one_word_eol(self):
+ self._test_decode('hello\r\nworld', 'helloXworld', eol='X')
+
+ def test_decode_two_lines(self):
+ self._test_decode('hello\r\nworld\r\n', 'hello\nworld\n')
+
+ def test_decode_two_lines_eol(self):
+ self._test_decode('hello\r\nworld\r\n', 'helloXworldX', eol='X')
+
+ def test_decode_one_long_line(self):
+ self._test_decode('Spam' * 250, 'Spam' * 250)
+
+ def test_decode_one_space(self):
+ self._test_decode(' ', '')
+
+ def test_decode_multiple_spaces(self):
+ self._test_decode(' ' * 5, '')
+
+ def test_decode_one_line_trailing_spaces(self):
+ self._test_decode('hello \r\n', 'hello\n')
+
+ def test_decode_two_lines_trailing_spaces(self):
+ self._test_decode('hello \r\nworld \r\n', 'hello\nworld\n')
+
+ def test_decode_quoted_word(self):
+ self._test_decode('=22quoted=20words=22', '"quoted words"')
+
+ def test_decode_uppercase_quoting(self):
+ self._test_decode('ab=CD=EF', 'ab\xcd\xef')
+
+ def test_decode_lowercase_quoting(self):
+ self._test_decode('ab=cd=ef', 'ab\xcd\xef')
+
+ def test_decode_soft_line_break(self):
+ self._test_decode('soft line=\r\nbreak', 'soft linebreak')
+
+ def test_decode_false_quoting(self):
+ self._test_decode('A=1,B=A ==> A+B==2', 'A=1,B=A ==> A+B==2')
+
+ def _test_encode(self, body, expected_encoded_body, maxlinelen=None, eol=None):
+ kwargs = {}
+ if maxlinelen is None:
+ # Use body_encode's default.
+ maxlinelen = 76
+ else:
+ kwargs['maxlinelen'] = maxlinelen
+ if eol is None:
+ # Use body_encode's default.
+ eol = '\n'
+ else:
+ kwargs['eol'] = eol
+ encoded_body = quoprimime.body_encode(body, **kwargs)
+ self.assertEqual(encoded_body, expected_encoded_body)
+ if eol == '\n' or eol == '\r\n':
+ # We know how to split the result back into lines, so maxlinelen
+ # can be checked.
+ for line in encoded_body.splitlines():
+ self.assertLessEqual(len(line), maxlinelen)
+
+ def test_encode_null(self):
+ self._test_encode('', '')
+
+ def test_encode_null_lines(self):
+ self._test_encode('\n\n', '\n\n')
+
+ def test_encode_one_line(self):
+ self._test_encode('hello\n', 'hello\n')
+
+ def test_encode_one_line_crlf(self):
+ self._test_encode('hello\r\n', 'hello\n')
+
+ def test_encode_one_line_eol(self):
+ self._test_encode('hello\n', 'hello\r\n', eol='\r\n')
+
+ def test_encode_one_space(self):
+ self._test_encode(' ', '=20')
+
+ def test_encode_one_line_one_space(self):
+ self._test_encode(' \n', '=20\n')
+
+# XXX: body_encode() expect strings, but uses ord(char) from these strings
+# to index into a 256-entry list. For code points above 255, this will fail.
+# Should there be a check for 8-bit only ord() values in body, or at least
+# a comment about the expected input?
+
+ def test_encode_two_lines_one_space(self):
+ self._test_encode(' \n \n', '=20\n=20\n')
+
+ def test_encode_one_word_trailing_spaces(self):
+ self._test_encode('hello ', 'hello =20')
+
+ def test_encode_one_line_trailing_spaces(self):
+ self._test_encode('hello \n', 'hello =20\n')
+
+ def test_encode_one_word_trailing_tab(self):
+ self._test_encode('hello \t', 'hello =09')
+
+ def test_encode_one_line_trailing_tab(self):
+ self._test_encode('hello \t\n', 'hello =09\n')
+
+ def test_encode_trailing_space_before_maxlinelen(self):
+ self._test_encode('abcd \n1234', 'abcd =\n\n1234', maxlinelen=6)
+
+ def test_encode_trailing_space_at_maxlinelen(self):
+ self._test_encode('abcd \n1234', 'abcd=\n=20\n1234', maxlinelen=5)
+
+ def test_encode_trailing_space_beyond_maxlinelen(self):
+ self._test_encode('abcd \n1234', 'abc=\nd=20\n1234', maxlinelen=4)
+
+ def test_encode_whitespace_lines(self):
+ self._test_encode(' \n' * 5, '=20\n' * 5)
+
+ def test_encode_quoted_equals(self):
+ self._test_encode('a = b', 'a =3D b')
+
+ def test_encode_one_long_string(self):
+ self._test_encode('x' * 100, 'x' * 75 + '=\n' + 'x' * 25)
+
+ def test_encode_one_long_line(self):
+ self._test_encode('x' * 100 + '\n', 'x' * 75 + '=\n' + 'x' * 25 + '\n')
+
+ def test_encode_one_very_long_line(self):
+ self._test_encode('x' * 200 + '\n',
+ 2 * ('x' * 75 + '=\n') + 'x' * 50 + '\n')
+
+ def test_encode_one_long_line(self):
+ self._test_encode('x' * 100 + '\n', 'x' * 75 + '=\n' + 'x' * 25 + '\n')
+
+ def test_encode_shortest_maxlinelen(self):
+ self._test_encode('=' * 5, '=3D=\n' * 4 + '=3D', maxlinelen=4)
+
+ def test_encode_maxlinelen_too_small(self):
+ self.assertRaises(ValueError, self._test_encode, '', '', maxlinelen=3)
+
+ def test_encode(self):
+ eq = self.assertEqual
+ eq(quoprimime.body_encode(''), '')
+ eq(quoprimime.body_encode('hello'), 'hello')
+ # Test the binary flag
+ eq(quoprimime.body_encode('hello\r\nworld'), 'hello\nworld')
+ # Test the maxlinelen arg
+ eq(quoprimime.body_encode('xxxx ' * 20, maxlinelen=40), """\
+xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxxx=
+ xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxx=
+x xxxx xxxx xxxx xxxx=20""")
+ # Test the eol argument
+ eq(quoprimime.body_encode('xxxx ' * 20, maxlinelen=40, eol='\r\n'),
+ """\
+xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxxx=\r
+ xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxx=\r
+x xxxx xxxx xxxx xxxx=20""")
+ eq(quoprimime.body_encode("""\
+one line
+
+two line"""), """\
+one line
+
+two line""")
+
+
+
+# Test the Charset class
+class TestCharset(unittest.TestCase):
+ def tearDown(self):
+ from email import charset as CharsetModule
+ try:
+ del CharsetModule.CHARSETS['fake']
+ except KeyError:
+ pass
+
+ def test_codec_encodeable(self):
+ eq = self.assertEqual
+ # Make sure us-ascii = no Unicode conversion
+ c = Charset('us-ascii')
+ eq(c.header_encode('Hello World!'), 'Hello World!')
+ # Test 8-bit idempotency with us-ascii
+ s = '\xa4\xa2\xa4\xa4\xa4\xa6\xa4\xa8\xa4\xaa'
+ self.assertRaises(UnicodeError, c.header_encode, s)
+ c = Charset('utf-8')
+ eq(c.header_encode(s), '=?utf-8?b?wqTCosKkwqTCpMKmwqTCqMKkwqo=?=')
+
+ def test_body_encode(self):
+ eq = self.assertEqual
+ # Try a charset with QP body encoding
+ c = Charset('iso-8859-1')
+ eq('hello w=F6rld', c.body_encode('hello w\xf6rld'))
+ # Try a charset with Base64 body encoding
+ c = Charset('utf-8')
+ eq('aGVsbG8gd29ybGQ=\n', c.body_encode(b'hello world'))
+ # Try a charset with None body encoding
+ c = Charset('us-ascii')
+ eq('hello world', c.body_encode('hello world'))
+ # Try the convert argument, where input codec != output codec
+ c = Charset('euc-jp')
+ # With apologies to Tokio Kikuchi ;)
+ # XXX FIXME
+## try:
+## eq('\x1b$B5FCO;~IW\x1b(B',
+## c.body_encode('\xb5\xc6\xc3\xcf\xbb\xfe\xc9\xd7'))
+## eq('\xb5\xc6\xc3\xcf\xbb\xfe\xc9\xd7',
+## c.body_encode('\xb5\xc6\xc3\xcf\xbb\xfe\xc9\xd7', False))
+## except LookupError:
+## # We probably don't have the Japanese codecs installed
+## pass
+ # Testing SF bug #625509, which we have to fake, since there are no
+ # built-in encodings where the header encoding is QP but the body
+ # encoding is not.
+ from email import charset as CharsetModule
+ CharsetModule.add_charset('fake', CharsetModule.QP, None, 'utf-8')
+ c = Charset('fake')
+ eq('hello world', c.body_encode('hello world'))
+
+ def test_unicode_charset_name(self):
+ charset = Charset('us-ascii')
+ self.assertEqual(str(charset), 'us-ascii')
+ self.assertRaises(errors.CharsetError, Charset, 'asc\xffii')
+
+
+
+# Test multilingual MIME headers.
+class TestHeader(TestEmailBase):
+ def test_simple(self):
+ eq = self.ndiffAssertEqual
+ h = Header('Hello World!')
+ eq(h.encode(), 'Hello World!')
+ h.append(' Goodbye World!')
+ eq(h.encode(), 'Hello World! Goodbye World!')
+
+ def test_simple_surprise(self):
+ eq = self.ndiffAssertEqual
+ h = Header('Hello World!')
+ eq(h.encode(), 'Hello World!')
+ h.append('Goodbye World!')
+ eq(h.encode(), 'Hello World! Goodbye World!')
+
+ def test_header_needs_no_decoding(self):
+ h = 'no decoding needed'
+ self.assertEqual(decode_header(h), [(h, None)])
+
+ def test_long(self):
+ h = Header("I am the very model of a modern Major-General; I've information vegetable, animal, and mineral; I know the kings of England, and I quote the fights historical from Marathon to Waterloo, in order categorical; I'm very well acquainted, too, with matters mathematical; I understand equations, both the simple and quadratical; about binomial theorem I'm teeming with a lot o' news, with many cheerful facts about the square of the hypotenuse.",
+ maxlinelen=76)
+ for l in h.encode(splitchars=' ').split('\n '):
+ self.assertTrue(len(l) <= 76)
+
+ def test_multilingual(self):
+ eq = self.ndiffAssertEqual
+ g = Charset("iso-8859-1")
+ cz = Charset("iso-8859-2")
+ utf8 = Charset("utf-8")
+ g_head = (b'Die Mieter treten hier ein werden mit einem '
+ b'Foerderband komfortabel den Korridor entlang, '
+ b'an s\xfcdl\xfcndischen Wandgem\xe4lden vorbei, '
+ b'gegen die rotierenden Klingen bef\xf6rdert. ')
+ cz_head = (b'Finan\xe8ni metropole se hroutily pod tlakem jejich '
+ b'd\xf9vtipu.. ')
+ utf8_head = ('\u6b63\u78ba\u306b\u8a00\u3046\u3068\u7ffb\u8a33\u306f'
+ '\u3055\u308c\u3066\u3044\u307e\u305b\u3093\u3002\u4e00'
+ '\u90e8\u306f\u30c9\u30a4\u30c4\u8a9e\u3067\u3059\u304c'
+ '\u3001\u3042\u3068\u306f\u3067\u305f\u3089\u3081\u3067'
+ '\u3059\u3002\u5b9f\u969b\u306b\u306f\u300cWenn ist das '
+ 'Nunstuck git und Slotermeyer? Ja! Beiherhund das Oder '
+ 'die Flipperwaldt gersput.\u300d\u3068\u8a00\u3063\u3066'
+ '\u3044\u307e\u3059\u3002')
+ h = Header(g_head, g)
+ h.append(cz_head, cz)
+ h.append(utf8_head, utf8)
+ enc = h.encode(maxlinelen=76)
+ eq(enc, """\
+=?iso-8859-1?q?Die_Mieter_treten_hier_ein_werden_mit_einem_Foerderband_kom?=
+ =?iso-8859-1?q?fortabel_den_Korridor_entlang=2C_an_s=FCdl=FCndischen_Wand?=
+ =?iso-8859-1?q?gem=E4lden_vorbei=2C_gegen_die_rotierenden_Klingen_bef=F6r?=
+ =?iso-8859-1?q?dert=2E_?= =?iso-8859-2?q?Finan=E8ni_metropole_se_hroutily?=
+ =?iso-8859-2?q?_pod_tlakem_jejich_d=F9vtipu=2E=2E_?= =?utf-8?b?5q2j56K6?=
+ =?utf-8?b?44Gr6KiA44GG44Go57+76Kiz44Gv44GV44KM44Gm44GE44G+44Gb44KT44CC?=
+ =?utf-8?b?5LiA6YOo44Gv44OJ44Kk44OE6Kqe44Gn44GZ44GM44CB44GC44Go44Gv44Gn?=
+ =?utf-8?b?44Gf44KJ44KB44Gn44GZ44CC5a6f6Zqb44Gr44Gv44CMV2VubiBpc3QgZGFz?=
+ =?utf-8?b?IE51bnN0dWNrIGdpdCB1bmQgU2xvdGVybWV5ZXI/IEphISBCZWloZXJodW5k?=
+ =?utf-8?b?IGRhcyBPZGVyIGRpZSBGbGlwcGVyd2FsZHQgZ2Vyc3B1dC7jgI3jgajoqIA=?=
+ =?utf-8?b?44Gj44Gm44GE44G+44GZ44CC?=""")
+ decoded = decode_header(enc)
+ eq(len(decoded), 3)
+ eq(decoded[0], (g_head, 'iso-8859-1'))
+ eq(decoded[1], (cz_head, 'iso-8859-2'))
+ eq(decoded[2], (utf8_head.encode('utf-8'), 'utf-8'))
+ ustr = str(h)
+ eq(ustr,
+ (b'Die Mieter treten hier ein werden mit einem Foerderband '
+ b'komfortabel den Korridor entlang, an s\xc3\xbcdl\xc3\xbcndischen '
+ b'Wandgem\xc3\xa4lden vorbei, gegen die rotierenden Klingen '
+ b'bef\xc3\xb6rdert. Finan\xc4\x8dni metropole se hroutily pod '
+ b'tlakem jejich d\xc5\xafvtipu.. \xe6\xad\xa3\xe7\xa2\xba\xe3\x81'
+ b'\xab\xe8\xa8\x80\xe3\x81\x86\xe3\x81\xa8\xe7\xbf\xbb\xe8\xa8\xb3'
+ b'\xe3\x81\xaf\xe3\x81\x95\xe3\x82\x8c\xe3\x81\xa6\xe3\x81\x84\xe3'
+ b'\x81\xbe\xe3\x81\x9b\xe3\x82\x93\xe3\x80\x82\xe4\xb8\x80\xe9\x83'
+ b'\xa8\xe3\x81\xaf\xe3\x83\x89\xe3\x82\xa4\xe3\x83\x84\xe8\xaa\x9e'
+ b'\xe3\x81\xa7\xe3\x81\x99\xe3\x81\x8c\xe3\x80\x81\xe3\x81\x82\xe3'
+ b'\x81\xa8\xe3\x81\xaf\xe3\x81\xa7\xe3\x81\x9f\xe3\x82\x89\xe3\x82'
+ b'\x81\xe3\x81\xa7\xe3\x81\x99\xe3\x80\x82\xe5\xae\x9f\xe9\x9a\x9b'
+ b'\xe3\x81\xab\xe3\x81\xaf\xe3\x80\x8cWenn ist das Nunstuck git '
+ b'und Slotermeyer? Ja! Beiherhund das Oder die Flipperwaldt '
+ b'gersput.\xe3\x80\x8d\xe3\x81\xa8\xe8\xa8\x80\xe3\x81\xa3\xe3\x81'
+ b'\xa6\xe3\x81\x84\xe3\x81\xbe\xe3\x81\x99\xe3\x80\x82'
+ ).decode('utf-8'))
+ # Test make_header()
+ newh = make_header(decode_header(enc))
+ eq(newh, h)
+
+ def test_empty_header_encode(self):
+ h = Header()
+ self.assertEqual(h.encode(), '')
+
+ def test_header_ctor_default_args(self):
+ eq = self.ndiffAssertEqual
+ h = Header()
+ eq(h, '')
+ h.append('foo', Charset('iso-8859-1'))
+ eq(h, 'foo')
+
+ def test_explicit_maxlinelen(self):
+ eq = self.ndiffAssertEqual
+ hstr = ('A very long line that must get split to something other '
+ 'than at the 76th character boundary to test the non-default '
+ 'behavior')
+ h = Header(hstr)
+ eq(h.encode(), '''\
+A very long line that must get split to something other than at the 76th
+ character boundary to test the non-default behavior''')
+ eq(str(h), hstr)
+ h = Header(hstr, header_name='Subject')
+ eq(h.encode(), '''\
+A very long line that must get split to something other than at the
+ 76th character boundary to test the non-default behavior''')
+ eq(str(h), hstr)
+ h = Header(hstr, maxlinelen=1024, header_name='Subject')
+ eq(h.encode(), hstr)
+ eq(str(h), hstr)
+
+ def test_quopri_splittable(self):
+ eq = self.ndiffAssertEqual
+ h = Header(charset='iso-8859-1', maxlinelen=20)
+ x = 'xxxx ' * 20
+ h.append(x)
+ s = h.encode()
+ eq(s, """\
+=?iso-8859-1?q?xxx?=
+ =?iso-8859-1?q?x_?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?_x?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?x_?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?_x?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?x_?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?_x?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?x_?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?_x?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?x_?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?_x?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?x_?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?_x?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?x_?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?_x?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?x_?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?_x?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?x_?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?_x?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?x_?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?xx?=
+ =?iso-8859-1?q?_?=""")
+ eq(x, str(make_header(decode_header(s))))
+ h = Header(charset='iso-8859-1', maxlinelen=40)
+ h.append('xxxx ' * 20)
+ s = h.encode()
+ eq(s, """\
+=?iso-8859-1?q?xxxx_xxxx_xxxx_xxxx_xxx?=
+ =?iso-8859-1?q?x_xxxx_xxxx_xxxx_xxxx_?=
+ =?iso-8859-1?q?xxxx_xxxx_xxxx_xxxx_xx?=
+ =?iso-8859-1?q?xx_xxxx_xxxx_xxxx_xxxx?=
+ =?iso-8859-1?q?_xxxx_xxxx_?=""")
+ eq(x, str(make_header(decode_header(s))))
+
+ def test_base64_splittable(self):
+ eq = self.ndiffAssertEqual
+ h = Header(charset='koi8-r', maxlinelen=20)
+ x = 'xxxx ' * 20
+ h.append(x)
+ s = h.encode()
+ eq(s, """\
+=?koi8-r?b?eHh4?=
+ =?koi8-r?b?eCB4?=
+ =?koi8-r?b?eHh4?=
+ =?koi8-r?b?IHh4?=
+ =?koi8-r?b?eHgg?=
+ =?koi8-r?b?eHh4?=
+ =?koi8-r?b?eCB4?=
+ =?koi8-r?b?eHh4?=
+ =?koi8-r?b?IHh4?=
+ =?koi8-r?b?eHgg?=
+ =?koi8-r?b?eHh4?=
+ =?koi8-r?b?eCB4?=
+ =?koi8-r?b?eHh4?=
+ =?koi8-r?b?IHh4?=
+ =?koi8-r?b?eHgg?=
+ =?koi8-r?b?eHh4?=
+ =?koi8-r?b?eCB4?=
+ =?koi8-r?b?eHh4?=
+ =?koi8-r?b?IHh4?=
+ =?koi8-r?b?eHgg?=
+ =?koi8-r?b?eHh4?=
+ =?koi8-r?b?eCB4?=
+ =?koi8-r?b?eHh4?=
+ =?koi8-r?b?IHh4?=
+ =?koi8-r?b?eHgg?=
+ =?koi8-r?b?eHh4?=
+ =?koi8-r?b?eCB4?=
+ =?koi8-r?b?eHh4?=
+ =?koi8-r?b?IHh4?=
+ =?koi8-r?b?eHgg?=
+ =?koi8-r?b?eHh4?=
+ =?koi8-r?b?eCB4?=
+ =?koi8-r?b?eHh4?=
+ =?koi8-r?b?IA==?=""")
+ eq(x, str(make_header(decode_header(s))))
+ h = Header(charset='koi8-r', maxlinelen=40)
+ h.append(x)
+ s = h.encode()
+ eq(s, """\
+=?koi8-r?b?eHh4eCB4eHh4IHh4eHggeHh4?=
+ =?koi8-r?b?eCB4eHh4IHh4eHggeHh4eCB4?=
+ =?koi8-r?b?eHh4IHh4eHggeHh4eCB4eHh4?=
+ =?koi8-r?b?IHh4eHggeHh4eCB4eHh4IHh4?=
+ =?koi8-r?b?eHggeHh4eCB4eHh4IHh4eHgg?=
+ =?koi8-r?b?eHh4eCB4eHh4IA==?=""")
+ eq(x, str(make_header(decode_header(s))))
+
+ def test_us_ascii_header(self):
+ eq = self.assertEqual
+ s = 'hello'
+ x = decode_header(s)
+ eq(x, [('hello', None)])
+ h = make_header(x)
+ eq(s, h.encode())
+
+ def test_string_charset(self):
+ eq = self.assertEqual
+ h = Header()
+ h.append('hello', 'iso-8859-1')
+ eq(h, 'hello')
+
+## def test_unicode_error(self):
+## raises = self.assertRaises
+## raises(UnicodeError, Header, u'[P\xf6stal]', 'us-ascii')
+## raises(UnicodeError, Header, '[P\xf6stal]', 'us-ascii')
+## h = Header()
+## raises(UnicodeError, h.append, u'[P\xf6stal]', 'us-ascii')
+## raises(UnicodeError, h.append, '[P\xf6stal]', 'us-ascii')
+## raises(UnicodeError, Header, u'\u83ca\u5730\u6642\u592b', 'iso-8859-1')
+
+ def test_utf8_shortest(self):
+ eq = self.assertEqual
+ h = Header('p\xf6stal', 'utf-8')
+ eq(h.encode(), '=?utf-8?q?p=C3=B6stal?=')
+ h = Header('\u83ca\u5730\u6642\u592b', 'utf-8')
+ eq(h.encode(), '=?utf-8?b?6I+K5Zyw5pmC5aSr?=')
+
+ def test_bad_8bit_header(self):
+ raises = self.assertRaises
+ eq = self.assertEqual
+ x = b'Ynwp4dUEbay Auction Semiar- No Charge \x96 Earn Big'
+ raises(UnicodeError, Header, x)
+ h = Header()
+ raises(UnicodeError, h.append, x)
+ e = x.decode('utf-8', 'replace')
+ eq(str(Header(x, errors='replace')), e)
+ h.append(x, errors='replace')
+ eq(str(h), e)
+
+ def test_escaped_8bit_header(self):
+ x = b'Ynwp4dUEbay Auction Semiar- No Charge \x96 Earn Big'
+ e = x.decode('ascii', 'surrogateescape')
+ h = Header(e, charset=email.charset.UNKNOWN8BIT)
+ self.assertEqual(str(h),
+ 'Ynwp4dUEbay Auction Semiar- No Charge \uFFFD Earn Big')
+ self.assertEqual(email.header.decode_header(h), [(x, 'unknown-8bit')])
+
+ def test_header_handles_binary_unknown8bit(self):
+ x = b'Ynwp4dUEbay Auction Semiar- No Charge \x96 Earn Big'
+ h = Header(x, charset=email.charset.UNKNOWN8BIT)
+ self.assertEqual(str(h),
+ 'Ynwp4dUEbay Auction Semiar- No Charge \uFFFD Earn Big')
+ self.assertEqual(email.header.decode_header(h), [(x, 'unknown-8bit')])
+
+ def test_make_header_handles_binary_unknown8bit(self):
+ x = b'Ynwp4dUEbay Auction Semiar- No Charge \x96 Earn Big'
+ h = Header(x, charset=email.charset.UNKNOWN8BIT)
+ h2 = email.header.make_header(email.header.decode_header(h))
+ self.assertEqual(str(h2),
+ 'Ynwp4dUEbay Auction Semiar- No Charge \uFFFD Earn Big')
+ self.assertEqual(email.header.decode_header(h2), [(x, 'unknown-8bit')])
+
+ def test_modify_returned_list_does_not_change_header(self):
+ h = Header('test')
+ chunks = email.header.decode_header(h)
+ chunks.append(('ascii', 'test2'))
+ self.assertEqual(str(h), 'test')
+
+ def test_encoded_adjacent_nonencoded(self):
+ eq = self.assertEqual
+ h = Header()
+ h.append('hello', 'iso-8859-1')
+ h.append('world')
+ s = h.encode()
+ eq(s, '=?iso-8859-1?q?hello?= world')
+ h = make_header(decode_header(s))
+ eq(h.encode(), s)
+
+ def test_whitespace_keeper(self):
+ eq = self.assertEqual
+ s = 'Subject: =?koi8-r?b?8NLP18XSy8EgzsEgxsnOwczYztk=?= =?koi8-r?q?=CA?= zz.'
+ parts = decode_header(s)
+ eq(parts, [(b'Subject: ', None), (b'\xf0\xd2\xcf\xd7\xc5\xd2\xcb\xc1 \xce\xc1 \xc6\xc9\xce\xc1\xcc\xd8\xce\xd9\xca', 'koi8-r'), (b' zz.', None)])
+ hdr = make_header(parts)
+ eq(hdr.encode(),
+ 'Subject: =?koi8-r?b?8NLP18XSy8EgzsEgxsnOwczYztnK?= zz.')
+
+ def test_broken_base64_header(self):
+ raises = self.assertRaises
+ s = 'Subject: =?EUC-KR?B?CSixpLDtKSC/7Liuvsax4iC6uLmwMcijIKHaILzSwd/H0SC8+LCjwLsgv7W/+Mj3I ?='
+ raises(errors.HeaderParseError, decode_header, s)
+
+ def test_shift_jis_charset(self):
+ h = Header('文', charset='shift_jis')
+ self.assertEqual(h.encode(), '=?iso-2022-jp?b?GyRCSjgbKEI=?=')
+
+ def test_flatten_header_with_no_value(self):
+ # Issue 11401 (regression from email 4.x) Note that the space after
+ # the header doesn't reflect the input, but this is also the way
+ # email 4.x behaved. At some point it would be nice to fix that.
+ msg = email.message_from_string("EmptyHeader:")
+ self.assertEqual(str(msg), "EmptyHeader: \n\n")
+
+ def test_encode_preserves_leading_ws_on_value(self):
+ msg = Message()
+ msg['SomeHeader'] = ' value with leading ws'
+ self.assertEqual(str(msg), "SomeHeader: value with leading ws\n\n")
+
+
+
+# Test RFC 2231 header parameters (en/de)coding
+class TestRFC2231(TestEmailBase):
+
+ # test_headerregistry.TestContentTypeHeader.rfc2231_encoded_with_double_quotes
+ # test_headerregistry.TestContentTypeHeader.rfc2231_single_quote_inside_double_quotes
+ def test_get_param(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_29.txt')
+ eq(msg.get_param('title'),
+ ('us-ascii', 'en', 'This is even more ***fun*** isn\'t it!'))
+ eq(msg.get_param('title', unquote=False),
+ ('us-ascii', 'en', '"This is even more ***fun*** isn\'t it!"'))
+
+ def test_set_param(self):
+ eq = self.ndiffAssertEqual
+ msg = Message()
+ msg.set_param('title', 'This is even more ***fun*** isn\'t it!',
+ charset='us-ascii')
+ eq(msg.get_param('title'),
+ ('us-ascii', '', 'This is even more ***fun*** isn\'t it!'))
+ msg.set_param('title', 'This is even more ***fun*** isn\'t it!',
+ charset='us-ascii', language='en')
+ eq(msg.get_param('title'),
+ ('us-ascii', 'en', 'This is even more ***fun*** isn\'t it!'))
+ msg = self._msgobj('msg_01.txt')
+ msg.set_param('title', 'This is even more ***fun*** isn\'t it!',
+ charset='us-ascii', language='en')
+ eq(msg.as_string(maxheaderlen=78), """\
+Return-Path: <bbb@zzz.org>
+Delivered-To: bbb@zzz.org
+Received: by mail.zzz.org (Postfix, from userid 889)
+\tid 27CEAD38CC; Fri, 4 May 2001 14:05:44 -0400 (EDT)
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+Message-ID: <15090.61304.110929.45684@aaa.zzz.org>
+From: bbb@ddd.com (John X. Doe)
+To: bbb@zzz.org
+Subject: This is a test message
+Date: Fri, 4 May 2001 14:05:44 -0400
+Content-Type: text/plain; charset=us-ascii;
+ title*=us-ascii'en'This%20is%20even%20more%20%2A%2A%2Afun%2A%2A%2A%20isn%27t%20it%21
+
+
+Hi,
+
+Do you like this message?
+
+-Me
+""")
+
+ def test_set_param_requote(self):
+ msg = Message()
+ msg.set_param('title', 'foo')
+ self.assertEqual(msg['content-type'], 'text/plain; title="foo"')
+ msg.set_param('title', 'bar', requote=False)
+ self.assertEqual(msg['content-type'], 'text/plain; title=bar')
+ # tspecial is still quoted.
+ msg.set_param('title', "(bar)bell", requote=False)
+ self.assertEqual(msg['content-type'], 'text/plain; title="(bar)bell"')
+
+ def test_del_param(self):
+ eq = self.ndiffAssertEqual
+ msg = self._msgobj('msg_01.txt')
+ msg.set_param('foo', 'bar', charset='us-ascii', language='en')
+ msg.set_param('title', 'This is even more ***fun*** isn\'t it!',
+ charset='us-ascii', language='en')
+ msg.del_param('foo', header='Content-Type')
+ eq(msg.as_string(maxheaderlen=78), """\
+Return-Path: <bbb@zzz.org>
+Delivered-To: bbb@zzz.org
+Received: by mail.zzz.org (Postfix, from userid 889)
+\tid 27CEAD38CC; Fri, 4 May 2001 14:05:44 -0400 (EDT)
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+Message-ID: <15090.61304.110929.45684@aaa.zzz.org>
+From: bbb@ddd.com (John X. Doe)
+To: bbb@zzz.org
+Subject: This is a test message
+Date: Fri, 4 May 2001 14:05:44 -0400
+Content-Type: text/plain; charset="us-ascii";
+ title*=us-ascii'en'This%20is%20even%20more%20%2A%2A%2Afun%2A%2A%2A%20isn%27t%20it%21
+
+
+Hi,
+
+Do you like this message?
+
+-Me
+""")
+
+ # test_headerregistry.TestContentTypeHeader.rfc2231_encoded_charset
+ # I changed the charset name, though, because the one in the file isn't
+ # a legal charset name. Should add a test for an illegal charset.
+ def test_rfc2231_get_content_charset(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_32.txt')
+ eq(msg.get_content_charset(), 'us-ascii')
+
+ # test_headerregistry.TestContentTypeHeader.rfc2231_encoded_no_double_quotes
+ def test_rfc2231_parse_rfc_quoting(self):
+ m = textwrap.dedent('''\
+ Content-Disposition: inline;
+ \tfilename*0*=''This%20is%20even%20more%20;
+ \tfilename*1*=%2A%2A%2Afun%2A%2A%2A%20;
+ \tfilename*2="is it not.pdf"
+
+ ''')
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_filename(),
+ 'This is even more ***fun*** is it not.pdf')
+ self.assertEqual(m, msg.as_string())
+
+ # test_headerregistry.TestContentTypeHeader.rfc2231_encoded_with_double_quotes
+ def test_rfc2231_parse_extra_quoting(self):
+ m = textwrap.dedent('''\
+ Content-Disposition: inline;
+ \tfilename*0*="''This%20is%20even%20more%20";
+ \tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20";
+ \tfilename*2="is it not.pdf"
+
+ ''')
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_filename(),
+ 'This is even more ***fun*** is it not.pdf')
+ self.assertEqual(m, msg.as_string())
+
+ # test_headerregistry.TestContentTypeHeader.rfc2231_no_language_or_charset
+ # but new test uses *0* because otherwise lang/charset is not valid.
+ # test_headerregistry.TestContentTypeHeader.rfc2231_segmented_normal_values
+ def test_rfc2231_no_language_or_charset(self):
+ m = '''\
+Content-Transfer-Encoding: 8bit
+Content-Disposition: inline; filename="file____C__DOCUMENTS_20AND_20SETTINGS_FABIEN_LOCAL_20SETTINGS_TEMP_nsmail.htm"
+Content-Type: text/html; NAME*0=file____C__DOCUMENTS_20AND_20SETTINGS_FABIEN_LOCAL_20SETTINGS_TEM; NAME*1=P_nsmail.htm
+
+'''
+ msg = email.message_from_string(m)
+ param = msg.get_param('NAME')
+ self.assertFalse(isinstance(param, tuple))
+ self.assertEqual(
+ param,
+ 'file____C__DOCUMENTS_20AND_20SETTINGS_FABIEN_LOCAL_20SETTINGS_TEMP_nsmail.htm')
+
+ # test_headerregistry.TestContentTypeHeader.rfc2231_encoded_no_charset
+ def test_rfc2231_no_language_or_charset_in_filename(self):
+ m = '''\
+Content-Disposition: inline;
+\tfilename*0*="''This%20is%20even%20more%20";
+\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tfilename*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_filename(),
+ 'This is even more ***fun*** is it not.pdf')
+
+ # Duplicate of previous test?
+ def test_rfc2231_no_language_or_charset_in_filename_encoded(self):
+ m = '''\
+Content-Disposition: inline;
+\tfilename*0*="''This%20is%20even%20more%20";
+\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tfilename*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_filename(),
+ 'This is even more ***fun*** is it not.pdf')
+
+ # test_headerregistry.TestContentTypeHeader.rfc2231_partly_encoded,
+ # but the test below is wrong (the first part should be decoded).
+ def test_rfc2231_partly_encoded(self):
+ m = '''\
+Content-Disposition: inline;
+\tfilename*0="''This%20is%20even%20more%20";
+\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tfilename*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(
+ msg.get_filename(),
+ 'This%20is%20even%20more%20***fun*** is it not.pdf')
+
+ def test_rfc2231_partly_nonencoded(self):
+ m = '''\
+Content-Disposition: inline;
+\tfilename*0="This%20is%20even%20more%20";
+\tfilename*1="%2A%2A%2Afun%2A%2A%2A%20";
+\tfilename*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(
+ msg.get_filename(),
+ 'This%20is%20even%20more%20%2A%2A%2Afun%2A%2A%2A%20is it not.pdf')
+
+ def test_rfc2231_no_language_or_charset_in_boundary(self):
+ m = '''\
+Content-Type: multipart/alternative;
+\tboundary*0*="''This%20is%20even%20more%20";
+\tboundary*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tboundary*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_boundary(),
+ 'This is even more ***fun*** is it not.pdf')
+
+ def test_rfc2231_no_language_or_charset_in_charset(self):
+ # This is a nonsensical charset value, but tests the code anyway
+ m = '''\
+Content-Type: text/plain;
+\tcharset*0*="This%20is%20even%20more%20";
+\tcharset*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tcharset*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_content_charset(),
+ 'this is even more ***fun*** is it not.pdf')
+
+ # test_headerregistry.TestContentTypeHeader.rfc2231_unknown_charset_treated_as_ascii
+ def test_rfc2231_bad_encoding_in_filename(self):
+ m = '''\
+Content-Disposition: inline;
+\tfilename*0*="bogus'xx'This%20is%20even%20more%20";
+\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tfilename*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_filename(),
+ 'This is even more ***fun*** is it not.pdf')
+
+ def test_rfc2231_bad_encoding_in_charset(self):
+ m = """\
+Content-Type: text/plain; charset*=bogus''utf-8%E2%80%9D
+
+"""
+ msg = email.message_from_string(m)
+ # This should return None because non-ascii characters in the charset
+ # are not allowed.
+ self.assertEqual(msg.get_content_charset(), None)
+
+ def test_rfc2231_bad_character_in_charset(self):
+ m = """\
+Content-Type: text/plain; charset*=ascii''utf-8%E2%80%9D
+
+"""
+ msg = email.message_from_string(m)
+ # This should return None because non-ascii characters in the charset
+ # are not allowed.
+ self.assertEqual(msg.get_content_charset(), None)
+
+ def test_rfc2231_bad_character_in_filename(self):
+ m = '''\
+Content-Disposition: inline;
+\tfilename*0*="ascii'xx'This%20is%20even%20more%20";
+\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tfilename*2*="is it not.pdf%E2"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_filename(),
+ 'This is even more ***fun*** is it not.pdf\ufffd')
+
+ def test_rfc2231_unknown_encoding(self):
+ m = """\
+Content-Transfer-Encoding: 8bit
+Content-Disposition: inline; filename*=X-UNKNOWN''myfile.txt
+
+"""
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_filename(), 'myfile.txt')
+
+ def test_rfc2231_single_tick_in_filename_extended(self):
+ eq = self.assertEqual
+ m = """\
+Content-Type: application/x-foo;
+\tname*0*=\"Frank's\"; name*1*=\" Document\"
+
+"""
+ msg = email.message_from_string(m)
+ charset, language, s = msg.get_param('name')
+ eq(charset, None)
+ eq(language, None)
+ eq(s, "Frank's Document")
+
+ # test_headerregistry.TestContentTypeHeader.rfc2231_single_quote_inside_double_quotes
+ def test_rfc2231_single_tick_in_filename(self):
+ m = """\
+Content-Type: application/x-foo; name*0=\"Frank's\"; name*1=\" Document\"
+
+"""
+ msg = email.message_from_string(m)
+ param = msg.get_param('name')
+ self.assertFalse(isinstance(param, tuple))
+ self.assertEqual(param, "Frank's Document")
+
+ # test_headerregistry.TestContentTypeHeader.rfc2231_single_quote_in_value_with_charset_and_lang
+ def test_rfc2231_tick_attack_extended(self):
+ eq = self.assertEqual
+ m = """\
+Content-Type: application/x-foo;
+\tname*0*=\"us-ascii'en-us'Frank's\"; name*1*=\" Document\"
+
+"""
+ msg = email.message_from_string(m)
+ charset, language, s = msg.get_param('name')
+ eq(charset, 'us-ascii')
+ eq(language, 'en-us')
+ eq(s, "Frank's Document")
+
+ # test_headerregistry.TestContentTypeHeader.rfc2231_single_quote_in_non_encoded_value
+ def test_rfc2231_tick_attack(self):
+ m = """\
+Content-Type: application/x-foo;
+\tname*0=\"us-ascii'en-us'Frank's\"; name*1=\" Document\"
+
+"""
+ msg = email.message_from_string(m)
+ param = msg.get_param('name')
+ self.assertFalse(isinstance(param, tuple))
+ self.assertEqual(param, "us-ascii'en-us'Frank's Document")
+
+ # test_headerregistry.TestContentTypeHeader.rfc2231_single_quotes_inside_quotes
+ def test_rfc2231_no_extended_values(self):
+ eq = self.assertEqual
+ m = """\
+Content-Type: application/x-foo; name=\"Frank's Document\"
+
+"""
+ msg = email.message_from_string(m)
+ eq(msg.get_param('name'), "Frank's Document")
+
+ # test_headerregistry.TestContentTypeHeader.rfc2231_encoded_then_unencoded_segments
+ def test_rfc2231_encoded_then_unencoded_segments(self):
+ eq = self.assertEqual
+ m = """\
+Content-Type: application/x-foo;
+\tname*0*=\"us-ascii'en-us'My\";
+\tname*1=\" Document\";
+\tname*2*=\" For You\"
+
+"""
+ msg = email.message_from_string(m)
+ charset, language, s = msg.get_param('name')
+ eq(charset, 'us-ascii')
+ eq(language, 'en-us')
+ eq(s, 'My Document For You')
+
+ # test_headerregistry.TestContentTypeHeader.rfc2231_unencoded_then_encoded_segments
+ # test_headerregistry.TestContentTypeHeader.rfc2231_quoted_unencoded_then_encoded_segments
+ def test_rfc2231_unencoded_then_encoded_segments(self):
+ eq = self.assertEqual
+ m = """\
+Content-Type: application/x-foo;
+\tname*0=\"us-ascii'en-us'My\";
+\tname*1*=\" Document\";
+\tname*2*=\" For You\"
+
+"""
+ msg = email.message_from_string(m)
+ charset, language, s = msg.get_param('name')
+ eq(charset, 'us-ascii')
+ eq(language, 'en-us')
+ eq(s, 'My Document For You')
+
+
+
+# Tests to ensure that signed parts of an email are completely preserved, as
+# required by RFC1847 section 2.1. Note that these are incomplete, because the
+# email package does not currently always preserve the body. See issue 1670765.
+class TestSigned(TestEmailBase):
+
+ def _msg_and_obj(self, filename):
+ with openfile(filename) as fp:
+ original = fp.read()
+ msg = email.message_from_string(original)
+ return original, msg
+
+ def _signed_parts_eq(self, original, result):
+ # Extract the first mime part of each message
+ import re
+ repart = re.compile(r'^--([^\n]+)\n(.*?)\n--\1$', re.S | re.M)
+ inpart = repart.search(original).group(2)
+ outpart = repart.search(result).group(2)
+ self.assertEqual(outpart, inpart)
+
+ def test_long_headers_as_string(self):
+ original, msg = self._msg_and_obj('msg_45.txt')
+ result = msg.as_string()
+ self._signed_parts_eq(original, result)
+
+ def test_long_headers_as_string_maxheaderlen(self):
+ original, msg = self._msg_and_obj('msg_45.txt')
+ result = msg.as_string(maxheaderlen=60)
+ self._signed_parts_eq(original, result)
+
+ def test_long_headers_flatten(self):
+ original, msg = self._msg_and_obj('msg_45.txt')
+ fp = StringIO()
+ Generator(fp).flatten(msg)
+ result = fp.getvalue()
+ self._signed_parts_eq(original, result)
+
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_email/test_generator.py b/Lib/test/test_email/test_generator.py
new file mode 100644
index 0000000..8917408
--- /dev/null
+++ b/Lib/test/test_email/test_generator.py
@@ -0,0 +1,199 @@
+import io
+import textwrap
+import unittest
+from email import message_from_string, message_from_bytes
+from email.generator import Generator, BytesGenerator
+from email import policy
+from test.test_email import TestEmailBase, parameterize
+
+
+@parameterize
+class TestGeneratorBase:
+
+ policy = policy.default
+
+ def msgmaker(self, msg, policy=None):
+ policy = self.policy if policy is None else policy
+ return self.msgfunc(msg, policy=policy)
+
+ refold_long_expected = {
+ 0: textwrap.dedent("""\
+ To: whom_it_may_concern@example.com
+ From: nobody_you_want_to_know@example.com
+ Subject: We the willing led by the unknowing are doing the
+ impossible for the ungrateful. We have done so much for so long with so little
+ we are now qualified to do anything with nothing.
+
+ None
+ """),
+ # From is wrapped because wrapped it fits in 40.
+ 40: textwrap.dedent("""\
+ To: whom_it_may_concern@example.com
+ From:
+ nobody_you_want_to_know@example.com
+ Subject: We the willing led by the
+ unknowing are doing the impossible for
+ the ungrateful. We have done so much
+ for so long with so little we are now
+ qualified to do anything with nothing.
+
+ None
+ """),
+ # Neither to nor from fit even if put on a new line,
+ # so we leave them sticking out on the first line.
+ 20: textwrap.dedent("""\
+ To: whom_it_may_concern@example.com
+ From: nobody_you_want_to_know@example.com
+ Subject: We the
+ willing led by the
+ unknowing are doing
+ the impossible for
+ the ungrateful. We
+ have done so much
+ for so long with so
+ little we are now
+ qualified to do
+ anything with
+ nothing.
+
+ None
+ """),
+ }
+ refold_long_expected[100] = refold_long_expected[0]
+
+ refold_all_expected = refold_long_expected.copy()
+ refold_all_expected[0] = (
+ "To: whom_it_may_concern@example.com\n"
+ "From: nobody_you_want_to_know@example.com\n"
+ "Subject: We the willing led by the unknowing are doing the "
+ "impossible for the ungrateful. We have done so much for "
+ "so long with so little we are now qualified to do anything "
+ "with nothing.\n"
+ "\n"
+ "None\n")
+ refold_all_expected[100] = (
+ "To: whom_it_may_concern@example.com\n"
+ "From: nobody_you_want_to_know@example.com\n"
+ "Subject: We the willing led by the unknowing are doing the "
+ "impossible for the ungrateful. We have\n"
+ " done so much for so long with so little we are now qualified "
+ "to do anything with nothing.\n"
+ "\n"
+ "None\n")
+
+ length_params = [n for n in refold_long_expected]
+
+ def length_as_maxheaderlen_parameter(self, n):
+ msg = self.msgmaker(self.typ(self.refold_long_expected[0]))
+ s = self.ioclass()
+ g = self.genclass(s, maxheaderlen=n, policy=self.policy)
+ g.flatten(msg)
+ self.assertEqual(s.getvalue(), self.typ(self.refold_long_expected[n]))
+
+ def length_as_max_line_length_policy(self, n):
+ msg = self.msgmaker(self.typ(self.refold_long_expected[0]))
+ s = self.ioclass()
+ g = self.genclass(s, policy=self.policy.clone(max_line_length=n))
+ g.flatten(msg)
+ self.assertEqual(s.getvalue(), self.typ(self.refold_long_expected[n]))
+
+ def length_as_maxheaderlen_parm_overrides_policy(self, n):
+ msg = self.msgmaker(self.typ(self.refold_long_expected[0]))
+ s = self.ioclass()
+ g = self.genclass(s, maxheaderlen=n,
+ policy=self.policy.clone(max_line_length=10))
+ g.flatten(msg)
+ self.assertEqual(s.getvalue(), self.typ(self.refold_long_expected[n]))
+
+ def length_as_max_line_length_with_refold_none_does_not_fold(self, n):
+ msg = self.msgmaker(self.typ(self.refold_long_expected[0]))
+ s = self.ioclass()
+ g = self.genclass(s, policy=self.policy.clone(refold_source='none',
+ max_line_length=n))
+ g.flatten(msg)
+ self.assertEqual(s.getvalue(), self.typ(self.refold_long_expected[0]))
+
+ def length_as_max_line_length_with_refold_all_folds(self, n):
+ msg = self.msgmaker(self.typ(self.refold_long_expected[0]))
+ s = self.ioclass()
+ g = self.genclass(s, policy=self.policy.clone(refold_source='all',
+ max_line_length=n))
+ g.flatten(msg)
+ self.assertEqual(s.getvalue(), self.typ(self.refold_all_expected[n]))
+
+ def test_crlf_control_via_policy(self):
+ source = "Subject: test\r\n\r\ntest body\r\n"
+ expected = source
+ msg = self.msgmaker(self.typ(source))
+ s = self.ioclass()
+ g = self.genclass(s, policy=policy.SMTP)
+ g.flatten(msg)
+ self.assertEqual(s.getvalue(), self.typ(expected))
+
+ def test_flatten_linesep_overrides_policy(self):
+ source = "Subject: test\n\ntest body\n"
+ expected = source
+ msg = self.msgmaker(self.typ(source))
+ s = self.ioclass()
+ g = self.genclass(s, policy=policy.SMTP)
+ g.flatten(msg, linesep='\n')
+ self.assertEqual(s.getvalue(), self.typ(expected))
+
+
+class TestGenerator(TestGeneratorBase, TestEmailBase):
+
+ msgfunc = staticmethod(message_from_string)
+ genclass = Generator
+ ioclass = io.StringIO
+ typ = str
+
+
+class TestBytesGenerator(TestGeneratorBase, TestEmailBase):
+
+ msgfunc = staticmethod(message_from_bytes)
+ genclass = BytesGenerator
+ ioclass = io.BytesIO
+ typ = lambda self, x: x.encode('ascii')
+
+ def test_cte_type_7bit_handles_unknown_8bit(self):
+ source = ("Subject: Maintenant je vous présente mon "
+ "collègue\n\n").encode('utf-8')
+ expected = ('Subject: Maintenant je vous =?unknown-8bit?q?'
+ 'pr=C3=A9sente_mon_coll=C3=A8gue?=\n\n').encode('ascii')
+ msg = message_from_bytes(source)
+ s = io.BytesIO()
+ g = BytesGenerator(s, policy=self.policy.clone(cte_type='7bit'))
+ g.flatten(msg)
+ self.assertEqual(s.getvalue(), expected)
+
+ def test_cte_type_7bit_transforms_8bit_cte(self):
+ source = textwrap.dedent("""\
+ From: foo@bar.com
+ To: Dinsdale
+ Subject: Nudge nudge, wink, wink
+ Mime-Version: 1.0
+ Content-Type: text/plain; charset="latin-1"
+ Content-Transfer-Encoding: 8bit
+
+ oh là là, know what I mean, know what I mean?
+ """).encode('latin1')
+ msg = message_from_bytes(source)
+ expected = textwrap.dedent("""\
+ From: foo@bar.com
+ To: Dinsdale
+ Subject: Nudge nudge, wink, wink
+ Mime-Version: 1.0
+ Content-Type: text/plain; charset="iso-8859-1"
+ Content-Transfer-Encoding: quoted-printable
+
+ oh l=E0 l=E0, know what I mean, know what I mean?
+ """).encode('ascii')
+ s = io.BytesIO()
+ g = BytesGenerator(s, policy=self.policy.clone(cte_type='7bit',
+ linesep='\n'))
+ g.flatten(msg)
+ self.assertEqual(s.getvalue(), expected)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_email/test_headerregistry.py b/Lib/test/test_email/test_headerregistry.py
new file mode 100644
index 0000000..c0c81c1
--- /dev/null
+++ b/Lib/test/test_email/test_headerregistry.py
@@ -0,0 +1,1515 @@
+import datetime
+import textwrap
+import unittest
+from email import errors
+from email import policy
+from email.message import Message
+from test.test_email import TestEmailBase, parameterize
+from email import headerregistry
+from email.headerregistry import Address, Group
+
+
+DITTO = object()
+
+
+class TestHeaderRegistry(TestEmailBase):
+
+ def test_arbitrary_name_unstructured(self):
+ factory = headerregistry.HeaderRegistry()
+ h = factory('foobar', 'test')
+ self.assertIsInstance(h, headerregistry.BaseHeader)
+ self.assertIsInstance(h, headerregistry.UnstructuredHeader)
+
+ def test_name_case_ignored(self):
+ factory = headerregistry.HeaderRegistry()
+ # Whitebox check that test is valid
+ self.assertNotIn('Subject', factory.registry)
+ h = factory('Subject', 'test')
+ self.assertIsInstance(h, headerregistry.BaseHeader)
+ self.assertIsInstance(h, headerregistry.UniqueUnstructuredHeader)
+
+ class FooBase:
+ def __init__(self, *args, **kw):
+ pass
+
+ def test_override_default_base_class(self):
+ factory = headerregistry.HeaderRegistry(base_class=self.FooBase)
+ h = factory('foobar', 'test')
+ self.assertIsInstance(h, self.FooBase)
+ self.assertIsInstance(h, headerregistry.UnstructuredHeader)
+
+ class FooDefault:
+ parse = headerregistry.UnstructuredHeader.parse
+
+ def test_override_default_class(self):
+ factory = headerregistry.HeaderRegistry(default_class=self.FooDefault)
+ h = factory('foobar', 'test')
+ self.assertIsInstance(h, headerregistry.BaseHeader)
+ self.assertIsInstance(h, self.FooDefault)
+
+ def test_override_default_class_only_overrides_default(self):
+ factory = headerregistry.HeaderRegistry(default_class=self.FooDefault)
+ h = factory('subject', 'test')
+ self.assertIsInstance(h, headerregistry.BaseHeader)
+ self.assertIsInstance(h, headerregistry.UniqueUnstructuredHeader)
+
+ def test_dont_use_default_map(self):
+ factory = headerregistry.HeaderRegistry(use_default_map=False)
+ h = factory('subject', 'test')
+ self.assertIsInstance(h, headerregistry.BaseHeader)
+ self.assertIsInstance(h, headerregistry.UnstructuredHeader)
+
+ def test_map_to_type(self):
+ factory = headerregistry.HeaderRegistry()
+ h1 = factory('foobar', 'test')
+ factory.map_to_type('foobar', headerregistry.UniqueUnstructuredHeader)
+ h2 = factory('foobar', 'test')
+ self.assertIsInstance(h1, headerregistry.BaseHeader)
+ self.assertIsInstance(h1, headerregistry.UnstructuredHeader)
+ self.assertIsInstance(h2, headerregistry.BaseHeader)
+ self.assertIsInstance(h2, headerregistry.UniqueUnstructuredHeader)
+
+
+class TestHeaderBase(TestEmailBase):
+
+ factory = headerregistry.HeaderRegistry()
+
+ def make_header(self, name, value):
+ return self.factory(name, value)
+
+
+class TestBaseHeaderFeatures(TestHeaderBase):
+
+ def test_str(self):
+ h = self.make_header('subject', 'this is a test')
+ self.assertIsInstance(h, str)
+ self.assertEqual(h, 'this is a test')
+ self.assertEqual(str(h), 'this is a test')
+
+ def test_substr(self):
+ h = self.make_header('subject', 'this is a test')
+ self.assertEqual(h[5:7], 'is')
+
+ def test_has_name(self):
+ h = self.make_header('subject', 'this is a test')
+ self.assertEqual(h.name, 'subject')
+
+ def _test_attr_ro(self, attr):
+ h = self.make_header('subject', 'this is a test')
+ with self.assertRaises(AttributeError):
+ setattr(h, attr, 'foo')
+
+ def test_name_read_only(self):
+ self._test_attr_ro('name')
+
+ def test_defects_read_only(self):
+ self._test_attr_ro('defects')
+
+ def test_defects_is_tuple(self):
+ h = self.make_header('subject', 'this is a test')
+ self.assertEqual(len(h.defects), 0)
+ self.assertIsInstance(h.defects, tuple)
+ # Make sure it is still true when there are defects.
+ h = self.make_header('date', '')
+ self.assertEqual(len(h.defects), 1)
+ self.assertIsInstance(h.defects, tuple)
+
+ # XXX: FIXME
+ #def test_CR_in_value(self):
+ # # XXX: this also re-raises the issue of embedded headers,
+ # # need test and solution for that.
+ # value = '\r'.join(['this is', ' a test'])
+ # h = self.make_header('subject', value)
+ # self.assertEqual(h, value)
+ # self.assertDefectsEqual(h.defects, [errors.ObsoleteHeaderDefect])
+
+ def test_RFC2047_value_decoded(self):
+ value = '=?utf-8?q?this_is_a_test?='
+ h = self.make_header('subject', value)
+ self.assertEqual(h, 'this is a test')
+
+
+class TestDateHeader(TestHeaderBase):
+
+ datestring = 'Sun, 23 Sep 2001 20:10:55 -0700'
+ utcoffset = datetime.timedelta(hours=-7)
+ tz = datetime.timezone(utcoffset)
+ dt = datetime.datetime(2001, 9, 23, 20, 10, 55, tzinfo=tz)
+
+ def test_parse_date(self):
+ h = self.make_header('date', self.datestring)
+ self.assertEqual(h, self.datestring)
+ self.assertEqual(h.datetime, self.dt)
+ self.assertEqual(h.datetime.utcoffset(), self.utcoffset)
+ self.assertEqual(h.defects, ())
+
+ def test_set_from_datetime(self):
+ h = self.make_header('date', self.dt)
+ self.assertEqual(h, self.datestring)
+ self.assertEqual(h.datetime, self.dt)
+ self.assertEqual(h.defects, ())
+
+ def test_date_header_properties(self):
+ h = self.make_header('date', self.datestring)
+ self.assertIsInstance(h, headerregistry.UniqueDateHeader)
+ self.assertEqual(h.max_count, 1)
+ self.assertEqual(h.defects, ())
+
+ def test_resent_date_header_properties(self):
+ h = self.make_header('resent-date', self.datestring)
+ self.assertIsInstance(h, headerregistry.DateHeader)
+ self.assertEqual(h.max_count, None)
+ self.assertEqual(h.defects, ())
+
+ def test_no_value_is_defect(self):
+ h = self.make_header('date', '')
+ self.assertEqual(len(h.defects), 1)
+ self.assertIsInstance(h.defects[0], errors.HeaderMissingRequiredValue)
+
+ def test_datetime_read_only(self):
+ h = self.make_header('date', self.datestring)
+ with self.assertRaises(AttributeError):
+ h.datetime = 'foo'
+
+ def test_set_date_header_from_datetime(self):
+ m = Message(policy=policy.default)
+ m['Date'] = self.dt
+ self.assertEqual(m['Date'], self.datestring)
+ self.assertEqual(m['Date'].datetime, self.dt)
+
+
+@parameterize
+class TestContentTypeHeader(TestHeaderBase):
+
+ def content_type_as_value(self,
+ source,
+ content_type,
+ maintype,
+ subtype,
+ *args):
+ l = len(args)
+ parmdict = args[0] if l>0 else {}
+ defects = args[1] if l>1 else []
+ decoded = args[2] if l>2 and args[2] is not DITTO else source
+ header = 'Content-Type:' + ' ' if source else ''
+ folded = args[3] if l>3 else header + source + '\n'
+ h = self.make_header('Content-Type', source)
+ self.assertEqual(h.content_type, content_type)
+ self.assertEqual(h.maintype, maintype)
+ self.assertEqual(h.subtype, subtype)
+ self.assertEqual(h.params, parmdict)
+ self.assertDefectsEqual(h.defects, defects)
+ self.assertEqual(h, decoded)
+ self.assertEqual(h.fold(policy=policy.default), folded)
+
+ content_type_params = {
+
+ # Examples from RFC 2045.
+
+ 'RFC_2045_1': (
+ 'text/plain; charset=us-ascii (Plain text)',
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'charset': 'us-ascii'},
+ [],
+ 'text/plain; charset="us-ascii"'),
+
+ 'RFC_2045_2': (
+ 'text/plain; charset=us-ascii',
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'charset': 'us-ascii'},
+ [],
+ 'text/plain; charset="us-ascii"'),
+
+ 'RFC_2045_3': (
+ 'text/plain; charset="us-ascii"',
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'charset': 'us-ascii'}),
+
+ # RFC 2045 5.2 says syntactically invalid values are to be treated as
+ # text/plain.
+
+ 'no_subtype_in_content_type': (
+ 'text/',
+ 'text/plain',
+ 'text',
+ 'plain',
+ {},
+ [errors.InvalidHeaderDefect]),
+
+ 'no_slash_in_content_type': (
+ 'foo',
+ 'text/plain',
+ 'text',
+ 'plain',
+ {},
+ [errors.InvalidHeaderDefect]),
+
+ 'junk_text_in_content_type': (
+ '<crazy "stuff">',
+ 'text/plain',
+ 'text',
+ 'plain',
+ {},
+ [errors.InvalidHeaderDefect]),
+
+ 'too_many_slashes_in_content_type': (
+ 'image/jpeg/foo',
+ 'text/plain',
+ 'text',
+ 'plain',
+ {},
+ [errors.InvalidHeaderDefect]),
+
+ # But unknown names are OK. We could make non-IANA names a defect, but
+ # by not doing so we make ourselves future proof. The fact that they
+ # are unknown will be detectable by the fact that they don't appear in
+ # the mime_registry...and the application is free to extend that list
+ # to handle them even if the core library doesn't.
+
+ 'unknown_content_type': (
+ 'bad/names',
+ 'bad/names',
+ 'bad',
+ 'names'),
+
+ # The content type is case insensitive, and CFWS is ignored.
+
+ 'mixed_case_content_type': (
+ 'ImAge/JPeg',
+ 'image/jpeg',
+ 'image',
+ 'jpeg'),
+
+ 'spaces_in_content_type': (
+ ' text / plain ',
+ 'text/plain',
+ 'text',
+ 'plain'),
+
+ 'cfws_in_content_type': (
+ '(foo) text (bar)/(baz)plain(stuff)',
+ 'text/plain',
+ 'text',
+ 'plain'),
+
+ # test some parameters (more tests could be added for parameters
+ # associated with other content types, but since parameter parsing is
+ # generic they would be redundant for the current implementation).
+
+ 'charset_param': (
+ 'text/plain; charset="utf-8"',
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'charset': 'utf-8'}),
+
+ 'capitalized_charset': (
+ 'text/plain; charset="US-ASCII"',
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'charset': 'US-ASCII'}),
+
+ 'unknown_charset': (
+ 'text/plain; charset="fOo"',
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'charset': 'fOo'}),
+
+ 'capitalized_charset_param_name_and_comment': (
+ 'text/plain; (interjection) Charset="utf-8"',
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'charset': 'utf-8'},
+ [],
+ # Should the parameter name be lowercased here?
+ 'text/plain; Charset="utf-8"'),
+
+ # Since this is pretty much the ur-mimeheader, we'll put all the tests
+ # that exercise the parameter parsing and formatting here.
+ #
+ # XXX: question: is minimal quoting preferred?
+
+ 'unquoted_param_value': (
+ 'text/plain; title=foo',
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'title': 'foo'},
+ [],
+ 'text/plain; title="foo"'),
+
+ 'param_value_with_tspecials': (
+ 'text/plain; title="(bar)foo blue"',
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'title': '(bar)foo blue'}),
+
+ 'param_with_extra_quoted_whitespace': (
+ 'text/plain; title=" a loong way \t home "',
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'title': ' a loong way \t home '}),
+
+ 'bad_params': (
+ 'blarg; baz; boo',
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'baz': '', 'boo': ''},
+ [errors.InvalidHeaderDefect]*3),
+
+ 'spaces_around_param_equals': (
+ 'Multipart/mixed; boundary = "CPIMSSMTPC06p5f3tG"',
+ 'multipart/mixed',
+ 'multipart',
+ 'mixed',
+ {'boundary': 'CPIMSSMTPC06p5f3tG'},
+ [],
+ 'Multipart/mixed; boundary="CPIMSSMTPC06p5f3tG"'),
+
+ 'spaces_around_semis': (
+ ('image/jpeg; name="wibble.JPG" ; x-mac-type="4A504547" ; '
+ 'x-mac-creator="474B4F4E"'),
+ 'image/jpeg',
+ 'image',
+ 'jpeg',
+ {'name': 'wibble.JPG',
+ 'x-mac-type': '4A504547',
+ 'x-mac-creator': '474B4F4E'},
+ [],
+ ('image/jpeg; name="wibble.JPG"; x-mac-type="4A504547"; '
+ 'x-mac-creator="474B4F4E"'),
+ # XXX: it could be that we will eventually prefer to fold starting
+ # from the decoded value, in which case these spaces and similar
+ # spaces in other tests will be wrong.
+ ('Content-Type: image/jpeg; name="wibble.JPG" ; '
+ 'x-mac-type="4A504547" ;\n'
+ ' x-mac-creator="474B4F4E"\n'),
+ ),
+
+ 'semis_inside_quotes': (
+ 'image/jpeg; name="Jim&amp;&amp;Jill"',
+ 'image/jpeg',
+ 'image',
+ 'jpeg',
+ {'name': 'Jim&amp;&amp;Jill'}),
+
+ 'single_quotes_inside_quotes': (
+ 'image/jpeg; name="Jim \'Bob\' Jill"',
+ 'image/jpeg',
+ 'image',
+ 'jpeg',
+ {'name': "Jim 'Bob' Jill"}),
+
+ 'double_quotes_inside_quotes': (
+ r'image/jpeg; name="Jim \"Bob\" Jill"',
+ 'image/jpeg',
+ 'image',
+ 'jpeg',
+ {'name': 'Jim "Bob" Jill'},
+ [],
+ r'image/jpeg; name="Jim \"Bob\" Jill"'),
+
+ # XXX: This test works except for the refolding of the header. I'll
+ # deal with that bug when I deal with the other folding bugs.
+ #'non_ascii_in_params': (
+ # ('foo\xa7/bar; b\xa7r=two; '
+ # 'baz=thr\xa7e'.encode('latin-1').decode('us-ascii',
+ # 'surrogateescape')),
+ # 'foo\uFFFD/bar',
+ # 'foo\uFFFD',
+ # 'bar',
+ # {'b\uFFFDr': 'two', 'baz': 'thr\uFFFDe'},
+ # [errors.UndecodableBytesDefect]*3,
+ # 'foo�/bar; b�r="two"; baz="thr�e"',
+ # ),
+
+ # RFC 2231 parameter tests.
+
+ 'rfc2231_segmented_normal_values': (
+ 'image/jpeg; name*0="abc"; name*1=".html"',
+ 'image/jpeg',
+ 'image',
+ 'jpeg',
+ {'name': "abc.html"},
+ [],
+ 'image/jpeg; name="abc.html"'),
+
+ 'quotes_inside_rfc2231_value': (
+ r'image/jpeg; bar*0="baz\"foobar"; bar*1="\"baz"',
+ 'image/jpeg',
+ 'image',
+ 'jpeg',
+ {'bar': 'baz"foobar"baz'},
+ [],
+ r'image/jpeg; bar="baz\"foobar\"baz"'),
+
+ # XXX: This test works except for the refolding of the header. I'll
+ # deal with that bug when I deal with the other folding bugs.
+ #'non_ascii_rfc2231_value': (
+ # ('text/plain; charset=us-ascii; '
+ # "title*=us-ascii'en'This%20is%20"
+ # 'not%20f\xa7n').encode('latin-1').decode('us-ascii',
+ # 'surrogateescape'),
+ # 'text/plain',
+ # 'text',
+ # 'plain',
+ # {'charset': 'us-ascii', 'title': 'This is not f\uFFFDn'},
+ # [errors.UndecodableBytesDefect],
+ # 'text/plain; charset="us-ascii"; title="This is not f�n"'),
+
+ 'rfc2231_encoded_charset': (
+ 'text/plain; charset*=ansi-x3.4-1968\'\'us-ascii',
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'charset': 'us-ascii'},
+ [],
+ 'text/plain; charset="us-ascii"'),
+
+ # This follows the RFC: no double quotes around encoded values.
+ 'rfc2231_encoded_no_double_quotes': (
+ ("text/plain;"
+ "\tname*0*=''This%20is%20;"
+ "\tname*1*=%2A%2A%2Afun%2A%2A%2A%20;"
+ '\tname*2="is it not.pdf"'),
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'name': 'This is ***fun*** is it not.pdf'},
+ [],
+ 'text/plain; name="This is ***fun*** is it not.pdf"',
+ ('Content-Type: text/plain;\tname*0*=\'\'This%20is%20;\n'
+ '\tname*1*=%2A%2A%2Afun%2A%2A%2A%20;\tname*2="is it not.pdf"\n'),
+ ),
+
+ # Make sure we also handle it if there are spurrious double qoutes.
+ 'rfc2231_encoded_with_double_quotes': (
+ ("text/plain;"
+ '\tname*0*="us-ascii\'\'This%20is%20even%20more%20";'
+ '\tname*1*="%2A%2A%2Afun%2A%2A%2A%20";'
+ '\tname*2="is it not.pdf"'),
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'name': 'This is even more ***fun*** is it not.pdf'},
+ [errors.InvalidHeaderDefect]*2,
+ 'text/plain; name="This is even more ***fun*** is it not.pdf"',
+ ('Content-Type: text/plain;\t'
+ 'name*0*="us-ascii\'\'This%20is%20even%20more%20";\n'
+ '\tname*1*="%2A%2A%2Afun%2A%2A%2A%20";\tname*2="is it not.pdf"\n'),
+ ),
+
+ 'rfc2231_single_quote_inside_double_quotes': (
+ ('text/plain; charset=us-ascii;'
+ '\ttitle*0*="us-ascii\'en\'This%20is%20really%20";'
+ '\ttitle*1*="%2A%2A%2Afun%2A%2A%2A%20";'
+ '\ttitle*2="isn\'t it!"'),
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'charset': 'us-ascii', 'title': "This is really ***fun*** isn't it!"},
+ [errors.InvalidHeaderDefect]*2,
+ ('text/plain; charset="us-ascii"; '
+ 'title="This is really ***fun*** isn\'t it!"'),
+ ('Content-Type: text/plain; charset=us-ascii;\n'
+ '\ttitle*0*="us-ascii\'en\'This%20is%20really%20";\n'
+ '\ttitle*1*="%2A%2A%2Afun%2A%2A%2A%20";\ttitle*2="isn\'t it!"\n'),
+ ),
+
+ 'rfc2231_single_quote_in_value_with_charset_and_lang': (
+ ('application/x-foo;'
+ "\tname*0*=\"us-ascii'en-us'Frank's\"; name*1*=\" Document\""),
+ 'application/x-foo',
+ 'application',
+ 'x-foo',
+ {'name': "Frank's Document"},
+ [errors.InvalidHeaderDefect]*2,
+ 'application/x-foo; name="Frank\'s Document"',
+ ('Content-Type: application/x-foo;\t'
+ 'name*0*="us-ascii\'en-us\'Frank\'s";\n'
+ ' name*1*=" Document"\n'),
+ ),
+
+ 'rfc2231_single_quote_in_non_encoded_value': (
+ ('application/x-foo;'
+ "\tname*0=\"us-ascii'en-us'Frank's\"; name*1=\" Document\""),
+ 'application/x-foo',
+ 'application',
+ 'x-foo',
+ {'name': "us-ascii'en-us'Frank's Document"},
+ [],
+ 'application/x-foo; name="us-ascii\'en-us\'Frank\'s Document"',
+ ('Content-Type: application/x-foo;\t'
+ 'name*0="us-ascii\'en-us\'Frank\'s";\n'
+ ' name*1=" Document"\n'),
+ ),
+
+ 'rfc2231_no_language_or_charset': (
+ 'text/plain; NAME*0*=english_is_the_default.html',
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'name': 'english_is_the_default.html'},
+ [errors.InvalidHeaderDefect],
+ 'text/plain; NAME="english_is_the_default.html"'),
+
+ 'rfc2231_encoded_no_charset': (
+ ("text/plain;"
+ '\tname*0*="\'\'This%20is%20even%20more%20";'
+ '\tname*1*="%2A%2A%2Afun%2A%2A%2A%20";'
+ '\tname*2="is it.pdf"'),
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'name': 'This is even more ***fun*** is it.pdf'},
+ [errors.InvalidHeaderDefect]*2,
+ 'text/plain; name="This is even more ***fun*** is it.pdf"',
+ ('Content-Type: text/plain;\t'
+ 'name*0*="\'\'This%20is%20even%20more%20";\n'
+ '\tname*1*="%2A%2A%2Afun%2A%2A%2A%20";\tname*2="is it.pdf"\n'),
+ ),
+
+ # XXX: see below...the first name line here should be *0 not *0*.
+ 'rfc2231_partly_encoded': (
+ ("text/plain;"
+ '\tname*0*="\'\'This%20is%20even%20more%20";'
+ '\tname*1*="%2A%2A%2Afun%2A%2A%2A%20";'
+ '\tname*2="is it.pdf"'),
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'name': 'This is even more ***fun*** is it.pdf'},
+ [errors.InvalidHeaderDefect]*2,
+ 'text/plain; name="This is even more ***fun*** is it.pdf"',
+ ('Content-Type: text/plain;\t'
+ 'name*0*="\'\'This%20is%20even%20more%20";\n'
+ '\tname*1*="%2A%2A%2Afun%2A%2A%2A%20";\tname*2="is it.pdf"\n'),
+ ),
+
+ 'rfc2231_partly_encoded_2': (
+ ("text/plain;"
+ '\tname*0*="\'\'This%20is%20even%20more%20";'
+ '\tname*1="%2A%2A%2Afun%2A%2A%2A%20";'
+ '\tname*2="is it.pdf"'),
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'name': 'This is even more %2A%2A%2Afun%2A%2A%2A%20is it.pdf'},
+ [errors.InvalidHeaderDefect],
+ 'text/plain; name="This is even more %2A%2A%2Afun%2A%2A%2A%20is it.pdf"',
+ ('Content-Type: text/plain;\t'
+ 'name*0*="\'\'This%20is%20even%20more%20";\n'
+ '\tname*1="%2A%2A%2Afun%2A%2A%2A%20";\tname*2="is it.pdf"\n'),
+ ),
+
+ 'rfc2231_unknown_charset_treated_as_ascii': (
+ "text/plain; name*0*=bogus'xx'ascii_is_the_default",
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'name': 'ascii_is_the_default'},
+ [],
+ 'text/plain; name="ascii_is_the_default"'),
+
+ 'rfc2231_bad_character_in_charset_parameter_value': (
+ "text/plain; charset*=ascii''utf-8%E2%80%9D",
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'charset': 'utf-8\uFFFD\uFFFD\uFFFD'},
+ [errors.UndecodableBytesDefect],
+ 'text/plain; charset="utf-8\uFFFD\uFFFD\uFFFD"'),
+
+ 'rfc2231_encoded_then_unencoded_segments': (
+ ('application/x-foo;'
+ '\tname*0*="us-ascii\'en-us\'My";'
+ '\tname*1=" Document";'
+ '\tname*2=" For You"'),
+ 'application/x-foo',
+ 'application',
+ 'x-foo',
+ {'name': 'My Document For You'},
+ [errors.InvalidHeaderDefect],
+ 'application/x-foo; name="My Document For You"',
+ ('Content-Type: application/x-foo;\t'
+ 'name*0*="us-ascii\'en-us\'My";\n'
+ '\tname*1=" Document";\tname*2=" For You"\n'),
+ ),
+
+ # My reading of the RFC is that this is an invalid header. The RFC
+ # says that if charset and language information is given, the first
+ # segment *must* be encoded.
+ 'rfc2231_unencoded_then_encoded_segments': (
+ ('application/x-foo;'
+ '\tname*0=us-ascii\'en-us\'My;'
+ '\tname*1*=" Document";'
+ '\tname*2*=" For You"'),
+ 'application/x-foo',
+ 'application',
+ 'x-foo',
+ {'name': 'My Document For You'},
+ [errors.InvalidHeaderDefect]*3,
+ 'application/x-foo; name="My Document For You"',
+ ("Content-Type: application/x-foo;\tname*0=us-ascii'en-us'My;\t"
+ # XXX: the newline is in the wrong place, come back and fix
+ # this when the rest of tests pass.
+ 'name*1*=" Document"\n;'
+ '\tname*2*=" For You"\n'),
+ ),
+
+ # XXX: I would say this one should default to ascii/en for the
+ # "encoded" segment, since the first segment is not encoded and is
+ # in double quotes, making the value a valid non-encoded string. The
+ # old parser decodes this just like the previous case, which may be the
+ # better Postel rule, but could equally result in borking headers that
+ # intentially have quoted quotes in them. We could get this 98% right
+ # if we treat it as a quoted string *unless* it matches the
+ # charset'lang'value pattern exactly *and* there is at least one
+ # encoded segment. Implementing that algorithm will require some
+ # refactoring, so I haven't done it (yet).
+
+ 'rfc2231_qouted_unencoded_then_encoded_segments': (
+ ('application/x-foo;'
+ '\tname*0="us-ascii\'en-us\'My";'
+ '\tname*1*=" Document";'
+ '\tname*2*=" For You"'),
+ 'application/x-foo',
+ 'application',
+ 'x-foo',
+ {'name': "us-ascii'en-us'My Document For You"},
+ [errors.InvalidHeaderDefect]*2,
+ 'application/x-foo; name="us-ascii\'en-us\'My Document For You"',
+ ('Content-Type: application/x-foo;\t'
+ 'name*0="us-ascii\'en-us\'My";\n'
+ '\tname*1*=" Document";\tname*2*=" For You"\n'),
+ ),
+
+ }
+
+
+@parameterize
+class TestContentTransferEncoding(TestHeaderBase):
+
+ def cte_as_value(self,
+ source,
+ cte,
+ *args):
+ l = len(args)
+ defects = args[0] if l>0 else []
+ decoded = args[1] if l>1 and args[1] is not DITTO else source
+ header = 'Content-Transfer-Encoding:' + ' ' if source else ''
+ folded = args[2] if l>2 else header + source + '\n'
+ h = self.make_header('Content-Transfer-Encoding', source)
+ self.assertEqual(h.cte, cte)
+ self.assertDefectsEqual(h.defects, defects)
+ self.assertEqual(h, decoded)
+ self.assertEqual(h.fold(policy=policy.default), folded)
+
+ cte_params = {
+
+ 'RFC_2183_1': (
+ 'base64',
+ 'base64',),
+
+ 'no_value': (
+ '',
+ '7bit',
+ [errors.HeaderMissingRequiredValue],
+ '',
+ 'Content-Transfer-Encoding:\n',
+ ),
+
+ 'junk_after_cte': (
+ '7bit and a bunch more',
+ '7bit',
+ [errors.InvalidHeaderDefect]),
+
+ }
+
+
+@parameterize
+class TestContentDisposition(TestHeaderBase):
+
+ def content_disp_as_value(self,
+ source,
+ content_disposition,
+ *args):
+ l = len(args)
+ parmdict = args[0] if l>0 else {}
+ defects = args[1] if l>1 else []
+ decoded = args[2] if l>2 and args[2] is not DITTO else source
+ header = 'Content-Disposition:' + ' ' if source else ''
+ folded = args[3] if l>3 else header + source + '\n'
+ h = self.make_header('Content-Disposition', source)
+ self.assertEqual(h.content_disposition, content_disposition)
+ self.assertEqual(h.params, parmdict)
+ self.assertDefectsEqual(h.defects, defects)
+ self.assertEqual(h, decoded)
+ self.assertEqual(h.fold(policy=policy.default), folded)
+
+ content_disp_params = {
+
+ # Examples from RFC 2183.
+
+ 'RFC_2183_1': (
+ 'inline',
+ 'inline',),
+
+ 'RFC_2183_2': (
+ ('attachment; filename=genome.jpeg;'
+ ' modification-date="Wed, 12 Feb 1997 16:29:51 -0500";'),
+ 'attachment',
+ {'filename': 'genome.jpeg',
+ 'modification-date': 'Wed, 12 Feb 1997 16:29:51 -0500'},
+ [],
+ ('attachment; filename="genome.jpeg"; '
+ 'modification-date="Wed, 12 Feb 1997 16:29:51 -0500"'),
+ ('Content-Disposition: attachment; filename=genome.jpeg;\n'
+ ' modification-date="Wed, 12 Feb 1997 16:29:51 -0500";\n'),
+ ),
+
+ 'no_value': (
+ '',
+ None,
+ {},
+ [errors.HeaderMissingRequiredValue],
+ '',
+ 'Content-Disposition:\n'),
+
+ 'invalid_value': (
+ 'ab./k',
+ 'ab.',
+ {},
+ [errors.InvalidHeaderDefect]),
+
+ 'invalid_value_with_params': (
+ 'ab./k; filename="foo"',
+ 'ab.',
+ {'filename': 'foo'},
+ [errors.InvalidHeaderDefect]),
+
+ }
+
+
+@parameterize
+class TestMIMEVersionHeader(TestHeaderBase):
+
+ def version_string_as_MIME_Version(self,
+ source,
+ decoded,
+ version,
+ major,
+ minor,
+ defects):
+ h = self.make_header('MIME-Version', source)
+ self.assertEqual(h, decoded)
+ self.assertEqual(h.version, version)
+ self.assertEqual(h.major, major)
+ self.assertEqual(h.minor, minor)
+ self.assertDefectsEqual(h.defects, defects)
+ if source:
+ source = ' ' + source
+ self.assertEqual(h.fold(policy=policy.default),
+ 'MIME-Version:' + source + '\n')
+
+ version_string_params = {
+
+ # Examples from the RFC.
+
+ 'RFC_2045_1': (
+ '1.0',
+ '1.0',
+ '1.0',
+ 1,
+ 0,
+ []),
+
+ 'RFC_2045_2': (
+ '1.0 (produced by MetaSend Vx.x)',
+ '1.0 (produced by MetaSend Vx.x)',
+ '1.0',
+ 1,
+ 0,
+ []),
+
+ 'RFC_2045_3': (
+ '(produced by MetaSend Vx.x) 1.0',
+ '(produced by MetaSend Vx.x) 1.0',
+ '1.0',
+ 1,
+ 0,
+ []),
+
+ 'RFC_2045_4': (
+ '1.(produced by MetaSend Vx.x)0',
+ '1.(produced by MetaSend Vx.x)0',
+ '1.0',
+ 1,
+ 0,
+ []),
+
+ # Other valid values.
+
+ '1_1': (
+ '1.1',
+ '1.1',
+ '1.1',
+ 1,
+ 1,
+ []),
+
+ '2_1': (
+ '2.1',
+ '2.1',
+ '2.1',
+ 2,
+ 1,
+ []),
+
+ 'whitespace': (
+ '1 .0',
+ '1 .0',
+ '1.0',
+ 1,
+ 0,
+ []),
+
+ 'leading_trailing_whitespace_ignored': (
+ ' 1.0 ',
+ ' 1.0 ',
+ '1.0',
+ 1,
+ 0,
+ []),
+
+ # Recoverable invalid values. We can recover here only because we
+ # already have a valid value by the time we encounter the garbage.
+ # Anywhere else, and we don't know where the garbage ends.
+
+ 'non_comment_garbage_after': (
+ '1.0 <abc>',
+ '1.0 <abc>',
+ '1.0',
+ 1,
+ 0,
+ [errors.InvalidHeaderDefect]),
+
+ # Unrecoverable invalid values. We *could* apply more heuristics to
+ # get someing out of the first two, but doing so is not worth the
+ # effort.
+
+ 'non_comment_garbage_before': (
+ '<abc> 1.0',
+ '<abc> 1.0',
+ None,
+ None,
+ None,
+ [errors.InvalidHeaderDefect]),
+
+ 'non_comment_garbage_inside': (
+ '1.<abc>0',
+ '1.<abc>0',
+ None,
+ None,
+ None,
+ [errors.InvalidHeaderDefect]),
+
+ 'two_periods': (
+ '1..0',
+ '1..0',
+ None,
+ None,
+ None,
+ [errors.InvalidHeaderDefect]),
+
+ '2_x': (
+ '2.x',
+ '2.x',
+ None, # This could be 2, but it seems safer to make it None.
+ None,
+ None,
+ [errors.InvalidHeaderDefect]),
+
+ 'foo': (
+ 'foo',
+ 'foo',
+ None,
+ None,
+ None,
+ [errors.InvalidHeaderDefect]),
+
+ 'missing': (
+ '',
+ '',
+ None,
+ None,
+ None,
+ [errors.HeaderMissingRequiredValue]),
+
+ }
+
+
+@parameterize
+class TestAddressHeader(TestHeaderBase):
+
+ example_params = {
+
+ 'empty':
+ ('<>',
+ [errors.InvalidHeaderDefect],
+ '<>',
+ '',
+ '<>',
+ '',
+ '',
+ None),
+
+ 'address_only':
+ ('zippy@pinhead.com',
+ [],
+ 'zippy@pinhead.com',
+ '',
+ 'zippy@pinhead.com',
+ 'zippy',
+ 'pinhead.com',
+ None),
+
+ 'name_and_address':
+ ('Zaphrod Beblebrux <zippy@pinhead.com>',
+ [],
+ 'Zaphrod Beblebrux <zippy@pinhead.com>',
+ 'Zaphrod Beblebrux',
+ 'zippy@pinhead.com',
+ 'zippy',
+ 'pinhead.com',
+ None),
+
+ 'quoted_local_part':
+ ('Zaphrod Beblebrux <"foo bar"@pinhead.com>',
+ [],
+ 'Zaphrod Beblebrux <"foo bar"@pinhead.com>',
+ 'Zaphrod Beblebrux',
+ '"foo bar"@pinhead.com',
+ 'foo bar',
+ 'pinhead.com',
+ None),
+
+ 'quoted_parens_in_name':
+ (r'"A \(Special\) Person" <person@dom.ain>',
+ [],
+ '"A (Special) Person" <person@dom.ain>',
+ 'A (Special) Person',
+ 'person@dom.ain',
+ 'person',
+ 'dom.ain',
+ None),
+
+ 'quoted_backslashes_in_name':
+ (r'"Arthur \\Backslash\\ Foobar" <person@dom.ain>',
+ [],
+ r'"Arthur \\Backslash\\ Foobar" <person@dom.ain>',
+ r'Arthur \Backslash\ Foobar',
+ 'person@dom.ain',
+ 'person',
+ 'dom.ain',
+ None),
+
+ 'name_with_dot':
+ ('John X. Doe <jxd@example.com>',
+ [errors.ObsoleteHeaderDefect],
+ '"John X. Doe" <jxd@example.com>',
+ 'John X. Doe',
+ 'jxd@example.com',
+ 'jxd',
+ 'example.com',
+ None),
+
+ 'quoted_strings_in_local_part':
+ ('""example" example"@example.com',
+ [errors.InvalidHeaderDefect]*3,
+ '"example example"@example.com',
+ '',
+ '"example example"@example.com',
+ 'example example',
+ 'example.com',
+ None),
+
+ 'escaped_quoted_strings_in_local_part':
+ (r'"\"example\" example"@example.com',
+ [],
+ r'"\"example\" example"@example.com',
+ '',
+ r'"\"example\" example"@example.com',
+ r'"example" example',
+ 'example.com',
+ None),
+
+ 'escaped_escapes_in_local_part':
+ (r'"\\"example\\" example"@example.com',
+ [errors.InvalidHeaderDefect]*5,
+ r'"\\example\\\\ example"@example.com',
+ '',
+ r'"\\example\\\\ example"@example.com',
+ r'\example\\ example',
+ 'example.com',
+ None),
+
+ 'spaces_in_unquoted_local_part_collapsed':
+ ('merwok wok @example.com',
+ [errors.InvalidHeaderDefect]*2,
+ '"merwok wok"@example.com',
+ '',
+ '"merwok wok"@example.com',
+ 'merwok wok',
+ 'example.com',
+ None),
+
+ 'spaces_around_dots_in_local_part_removed':
+ ('merwok. wok . wok@example.com',
+ [errors.ObsoleteHeaderDefect],
+ 'merwok.wok.wok@example.com',
+ '',
+ 'merwok.wok.wok@example.com',
+ 'merwok.wok.wok',
+ 'example.com',
+ None),
+
+ }
+
+ # XXX: Need many more examples, and in particular some with names in
+ # trailing comments, which aren't currently handled. comments in
+ # general are not handled yet.
+
+ def example_as_address(self, source, defects, decoded, display_name,
+ addr_spec, username, domain, comment):
+ h = self.make_header('sender', source)
+ self.assertEqual(h, decoded)
+ self.assertDefectsEqual(h.defects, defects)
+ a = h.address
+ self.assertEqual(str(a), decoded)
+ self.assertEqual(len(h.groups), 1)
+ self.assertEqual([a], list(h.groups[0].addresses))
+ self.assertEqual([a], list(h.addresses))
+ self.assertEqual(a.display_name, display_name)
+ self.assertEqual(a.addr_spec, addr_spec)
+ self.assertEqual(a.username, username)
+ self.assertEqual(a.domain, domain)
+ # XXX: we have no comment support yet.
+ #self.assertEqual(a.comment, comment)
+
+ def example_as_group(self, source, defects, decoded, display_name,
+ addr_spec, username, domain, comment):
+ source = 'foo: {};'.format(source)
+ gdecoded = 'foo: {};'.format(decoded) if decoded else 'foo:;'
+ h = self.make_header('to', source)
+ self.assertEqual(h, gdecoded)
+ self.assertDefectsEqual(h.defects, defects)
+ self.assertEqual(h.groups[0].addresses, h.addresses)
+ self.assertEqual(len(h.groups), 1)
+ self.assertEqual(len(h.addresses), 1)
+ a = h.addresses[0]
+ self.assertEqual(str(a), decoded)
+ self.assertEqual(a.display_name, display_name)
+ self.assertEqual(a.addr_spec, addr_spec)
+ self.assertEqual(a.username, username)
+ self.assertEqual(a.domain, domain)
+
+ def test_simple_address_list(self):
+ value = ('Fred <dinsdale@python.org>, foo@example.com, '
+ '"Harry W. Hastings" <hasty@example.com>')
+ h = self.make_header('to', value)
+ self.assertEqual(h, value)
+ self.assertEqual(len(h.groups), 3)
+ self.assertEqual(len(h.addresses), 3)
+ for i in range(3):
+ self.assertEqual(h.groups[i].addresses[0], h.addresses[i])
+ self.assertEqual(str(h.addresses[0]), 'Fred <dinsdale@python.org>')
+ self.assertEqual(str(h.addresses[1]), 'foo@example.com')
+ self.assertEqual(str(h.addresses[2]),
+ '"Harry W. Hastings" <hasty@example.com>')
+ self.assertEqual(h.addresses[2].display_name,
+ 'Harry W. Hastings')
+
+ def test_complex_address_list(self):
+ examples = list(self.example_params.values())
+ source = ('dummy list:;, another: (empty);,' +
+ ', '.join([x[0] for x in examples[:4]]) + ', ' +
+ r'"A \"list\"": ' +
+ ', '.join([x[0] for x in examples[4:6]]) + ';,' +
+ ', '.join([x[0] for x in examples[6:]])
+ )
+ # XXX: the fact that (empty) disappears here is a potential API design
+ # bug. We don't currently have a way to preserve comments.
+ expected = ('dummy list:;, another:;, ' +
+ ', '.join([x[2] for x in examples[:4]]) + ', ' +
+ r'"A \"list\"": ' +
+ ', '.join([x[2] for x in examples[4:6]]) + ';, ' +
+ ', '.join([x[2] for x in examples[6:]])
+ )
+
+ h = self.make_header('to', source)
+ self.assertEqual(h.split(','), expected.split(','))
+ self.assertEqual(h, expected)
+ self.assertEqual(len(h.groups), 7 + len(examples) - 6)
+ self.assertEqual(h.groups[0].display_name, 'dummy list')
+ self.assertEqual(h.groups[1].display_name, 'another')
+ self.assertEqual(h.groups[6].display_name, 'A "list"')
+ self.assertEqual(len(h.addresses), len(examples))
+ for i in range(4):
+ self.assertIsNone(h.groups[i+2].display_name)
+ self.assertEqual(str(h.groups[i+2].addresses[0]), examples[i][2])
+ for i in range(7, 7 + len(examples) - 6):
+ self.assertIsNone(h.groups[i].display_name)
+ self.assertEqual(str(h.groups[i].addresses[0]), examples[i-1][2])
+ for i in range(len(examples)):
+ self.assertEqual(str(h.addresses[i]), examples[i][2])
+ self.assertEqual(h.addresses[i].addr_spec, examples[i][4])
+
+ def test_address_read_only(self):
+ h = self.make_header('sender', 'abc@xyz.com')
+ with self.assertRaises(AttributeError):
+ h.address = 'foo'
+
+ def test_addresses_read_only(self):
+ h = self.make_header('sender', 'abc@xyz.com')
+ with self.assertRaises(AttributeError):
+ h.addresses = 'foo'
+
+ def test_groups_read_only(self):
+ h = self.make_header('sender', 'abc@xyz.com')
+ with self.assertRaises(AttributeError):
+ h.groups = 'foo'
+
+ def test_addresses_types(self):
+ source = 'me <who@example.com>'
+ h = self.make_header('to', source)
+ self.assertIsInstance(h.addresses, tuple)
+ self.assertIsInstance(h.addresses[0], Address)
+
+ def test_groups_types(self):
+ source = 'me <who@example.com>'
+ h = self.make_header('to', source)
+ self.assertIsInstance(h.groups, tuple)
+ self.assertIsInstance(h.groups[0], Group)
+
+ def test_set_from_Address(self):
+ h = self.make_header('to', Address('me', 'foo', 'example.com'))
+ self.assertEqual(h, 'me <foo@example.com>')
+
+ def test_set_from_Address_list(self):
+ h = self.make_header('to', [Address('me', 'foo', 'example.com'),
+ Address('you', 'bar', 'example.com')])
+ self.assertEqual(h, 'me <foo@example.com>, you <bar@example.com>')
+
+ def test_set_from_Address_and_Group_list(self):
+ h = self.make_header('to', [Address('me', 'foo', 'example.com'),
+ Group('bing', [Address('fiz', 'z', 'b.com'),
+ Address('zif', 'f', 'c.com')]),
+ Address('you', 'bar', 'example.com')])
+ self.assertEqual(h, 'me <foo@example.com>, bing: fiz <z@b.com>, '
+ 'zif <f@c.com>;, you <bar@example.com>')
+ self.assertEqual(h.fold(policy=policy.default.clone(max_line_length=40)),
+ 'to: me <foo@example.com>,\n'
+ ' bing: fiz <z@b.com>, zif <f@c.com>;,\n'
+ ' you <bar@example.com>\n')
+
+ def test_set_from_Group_list(self):
+ h = self.make_header('to', [Group('bing', [Address('fiz', 'z', 'b.com'),
+ Address('zif', 'f', 'c.com')])])
+ self.assertEqual(h, 'bing: fiz <z@b.com>, zif <f@c.com>;')
+
+
+class TestAddressAndGroup(TestEmailBase):
+
+ def _test_attr_ro(self, obj, attr):
+ with self.assertRaises(AttributeError):
+ setattr(obj, attr, 'foo')
+
+ def test_address_display_name_ro(self):
+ self._test_attr_ro(Address('foo', 'bar', 'baz'), 'display_name')
+
+ def test_address_username_ro(self):
+ self._test_attr_ro(Address('foo', 'bar', 'baz'), 'username')
+
+ def test_address_domain_ro(self):
+ self._test_attr_ro(Address('foo', 'bar', 'baz'), 'domain')
+
+ def test_group_display_name_ro(self):
+ self._test_attr_ro(Group('foo'), 'display_name')
+
+ def test_group_addresses_ro(self):
+ self._test_attr_ro(Group('foo'), 'addresses')
+
+ def test_address_from_username_domain(self):
+ a = Address('foo', 'bar', 'baz')
+ self.assertEqual(a.display_name, 'foo')
+ self.assertEqual(a.username, 'bar')
+ self.assertEqual(a.domain, 'baz')
+ self.assertEqual(a.addr_spec, 'bar@baz')
+ self.assertEqual(str(a), 'foo <bar@baz>')
+
+ def test_address_from_addr_spec(self):
+ a = Address('foo', addr_spec='bar@baz')
+ self.assertEqual(a.display_name, 'foo')
+ self.assertEqual(a.username, 'bar')
+ self.assertEqual(a.domain, 'baz')
+ self.assertEqual(a.addr_spec, 'bar@baz')
+ self.assertEqual(str(a), 'foo <bar@baz>')
+
+ def test_address_with_no_display_name(self):
+ a = Address(addr_spec='bar@baz')
+ self.assertEqual(a.display_name, '')
+ self.assertEqual(a.username, 'bar')
+ self.assertEqual(a.domain, 'baz')
+ self.assertEqual(a.addr_spec, 'bar@baz')
+ self.assertEqual(str(a), 'bar@baz')
+
+ def test_null_address(self):
+ a = Address()
+ self.assertEqual(a.display_name, '')
+ self.assertEqual(a.username, '')
+ self.assertEqual(a.domain, '')
+ self.assertEqual(a.addr_spec, '<>')
+ self.assertEqual(str(a), '<>')
+
+ def test_domain_only(self):
+ # This isn't really a valid address.
+ a = Address(domain='buzz')
+ self.assertEqual(a.display_name, '')
+ self.assertEqual(a.username, '')
+ self.assertEqual(a.domain, 'buzz')
+ self.assertEqual(a.addr_spec, '@buzz')
+ self.assertEqual(str(a), '@buzz')
+
+ def test_username_only(self):
+ # This isn't really a valid address.
+ a = Address(username='buzz')
+ self.assertEqual(a.display_name, '')
+ self.assertEqual(a.username, 'buzz')
+ self.assertEqual(a.domain, '')
+ self.assertEqual(a.addr_spec, 'buzz')
+ self.assertEqual(str(a), 'buzz')
+
+ def test_display_name_only(self):
+ a = Address('buzz')
+ self.assertEqual(a.display_name, 'buzz')
+ self.assertEqual(a.username, '')
+ self.assertEqual(a.domain, '')
+ self.assertEqual(a.addr_spec, '<>')
+ self.assertEqual(str(a), 'buzz <>')
+
+ def test_quoting(self):
+ # Ideally we'd check every special individually, but I'm not up for
+ # writing that many tests.
+ a = Address('Sara J.', 'bad name', 'example.com')
+ self.assertEqual(a.display_name, 'Sara J.')
+ self.assertEqual(a.username, 'bad name')
+ self.assertEqual(a.domain, 'example.com')
+ self.assertEqual(a.addr_spec, '"bad name"@example.com')
+ self.assertEqual(str(a), '"Sara J." <"bad name"@example.com>')
+
+ def test_il8n(self):
+ a = Address('Éric', 'wok', 'exàmple.com')
+ self.assertEqual(a.display_name, 'Éric')
+ self.assertEqual(a.username, 'wok')
+ self.assertEqual(a.domain, 'exàmple.com')
+ self.assertEqual(a.addr_spec, 'wok@exàmple.com')
+ self.assertEqual(str(a), 'Éric <wok@exàmple.com>')
+
+ # XXX: there is an API design issue that needs to be solved here.
+ #def test_non_ascii_username_raises(self):
+ # with self.assertRaises(ValueError):
+ # Address('foo', 'wők', 'example.com')
+
+ def test_non_ascii_username_in_addr_spec_raises(self):
+ with self.assertRaises(ValueError):
+ Address('foo', addr_spec='wők@example.com')
+
+ def test_address_addr_spec_and_username_raises(self):
+ with self.assertRaises(TypeError):
+ Address('foo', username='bing', addr_spec='bar@baz')
+
+ def test_address_addr_spec_and_domain_raises(self):
+ with self.assertRaises(TypeError):
+ Address('foo', domain='bing', addr_spec='bar@baz')
+
+ def test_address_addr_spec_and_username_and_domain_raises(self):
+ with self.assertRaises(TypeError):
+ Address('foo', username='bong', domain='bing', addr_spec='bar@baz')
+
+ def test_space_in_addr_spec_username_raises(self):
+ with self.assertRaises(ValueError):
+ Address('foo', addr_spec="bad name@example.com")
+
+ def test_bad_addr_sepc_raises(self):
+ with self.assertRaises(ValueError):
+ Address('foo', addr_spec="name@ex[]ample.com")
+
+ def test_empty_group(self):
+ g = Group('foo')
+ self.assertEqual(g.display_name, 'foo')
+ self.assertEqual(g.addresses, tuple())
+ self.assertEqual(str(g), 'foo:;')
+
+ def test_empty_group_list(self):
+ g = Group('foo', addresses=[])
+ self.assertEqual(g.display_name, 'foo')
+ self.assertEqual(g.addresses, tuple())
+ self.assertEqual(str(g), 'foo:;')
+
+ def test_null_group(self):
+ g = Group()
+ self.assertIsNone(g.display_name)
+ self.assertEqual(g.addresses, tuple())
+ self.assertEqual(str(g), 'None:;')
+
+ def test_group_with_addresses(self):
+ addrs = [Address('b', 'b', 'c'), Address('a', 'b','c')]
+ g = Group('foo', addrs)
+ self.assertEqual(g.display_name, 'foo')
+ self.assertEqual(g.addresses, tuple(addrs))
+ self.assertEqual(str(g), 'foo: b <b@c>, a <b@c>;')
+
+ def test_group_with_addresses_no_display_name(self):
+ addrs = [Address('b', 'b', 'c'), Address('a', 'b','c')]
+ g = Group(addresses=addrs)
+ self.assertIsNone(g.display_name)
+ self.assertEqual(g.addresses, tuple(addrs))
+ self.assertEqual(str(g), 'None: b <b@c>, a <b@c>;')
+
+ def test_group_with_one_address_no_display_name(self):
+ addrs = [Address('b', 'b', 'c')]
+ g = Group(addresses=addrs)
+ self.assertIsNone(g.display_name)
+ self.assertEqual(g.addresses, tuple(addrs))
+ self.assertEqual(str(g), 'b <b@c>')
+
+ def test_display_name_quoting(self):
+ g = Group('foo.bar')
+ self.assertEqual(g.display_name, 'foo.bar')
+ self.assertEqual(g.addresses, tuple())
+ self.assertEqual(str(g), '"foo.bar":;')
+
+ def test_display_name_blanks_not_quoted(self):
+ g = Group('foo bar')
+ self.assertEqual(g.display_name, 'foo bar')
+ self.assertEqual(g.addresses, tuple())
+ self.assertEqual(str(g), 'foo bar:;')
+
+ def test_set_message_header_from_address(self):
+ a = Address('foo', 'bar', 'example.com')
+ m = Message(policy=policy.default)
+ m['To'] = a
+ self.assertEqual(m['to'], 'foo <bar@example.com>')
+ self.assertEqual(m['to'].addresses, (a,))
+
+ def test_set_message_header_from_group(self):
+ g = Group('foo bar')
+ m = Message(policy=policy.default)
+ m['To'] = g
+ self.assertEqual(m['to'], 'foo bar:;')
+ self.assertEqual(m['to'].addresses, g.addresses)
+
+
+class TestFolding(TestHeaderBase):
+
+ def test_short_unstructured(self):
+ h = self.make_header('subject', 'this is a test')
+ self.assertEqual(h.fold(policy=policy.default),
+ 'subject: this is a test\n')
+
+ def test_long_unstructured(self):
+ h = self.make_header('Subject', 'This is a long header '
+ 'line that will need to be folded into two lines '
+ 'and will demonstrate basic folding')
+ self.assertEqual(h.fold(policy=policy.default),
+ 'Subject: This is a long header line that will '
+ 'need to be folded into two lines\n'
+ ' and will demonstrate basic folding\n')
+
+ def test_unstructured_short_max_line_length(self):
+ h = self.make_header('Subject', 'this is a short header '
+ 'that will be folded anyway')
+ self.assertEqual(
+ h.fold(policy=policy.default.clone(max_line_length=20)),
+ textwrap.dedent("""\
+ Subject: this is a
+ short header that
+ will be folded
+ anyway
+ """))
+
+ def test_fold_unstructured_single_word(self):
+ h = self.make_header('Subject', 'test')
+ self.assertEqual(h.fold(policy=policy.default), 'Subject: test\n')
+
+ def test_fold_unstructured_short(self):
+ h = self.make_header('Subject', 'test test test')
+ self.assertEqual(h.fold(policy=policy.default),
+ 'Subject: test test test\n')
+
+ def test_fold_unstructured_with_overlong_word(self):
+ h = self.make_header('Subject', 'thisisaverylonglineconsistingofa'
+ 'singlewordthatwontfit')
+ self.assertEqual(
+ h.fold(policy=policy.default.clone(max_line_length=20)),
+ 'Subject: thisisaverylonglineconsistingofasinglewordthatwontfit\n')
+
+ def test_fold_unstructured_with_two_overlong_words(self):
+ h = self.make_header('Subject', 'thisisaverylonglineconsistingofa'
+ 'singlewordthatwontfit plusanotherverylongwordthatwontfit')
+ self.assertEqual(
+ h.fold(policy=policy.default.clone(max_line_length=20)),
+ 'Subject: thisisaverylonglineconsistingofasinglewordthatwontfit\n'
+ ' plusanotherverylongwordthatwontfit\n')
+
+ def test_fold_unstructured_with_slightly_long_word(self):
+ h = self.make_header('Subject', 'thislongwordislessthanmaxlinelen')
+ self.assertEqual(
+ h.fold(policy=policy.default.clone(max_line_length=35)),
+ 'Subject:\n thislongwordislessthanmaxlinelen\n')
+
+ def test_fold_unstructured_with_commas(self):
+ # The old wrapper would fold this at the commas.
+ h = self.make_header('Subject', "This header is intended to "
+ "demonstrate, in a fairly susinct way, that we now do "
+ "not give a , special treatment in unstructured headers.")
+ self.assertEqual(
+ h.fold(policy=policy.default.clone(max_line_length=60)),
+ textwrap.dedent("""\
+ Subject: This header is intended to demonstrate, in a fairly
+ susinct way, that we now do not give a , special treatment
+ in unstructured headers.
+ """))
+
+ def test_fold_address_list(self):
+ h = self.make_header('To', '"Theodore H. Perfect" <yes@man.com>, '
+ '"My address is very long because my name is long" <foo@bar.com>, '
+ '"Only A. Friend" <no@yes.com>')
+ self.assertEqual(h.fold(policy=policy.default), textwrap.dedent("""\
+ To: "Theodore H. Perfect" <yes@man.com>,
+ "My address is very long because my name is long" <foo@bar.com>,
+ "Only A. Friend" <no@yes.com>
+ """))
+
+ def test_fold_date_header(self):
+ h = self.make_header('Date', 'Sat, 2 Feb 2002 17:00:06 -0800')
+ self.assertEqual(h.fold(policy=policy.default),
+ 'Date: Sat, 02 Feb 2002 17:00:06 -0800\n')
+
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_email/test_inversion.py b/Lib/test/test_email/test_inversion.py
new file mode 100644
index 0000000..6f5c534
--- /dev/null
+++ b/Lib/test/test_email/test_inversion.py
@@ -0,0 +1,45 @@
+"""Test the parser and generator are inverses.
+
+Note that this is only strictly true if we are parsing RFC valid messages and
+producing RFC valid messages.
+"""
+
+import io
+import unittest
+from email import policy, message_from_bytes
+from email.generator import BytesGenerator
+from test.test_email import TestEmailBase, parameterize
+
+# This is like textwrap.dedent for bytes, except that it uses \r\n for the line
+# separators on the rebuilt string.
+def dedent(bstr):
+ lines = bstr.splitlines()
+ if not lines[0].strip():
+ raise ValueError("First line must contain text")
+ stripamt = len(lines[0]) - len(lines[0].lstrip())
+ return b'\r\n'.join(
+ [x[stripamt:] if len(x)>=stripamt else b''
+ for x in lines])
+
+
+@parameterize
+class TestInversion(TestEmailBase, unittest.TestCase):
+
+ def msg_as_input(self, msg):
+ m = message_from_bytes(msg, policy=policy.SMTP)
+ b = io.BytesIO()
+ g = BytesGenerator(b)
+ g.flatten(m)
+ self.assertEqual(b.getvalue(), msg)
+
+ # XXX: spaces are not preserved correctly here yet in the general case.
+ msg_params = {
+ 'header_with_one_space_body': (dedent(b"""\
+ From: abc@xyz.com
+ X-Status:\x20
+ Subject: test
+
+ foo
+ """),),
+
+ }
diff --git a/Lib/test/test_email/test_message.py b/Lib/test/test_email/test_message.py
new file mode 100644
index 0000000..8cc3f80
--- /dev/null
+++ b/Lib/test/test_email/test_message.py
@@ -0,0 +1,18 @@
+import unittest
+from email import policy
+from test.test_email import TestEmailBase
+
+
+class Test(TestEmailBase):
+
+ policy = policy.default
+
+ def test_error_on_setitem_if_max_count_exceeded(self):
+ m = self._str_msg("")
+ m['To'] = 'abc@xyz'
+ with self.assertRaises(ValueError):
+ m['To'] = 'xyz@abc'
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_email/test_parser.py b/Lib/test/test_email/test_parser.py
new file mode 100644
index 0000000..3abd11a
--- /dev/null
+++ b/Lib/test/test_email/test_parser.py
@@ -0,0 +1,36 @@
+import io
+import email
+import unittest
+from email.message import Message
+from test.test_email import TestEmailBase
+
+
+class TestCustomMessage(TestEmailBase):
+
+ class MyMessage(Message):
+ def __init__(self, policy):
+ self.check_policy = policy
+ super().__init__()
+
+ MyPolicy = TestEmailBase.policy.clone(linesep='boo')
+
+ def test_custom_message_gets_policy_if_possible_from_string(self):
+ msg = email.message_from_string("Subject: bogus\n\nmsg\n",
+ self.MyMessage,
+ policy=self.MyPolicy)
+ self.assertTrue(isinstance(msg, self.MyMessage))
+ self.assertIs(msg.check_policy, self.MyPolicy)
+
+ def test_custom_message_gets_policy_if_possible_from_file(self):
+ source_file = io.StringIO("Subject: bogus\n\nmsg\n")
+ msg = email.message_from_file(source_file,
+ self.MyMessage,
+ policy=self.MyPolicy)
+ self.assertTrue(isinstance(msg, self.MyMessage))
+ self.assertIs(msg.check_policy, self.MyPolicy)
+
+ # XXX add tests for other functions that take Message arg.
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_email/test_pickleable.py b/Lib/test/test_email/test_pickleable.py
new file mode 100644
index 0000000..daa8d25
--- /dev/null
+++ b/Lib/test/test_email/test_pickleable.py
@@ -0,0 +1,74 @@
+import unittest
+import textwrap
+import copy
+import pickle
+import email
+import email.message
+from email import policy
+from email.headerregistry import HeaderRegistry
+from test.test_email import TestEmailBase, parameterize
+
+
+@parameterize
+class TestPickleCopyHeader(TestEmailBase):
+
+ header_factory = HeaderRegistry()
+
+ unstructured = header_factory('subject', 'this is a test')
+
+ header_params = {
+ 'subject': ('subject', 'this is a test'),
+ 'from': ('from', 'frodo@mordor.net'),
+ 'to': ('to', 'a: k@b.com, y@z.com;, j@f.com'),
+ 'date': ('date', 'Tue, 29 May 2012 09:24:26 +1000'),
+ }
+
+ def header_as_deepcopy(self, name, value):
+ header = self.header_factory(name, value)
+ h = copy.deepcopy(header)
+ self.assertEqual(str(h), str(header))
+
+ def header_as_pickle(self, name, value):
+ header = self.header_factory(name, value)
+ p = pickle.dumps(header)
+ h = pickle.loads(p)
+ self.assertEqual(str(h), str(header))
+
+
+@parameterize
+class TestPickleCopyMessage(TestEmailBase):
+
+ # Message objects are a sequence, so we have to make them a one-tuple in
+ # msg_params so they get passed to the parameterized test method as a
+ # single argument instead of as a list of headers.
+ msg_params = {}
+
+ # Note: there will be no custom header objects in the parsed message.
+ msg_params['parsed'] = (email.message_from_string(textwrap.dedent("""\
+ Date: Tue, 29 May 2012 09:24:26 +1000
+ From: frodo@mordor.net
+ To: bilbo@underhill.org
+ Subject: help
+
+ I think I forgot the ring.
+ """), policy=policy.default),)
+
+ msg_params['created'] = (email.message.Message(policy=policy.default),)
+ msg_params['created'][0]['Date'] = 'Tue, 29 May 2012 09:24:26 +1000'
+ msg_params['created'][0]['From'] = 'frodo@mordor.net'
+ msg_params['created'][0]['To'] = 'bilbo@underhill.org'
+ msg_params['created'][0]['Subject'] = 'help'
+ msg_params['created'][0].set_payload('I think I forgot the ring.')
+
+ def msg_as_deepcopy(self, msg):
+ msg2 = copy.deepcopy(msg)
+ self.assertEqual(msg2.as_string(), msg.as_string())
+
+ def msg_as_pickle(self, msg):
+ p = pickle.dumps(msg)
+ msg2 = pickle.loads(p)
+ self.assertEqual(msg2.as_string(), msg.as_string())
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_email/test_policy.py b/Lib/test/test_email/test_policy.py
new file mode 100644
index 0000000..983bd49
--- /dev/null
+++ b/Lib/test/test_email/test_policy.py
@@ -0,0 +1,322 @@
+import io
+import types
+import textwrap
+import unittest
+import email.policy
+import email.parser
+import email.generator
+from email import headerregistry
+
+def make_defaults(base_defaults, differences):
+ defaults = base_defaults.copy()
+ defaults.update(differences)
+ return defaults
+
+class PolicyAPITests(unittest.TestCase):
+
+ longMessage = True
+
+ # Base default values.
+ compat32_defaults = {
+ 'max_line_length': 78,
+ 'linesep': '\n',
+ 'cte_type': '8bit',
+ 'raise_on_defect': False,
+ }
+ # These default values are the ones set on email.policy.default.
+ # If any of these defaults change, the docs must be updated.
+ policy_defaults = compat32_defaults.copy()
+ policy_defaults.update({
+ 'raise_on_defect': False,
+ 'header_factory': email.policy.EmailPolicy.header_factory,
+ 'refold_source': 'long',
+ })
+
+ # For each policy under test, we give here what we expect the defaults to
+ # be for that policy. The second argument to make defaults is the
+ # difference between the base defaults and that for the particular policy.
+ new_policy = email.policy.EmailPolicy()
+ policies = {
+ email.policy.compat32: make_defaults(compat32_defaults, {}),
+ email.policy.default: make_defaults(policy_defaults, {}),
+ email.policy.SMTP: make_defaults(policy_defaults,
+ {'linesep': '\r\n'}),
+ email.policy.HTTP: make_defaults(policy_defaults,
+ {'linesep': '\r\n',
+ 'max_line_length': None}),
+ email.policy.strict: make_defaults(policy_defaults,
+ {'raise_on_defect': True}),
+ new_policy: make_defaults(policy_defaults, {}),
+ }
+ # Creating a new policy creates a new header factory. There is a test
+ # later that proves this.
+ policies[new_policy]['header_factory'] = new_policy.header_factory
+
+ def test_defaults(self):
+ for policy, expected in self.policies.items():
+ for attr, value in expected.items():
+ self.assertEqual(getattr(policy, attr), value,
+ ("change {} docs/docstrings if defaults have "
+ "changed").format(policy))
+
+ def test_all_attributes_covered(self):
+ for policy, expected in self.policies.items():
+ for attr in dir(policy):
+ if (attr.startswith('_') or
+ isinstance(getattr(email.policy.EmailPolicy, attr),
+ types.FunctionType)):
+ continue
+ else:
+ self.assertIn(attr, expected,
+ "{} is not fully tested".format(attr))
+
+ def test_abc(self):
+ with self.assertRaises(TypeError) as cm:
+ email.policy.Policy()
+ msg = str(cm.exception)
+ abstract_methods = ('fold',
+ 'fold_binary',
+ 'header_fetch_parse',
+ 'header_source_parse',
+ 'header_store_parse')
+ for method in abstract_methods:
+ self.assertIn(method, msg)
+
+ def test_policy_is_immutable(self):
+ for policy, defaults in self.policies.items():
+ for attr in defaults:
+ with self.assertRaisesRegex(AttributeError, attr+".*read-only"):
+ setattr(policy, attr, None)
+ with self.assertRaisesRegex(AttributeError, 'no attribute.*foo'):
+ policy.foo = None
+
+ def test_set_policy_attrs_when_cloned(self):
+ # None of the attributes has a default value of None, so we set them
+ # all to None in the clone call and check that it worked.
+ for policyclass, defaults in self.policies.items():
+ testattrdict = {attr: None for attr in defaults}
+ policy = policyclass.clone(**testattrdict)
+ for attr in defaults:
+ self.assertIsNone(getattr(policy, attr))
+
+ def test_reject_non_policy_keyword_when_called(self):
+ for policyclass in self.policies:
+ with self.assertRaises(TypeError):
+ policyclass(this_keyword_should_not_be_valid=None)
+ with self.assertRaises(TypeError):
+ policyclass(newtline=None)
+
+ def test_policy_addition(self):
+ expected = self.policy_defaults.copy()
+ p1 = email.policy.default.clone(max_line_length=100)
+ p2 = email.policy.default.clone(max_line_length=50)
+ added = p1 + p2
+ expected.update(max_line_length=50)
+ for attr, value in expected.items():
+ self.assertEqual(getattr(added, attr), value)
+ added = p2 + p1
+ expected.update(max_line_length=100)
+ for attr, value in expected.items():
+ self.assertEqual(getattr(added, attr), value)
+ added = added + email.policy.default
+ for attr, value in expected.items():
+ self.assertEqual(getattr(added, attr), value)
+
+ def test_register_defect(self):
+ class Dummy:
+ def __init__(self):
+ self.defects = []
+ obj = Dummy()
+ defect = object()
+ policy = email.policy.EmailPolicy()
+ policy.register_defect(obj, defect)
+ self.assertEqual(obj.defects, [defect])
+ defect2 = object()
+ policy.register_defect(obj, defect2)
+ self.assertEqual(obj.defects, [defect, defect2])
+
+ class MyObj:
+ def __init__(self):
+ self.defects = []
+
+ class MyDefect(Exception):
+ pass
+
+ def test_handle_defect_raises_on_strict(self):
+ foo = self.MyObj()
+ defect = self.MyDefect("the telly is broken")
+ with self.assertRaisesRegex(self.MyDefect, "the telly is broken"):
+ email.policy.strict.handle_defect(foo, defect)
+
+ def test_handle_defect_registers_defect(self):
+ foo = self.MyObj()
+ defect1 = self.MyDefect("one")
+ email.policy.default.handle_defect(foo, defect1)
+ self.assertEqual(foo.defects, [defect1])
+ defect2 = self.MyDefect("two")
+ email.policy.default.handle_defect(foo, defect2)
+ self.assertEqual(foo.defects, [defect1, defect2])
+
+ class MyPolicy(email.policy.EmailPolicy):
+ defects = None
+ def __init__(self, *args, **kw):
+ super().__init__(*args, defects=[], **kw)
+ def register_defect(self, obj, defect):
+ self.defects.append(defect)
+
+ def test_overridden_register_defect_still_raises(self):
+ foo = self.MyObj()
+ defect = self.MyDefect("the telly is broken")
+ with self.assertRaisesRegex(self.MyDefect, "the telly is broken"):
+ self.MyPolicy(raise_on_defect=True).handle_defect(foo, defect)
+
+ def test_overriden_register_defect_works(self):
+ foo = self.MyObj()
+ defect1 = self.MyDefect("one")
+ my_policy = self.MyPolicy()
+ my_policy.handle_defect(foo, defect1)
+ self.assertEqual(my_policy.defects, [defect1])
+ self.assertEqual(foo.defects, [])
+ defect2 = self.MyDefect("two")
+ my_policy.handle_defect(foo, defect2)
+ self.assertEqual(my_policy.defects, [defect1, defect2])
+ self.assertEqual(foo.defects, [])
+
+ def test_default_header_factory(self):
+ h = email.policy.default.header_factory('Test', 'test')
+ self.assertEqual(h.name, 'Test')
+ self.assertIsInstance(h, headerregistry.UnstructuredHeader)
+ self.assertIsInstance(h, headerregistry.BaseHeader)
+
+ class Foo:
+ parse = headerregistry.UnstructuredHeader.parse
+
+ def test_each_Policy_gets_unique_factory(self):
+ policy1 = email.policy.EmailPolicy()
+ policy2 = email.policy.EmailPolicy()
+ policy1.header_factory.map_to_type('foo', self.Foo)
+ h = policy1.header_factory('foo', 'test')
+ self.assertIsInstance(h, self.Foo)
+ self.assertNotIsInstance(h, headerregistry.UnstructuredHeader)
+ h = policy2.header_factory('foo', 'test')
+ self.assertNotIsInstance(h, self.Foo)
+ self.assertIsInstance(h, headerregistry.UnstructuredHeader)
+
+ def test_clone_copies_factory(self):
+ policy1 = email.policy.EmailPolicy()
+ policy2 = policy1.clone()
+ policy1.header_factory.map_to_type('foo', self.Foo)
+ h = policy1.header_factory('foo', 'test')
+ self.assertIsInstance(h, self.Foo)
+ h = policy2.header_factory('foo', 'test')
+ self.assertIsInstance(h, self.Foo)
+
+ def test_new_factory_overrides_default(self):
+ mypolicy = email.policy.EmailPolicy()
+ myfactory = mypolicy.header_factory
+ newpolicy = mypolicy + email.policy.strict
+ self.assertEqual(newpolicy.header_factory, myfactory)
+ newpolicy = email.policy.strict + mypolicy
+ self.assertEqual(newpolicy.header_factory, myfactory)
+
+ def test_adding_default_policies_preserves_default_factory(self):
+ newpolicy = email.policy.default + email.policy.strict
+ self.assertEqual(newpolicy.header_factory,
+ email.policy.EmailPolicy.header_factory)
+ self.assertEqual(newpolicy.__dict__, {'raise_on_defect': True})
+
+ # XXX: Need subclassing tests.
+ # For adding subclassed objects, make sure the usual rules apply (subclass
+ # wins), but that the order still works (right overrides left).
+
+
+class TestPolicyPropagation(unittest.TestCase):
+
+ # The abstract methods are used by the parser but not by the wrapper
+ # functions that call it, so if the exception gets raised we know that the
+ # policy was actually propagated all the way to feedparser.
+ class MyPolicy(email.policy.Policy):
+ def badmethod(self, *args, **kw):
+ raise Exception("test")
+ fold = fold_binary = header_fetch_parser = badmethod
+ header_source_parse = header_store_parse = badmethod
+
+ def test_message_from_string(self):
+ with self.assertRaisesRegex(Exception, "^test$"):
+ email.message_from_string("Subject: test\n\n",
+ policy=self.MyPolicy)
+
+ def test_message_from_bytes(self):
+ with self.assertRaisesRegex(Exception, "^test$"):
+ email.message_from_bytes(b"Subject: test\n\n",
+ policy=self.MyPolicy)
+
+ def test_message_from_file(self):
+ f = io.StringIO('Subject: test\n\n')
+ with self.assertRaisesRegex(Exception, "^test$"):
+ email.message_from_file(f, policy=self.MyPolicy)
+
+ def test_message_from_binary_file(self):
+ f = io.BytesIO(b'Subject: test\n\n')
+ with self.assertRaisesRegex(Exception, "^test$"):
+ email.message_from_binary_file(f, policy=self.MyPolicy)
+
+ # These are redundant, but we need them for black-box completeness.
+
+ def test_parser(self):
+ p = email.parser.Parser(policy=self.MyPolicy)
+ with self.assertRaisesRegex(Exception, "^test$"):
+ p.parsestr('Subject: test\n\n')
+
+ def test_bytes_parser(self):
+ p = email.parser.BytesParser(policy=self.MyPolicy)
+ with self.assertRaisesRegex(Exception, "^test$"):
+ p.parsebytes(b'Subject: test\n\n')
+
+ # Now that we've established that all the parse methods get the
+ # policy in to feedparser, we can use message_from_string for
+ # the rest of the propagation tests.
+
+ def _make_msg(self, source='Subject: test\n\n', policy=None):
+ self.policy = email.policy.default.clone() if policy is None else policy
+ return email.message_from_string(source, policy=self.policy)
+
+ def test_parser_propagates_policy_to_message(self):
+ msg = self._make_msg()
+ self.assertIs(msg.policy, self.policy)
+
+ def test_parser_propagates_policy_to_sub_messages(self):
+ msg = self._make_msg(textwrap.dedent("""\
+ Subject: mime test
+ MIME-Version: 1.0
+ Content-Type: multipart/mixed, boundary="XXX"
+
+ --XXX
+ Content-Type: text/plain
+
+ test
+ --XXX
+ Content-Type: text/plain
+
+ test2
+ --XXX--
+ """))
+ for part in msg.walk():
+ self.assertIs(part.policy, self.policy)
+
+ def test_message_policy_propagates_to_generator(self):
+ msg = self._make_msg("Subject: test\nTo: foo\n\n",
+ policy=email.policy.default.clone(linesep='X'))
+ s = io.StringIO()
+ g = email.generator.Generator(s)
+ g.flatten(msg)
+ self.assertEqual(s.getvalue(), "Subject: testXTo: fooXX")
+
+ def test_message_policy_used_by_as_string(self):
+ msg = self._make_msg("Subject: test\nTo: foo\n\n",
+ policy=email.policy.default.clone(linesep='X'))
+ self.assertEqual(msg.as_string(), "Subject: testXTo: fooXX")
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_email/test_utils.py b/Lib/test/test_email/test_utils.py
new file mode 100644
index 0000000..622ef3a
--- /dev/null
+++ b/Lib/test/test_email/test_utils.py
@@ -0,0 +1,136 @@
+import datetime
+from email import utils
+import test.support
+import time
+import unittest
+import sys
+
+class DateTimeTests(unittest.TestCase):
+
+ datestring = 'Sun, 23 Sep 2001 20:10:55'
+ dateargs = (2001, 9, 23, 20, 10, 55)
+ offsetstring = ' -0700'
+ utcoffset = datetime.timedelta(hours=-7)
+ tz = datetime.timezone(utcoffset)
+ naive_dt = datetime.datetime(*dateargs)
+ aware_dt = datetime.datetime(*dateargs, tzinfo=tz)
+
+ def test_naive_datetime(self):
+ self.assertEqual(utils.format_datetime(self.naive_dt),
+ self.datestring + ' -0000')
+
+ def test_aware_datetime(self):
+ self.assertEqual(utils.format_datetime(self.aware_dt),
+ self.datestring + self.offsetstring)
+
+ def test_usegmt(self):
+ utc_dt = datetime.datetime(*self.dateargs,
+ tzinfo=datetime.timezone.utc)
+ self.assertEqual(utils.format_datetime(utc_dt, usegmt=True),
+ self.datestring + ' GMT')
+
+ def test_usegmt_with_naive_datetime_raises(self):
+ with self.assertRaises(ValueError):
+ utils.format_datetime(self.naive_dt, usegmt=True)
+
+ def test_usegmt_with_non_utc_datetime_raises(self):
+ with self.assertRaises(ValueError):
+ utils.format_datetime(self.aware_dt, usegmt=True)
+
+ def test_parsedate_to_datetime(self):
+ self.assertEqual(
+ utils.parsedate_to_datetime(self.datestring + self.offsetstring),
+ self.aware_dt)
+
+ def test_parsedate_to_datetime_naive(self):
+ self.assertEqual(
+ utils.parsedate_to_datetime(self.datestring + ' -0000'),
+ self.naive_dt)
+
+
+class LocaltimeTests(unittest.TestCase):
+
+ def test_localtime_is_tz_aware_daylight_true(self):
+ test.support.patch(self, time, 'daylight', True)
+ t = utils.localtime()
+ self.assertIsNot(t.tzinfo, None)
+
+ def test_localtime_is_tz_aware_daylight_false(self):
+ test.support.patch(self, time, 'daylight', False)
+ t = utils.localtime()
+ self.assertIsNot(t.tzinfo, None)
+
+ def test_localtime_daylight_true_dst_false(self):
+ test.support.patch(self, time, 'daylight', True)
+ t0 = datetime.datetime(2012, 3, 12, 1, 1)
+ t1 = utils.localtime(t0, isdst=-1)
+ t2 = utils.localtime(t1)
+ self.assertEqual(t1, t2)
+
+ def test_localtime_daylight_false_dst_false(self):
+ test.support.patch(self, time, 'daylight', False)
+ t0 = datetime.datetime(2012, 3, 12, 1, 1)
+ t1 = utils.localtime(t0, isdst=-1)
+ t2 = utils.localtime(t1)
+ self.assertEqual(t1, t2)
+
+ def test_localtime_daylight_true_dst_true(self):
+ test.support.patch(self, time, 'daylight', True)
+ t0 = datetime.datetime(2012, 3, 12, 1, 1)
+ t1 = utils.localtime(t0, isdst=1)
+ t2 = utils.localtime(t1)
+ self.assertEqual(t1, t2)
+
+ def test_localtime_daylight_false_dst_true(self):
+ test.support.patch(self, time, 'daylight', False)
+ t0 = datetime.datetime(2012, 3, 12, 1, 1)
+ t1 = utils.localtime(t0, isdst=1)
+ t2 = utils.localtime(t1)
+ self.assertEqual(t1, t2)
+
+ @test.support.run_with_tz('EST+05EDT,M3.2.0,M11.1.0')
+ def test_localtime_epoch_utc_daylight_true(self):
+ test.support.patch(self, time, 'daylight', True)
+ t0 = datetime.datetime(1990, 1, 1, tzinfo = datetime.timezone.utc)
+ t1 = utils.localtime(t0)
+ t2 = t0 - datetime.timedelta(hours=5)
+ t2 = t2.replace(tzinfo = datetime.timezone(datetime.timedelta(hours=-5)))
+ self.assertEqual(t1, t2)
+
+ @test.support.run_with_tz('EST+05EDT,M3.2.0,M11.1.0')
+ def test_localtime_epoch_utc_daylight_false(self):
+ test.support.patch(self, time, 'daylight', False)
+ t0 = datetime.datetime(1990, 1, 1, tzinfo = datetime.timezone.utc)
+ t1 = utils.localtime(t0)
+ t2 = t0 - datetime.timedelta(hours=5)
+ t2 = t2.replace(tzinfo = datetime.timezone(datetime.timedelta(hours=-5)))
+ self.assertEqual(t1, t2)
+
+ def test_localtime_epoch_notz_daylight_true(self):
+ test.support.patch(self, time, 'daylight', True)
+ t0 = datetime.datetime(1990, 1, 1)
+ t1 = utils.localtime(t0)
+ t2 = utils.localtime(t0.replace(tzinfo=None))
+ self.assertEqual(t1, t2)
+
+ def test_localtime_epoch_notz_daylight_false(self):
+ test.support.patch(self, time, 'daylight', False)
+ t0 = datetime.datetime(1990, 1, 1)
+ t1 = utils.localtime(t0)
+ t2 = utils.localtime(t0.replace(tzinfo=None))
+ self.assertEqual(t1, t2)
+
+ # XXX: Need a more robust test for Olson's tzdata
+ @unittest.skipIf(sys.platform.startswith('win'),
+ "Windows does not use Olson's TZ database")
+ @test.support.run_with_tz('Europe/Kiev')
+ def test_variable_tzname(self):
+ t0 = datetime.datetime(1984, 1, 1, tzinfo=datetime.timezone.utc)
+ t1 = utils.localtime(t0)
+ self.assertEqual(t1.tzname(), 'MSK')
+ t0 = datetime.datetime(1994, 1, 1, tzinfo=datetime.timezone.utc)
+ t1 = utils.localtime(t0)
+ self.assertEqual(t1.tzname(), 'EET')
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_email/torture_test.py b/Lib/test/test_email/torture_test.py
new file mode 100644
index 0000000..544b1bb
--- /dev/null
+++ b/Lib/test/test_email/torture_test.py
@@ -0,0 +1,136 @@
+# Copyright (C) 2002-2004 Python Software Foundation
+#
+# A torture test of the email package. This should not be run as part of the
+# standard Python test suite since it requires several meg of email messages
+# collected in the wild. These source messages are not checked into the
+# Python distro, but are available as part of the standalone email package at
+# http://sf.net/projects/mimelib
+
+import sys
+import os
+import unittest
+from io import StringIO
+from types import ListType
+
+from email.test.test_email import TestEmailBase
+from test.support import TestSkipped, run_unittest
+
+import email
+from email import __file__ as testfile
+from email.iterators import _structure
+
+def openfile(filename):
+ from os.path import join, dirname, abspath
+ path = abspath(join(dirname(testfile), os.pardir, 'moredata', filename))
+ return open(path, 'r')
+
+# Prevent this test from running in the Python distro
+try:
+ openfile('crispin-torture.txt')
+except IOError:
+ raise TestSkipped
+
+
+
+class TortureBase(TestEmailBase):
+ def _msgobj(self, filename):
+ fp = openfile(filename)
+ try:
+ msg = email.message_from_file(fp)
+ finally:
+ fp.close()
+ return msg
+
+
+
+class TestCrispinTorture(TortureBase):
+ # Mark Crispin's torture test from the SquirrelMail project
+ def test_mondo_message(self):
+ eq = self.assertEqual
+ neq = self.ndiffAssertEqual
+ msg = self._msgobj('crispin-torture.txt')
+ payload = msg.get_payload()
+ eq(type(payload), ListType)
+ eq(len(payload), 12)
+ eq(msg.preamble, None)
+ eq(msg.epilogue, '\n')
+ # Probably the best way to verify the message is parsed correctly is to
+ # dump its structure and compare it against the known structure.
+ fp = StringIO()
+ _structure(msg, fp=fp)
+ neq(fp.getvalue(), """\
+multipart/mixed
+ text/plain
+ message/rfc822
+ multipart/alternative
+ text/plain
+ multipart/mixed
+ text/richtext
+ application/andrew-inset
+ message/rfc822
+ audio/basic
+ audio/basic
+ image/pbm
+ message/rfc822
+ multipart/mixed
+ multipart/mixed
+ text/plain
+ audio/x-sun
+ multipart/mixed
+ image/gif
+ image/gif
+ application/x-be2
+ application/atomicmail
+ audio/x-sun
+ message/rfc822
+ multipart/mixed
+ text/plain
+ image/pgm
+ text/plain
+ message/rfc822
+ multipart/mixed
+ text/plain
+ image/pbm
+ message/rfc822
+ application/postscript
+ image/gif
+ message/rfc822
+ multipart/mixed
+ audio/basic
+ audio/basic
+ message/rfc822
+ multipart/mixed
+ application/postscript
+ text/plain
+ message/rfc822
+ multipart/mixed
+ text/plain
+ multipart/parallel
+ image/gif
+ audio/basic
+ application/atomicmail
+ message/rfc822
+ audio/x-sun
+""")
+
+
+def _testclasses():
+ mod = sys.modules[__name__]
+ return [getattr(mod, name) for name in dir(mod) if name.startswith('Test')]
+
+
+def suite():
+ suite = unittest.TestSuite()
+ for testclass in _testclasses():
+ suite.addTest(unittest.makeSuite(testclass))
+ return suite
+
+
+def test_main():
+ for testclass in _testclasses():
+ run_unittest(testclass)
+
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/Lib/test/test_enumerate.py b/Lib/test/test_enumerate.py
index 095820b..2e904cf 100644
--- a/Lib/test/test_enumerate.py
+++ b/Lib/test/test_enumerate.py
@@ -1,5 +1,6 @@
import unittest
import sys
+import pickle
from test import support
@@ -61,7 +62,25 @@ class N:
def __iter__(self):
return self
-class EnumerateTestCase(unittest.TestCase):
+class PickleTest:
+ # Helper to check picklability
+ def check_pickle(self, itorg, seq):
+ d = pickle.dumps(itorg)
+ it = pickle.loads(d)
+ self.assertEqual(type(itorg), type(it))
+ self.assertEqual(list(it), seq)
+
+ it = pickle.loads(d)
+ try:
+ next(it)
+ except StopIteration:
+ self.assertFalse(seq[1:])
+ return
+ d = pickle.dumps(it)
+ it = pickle.loads(d)
+ self.assertEqual(list(it), seq[1:])
+
+class EnumerateTestCase(unittest.TestCase, PickleTest):
enum = enumerate
seq, res = 'abc', [(0,'a'), (1,'b'), (2,'c')]
@@ -73,6 +92,9 @@ class EnumerateTestCase(unittest.TestCase):
self.assertEqual(list(self.enum(self.seq)), self.res)
self.enum.__doc__
+ def test_pickle(self):
+ self.check_pickle(self.enum(self.seq), self.res)
+
def test_getitemseqn(self):
self.assertEqual(list(self.enum(G(self.seq))), self.res)
e = self.enum(G(''))
@@ -126,7 +148,7 @@ class TestBig(EnumerateTestCase):
seq = range(10,20000,2)
res = list(zip(range(20000), seq))
-class TestReversed(unittest.TestCase):
+class TestReversed(unittest.TestCase, PickleTest):
def test_simple(self):
class A:
@@ -212,6 +234,10 @@ class TestReversed(unittest.TestCase):
ngi = NoGetItem()
self.assertRaises(TypeError, reversed, ngi)
+ def test_pickle(self):
+ for data in 'abc', range(5), tuple(enumerate('abc')), range(1,17,5):
+ self.check_pickle(reversed(data), list(data)[::-1])
+
class EnumerateStartTestCase(EnumerateTestCase):
diff --git a/Lib/test/test_epoll.py b/Lib/test/test_epoll.py
index 083fd7f..7f9547f 100644
--- a/Lib/test/test_epoll.py
+++ b/Lib/test/test_epoll.py
@@ -75,6 +75,9 @@ class TestEPoll(unittest.TestCase):
ep.close()
self.assertTrue(ep.closed)
self.assertRaises(ValueError, ep.fileno)
+ if hasattr(select, "EPOLL_CLOEXEC"):
+ select.epoll(select.EPOLL_CLOEXEC).close()
+ self.assertRaises(OSError, select.epoll, flags=12356)
def test_badcreate(self):
self.assertRaises(TypeError, select.epoll, 1, 2, 3)
diff --git a/Lib/test/test_exceptions.py b/Lib/test/test_exceptions.py
index 79bd7ff..1ad7f97 100644
--- a/Lib/test/test_exceptions.py
+++ b/Lib/test/test_exceptions.py
@@ -8,7 +8,17 @@ import weakref
import errno
from test.support import (TESTFN, captured_output, check_impl_detail,
- cpython_only, gc_collect, run_unittest, unlink)
+ cpython_only, gc_collect, run_unittest, no_tracing,
+ unlink)
+
+class NaiveException(Exception):
+ def __init__(self, x):
+ self.x = x
+
+class SlottedNaiveException(Exception):
+ __slots__ = ('x',)
+ def __init__(self, x):
+ self.x = x
# XXX This is not really enough, each *operation* should be tested!
@@ -46,8 +56,8 @@ class ExceptionTests(unittest.TestCase):
fp.close()
unlink(TESTFN)
- self.raise_catch(IOError, "IOError")
- self.assertRaises(IOError, open, 'this file does not exist', 'r')
+ self.raise_catch(OSError, "OSError")
+ self.assertRaises(OSError, open, 'this file does not exist', 'r')
self.raise_catch(ImportError, "ImportError")
self.assertRaises(ImportError, __import__, "undefined_module")
@@ -192,11 +202,35 @@ class ExceptionTests(unittest.TestCase):
except NameError:
pass
else:
- self.assertEqual(str(WindowsError(1001)), "1001")
- self.assertEqual(str(WindowsError(1001, "message")),
- "[Error 1001] message")
- self.assertEqual(WindowsError(1001, "message").errno, 22)
- self.assertEqual(WindowsError(1001, "message").winerror, 1001)
+ self.assertIs(WindowsError, OSError)
+ self.assertEqual(str(OSError(1001)), "1001")
+ self.assertEqual(str(OSError(1001, "message")),
+ "[Errno 1001] message")
+ # POSIX errno (9 aka EBADF) is untranslated
+ w = OSError(9, 'foo', 'bar')
+ self.assertEqual(w.errno, 9)
+ self.assertEqual(w.winerror, None)
+ self.assertEqual(str(w), "[Errno 9] foo: 'bar'")
+ # ERROR_PATH_NOT_FOUND (win error 3) becomes ENOENT (2)
+ w = OSError(0, 'foo', 'bar', 3)
+ self.assertEqual(w.errno, 2)
+ self.assertEqual(w.winerror, 3)
+ self.assertEqual(w.strerror, 'foo')
+ self.assertEqual(w.filename, 'bar')
+ self.assertEqual(str(w), "[WinError 3] foo: 'bar'")
+ # Unknown win error becomes EINVAL (22)
+ w = OSError(0, 'foo', None, 1001)
+ self.assertEqual(w.errno, 22)
+ self.assertEqual(w.winerror, 1001)
+ self.assertEqual(w.strerror, 'foo')
+ self.assertEqual(w.filename, None)
+ self.assertEqual(str(w), "[WinError 1001] foo")
+ # Non-numeric "errno"
+ w = OSError('bar', 'foo')
+ self.assertEqual(w.errno, 'bar')
+ self.assertEqual(w.winerror, None)
+ self.assertEqual(w.strerror, 'foo')
+ self.assertEqual(w.filename, None)
def testAttributes(self):
# test that exception attributes are happy
@@ -272,13 +306,18 @@ class ExceptionTests(unittest.TestCase):
{'args' : ('\u3042', 0, 1, 'ouch'),
'object' : '\u3042', 'reason' : 'ouch',
'start' : 0, 'end' : 1}),
+ (NaiveException, ('foo',),
+ {'args': ('foo',), 'x': 'foo'}),
+ (SlottedNaiveException, ('foo',),
+ {'args': ('foo',), 'x': 'foo'}),
]
try:
+ # More tests are in test_WindowsError
exceptionList.append(
(WindowsError, (1, 'strErrorStr', 'filenameStr'),
{'args' : (1, 'strErrorStr'),
- 'strerror' : 'strErrorStr', 'winerror' : 1,
- 'errno' : 22, 'filename' : 'filenameStr'})
+ 'strerror' : 'strErrorStr', 'winerror' : None,
+ 'errno' : 1, 'filename' : 'filenameStr'})
)
except NameError:
pass
@@ -291,7 +330,8 @@ class ExceptionTests(unittest.TestCase):
raise
else:
# Verify module name
- self.assertEqual(type(e).__module__, 'builtins')
+ if not type(e).__name__.endswith('NaiveException'):
+ self.assertEqual(type(e).__module__, 'builtins')
# Verify no ref leaks in Exc_str()
s = str(e)
for checkArgName in expected:
@@ -362,19 +402,37 @@ class ExceptionTests(unittest.TestCase):
def testChainingAttrs(self):
e = Exception()
- self.assertEqual(e.__context__, None)
- self.assertEqual(e.__cause__, None)
+ self.assertIsNone(e.__context__)
+ self.assertIsNone(e.__cause__)
e = TypeError()
- self.assertEqual(e.__context__, None)
- self.assertEqual(e.__cause__, None)
+ self.assertIsNone(e.__context__)
+ self.assertIsNone(e.__cause__)
class MyException(EnvironmentError):
pass
e = MyException()
- self.assertEqual(e.__context__, None)
- self.assertEqual(e.__cause__, None)
+ self.assertIsNone(e.__context__)
+ self.assertIsNone(e.__cause__)
+
+ def testChainingDescriptors(self):
+ try:
+ raise Exception()
+ except Exception as exc:
+ e = exc
+
+ self.assertIsNone(e.__context__)
+ self.assertIsNone(e.__cause__)
+ self.assertFalse(e.__suppress_context__)
+
+ e.__context__ = NameError()
+ e.__cause__ = None
+ self.assertIsInstance(e.__context__, NameError)
+ self.assertIsNone(e.__cause__)
+ self.assertTrue(e.__suppress_context__)
+ e.__suppress_context__ = False
+ self.assertFalse(e.__suppress_context__)
def testKeywordArgs(self):
# test that builtin exception don't take keyword args,
@@ -389,6 +447,7 @@ class ExceptionTests(unittest.TestCase):
x = DerivedException(fancy_arg=42)
self.assertEqual(x.fancy_arg, 42)
+ @no_tracing
def testInfiniteRecursion(self):
def f():
return f()
@@ -728,6 +787,7 @@ class ExceptionTests(unittest.TestCase):
u.start = 1000
self.assertEqual(str(u), "can't translate characters in position 1000-4: 965230951443685724997")
+ @no_tracing
def test_badisinstance(self):
# Bug #2542: if issubclass(e, MyException) raises an exception,
# it should be ignored
@@ -838,6 +898,7 @@ class ExceptionTests(unittest.TestCase):
self.fail("MemoryError not raised")
self.assertEqual(wr(), None)
+ @no_tracing
def test_recursion_error_cleanup(self):
# Same test as above, but with "recursion exceeded" errors
class C:
@@ -864,8 +925,35 @@ class ExceptionTests(unittest.TestCase):
self.assertEqual(cm.exception.errno, errno.ENOTDIR, cm.exception)
+class ImportErrorTests(unittest.TestCase):
+
+ def test_attributes(self):
+ # Setting 'name' and 'path' should not be a problem.
+ exc = ImportError('test')
+ self.assertIsNone(exc.name)
+ self.assertIsNone(exc.path)
+
+ exc = ImportError('test', name='somemodule')
+ self.assertEqual(exc.name, 'somemodule')
+ self.assertIsNone(exc.path)
+
+ exc = ImportError('test', path='somepath')
+ self.assertEqual(exc.path, 'somepath')
+ self.assertIsNone(exc.name)
+
+ exc = ImportError('test', path='somepath', name='somename')
+ self.assertEqual(exc.name, 'somename')
+ self.assertEqual(exc.path, 'somepath')
+
+ def test_non_str_argument(self):
+ # Issue #15778
+ arg = b'abc'
+ exc = ImportError(arg)
+ self.assertEqual(str(arg), str(exc))
+
+
def test_main():
- run_unittest(ExceptionTests)
+ run_unittest(ExceptionTests, ImportErrorTests)
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_extcall.py b/Lib/test/test_extcall.py
index 1f7f630..6b6c12d 100644
--- a/Lib/test/test_extcall.py
+++ b/Lib/test/test_extcall.py
@@ -66,17 +66,17 @@ Verify clearing of SF bug #733667
>>> g()
Traceback (most recent call last):
...
- TypeError: g() takes at least 1 argument (0 given)
+ TypeError: g() missing 1 required positional argument: 'x'
>>> g(*())
Traceback (most recent call last):
...
- TypeError: g() takes at least 1 argument (0 given)
+ TypeError: g() missing 1 required positional argument: 'x'
>>> g(*(), **{})
Traceback (most recent call last):
...
- TypeError: g() takes at least 1 argument (0 given)
+ TypeError: g() missing 1 required positional argument: 'x'
>>> g(1)
1 () {}
@@ -151,7 +151,7 @@ What about willful misconduct?
>>> g(1, 2, 3, **{'x': 4, 'y': 5})
Traceback (most recent call last):
...
- TypeError: g() got multiple values for keyword argument 'x'
+ TypeError: g() got multiple values for argument 'x'
>>> f(**{1:2})
Traceback (most recent call last):
@@ -263,29 +263,80 @@ the function call setup. See <http://bugs.python.org/issue2016>.
>>> f(**x)
1 2
-A obscure message:
+Too many arguments:
- >>> def f(a, b):
- ... pass
- >>> f(b=1)
+ >>> def f(): pass
+ >>> f(1)
+ Traceback (most recent call last):
+ ...
+ TypeError: f() takes 0 positional arguments but 1 was given
+ >>> def f(a): pass
+ >>> f(1, 2)
Traceback (most recent call last):
...
- TypeError: f() takes exactly 2 arguments (1 given)
+ TypeError: f() takes 1 positional argument but 2 were given
+ >>> def f(a, b=1): pass
+ >>> f(1, 2, 3)
+ Traceback (most recent call last):
+ ...
+ TypeError: f() takes from 1 to 2 positional arguments but 3 were given
+ >>> def f(*, kw): pass
+ >>> f(1, kw=3)
+ Traceback (most recent call last):
+ ...
+ TypeError: f() takes 0 positional arguments but 1 positional argument (and 1 keyword-only argument) were given
+ >>> def f(*, kw, b): pass
+ >>> f(1, 2, 3, b=3, kw=3)
+ Traceback (most recent call last):
+ ...
+ TypeError: f() takes 0 positional arguments but 3 positional arguments (and 2 keyword-only arguments) were given
+ >>> def f(a, b=2, *, kw): pass
+ >>> f(2, 3, 4, kw=4)
+ Traceback (most recent call last):
+ ...
+ TypeError: f() takes from 1 to 2 positional arguments but 3 positional arguments (and 1 keyword-only argument) were given
-The number of arguments passed in includes keywords:
+Too few and missing arguments:
- >>> def f(a):
- ... pass
- >>> f(6, a=4, *(1, 2, 3))
+ >>> def f(a): pass
+ >>> f()
Traceback (most recent call last):
...
- TypeError: f() takes exactly 1 positional argument (5 given)
- >>> def f(a, *, kw):
- ... pass
- >>> f(6, 4, kw=4)
+ TypeError: f() missing 1 required positional argument: 'a'
+ >>> def f(a, b): pass
+ >>> f()
Traceback (most recent call last):
...
- TypeError: f() takes exactly 1 positional argument (3 given)
+ TypeError: f() missing 2 required positional arguments: 'a' and 'b'
+ >>> def f(a, b, c): pass
+ >>> f()
+ Traceback (most recent call last):
+ ...
+ TypeError: f() missing 3 required positional arguments: 'a', 'b', and 'c'
+ >>> def f(a, b, c, d, e): pass
+ >>> f()
+ Traceback (most recent call last):
+ ...
+ TypeError: f() missing 5 required positional arguments: 'a', 'b', 'c', 'd', and 'e'
+ >>> def f(a, b=4, c=5, d=5): pass
+ >>> f(c=12, b=9)
+ Traceback (most recent call last):
+ ...
+ TypeError: f() missing 1 required positional argument: 'a'
+
+Same with keyword only args:
+
+ >>> def f(*, w): pass
+ >>> f()
+ Traceback (most recent call last):
+ ...
+ TypeError: f() missing 1 required keyword-only argument: 'w'
+ >>> def f(*, a, b, c, d, e): pass
+ >>> f()
+ Traceback (most recent call last):
+ ...
+ TypeError: f() missing 5 required keyword-only arguments: 'a', 'b', 'c', 'd', and 'e'
+
"""
import sys
diff --git a/Lib/test/test_faulthandler.py b/Lib/test/test_faulthandler.py
new file mode 100644
index 0000000..b81b34d
--- /dev/null
+++ b/Lib/test/test_faulthandler.py
@@ -0,0 +1,593 @@
+from contextlib import contextmanager
+import datetime
+import faulthandler
+import os
+import re
+import signal
+import subprocess
+import sys
+from test import support, script_helper
+from test.script_helper import assert_python_ok
+import tempfile
+import unittest
+
+try:
+ import threading
+ HAVE_THREADS = True
+except ImportError:
+ HAVE_THREADS = False
+
+TIMEOUT = 0.5
+
+try:
+ from resource import setrlimit, RLIMIT_CORE, error as resource_error
+except ImportError:
+ prepare_subprocess = None
+else:
+ def prepare_subprocess():
+ # don't create core file
+ try:
+ setrlimit(RLIMIT_CORE, (0, 0))
+ except (ValueError, resource_error):
+ pass
+
+def expected_traceback(lineno1, lineno2, header, min_count=1):
+ regex = header
+ regex += ' File "<string>", line %s in func\n' % lineno1
+ regex += ' File "<string>", line %s in <module>' % lineno2
+ if 1 < min_count:
+ return '^' + (regex + '\n') * (min_count - 1) + regex
+ else:
+ return '^' + regex + '$'
+
+@contextmanager
+def temporary_filename():
+ filename = tempfile.mktemp()
+ try:
+ yield filename
+ finally:
+ support.unlink(filename)
+
+class FaultHandlerTests(unittest.TestCase):
+ def get_output(self, code, filename=None):
+ """
+ Run the specified code in Python (in a new child process) and read the
+ output from the standard error or from a file (if filename is set).
+ Return the output lines as a list.
+
+ Strip the reference count from the standard error for Python debug
+ build, and replace "Current thread 0x00007f8d8fbd9700" by "Current
+ thread XXX".
+ """
+ options = {}
+ if prepare_subprocess:
+ options['preexec_fn'] = prepare_subprocess
+ process = script_helper.spawn_python('-c', code, **options)
+ stdout, stderr = process.communicate()
+ exitcode = process.wait()
+ output = support.strip_python_stderr(stdout)
+ output = output.decode('ascii', 'backslashreplace')
+ if filename:
+ self.assertEqual(output, '')
+ with open(filename, "rb") as fp:
+ output = fp.read()
+ output = output.decode('ascii', 'backslashreplace')
+ output = re.sub('Current thread 0x[0-9a-f]+',
+ 'Current thread XXX',
+ output)
+ return output.splitlines(), exitcode
+
+ def check_fatal_error(self, code, line_number, name_regex,
+ filename=None, all_threads=True, other_regex=None):
+ """
+ Check that the fault handler for fatal errors is enabled and check the
+ traceback from the child process output.
+
+ Raise an error if the output doesn't match the expected format.
+ """
+ if all_threads:
+ header = 'Current thread XXX'
+ else:
+ header = 'Traceback (most recent call first)'
+ regex = """
+^Fatal Python error: {name}
+
+{header}:
+ File "<string>", line {lineno} in <module>
+""".strip()
+ regex = regex.format(
+ lineno=line_number,
+ name=name_regex,
+ header=re.escape(header))
+ if other_regex:
+ regex += '|' + other_regex
+ output, exitcode = self.get_output(code, filename)
+ output = '\n'.join(output)
+ self.assertRegex(output, regex)
+ self.assertNotEqual(exitcode, 0)
+
+ def test_read_null(self):
+ self.check_fatal_error("""
+import faulthandler
+faulthandler.enable()
+faulthandler._read_null()
+""".strip(),
+ 3,
+ # Issue #12700: Read NULL raises SIGILL on Mac OS X Lion
+ '(?:Segmentation fault|Bus error|Illegal instruction)')
+
+ def test_sigsegv(self):
+ self.check_fatal_error("""
+import faulthandler
+faulthandler.enable()
+faulthandler._sigsegv()
+""".strip(),
+ 3,
+ 'Segmentation fault')
+
+ def test_sigabrt(self):
+ self.check_fatal_error("""
+import faulthandler
+faulthandler.enable()
+faulthandler._sigabrt()
+""".strip(),
+ 3,
+ 'Aborted')
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "SIGFPE cannot be caught on Windows")
+ def test_sigfpe(self):
+ self.check_fatal_error("""
+import faulthandler
+faulthandler.enable()
+faulthandler._sigfpe()
+""".strip(),
+ 3,
+ 'Floating point exception')
+
+ @unittest.skipIf(not hasattr(faulthandler, '_sigbus'),
+ "need faulthandler._sigbus()")
+ def test_sigbus(self):
+ self.check_fatal_error("""
+import faulthandler
+faulthandler.enable()
+faulthandler._sigbus()
+""".strip(),
+ 3,
+ 'Bus error')
+
+ @unittest.skipIf(not hasattr(faulthandler, '_sigill'),
+ "need faulthandler._sigill()")
+ def test_sigill(self):
+ self.check_fatal_error("""
+import faulthandler
+faulthandler.enable()
+faulthandler._sigill()
+""".strip(),
+ 3,
+ 'Illegal instruction')
+
+ def test_fatal_error(self):
+ self.check_fatal_error("""
+import faulthandler
+faulthandler._fatal_error(b'xyz')
+""".strip(),
+ 2,
+ 'xyz')
+
+ @unittest.skipIf(sys.platform.startswith('openbsd') and HAVE_THREADS,
+ "Issue #12868: sigaltstack() doesn't work on "
+ "OpenBSD if Python is compiled with pthread")
+ @unittest.skipIf(not hasattr(faulthandler, '_stack_overflow'),
+ 'need faulthandler._stack_overflow()')
+ def test_stack_overflow(self):
+ self.check_fatal_error("""
+import faulthandler
+faulthandler.enable()
+faulthandler._stack_overflow()
+""".strip(),
+ 3,
+ '(?:Segmentation fault|Bus error)',
+ other_regex='unable to raise a stack overflow')
+
+ def test_gil_released(self):
+ self.check_fatal_error("""
+import faulthandler
+faulthandler.enable()
+faulthandler._read_null(True)
+""".strip(),
+ 3,
+ '(?:Segmentation fault|Bus error|Illegal instruction)')
+
+ def test_enable_file(self):
+ with temporary_filename() as filename:
+ self.check_fatal_error("""
+import faulthandler
+output = open({filename}, 'wb')
+faulthandler.enable(output)
+faulthandler._read_null()
+""".strip().format(filename=repr(filename)),
+ 4,
+ '(?:Segmentation fault|Bus error|Illegal instruction)',
+ filename=filename)
+
+ def test_enable_single_thread(self):
+ self.check_fatal_error("""
+import faulthandler
+faulthandler.enable(all_threads=False)
+faulthandler._read_null()
+""".strip(),
+ 3,
+ '(?:Segmentation fault|Bus error|Illegal instruction)',
+ all_threads=False)
+
+ def test_disable(self):
+ code = """
+import faulthandler
+faulthandler.enable()
+faulthandler.disable()
+faulthandler._read_null()
+""".strip()
+ not_expected = 'Fatal Python error'
+ stderr, exitcode = self.get_output(code)
+ stder = '\n'.join(stderr)
+ self.assertTrue(not_expected not in stderr,
+ "%r is present in %r" % (not_expected, stderr))
+ self.assertNotEqual(exitcode, 0)
+
+ def test_is_enabled(self):
+ orig_stderr = sys.stderr
+ try:
+ # regrtest may replace sys.stderr by io.StringIO object, but
+ # faulthandler.enable() requires that sys.stderr has a fileno()
+ # method
+ sys.stderr = sys.__stderr__
+
+ was_enabled = faulthandler.is_enabled()
+ try:
+ faulthandler.enable()
+ self.assertTrue(faulthandler.is_enabled())
+ faulthandler.disable()
+ self.assertFalse(faulthandler.is_enabled())
+ finally:
+ if was_enabled:
+ faulthandler.enable()
+ else:
+ faulthandler.disable()
+ finally:
+ sys.stderr = orig_stderr
+
+ def test_disabled_by_default(self):
+ # By default, the module should be disabled
+ code = "import faulthandler; print(faulthandler.is_enabled())"
+ rc, stdout, stderr = assert_python_ok("-c", code)
+ stdout = (stdout + stderr).strip()
+ self.assertEqual(stdout, b"False")
+
+ def test_sys_xoptions(self):
+ # Test python -X faulthandler
+ code = "import faulthandler; print(faulthandler.is_enabled())"
+ rc, stdout, stderr = assert_python_ok("-X", "faulthandler", "-c", code)
+ stdout = (stdout + stderr).strip()
+ self.assertEqual(stdout, b"True")
+
+ def check_dump_traceback(self, filename):
+ """
+ Explicitly call dump_traceback() function and check its output.
+ Raise an error if the output doesn't match the expected format.
+ """
+ code = """
+import faulthandler
+
+def funcB():
+ if {has_filename}:
+ with open({filename}, "wb") as fp:
+ faulthandler.dump_traceback(fp, all_threads=False)
+ else:
+ faulthandler.dump_traceback(all_threads=False)
+
+def funcA():
+ funcB()
+
+funcA()
+""".strip()
+ code = code.format(
+ filename=repr(filename),
+ has_filename=bool(filename),
+ )
+ if filename:
+ lineno = 6
+ else:
+ lineno = 8
+ expected = [
+ 'Traceback (most recent call first):',
+ ' File "<string>", line %s in funcB' % lineno,
+ ' File "<string>", line 11 in funcA',
+ ' File "<string>", line 13 in <module>'
+ ]
+ trace, exitcode = self.get_output(code, filename)
+ self.assertEqual(trace, expected)
+ self.assertEqual(exitcode, 0)
+
+ def test_dump_traceback(self):
+ self.check_dump_traceback(None)
+
+ def test_dump_traceback_file(self):
+ with temporary_filename() as filename:
+ self.check_dump_traceback(filename)
+
+ def test_truncate(self):
+ maxlen = 500
+ func_name = 'x' * (maxlen + 50)
+ truncated = 'x' * maxlen + '...'
+ code = """
+import faulthandler
+
+def {func_name}():
+ faulthandler.dump_traceback(all_threads=False)
+
+{func_name}()
+""".strip()
+ code = code.format(
+ func_name=func_name,
+ )
+ expected = [
+ 'Traceback (most recent call first):',
+ ' File "<string>", line 4 in %s' % truncated,
+ ' File "<string>", line 6 in <module>'
+ ]
+ trace, exitcode = self.get_output(code)
+ self.assertEqual(trace, expected)
+ self.assertEqual(exitcode, 0)
+
+ @unittest.skipIf(not HAVE_THREADS, 'need threads')
+ def check_dump_traceback_threads(self, filename):
+ """
+ Call explicitly dump_traceback(all_threads=True) and check the output.
+ Raise an error if the output doesn't match the expected format.
+ """
+ code = """
+import faulthandler
+from threading import Thread, Event
+import time
+
+def dump():
+ if {filename}:
+ with open({filename}, "wb") as fp:
+ faulthandler.dump_traceback(fp, all_threads=True)
+ else:
+ faulthandler.dump_traceback(all_threads=True)
+
+class Waiter(Thread):
+ # avoid blocking if the main thread raises an exception.
+ daemon = True
+
+ def __init__(self):
+ Thread.__init__(self)
+ self.running = Event()
+ self.stop = Event()
+
+ def run(self):
+ self.running.set()
+ self.stop.wait()
+
+waiter = Waiter()
+waiter.start()
+waiter.running.wait()
+dump()
+waiter.stop.set()
+waiter.join()
+""".strip()
+ code = code.format(filename=repr(filename))
+ output, exitcode = self.get_output(code, filename)
+ output = '\n'.join(output)
+ if filename:
+ lineno = 8
+ else:
+ lineno = 10
+ regex = """
+^Thread 0x[0-9a-f]+:
+(?: File ".*threading.py", line [0-9]+ in [_a-z]+
+){{1,3}} File "<string>", line 23 in run
+ File ".*threading.py", line [0-9]+ in _bootstrap_inner
+ File ".*threading.py", line [0-9]+ in _bootstrap
+
+Current thread XXX:
+ File "<string>", line {lineno} in dump
+ File "<string>", line 28 in <module>$
+""".strip()
+ regex = regex.format(lineno=lineno)
+ self.assertRegex(output, regex)
+ self.assertEqual(exitcode, 0)
+
+ def test_dump_traceback_threads(self):
+ self.check_dump_traceback_threads(None)
+
+ def test_dump_traceback_threads_file(self):
+ with temporary_filename() as filename:
+ self.check_dump_traceback_threads(filename)
+
+ def _check_dump_traceback_later(self, repeat, cancel, filename, loops):
+ """
+ Check how many times the traceback is written in timeout x 2.5 seconds,
+ or timeout x 3.5 seconds if cancel is True: 1, 2 or 3 times depending
+ on repeat and cancel options.
+
+ Raise an error if the output doesn't match the expect format.
+ """
+ timeout_str = str(datetime.timedelta(seconds=TIMEOUT))
+ code = """
+import faulthandler
+import time
+
+def func(timeout, repeat, cancel, file, loops):
+ for loop in range(loops):
+ faulthandler.dump_traceback_later(timeout, repeat=repeat, file=file)
+ if cancel:
+ faulthandler.cancel_dump_traceback_later()
+ time.sleep(timeout * 5)
+ faulthandler.cancel_dump_traceback_later()
+
+timeout = {timeout}
+repeat = {repeat}
+cancel = {cancel}
+loops = {loops}
+if {has_filename}:
+ file = open({filename}, "wb")
+else:
+ file = None
+func(timeout, repeat, cancel, file, loops)
+if file is not None:
+ file.close()
+""".strip()
+ code = code.format(
+ timeout=TIMEOUT,
+ repeat=repeat,
+ cancel=cancel,
+ loops=loops,
+ has_filename=bool(filename),
+ filename=repr(filename),
+ )
+ trace, exitcode = self.get_output(code, filename)
+ trace = '\n'.join(trace)
+
+ if not cancel:
+ count = loops
+ if repeat:
+ count *= 2
+ header = r'Timeout \(%s\)!\nThread 0x[0-9a-f]+:\n' % timeout_str
+ regex = expected_traceback(9, 20, header, min_count=count)
+ self.assertRegex(trace, regex)
+ else:
+ self.assertEqual(trace, '')
+ self.assertEqual(exitcode, 0)
+
+ @unittest.skipIf(not hasattr(faulthandler, 'dump_traceback_later'),
+ 'need faulthandler.dump_traceback_later()')
+ def check_dump_traceback_later(self, repeat=False, cancel=False,
+ file=False, twice=False):
+ if twice:
+ loops = 2
+ else:
+ loops = 1
+ if file:
+ with temporary_filename() as filename:
+ self._check_dump_traceback_later(repeat, cancel,
+ filename, loops)
+ else:
+ self._check_dump_traceback_later(repeat, cancel, None, loops)
+
+ def test_dump_traceback_later(self):
+ self.check_dump_traceback_later()
+
+ def test_dump_traceback_later_repeat(self):
+ self.check_dump_traceback_later(repeat=True)
+
+ def test_dump_traceback_later_cancel(self):
+ self.check_dump_traceback_later(cancel=True)
+
+ def test_dump_traceback_later_file(self):
+ self.check_dump_traceback_later(file=True)
+
+ def test_dump_traceback_later_twice(self):
+ self.check_dump_traceback_later(twice=True)
+
+ @unittest.skipIf(not hasattr(faulthandler, "register"),
+ "need faulthandler.register")
+ def check_register(self, filename=False, all_threads=False,
+ unregister=False, chain=False):
+ """
+ Register a handler displaying the traceback on a user signal. Raise the
+ signal and check the written traceback.
+
+ If chain is True, check that the previous signal handler is called.
+
+ Raise an error if the output doesn't match the expected format.
+ """
+ signum = signal.SIGUSR1
+ code = """
+import faulthandler
+import os
+import signal
+import sys
+
+def func(signum):
+ os.kill(os.getpid(), signum)
+
+def handler(signum, frame):
+ handler.called = True
+handler.called = False
+
+exitcode = 0
+signum = {signum}
+unregister = {unregister}
+chain = {chain}
+
+if {has_filename}:
+ file = open({filename}, "wb")
+else:
+ file = None
+if chain:
+ signal.signal(signum, handler)
+faulthandler.register(signum, file=file,
+ all_threads={all_threads}, chain={chain})
+if unregister:
+ faulthandler.unregister(signum)
+func(signum)
+if chain and not handler.called:
+ if file is not None:
+ output = file
+ else:
+ output = sys.stderr
+ print("Error: signal handler not called!", file=output)
+ exitcode = 1
+if file is not None:
+ file.close()
+sys.exit(exitcode)
+""".strip()
+ code = code.format(
+ filename=repr(filename),
+ has_filename=bool(filename),
+ all_threads=all_threads,
+ signum=signum,
+ unregister=unregister,
+ chain=chain,
+ )
+ trace, exitcode = self.get_output(code, filename)
+ trace = '\n'.join(trace)
+ if not unregister:
+ if all_threads:
+ regex = 'Current thread XXX:\n'
+ else:
+ regex = 'Traceback \(most recent call first\):\n'
+ regex = expected_traceback(7, 28, regex)
+ self.assertRegex(trace, regex)
+ else:
+ self.assertEqual(trace, '')
+ if unregister:
+ self.assertNotEqual(exitcode, 0)
+ else:
+ self.assertEqual(exitcode, 0)
+
+ def test_register(self):
+ self.check_register()
+
+ def test_unregister(self):
+ self.check_register(unregister=True)
+
+ def test_register_file(self):
+ with temporary_filename() as filename:
+ self.check_register(filename=filename)
+
+ def test_register_threads(self):
+ self.check_register(all_threads=True)
+
+ def test_register_chain(self):
+ self.check_register(chain=True)
+
+
+def test_main():
+ support.run_unittest(FaultHandlerTests)
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_file.py b/Lib/test/test_file.py
index bb0da79..a78ddf3 100644
--- a/Lib/test/test_file.py
+++ b/Lib/test/test_file.py
@@ -10,7 +10,7 @@ import _pyio as pyio
from test.support import TESTFN, run_unittest
from collections import UserList
-class AutoFileTests(unittest.TestCase):
+class AutoFileTests:
# file tests for which a test file is automatically set up
def setUp(self):
@@ -128,14 +128,14 @@ class AutoFileTests(unittest.TestCase):
def testReadWhenWriting(self):
self.assertRaises(IOError, self.f.read)
-class CAutoFileTests(AutoFileTests):
+class CAutoFileTests(AutoFileTests, unittest.TestCase):
open = io.open
-class PyAutoFileTests(AutoFileTests):
+class PyAutoFileTests(AutoFileTests, unittest.TestCase):
open = staticmethod(pyio.open)
-class OtherFileTests(unittest.TestCase):
+class OtherFileTests:
def testModeStrings(self):
# check invalid mode strings
@@ -322,22 +322,18 @@ class OtherFileTests(unittest.TestCase):
finally:
os.unlink(TESTFN)
-class COtherFileTests(OtherFileTests):
+class COtherFileTests(OtherFileTests, unittest.TestCase):
open = io.open
-class PyOtherFileTests(OtherFileTests):
+class PyOtherFileTests(OtherFileTests, unittest.TestCase):
open = staticmethod(pyio.open)
-def test_main():
+def tearDownModule():
# Historically, these tests have been sloppy about removing TESTFN.
# So get rid of it no matter what.
- try:
- run_unittest(CAutoFileTests, PyAutoFileTests,
- COtherFileTests, PyOtherFileTests)
- finally:
- if os.path.exists(TESTFN):
- os.unlink(TESTFN)
+ if os.path.exists(TESTFN):
+ os.unlink(TESTFN)
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_fileinput.py b/Lib/test/test_fileinput.py
index f312882..1e70641 100644
--- a/Lib/test/test_fileinput.py
+++ b/Lib/test/test_fileinput.py
@@ -2,18 +2,33 @@
Tests for fileinput module.
Nick Mathewson
'''
-
+import os
+import sys
+import re
+import fileinput
+import collections
+import builtins
import unittest
-from test.support import verbose, TESTFN, run_unittest
-from test.support import unlink as safe_unlink
-import sys, re
+
+try:
+ import bz2
+except ImportError:
+ bz2 = None
+try:
+ import gzip
+except ImportError:
+ gzip = None
+
from io import StringIO
from fileinput import FileInput, hook_encoded
+from test.support import verbose, TESTFN, run_unittest
+from test.support import unlink as safe_unlink
+
+
# The fileinput module has 2 interfaces: the FileInput class which does
# all the work, and a few functions (input, etc.) that use a global _state
-# variable. We only test the FileInput class, since the other functions
-# only provide a thin facade over FileInput.
+# variable.
# Write lines (a list of lines) to temp file number i, and return the
# temp file's name.
@@ -121,7 +136,16 @@ class BufferSizesTests(unittest.TestCase):
self.assertEqual(int(m.group(1)), fi.filelineno())
fi.close()
+class UnconditionallyRaise:
+ def __init__(self, exception_type):
+ self.exception_type = exception_type
+ self.invoked = False
+ def __call__(self, *args, **kwargs):
+ self.invoked = True
+ raise self.exception_type()
+
class FileInputTests(unittest.TestCase):
+
def test_zero_byte_files(self):
t1 = t2 = t3 = t4 = None
try:
@@ -219,17 +243,20 @@ class FileInputTests(unittest.TestCase):
self.fail("FileInput should check openhook for being callable")
except ValueError:
pass
- # XXX The rot13 codec was removed.
- # So this test needs to be changed to use something else.
- # (Or perhaps the API needs to change so we can just pass
- # an encoding rather than using a hook?)
-## try:
-## t1 = writeTmp(1, ["A\nB"], mode="wb")
-## fi = FileInput(files=t1, openhook=hook_encoded("rot13"))
-## lines = list(fi)
-## self.assertEqual(lines, ["N\n", "O"])
-## finally:
-## remove_tempfiles(t1)
+
+ class CustomOpenHook:
+ def __init__(self):
+ self.invoked = False
+ def __call__(self, *args):
+ self.invoked = True
+ return open(*args)
+
+ t = writeTmp(1, ["\n"])
+ self.addCleanup(remove_tempfiles, t)
+ custom_open_hook = CustomOpenHook()
+ with FileInput([t], openhook=custom_open_hook) as fi:
+ fi.readline()
+ self.assertTrue(custom_open_hook.invoked, "openhook not invoked")
def test_context_manager(self):
try:
@@ -254,9 +281,576 @@ class FileInputTests(unittest.TestCase):
finally:
remove_tempfiles(t1)
+ def test_empty_files_list_specified_to_constructor(self):
+ with FileInput(files=[]) as fi:
+ self.assertEqual(fi._files, ('-',))
+
+ def test__getitem__(self):
+ """Tests invoking FileInput.__getitem__() with the current
+ line number"""
+ t = writeTmp(1, ["line1\n", "line2\n"])
+ self.addCleanup(remove_tempfiles, t)
+ with FileInput(files=[t]) as fi:
+ retval1 = fi[0]
+ self.assertEqual(retval1, "line1\n")
+ retval2 = fi[1]
+ self.assertEqual(retval2, "line2\n")
+
+ def test__getitem__invalid_key(self):
+ """Tests invoking FileInput.__getitem__() with an index unequal to
+ the line number"""
+ t = writeTmp(1, ["line1\n", "line2\n"])
+ self.addCleanup(remove_tempfiles, t)
+ with FileInput(files=[t]) as fi:
+ with self.assertRaises(RuntimeError) as cm:
+ fi[1]
+ self.assertEqual(cm.exception.args, ("accessing lines out of order",))
+
+ def test__getitem__eof(self):
+ """Tests invoking FileInput.__getitem__() with the line number but at
+ end-of-input"""
+ t = writeTmp(1, [])
+ self.addCleanup(remove_tempfiles, t)
+ with FileInput(files=[t]) as fi:
+ with self.assertRaises(IndexError) as cm:
+ fi[0]
+ self.assertEqual(cm.exception.args, ("end of input reached",))
+
+ def test_nextfile_oserror_deleting_backup(self):
+ """Tests invoking FileInput.nextfile() when the attempt to delete
+ the backup file would raise OSError. This error is expected to be
+ silently ignored"""
+
+ os_unlink_orig = os.unlink
+ os_unlink_replacement = UnconditionallyRaise(OSError)
+ try:
+ t = writeTmp(1, ["\n"])
+ self.addCleanup(remove_tempfiles, t)
+ with FileInput(files=[t], inplace=True) as fi:
+ next(fi) # make sure the file is opened
+ os.unlink = os_unlink_replacement
+ fi.nextfile()
+ finally:
+ os.unlink = os_unlink_orig
+
+ # sanity check to make sure that our test scenario was actually hit
+ self.assertTrue(os_unlink_replacement.invoked,
+ "os.unlink() was not invoked")
+
+ def test_readline_os_fstat_raises_OSError(self):
+ """Tests invoking FileInput.readline() when os.fstat() raises OSError.
+ This exception should be silently discarded."""
+
+ os_fstat_orig = os.fstat
+ os_fstat_replacement = UnconditionallyRaise(OSError)
+ try:
+ t = writeTmp(1, ["\n"])
+ self.addCleanup(remove_tempfiles, t)
+ with FileInput(files=[t], inplace=True) as fi:
+ os.fstat = os_fstat_replacement
+ fi.readline()
+ finally:
+ os.fstat = os_fstat_orig
+
+ # sanity check to make sure that our test scenario was actually hit
+ self.assertTrue(os_fstat_replacement.invoked,
+ "os.fstat() was not invoked")
+
+ @unittest.skipIf(not hasattr(os, "chmod"), "os.chmod does not exist")
+ def test_readline_os_chmod_raises_OSError(self):
+ """Tests invoking FileInput.readline() when os.chmod() raises OSError.
+ This exception should be silently discarded."""
+
+ os_chmod_orig = os.chmod
+ os_chmod_replacement = UnconditionallyRaise(OSError)
+ try:
+ t = writeTmp(1, ["\n"])
+ self.addCleanup(remove_tempfiles, t)
+ with FileInput(files=[t], inplace=True) as fi:
+ os.chmod = os_chmod_replacement
+ fi.readline()
+ finally:
+ os.chmod = os_chmod_orig
+
+ # sanity check to make sure that our test scenario was actually hit
+ self.assertTrue(os_chmod_replacement.invoked,
+ "os.fstat() was not invoked")
+
+ def test_fileno_when_ValueError_raised(self):
+ class FilenoRaisesValueError(UnconditionallyRaise):
+ def __init__(self):
+ UnconditionallyRaise.__init__(self, ValueError)
+ def fileno(self):
+ self.__call__()
+
+ unconditionally_raise_ValueError = FilenoRaisesValueError()
+ t = writeTmp(1, ["\n"])
+ self.addCleanup(remove_tempfiles, t)
+ with FileInput(files=[t]) as fi:
+ file_backup = fi._file
+ try:
+ fi._file = unconditionally_raise_ValueError
+ result = fi.fileno()
+ finally:
+ fi._file = file_backup # make sure the file gets cleaned up
+
+ # sanity check to make sure that our test scenario was actually hit
+ self.assertTrue(unconditionally_raise_ValueError.invoked,
+ "_file.fileno() was not invoked")
+
+ self.assertEqual(result, -1, "fileno() should return -1")
+
+class MockFileInput:
+ """A class that mocks out fileinput.FileInput for use during unit tests"""
+
+ def __init__(self, files=None, inplace=False, backup="", bufsize=0,
+ mode="r", openhook=None):
+ self.files = files
+ self.inplace = inplace
+ self.backup = backup
+ self.bufsize = bufsize
+ self.mode = mode
+ self.openhook = openhook
+ self._file = None
+ self.invocation_counts = collections.defaultdict(lambda: 0)
+ self.return_values = {}
+
+ def close(self):
+ self.invocation_counts["close"] += 1
+
+ def nextfile(self):
+ self.invocation_counts["nextfile"] += 1
+ return self.return_values["nextfile"]
+
+ def filename(self):
+ self.invocation_counts["filename"] += 1
+ return self.return_values["filename"]
+
+ def lineno(self):
+ self.invocation_counts["lineno"] += 1
+ return self.return_values["lineno"]
+
+ def filelineno(self):
+ self.invocation_counts["filelineno"] += 1
+ return self.return_values["filelineno"]
+
+ def fileno(self):
+ self.invocation_counts["fileno"] += 1
+ return self.return_values["fileno"]
+
+ def isfirstline(self):
+ self.invocation_counts["isfirstline"] += 1
+ return self.return_values["isfirstline"]
+
+ def isstdin(self):
+ self.invocation_counts["isstdin"] += 1
+ return self.return_values["isstdin"]
+
+class BaseFileInputGlobalMethodsTest(unittest.TestCase):
+ """Base class for unit tests for the global function of
+ the fileinput module."""
+
+ def setUp(self):
+ self._orig_state = fileinput._state
+ self._orig_FileInput = fileinput.FileInput
+ fileinput.FileInput = MockFileInput
+
+ def tearDown(self):
+ fileinput.FileInput = self._orig_FileInput
+ fileinput._state = self._orig_state
+
+ def assertExactlyOneInvocation(self, mock_file_input, method_name):
+ # assert that the method with the given name was invoked once
+ actual_count = mock_file_input.invocation_counts[method_name]
+ self.assertEqual(actual_count, 1, method_name)
+ # assert that no other unexpected methods were invoked
+ actual_total_count = len(mock_file_input.invocation_counts)
+ self.assertEqual(actual_total_count, 1)
+
+class Test_fileinput_input(BaseFileInputGlobalMethodsTest):
+ """Unit tests for fileinput.input()"""
+
+ def test_state_is_not_None_and_state_file_is_not_None(self):
+ """Tests invoking fileinput.input() when fileinput._state is not None
+ and its _file attribute is also not None. Expect RuntimeError to
+ be raised with a meaningful error message and for fileinput._state
+ to *not* be modified."""
+ instance = MockFileInput()
+ instance._file = object()
+ fileinput._state = instance
+ with self.assertRaises(RuntimeError) as cm:
+ fileinput.input()
+ self.assertEqual(("input() already active",), cm.exception.args)
+ self.assertIs(instance, fileinput._state, "fileinput._state")
+
+ def test_state_is_not_None_and_state_file_is_None(self):
+ """Tests invoking fileinput.input() when fileinput._state is not None
+ but its _file attribute *is* None. Expect it to create and return
+ a new fileinput.FileInput object with all method parameters passed
+ explicitly to the __init__() method; also ensure that
+ fileinput._state is set to the returned instance."""
+ instance = MockFileInput()
+ instance._file = None
+ fileinput._state = instance
+ self.do_test_call_input()
+
+ def test_state_is_None(self):
+ """Tests invoking fileinput.input() when fileinput._state is None
+ Expect it to create and return a new fileinput.FileInput object
+ with all method parameters passed explicitly to the __init__()
+ method; also ensure that fileinput._state is set to the returned
+ instance."""
+ fileinput._state = None
+ self.do_test_call_input()
+
+ def do_test_call_input(self):
+ """Tests that fileinput.input() creates a new fileinput.FileInput
+ object, passing the given parameters unmodified to
+ fileinput.FileInput.__init__(). Note that this test depends on the
+ monkey patching of fileinput.FileInput done by setUp()."""
+ files = object()
+ inplace = object()
+ backup = object()
+ bufsize = object()
+ mode = object()
+ openhook = object()
+
+ # call fileinput.input() with different values for each argument
+ result = fileinput.input(files=files, inplace=inplace, backup=backup,
+ bufsize=bufsize,
+ mode=mode, openhook=openhook)
+
+ # ensure fileinput._state was set to the returned object
+ self.assertIs(result, fileinput._state, "fileinput._state")
+
+ # ensure the parameters to fileinput.input() were passed directly
+ # to FileInput.__init__()
+ self.assertIs(files, result.files, "files")
+ self.assertIs(inplace, result.inplace, "inplace")
+ self.assertIs(backup, result.backup, "backup")
+ self.assertIs(bufsize, result.bufsize, "bufsize")
+ self.assertIs(mode, result.mode, "mode")
+ self.assertIs(openhook, result.openhook, "openhook")
+
+class Test_fileinput_close(BaseFileInputGlobalMethodsTest):
+ """Unit tests for fileinput.close()"""
+
+ def test_state_is_None(self):
+ """Tests that fileinput.close() does nothing if fileinput._state
+ is None"""
+ fileinput._state = None
+ fileinput.close()
+ self.assertIsNone(fileinput._state)
+
+ def test_state_is_not_None(self):
+ """Tests that fileinput.close() invokes close() on fileinput._state
+ and sets _state=None"""
+ instance = MockFileInput()
+ fileinput._state = instance
+ fileinput.close()
+ self.assertExactlyOneInvocation(instance, "close")
+ self.assertIsNone(fileinput._state)
+
+class Test_fileinput_nextfile(BaseFileInputGlobalMethodsTest):
+ """Unit tests for fileinput.nextfile()"""
+
+ def test_state_is_None(self):
+ """Tests fileinput.nextfile() when fileinput._state is None.
+ Ensure that it raises RuntimeError with a meaningful error message
+ and does not modify fileinput._state"""
+ fileinput._state = None
+ with self.assertRaises(RuntimeError) as cm:
+ fileinput.nextfile()
+ self.assertEqual(("no active input()",), cm.exception.args)
+ self.assertIsNone(fileinput._state)
+
+ def test_state_is_not_None(self):
+ """Tests fileinput.nextfile() when fileinput._state is not None.
+ Ensure that it invokes fileinput._state.nextfile() exactly once,
+ returns whatever it returns, and does not modify fileinput._state
+ to point to a different object."""
+ nextfile_retval = object()
+ instance = MockFileInput()
+ instance.return_values["nextfile"] = nextfile_retval
+ fileinput._state = instance
+ retval = fileinput.nextfile()
+ self.assertExactlyOneInvocation(instance, "nextfile")
+ self.assertIs(retval, nextfile_retval)
+ self.assertIs(fileinput._state, instance)
+
+class Test_fileinput_filename(BaseFileInputGlobalMethodsTest):
+ """Unit tests for fileinput.filename()"""
+
+ def test_state_is_None(self):
+ """Tests fileinput.filename() when fileinput._state is None.
+ Ensure that it raises RuntimeError with a meaningful error message
+ and does not modify fileinput._state"""
+ fileinput._state = None
+ with self.assertRaises(RuntimeError) as cm:
+ fileinput.filename()
+ self.assertEqual(("no active input()",), cm.exception.args)
+ self.assertIsNone(fileinput._state)
+
+ def test_state_is_not_None(self):
+ """Tests fileinput.filename() when fileinput._state is not None.
+ Ensure that it invokes fileinput._state.filename() exactly once,
+ returns whatever it returns, and does not modify fileinput._state
+ to point to a different object."""
+ filename_retval = object()
+ instance = MockFileInput()
+ instance.return_values["filename"] = filename_retval
+ fileinput._state = instance
+ retval = fileinput.filename()
+ self.assertExactlyOneInvocation(instance, "filename")
+ self.assertIs(retval, filename_retval)
+ self.assertIs(fileinput._state, instance)
+
+class Test_fileinput_lineno(BaseFileInputGlobalMethodsTest):
+ """Unit tests for fileinput.lineno()"""
+
+ def test_state_is_None(self):
+ """Tests fileinput.lineno() when fileinput._state is None.
+ Ensure that it raises RuntimeError with a meaningful error message
+ and does not modify fileinput._state"""
+ fileinput._state = None
+ with self.assertRaises(RuntimeError) as cm:
+ fileinput.lineno()
+ self.assertEqual(("no active input()",), cm.exception.args)
+ self.assertIsNone(fileinput._state)
+
+ def test_state_is_not_None(self):
+ """Tests fileinput.lineno() when fileinput._state is not None.
+ Ensure that it invokes fileinput._state.lineno() exactly once,
+ returns whatever it returns, and does not modify fileinput._state
+ to point to a different object."""
+ lineno_retval = object()
+ instance = MockFileInput()
+ instance.return_values["lineno"] = lineno_retval
+ fileinput._state = instance
+ retval = fileinput.lineno()
+ self.assertExactlyOneInvocation(instance, "lineno")
+ self.assertIs(retval, lineno_retval)
+ self.assertIs(fileinput._state, instance)
+
+class Test_fileinput_filelineno(BaseFileInputGlobalMethodsTest):
+ """Unit tests for fileinput.filelineno()"""
+
+ def test_state_is_None(self):
+ """Tests fileinput.filelineno() when fileinput._state is None.
+ Ensure that it raises RuntimeError with a meaningful error message
+ and does not modify fileinput._state"""
+ fileinput._state = None
+ with self.assertRaises(RuntimeError) as cm:
+ fileinput.filelineno()
+ self.assertEqual(("no active input()",), cm.exception.args)
+ self.assertIsNone(fileinput._state)
+
+ def test_state_is_not_None(self):
+ """Tests fileinput.filelineno() when fileinput._state is not None.
+ Ensure that it invokes fileinput._state.filelineno() exactly once,
+ returns whatever it returns, and does not modify fileinput._state
+ to point to a different object."""
+ filelineno_retval = object()
+ instance = MockFileInput()
+ instance.return_values["filelineno"] = filelineno_retval
+ fileinput._state = instance
+ retval = fileinput.filelineno()
+ self.assertExactlyOneInvocation(instance, "filelineno")
+ self.assertIs(retval, filelineno_retval)
+ self.assertIs(fileinput._state, instance)
+
+class Test_fileinput_fileno(BaseFileInputGlobalMethodsTest):
+ """Unit tests for fileinput.fileno()"""
+
+ def test_state_is_None(self):
+ """Tests fileinput.fileno() when fileinput._state is None.
+ Ensure that it raises RuntimeError with a meaningful error message
+ and does not modify fileinput._state"""
+ fileinput._state = None
+ with self.assertRaises(RuntimeError) as cm:
+ fileinput.fileno()
+ self.assertEqual(("no active input()",), cm.exception.args)
+ self.assertIsNone(fileinput._state)
+
+ def test_state_is_not_None(self):
+ """Tests fileinput.fileno() when fileinput._state is not None.
+ Ensure that it invokes fileinput._state.fileno() exactly once,
+ returns whatever it returns, and does not modify fileinput._state
+ to point to a different object."""
+ fileno_retval = object()
+ instance = MockFileInput()
+ instance.return_values["fileno"] = fileno_retval
+ instance.fileno_retval = fileno_retval
+ fileinput._state = instance
+ retval = fileinput.fileno()
+ self.assertExactlyOneInvocation(instance, "fileno")
+ self.assertIs(retval, fileno_retval)
+ self.assertIs(fileinput._state, instance)
+
+class Test_fileinput_isfirstline(BaseFileInputGlobalMethodsTest):
+ """Unit tests for fileinput.isfirstline()"""
+
+ def test_state_is_None(self):
+ """Tests fileinput.isfirstline() when fileinput._state is None.
+ Ensure that it raises RuntimeError with a meaningful error message
+ and does not modify fileinput._state"""
+ fileinput._state = None
+ with self.assertRaises(RuntimeError) as cm:
+ fileinput.isfirstline()
+ self.assertEqual(("no active input()",), cm.exception.args)
+ self.assertIsNone(fileinput._state)
+
+ def test_state_is_not_None(self):
+ """Tests fileinput.isfirstline() when fileinput._state is not None.
+ Ensure that it invokes fileinput._state.isfirstline() exactly once,
+ returns whatever it returns, and does not modify fileinput._state
+ to point to a different object."""
+ isfirstline_retval = object()
+ instance = MockFileInput()
+ instance.return_values["isfirstline"] = isfirstline_retval
+ fileinput._state = instance
+ retval = fileinput.isfirstline()
+ self.assertExactlyOneInvocation(instance, "isfirstline")
+ self.assertIs(retval, isfirstline_retval)
+ self.assertIs(fileinput._state, instance)
+
+class Test_fileinput_isstdin(BaseFileInputGlobalMethodsTest):
+ """Unit tests for fileinput.isstdin()"""
+
+ def test_state_is_None(self):
+ """Tests fileinput.isstdin() when fileinput._state is None.
+ Ensure that it raises RuntimeError with a meaningful error message
+ and does not modify fileinput._state"""
+ fileinput._state = None
+ with self.assertRaises(RuntimeError) as cm:
+ fileinput.isstdin()
+ self.assertEqual(("no active input()",), cm.exception.args)
+ self.assertIsNone(fileinput._state)
+
+ def test_state_is_not_None(self):
+ """Tests fileinput.isstdin() when fileinput._state is not None.
+ Ensure that it invokes fileinput._state.isstdin() exactly once,
+ returns whatever it returns, and does not modify fileinput._state
+ to point to a different object."""
+ isstdin_retval = object()
+ instance = MockFileInput()
+ instance.return_values["isstdin"] = isstdin_retval
+ fileinput._state = instance
+ retval = fileinput.isstdin()
+ self.assertExactlyOneInvocation(instance, "isstdin")
+ self.assertIs(retval, isstdin_retval)
+ self.assertIs(fileinput._state, instance)
+
+class InvocationRecorder:
+ def __init__(self):
+ self.invocation_count = 0
+ def __call__(self, *args, **kwargs):
+ self.invocation_count += 1
+ self.last_invocation = (args, kwargs)
+
+class Test_hook_compressed(unittest.TestCase):
+ """Unit tests for fileinput.hook_compressed()"""
+
+ def setUp(self):
+ self.fake_open = InvocationRecorder()
+
+ def test_empty_string(self):
+ self.do_test_use_builtin_open("", 1)
+
+ def test_no_ext(self):
+ self.do_test_use_builtin_open("abcd", 2)
+
+ @unittest.skipUnless(gzip, "Requires gzip and zlib")
+ def test_gz_ext_fake(self):
+ original_open = gzip.open
+ gzip.open = self.fake_open
+ try:
+ result = fileinput.hook_compressed("test.gz", 3)
+ finally:
+ gzip.open = original_open
+
+ self.assertEqual(self.fake_open.invocation_count, 1)
+ self.assertEqual(self.fake_open.last_invocation, (("test.gz", 3), {}))
+
+ @unittest.skipUnless(bz2, "Requires bz2")
+ def test_bz2_ext_fake(self):
+ original_open = bz2.BZ2File
+ bz2.BZ2File = self.fake_open
+ try:
+ result = fileinput.hook_compressed("test.bz2", 4)
+ finally:
+ bz2.BZ2File = original_open
+
+ self.assertEqual(self.fake_open.invocation_count, 1)
+ self.assertEqual(self.fake_open.last_invocation, (("test.bz2", 4), {}))
+
+ def test_blah_ext(self):
+ self.do_test_use_builtin_open("abcd.blah", 5)
+
+ def test_gz_ext_builtin(self):
+ self.do_test_use_builtin_open("abcd.Gz", 6)
+
+ def test_bz2_ext_builtin(self):
+ self.do_test_use_builtin_open("abcd.Bz2", 7)
+
+ def do_test_use_builtin_open(self, filename, mode):
+ original_open = self.replace_builtin_open(self.fake_open)
+ try:
+ result = fileinput.hook_compressed(filename, mode)
+ finally:
+ self.replace_builtin_open(original_open)
+
+ self.assertEqual(self.fake_open.invocation_count, 1)
+ self.assertEqual(self.fake_open.last_invocation,
+ ((filename, mode), {}))
+
+ @staticmethod
+ def replace_builtin_open(new_open_func):
+ original_open = builtins.open
+ builtins.open = new_open_func
+ return original_open
+
+class Test_hook_encoded(unittest.TestCase):
+ """Unit tests for fileinput.hook_encoded()"""
+
+ def test(self):
+ encoding = object()
+ result = fileinput.hook_encoded(encoding)
+
+ fake_open = InvocationRecorder()
+ original_open = builtins.open
+ builtins.open = fake_open
+ try:
+ filename = object()
+ mode = object()
+ open_result = result(filename, mode)
+ finally:
+ builtins.open = original_open
+
+ self.assertEqual(fake_open.invocation_count, 1)
+
+ args, kwargs = fake_open.last_invocation
+ self.assertIs(args[0], filename)
+ self.assertIs(args[1], mode)
+ self.assertIs(kwargs.pop('encoding'), encoding)
+ self.assertFalse(kwargs)
def test_main():
- run_unittest(BufferSizesTests, FileInputTests)
+ run_unittest(
+ BufferSizesTests,
+ FileInputTests,
+ Test_fileinput_input,
+ Test_fileinput_close,
+ Test_fileinput_nextfile,
+ Test_fileinput_filename,
+ Test_fileinput_lineno,
+ Test_fileinput_filelineno,
+ Test_fileinput_fileno,
+ Test_fileinput_isfirstline,
+ Test_fileinput_isstdin,
+ Test_hook_compressed,
+ Test_hook_encoded,
+ )
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_fileio.py b/Lib/test/test_fileio.py
index 8267d9e..19737d9 100644
--- a/Lib/test/test_fileio.py
+++ b/Lib/test/test_fileio.py
@@ -2,6 +2,7 @@
import sys
import os
+import io
import errno
import unittest
from array import array
@@ -373,10 +374,10 @@ class OtherFileTests(unittest.TestCase):
self.assertEqual(f.tell(), 10)
f.truncate(5)
self.assertEqual(f.tell(), 10)
- self.assertEqual(f.seek(0, os.SEEK_END), 5)
+ self.assertEqual(f.seek(0, io.SEEK_END), 5)
f.truncate(15)
self.assertEqual(f.tell(), 5)
- self.assertEqual(f.seek(0, os.SEEK_END), 15)
+ self.assertEqual(f.seek(0, io.SEEK_END), 15)
f.close()
def testTruncateOnWindows(self):
diff --git a/Lib/test/test_float.py b/Lib/test/test_float.py
index 4d7bbba..502292f 100644
--- a/Lib/test/test_float.py
+++ b/Lib/test/test_float.py
@@ -88,7 +88,7 @@ class GeneralFloatCases(unittest.TestCase):
self.assertRaises(ValueError, float, " -0x3.p-1 ")
self.assertRaises(ValueError, float, " +0x3.p-1 ")
self.assertEqual(float(" 25.e-1 "), 2.5)
- self.assertEqual(support.fcmp(float(" .25e-1 "), .025), 0)
+ self.assertAlmostEqual(float(" .25e-1 "), .025)
def test_floatconversion(self):
# Make sure that calls to __float__() work properly
@@ -860,15 +860,18 @@ class InfNanTest(unittest.TestCase):
self.assertEqual(str(1e300 * 1e300 * 0), "nan")
self.assertEqual(str(-1e300 * 1e300 * 0), "nan")
- def notest_float_nan(self):
- self.assertTrue(NAN.is_nan())
- self.assertFalse(INF.is_nan())
- self.assertFalse((0.).is_nan())
+ def test_inf_signs(self):
+ self.assertEqual(copysign(1.0, float('inf')), 1.0)
+ self.assertEqual(copysign(1.0, float('-inf')), -1.0)
+
+ @unittest.skipUnless(getattr(sys, 'float_repr_style', '') == 'short',
+ "applies only when using short float repr style")
+ def test_nan_signs(self):
+ # When using the dtoa.c code, the sign of float('nan') should
+ # be predictable.
+ self.assertEqual(copysign(1.0, float('nan')), 1.0)
+ self.assertEqual(copysign(1.0, float('-nan')), -1.0)
- def notest_float_inf(self):
- self.assertTrue(INF.is_inf())
- self.assertFalse(NAN.is_inf())
- self.assertFalse((0.).is_inf())
fromHex = float.fromhex
toHex = float.hex
diff --git a/Lib/test/test_format.py b/Lib/test/test_format.py
index 7fa950d..b6e2540 100644
--- a/Lib/test/test_format.py
+++ b/Lib/test/test_format.py
@@ -1,4 +1,5 @@
from test.support import verbose, TestFailed
+import locale
import sys
import test.support as support
import unittest
@@ -263,6 +264,51 @@ class FormatTest(unittest.TestCase):
else:
raise TestFailed('"%*d"%(maxsize, -127) should fail')
+ def test_non_ascii(self):
+ testformat("\u20ac=%f", (1.0,), "\u20ac=1.000000")
+
+ self.assertEqual(format("abc", "\u2007<5"), "abc\u2007\u2007")
+ self.assertEqual(format(123, "\u2007<5"), "123\u2007\u2007")
+ self.assertEqual(format(12.3, "\u2007<6"), "12.3\u2007\u2007")
+ self.assertEqual(format(0j, "\u2007<4"), "0j\u2007\u2007")
+ self.assertEqual(format(1+2j, "\u2007<8"), "(1+2j)\u2007\u2007")
+
+ self.assertEqual(format("abc", "\u2007>5"), "\u2007\u2007abc")
+ self.assertEqual(format(123, "\u2007>5"), "\u2007\u2007123")
+ self.assertEqual(format(12.3, "\u2007>6"), "\u2007\u200712.3")
+ self.assertEqual(format(1+2j, "\u2007>8"), "\u2007\u2007(1+2j)")
+ self.assertEqual(format(0j, "\u2007>4"), "\u2007\u20070j")
+
+ self.assertEqual(format("abc", "\u2007^5"), "\u2007abc\u2007")
+ self.assertEqual(format(123, "\u2007^5"), "\u2007123\u2007")
+ self.assertEqual(format(12.3, "\u2007^6"), "\u200712.3\u2007")
+ self.assertEqual(format(1+2j, "\u2007^8"), "\u2007(1+2j)\u2007")
+ self.assertEqual(format(0j, "\u2007^4"), "\u20070j\u2007")
+
+ def test_locale(self):
+ try:
+ oldloc = locale.setlocale(locale.LC_ALL)
+ locale.setlocale(locale.LC_ALL, '')
+ except locale.Error as err:
+ self.skipTest("Cannot set locale: {}".format(err))
+ try:
+ localeconv = locale.localeconv()
+ sep = localeconv['thousands_sep']
+ point = localeconv['decimal_point']
+
+ text = format(123456789, "n")
+ self.assertIn(sep, text)
+ self.assertEqual(text.replace(sep, ''), '123456789')
+
+ text = format(1234.5, "n")
+ self.assertIn(sep, text)
+ self.assertIn(point, text)
+ self.assertEqual(text.replace(sep, ''), '1234' + point + '5')
+ finally:
+ locale.setlocale(locale.LC_ALL, oldloc)
+
+
+
def test_main():
support.run_unittest(FormatTest)
diff --git a/Lib/test/test_fractions.py b/Lib/test/test_fractions.py
index 084ae0c..1fad921 100644
--- a/Lib/test/test_fractions.py
+++ b/Lib/test/test_fractions.py
@@ -401,14 +401,10 @@ class FractionTest(unittest.TestCase):
def testMixingWithDecimal(self):
# Decimal refuses mixed arithmetic (but not mixed comparisons)
- self.assertRaisesMessage(
- TypeError,
- "unsupported operand type(s) for +: 'Fraction' and 'Decimal'",
- operator.add, F(3,11), Decimal('3.1415926'))
- self.assertRaisesMessage(
- TypeError,
- "unsupported operand type(s) for +: 'Decimal' and 'Fraction'",
- operator.add, Decimal('3.1415926'), F(3,11))
+ self.assertRaises(TypeError, operator.add,
+ F(3,11), Decimal('3.1415926'))
+ self.assertRaises(TypeError, operator.add,
+ Decimal('3.1415926'), F(3,11))
def testComparisons(self):
self.assertTrue(F(1, 2) < F(2, 3))
diff --git a/Lib/test/test_frozen.py b/Lib/test/test_frozen.py
index 5243ebb..fd6761c 100644
--- a/Lib/test/test_frozen.py
+++ b/Lib/test/test_frozen.py
@@ -5,6 +5,12 @@ import unittest
import sys
class FrozenTests(unittest.TestCase):
+
+ module_attrs = frozenset(['__builtins__', '__cached__', '__doc__',
+ '__loader__', '__name__',
+ '__package__'])
+ package_attrs = frozenset(list(module_attrs) + ['__path__'])
+
def test_frozen(self):
with captured_stdout() as stdout:
try:
@@ -12,7 +18,9 @@ class FrozenTests(unittest.TestCase):
except ImportError as x:
self.fail("import __hello__ failed:" + str(x))
self.assertEqual(__hello__.initialized, True)
- self.assertEqual(len(dir(__hello__)), 7, dir(__hello__))
+ expect = set(self.module_attrs)
+ expect.add('initialized')
+ self.assertEqual(set(dir(__hello__)), expect)
self.assertEqual(stdout.getvalue(), 'Hello world!\n')
with captured_stdout() as stdout:
@@ -21,10 +29,13 @@ class FrozenTests(unittest.TestCase):
except ImportError as x:
self.fail("import __phello__ failed:" + str(x))
self.assertEqual(__phello__.initialized, True)
+ expect = set(self.package_attrs)
+ expect.add('initialized')
if not "__phello__.spam" in sys.modules:
- self.assertEqual(len(dir(__phello__)), 8, dir(__phello__))
+ self.assertEqual(set(dir(__phello__)), expect)
else:
- self.assertEqual(len(dir(__phello__)), 9, dir(__phello__))
+ expect.add('spam')
+ self.assertEqual(set(dir(__phello__)), expect)
self.assertEqual(__phello__.__path__, [__phello__.__name__])
self.assertEqual(stdout.getvalue(), 'Hello world!\n')
@@ -34,8 +45,13 @@ class FrozenTests(unittest.TestCase):
except ImportError as x:
self.fail("import __phello__.spam failed:" + str(x))
self.assertEqual(__phello__.spam.initialized, True)
- self.assertEqual(len(dir(__phello__.spam)), 7)
- self.assertEqual(len(dir(__phello__)), 9)
+ spam_expect = set(self.module_attrs)
+ spam_expect.add('initialized')
+ self.assertEqual(set(dir(__phello__.spam)), spam_expect)
+ phello_expect = set(self.package_attrs)
+ phello_expect.add('initialized')
+ phello_expect.add('spam')
+ self.assertEqual(set(dir(__phello__)), phello_expect)
self.assertEqual(stdout.getvalue(), 'Hello world!\n')
try:
diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py
index 71bc23e..824b7c1 100644
--- a/Lib/test/test_ftplib.py
+++ b/Lib/test/test_ftplib.py
@@ -22,10 +22,25 @@ from test.support import HOST
threading = support.import_module('threading')
# the dummy data returned by server over the data channel when
-# RETR, LIST and NLST commands are issued
+# RETR, LIST, NLST, MLSD commands are issued
RETR_DATA = 'abcde12345\r\n' * 1000
LIST_DATA = 'foo\r\nbar\r\n'
NLST_DATA = 'foo\r\nbar\r\n'
+MLSD_DATA = ("type=cdir;perm=el;unique==keVO1+ZF4; test\r\n"
+ "type=pdir;perm=e;unique==keVO1+d?3; ..\r\n"
+ "type=OS.unix=slink:/foobar;perm=;unique==keVO1+4G4; foobar\r\n"
+ "type=OS.unix=chr-13/29;perm=;unique==keVO1+5G4; device\r\n"
+ "type=OS.unix=blk-11/108;perm=;unique==keVO1+6G4; block\r\n"
+ "type=file;perm=awr;unique==keVO1+8G4; writable\r\n"
+ "type=dir;perm=cpmel;unique==keVO1+7G4; promiscuous\r\n"
+ "type=dir;perm=;unique==keVO1+1t2; no-exec\r\n"
+ "type=file;perm=r;unique==keVO1+EG4; two words\r\n"
+ "type=file;perm=r;unique==keVO1+IH4; leading space\r\n"
+ "type=file;perm=r;unique==keVO1+1G4; file1\r\n"
+ "type=dir;perm=cpmel;unique==keVO1+7G4; incoming\r\n"
+ "type=file;perm=r;unique==keVO1+1G4; file2\r\n"
+ "type=file;perm=r;unique==keVO1+1G4; file3\r\n"
+ "type=file;perm=r;unique==keVO1+1G4; file4\r\n")
class DummyDTPHandler(asynchat.async_chat):
@@ -49,6 +64,11 @@ class DummyDTPHandler(asynchat.async_chat):
self.dtp_conn_closed = True
def push(self, what):
+ if self.baseclass.next_data is not None:
+ what = self.baseclass.next_data
+ self.baseclass.next_data = None
+ if not what:
+ return self.close_when_done()
super(DummyDTPHandler, self).push(what.encode('ascii'))
def handle_error(self):
@@ -69,6 +89,7 @@ class DummyFTPHandler(asynchat.async_chat):
self.last_received_cmd = None
self.last_received_data = ''
self.next_response = ''
+ self.next_data = None
self.rest = None
self.push('220 welcome')
@@ -104,7 +125,7 @@ class DummyFTPHandler(asynchat.async_chat):
addr = list(map(int, arg.split(',')))
ip = '%d.%d.%d.%d' %tuple(addr[:4])
port = (addr[4] * 256) + addr[5]
- s = socket.create_connection((ip, port), timeout=10)
+ s = socket.create_connection((ip, port), timeout=2)
self.dtp = self.dtp_handler(s, baseclass=self)
self.push('200 active data connection established')
@@ -122,7 +143,7 @@ class DummyFTPHandler(asynchat.async_chat):
def cmd_eprt(self, arg):
af, ip, port = arg.split(arg[0])[1:-1]
port = int(port)
- s = socket.create_connection((ip, port), timeout=10)
+ s = socket.create_connection((ip, port), timeout=2)
self.dtp = self.dtp_handler(s, baseclass=self)
self.push('200 active data connection established')
@@ -213,6 +234,14 @@ class DummyFTPHandler(asynchat.async_chat):
self.dtp.push(NLST_DATA)
self.dtp.close_when_done()
+ def cmd_opts(self, arg):
+ self.push('200 opts ok')
+
+ def cmd_mlsd(self, arg):
+ self.push('125 mlsd ok')
+ self.dtp.push(MLSD_DATA)
+ self.dtp.close_when_done()
+
class DummyFTPServer(asyncore.dispatcher, threading.Thread):
@@ -274,11 +303,11 @@ if ssl is not None:
_ssl_closing = False
def secure_connection(self):
- self.del_channel()
socket = ssl.wrap_socket(self.socket, suppress_ragged_eofs=False,
certfile=CERTFILE, server_side=True,
do_handshake_on_connect=False,
ssl_version=ssl.PROTOCOL_SSLv23)
+ self.del_channel()
self.set_socket(socket)
self._ssl_accepting = True
@@ -313,7 +342,10 @@ if ssl is not None:
# http://www.mail-archive.com/openssl-users@openssl.org/msg60710.html
pass
self._ssl_closing = False
- super(SSLConnection, self).close()
+ if getattr(self, '_ccc', False) is False:
+ super(SSLConnection, self).close()
+ else:
+ pass
def handle_read_event(self):
if self._ssl_accepting:
@@ -381,12 +413,18 @@ if ssl is not None:
def __init__(self, conn):
DummyFTPHandler.__init__(self, conn)
self.secure_data_channel = False
+ self._ccc = False
def cmd_auth(self, line):
"""Set up secure control channel."""
self.push('234 AUTH TLS successful')
self.secure_connection()
+ def cmd_ccc(self, line):
+ self.push('220 Reverting back to clear-text')
+ self._ccc = True
+ self._do_ssl_shutdown()
+
def cmd_pbsz(self, line):
"""Negotiate size of buffer for secure data transfer.
For TLS/SSL the only valid value for the parameter is '0'.
@@ -416,13 +454,17 @@ class TestFTPClass(TestCase):
def setUp(self):
self.server = DummyFTPServer((HOST, 0))
self.server.start()
- self.client = ftplib.FTP(timeout=10)
+ self.client = ftplib.FTP(timeout=2)
self.client.connect(self.server.host, self.server.port)
def tearDown(self):
self.client.close()
self.server.stop()
+ def check_data(self, received, expected):
+ self.assertEqual(len(received), len(expected))
+ self.assertEqual(received, expected)
+
def test_getwelcome(self):
self.assertEqual(self.client.getwelcome(), '220 welcome')
@@ -504,7 +546,7 @@ class TestFTPClass(TestCase):
received.append(data.decode('ascii'))
received = []
self.client.retrbinary('retr', callback)
- self.assertEqual(''.join(received), RETR_DATA)
+ self.check_data(''.join(received), RETR_DATA)
def test_retrbinary_rest(self):
def callback(data):
@@ -512,20 +554,17 @@ class TestFTPClass(TestCase):
for rest in (0, 10, 20):
received = []
self.client.retrbinary('retr', callback, rest=rest)
- self.assertEqual(''.join(received), RETR_DATA[rest:],
- msg='rest test case %d %d %d' % (rest,
- len(''.join(received)),
- len(RETR_DATA[rest:])))
+ self.check_data(''.join(received), RETR_DATA[rest:])
def test_retrlines(self):
received = []
self.client.retrlines('retr', received.append)
- self.assertEqual(''.join(received), RETR_DATA.replace('\r\n', ''))
+ self.check_data(''.join(received), RETR_DATA.replace('\r\n', ''))
def test_storbinary(self):
f = io.BytesIO(RETR_DATA.encode('ascii'))
self.client.storbinary('stor', f)
- self.assertEqual(self.server.handler_instance.last_received_data, RETR_DATA)
+ self.check_data(self.server.handler_instance.last_received_data, RETR_DATA)
# test new callback arg
flag = []
f.seek(0)
@@ -542,7 +581,7 @@ class TestFTPClass(TestCase):
def test_storlines(self):
f = io.BytesIO(RETR_DATA.replace('\r\n', '\n').encode('ascii'))
self.client.storlines('stor', f)
- self.assertEqual(self.server.handler_instance.last_received_data, RETR_DATA)
+ self.check_data(self.server.handler_instance.last_received_data, RETR_DATA)
# test new callback arg
flag = []
f.seek(0)
@@ -558,6 +597,64 @@ class TestFTPClass(TestCase):
self.client.dir(lambda x: l.append(x))
self.assertEqual(''.join(l), LIST_DATA.replace('\r\n', ''))
+ def test_mlsd(self):
+ list(self.client.mlsd())
+ list(self.client.mlsd(path='/'))
+ list(self.client.mlsd(path='/', facts=['size', 'type']))
+
+ ls = list(self.client.mlsd())
+ for name, facts in ls:
+ self.assertIsInstance(name, str)
+ self.assertIsInstance(facts, dict)
+ self.assertTrue(name)
+ self.assertIn('type', facts)
+ self.assertIn('perm', facts)
+ self.assertIn('unique', facts)
+
+ def set_data(data):
+ self.server.handler_instance.next_data = data
+
+ def test_entry(line, type=None, perm=None, unique=None, name=None):
+ type = 'type' if type is None else type
+ perm = 'perm' if perm is None else perm
+ unique = 'unique' if unique is None else unique
+ name = 'name' if name is None else name
+ set_data(line)
+ _name, facts = next(self.client.mlsd())
+ self.assertEqual(_name, name)
+ self.assertEqual(facts['type'], type)
+ self.assertEqual(facts['perm'], perm)
+ self.assertEqual(facts['unique'], unique)
+
+ # plain
+ test_entry('type=type;perm=perm;unique=unique; name\r\n')
+ # "=" in fact value
+ test_entry('type=ty=pe;perm=perm;unique=unique; name\r\n', type="ty=pe")
+ test_entry('type==type;perm=perm;unique=unique; name\r\n', type="=type")
+ test_entry('type=t=y=pe;perm=perm;unique=unique; name\r\n', type="t=y=pe")
+ test_entry('type=====;perm=perm;unique=unique; name\r\n', type="====")
+ # spaces in name
+ test_entry('type=type;perm=perm;unique=unique; na me\r\n', name="na me")
+ test_entry('type=type;perm=perm;unique=unique; name \r\n', name="name ")
+ test_entry('type=type;perm=perm;unique=unique; name\r\n', name=" name")
+ test_entry('type=type;perm=perm;unique=unique; n am e\r\n', name="n am e")
+ # ";" in name
+ test_entry('type=type;perm=perm;unique=unique; na;me\r\n', name="na;me")
+ test_entry('type=type;perm=perm;unique=unique; ;name\r\n', name=";name")
+ test_entry('type=type;perm=perm;unique=unique; ;name;\r\n', name=";name;")
+ test_entry('type=type;perm=perm;unique=unique; ;;;;\r\n', name=";;;;")
+ # case sensitiveness
+ set_data('Type=type;TyPe=perm;UNIQUE=unique; name\r\n')
+ _name, facts = next(self.client.mlsd())
+ for x in facts:
+ self.assertTrue(x.islower())
+ # no data (directory empty)
+ set_data('')
+ self.assertRaises(StopIteration, next, self.client.mlsd())
+ set_data('')
+ for x in self.client.mlsd():
+ self.fail("unexpected data %s" % data)
+
def test_makeport(self):
with self.client.makeport():
# IPv4 is in use, just make sure send_eprt has not been used
@@ -584,7 +681,7 @@ class TestFTPClass(TestCase):
return True
# base test
- with ftplib.FTP(timeout=10) as self.client:
+ with ftplib.FTP(timeout=2) as self.client:
self.client.connect(self.server.host, self.server.port)
self.client.sendcmd('noop')
self.assertTrue(is_client_connected())
@@ -592,7 +689,7 @@ class TestFTPClass(TestCase):
self.assertFalse(is_client_connected())
# QUIT sent inside the with block
- with ftplib.FTP(timeout=10) as self.client:
+ with ftplib.FTP(timeout=2) as self.client:
self.client.connect(self.server.host, self.server.port)
self.client.sendcmd('noop')
self.client.quit()
@@ -602,7 +699,7 @@ class TestFTPClass(TestCase):
# force a wrong response code to be sent on QUIT: error_perm
# is expected and the connection is supposed to be closed
try:
- with ftplib.FTP(timeout=10) as self.client:
+ with ftplib.FTP(timeout=2) as self.client:
self.client.connect(self.server.host, self.server.port)
self.client.sendcmd('noop')
self.server.handler_instance.next_response = '550 error on quit'
@@ -616,6 +713,30 @@ class TestFTPClass(TestCase):
self.assertEqual(self.server.handler_instance.last_received_cmd, 'quit')
self.assertFalse(is_client_connected())
+ def test_source_address(self):
+ self.client.quit()
+ port = support.find_unused_port()
+ try:
+ self.client.connect(self.server.host, self.server.port,
+ source_address=(HOST, port))
+ self.assertEqual(self.client.sock.getsockname()[1], port)
+ self.client.quit()
+ except IOError as e:
+ if e.errno == errno.EADDRINUSE:
+ self.skipTest("couldn't bind to port %d" % port)
+ raise
+
+ def test_source_address_passive_connection(self):
+ port = support.find_unused_port()
+ self.client.source_address = (HOST, port)
+ try:
+ with self.client.transfercmd('list') as sock:
+ self.assertEqual(sock.getsockname()[1], port)
+ except IOError as e:
+ if e.errno == errno.EADDRINUSE:
+ self.skipTest("couldn't bind to port %d" % port)
+ raise
+
def test_parse257(self):
self.assertEqual(ftplib.parse257('257 "/foo/bar"'), '/foo/bar')
self.assertEqual(ftplib.parse257('257 "/foo/bar" created'), '/foo/bar')
@@ -632,7 +753,7 @@ class TestFTPClass(TestCase):
class TestIPv6Environment(TestCase):
def setUp(self):
- self.server = DummyFTPServer((HOST, 0), af=socket.AF_INET6)
+ self.server = DummyFTPServer(('::1', 0), af=socket.AF_INET6)
self.server.start()
self.client = ftplib.FTP()
self.client.connect(self.server.host, self.server.port)
@@ -661,6 +782,7 @@ class TestIPv6Environment(TestCase):
received.append(data.decode('ascii'))
received = []
self.client.retrbinary('retr', callback)
+ self.assertEqual(len(''.join(received)), len(RETR_DATA))
self.assertEqual(''.join(received), RETR_DATA)
self.client.set_pasv(True)
retr()
@@ -676,7 +798,7 @@ class TestTLS_FTPClassMixin(TestFTPClass):
def setUp(self):
self.server = DummyTLS_FTPServer((HOST, 0))
self.server.start()
- self.client = ftplib.FTP_TLS(timeout=10)
+ self.client = ftplib.FTP_TLS(timeout=2)
self.client.connect(self.server.host, self.server.port)
# enable TLS
self.client.auth()
@@ -689,7 +811,7 @@ class TestTLS_FTPClass(TestCase):
def setUp(self):
self.server = DummyTLS_FTPServer((HOST, 0))
self.server.start()
- self.client = ftplib.FTP_TLS(timeout=10)
+ self.client = ftplib.FTP_TLS(timeout=2)
self.client.connect(self.server.host, self.server.port)
def tearDown(self):
@@ -749,7 +871,7 @@ class TestTLS_FTPClass(TestCase):
self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE,
keyfile=CERTFILE, context=ctx)
- self.client = ftplib.FTP_TLS(context=ctx, timeout=10)
+ self.client = ftplib.FTP_TLS(context=ctx, timeout=2)
self.client.connect(self.server.host, self.server.port)
self.assertNotIsInstance(self.client.sock, ssl.SSLSocket)
self.client.auth()
@@ -761,45 +883,53 @@ class TestTLS_FTPClass(TestCase):
self.assertIs(sock.context, ctx)
self.assertIsInstance(sock, ssl.SSLSocket)
+ def test_ccc(self):
+ self.assertRaises(ValueError, self.client.ccc)
+ self.client.login(secure=True)
+ self.assertIsInstance(self.client.sock, ssl.SSLSocket)
+ self.client.ccc()
+ self.assertRaises(ValueError, self.client.sock.unwrap)
+
class TestTimeouts(TestCase):
def setUp(self):
self.evt = threading.Event()
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- self.sock.settimeout(10)
+ self.sock.settimeout(20)
self.port = support.bind_port(self.sock)
- threading.Thread(target=self.server, args=(self.evt,self.sock)).start()
+ self.server_thread = threading.Thread(target=self.server)
+ self.server_thread.start()
# Wait for the server to be ready.
self.evt.wait()
self.evt.clear()
+ self.old_port = ftplib.FTP.port
ftplib.FTP.port = self.port
def tearDown(self):
- self.evt.wait()
- self.sock.close()
+ ftplib.FTP.port = self.old_port
+ self.server_thread.join()
- def server(self, evt, serv):
+ def server(self):
# This method sets the evt 3 times:
# 1) when the connection is ready to be accepted.
# 2) when it is safe for the caller to close the connection
# 3) when we have closed the socket
- serv.listen(5)
+ self.sock.listen(5)
# (1) Signal the caller that we are ready to accept the connection.
- evt.set()
+ self.evt.set()
try:
- conn, addr = serv.accept()
+ conn, addr = self.sock.accept()
except socket.timeout:
pass
else:
- conn.send(b"1 Hola mundo\n")
+ conn.sendall(b"1 Hola mundo\n")
+ conn.shutdown(socket.SHUT_WR)
# (2) Signal the caller that it is safe to close the socket.
- evt.set()
+ self.evt.set()
conn.close()
finally:
- serv.close()
- # (3) Signal the caller that we are done.
- evt.set()
+ self.sock.close()
def testTimeoutDefault(self):
# default -- use global socket timeout
@@ -857,13 +987,8 @@ class TestTimeouts(TestCase):
def test_main():
tests = [TestFTPClass, TestTimeouts]
- if socket.has_ipv6:
- try:
- DummyFTPServer((HOST, 0), af=socket.AF_INET6)
- except socket.error:
- pass
- else:
- tests.append(TestIPv6Environment)
+ if support.IPV6_ENABLED:
+ tests.append(TestIPv6Environment)
if ssl is not None:
tests.extend([TestTLS_FTPClassMixin, TestTLS_FTPClass])
diff --git a/Lib/test/test_funcattrs.py b/Lib/test/test_funcattrs.py
index 4d19368..c8ed830 100644
--- a/Lib/test/test_funcattrs.py
+++ b/Lib/test/test_funcattrs.py
@@ -2,6 +2,15 @@ from test import support
import types
import unittest
+
+def global_function():
+ def inner_function():
+ class LocalClass:
+ pass
+ return LocalClass
+ return lambda: inner_function
+
+
class FuncAttrsTest(unittest.TestCase):
def setUp(self):
class F:
@@ -96,6 +105,24 @@ class FunctionPropertiesTest(FuncAttrsTest):
self.assertEqual(self.fi.a.__name__, 'a')
self.cannot_set_attr(self.fi.a, "__name__", 'a', AttributeError)
+ def test___qualname__(self):
+ # PEP 3155
+ self.assertEqual(self.b.__qualname__, 'FuncAttrsTest.setUp.<locals>.b')
+ self.assertEqual(FuncAttrsTest.setUp.__qualname__, 'FuncAttrsTest.setUp')
+ self.assertEqual(global_function.__qualname__, 'global_function')
+ self.assertEqual(global_function().__qualname__,
+ 'global_function.<locals>.<lambda>')
+ self.assertEqual(global_function()().__qualname__,
+ 'global_function.<locals>.inner_function')
+ self.assertEqual(global_function()()().__qualname__,
+ 'global_function.<locals>.inner_function.<locals>.LocalClass')
+ self.b.__qualname__ = 'c'
+ self.assertEqual(self.b.__qualname__, 'c')
+ self.b.__qualname__ = 'd'
+ self.assertEqual(self.b.__qualname__, 'd')
+ # __qualname__ must be a string
+ self.cannot_set_attr(self.b, '__qualname__', 7, TypeError)
+
def test___code__(self):
num_one, num_two = 7, 8
def a(): pass
@@ -315,11 +342,37 @@ class StaticMethodAttrsTest(unittest.TestCase):
self.assertTrue(s.__func__ is f)
+class BuiltinFunctionPropertiesTest(unittest.TestCase):
+ # XXX Not sure where this should really go since I can't find a
+ # test module specifically for builtin_function_or_method.
+
+ def test_builtin__qualname__(self):
+ import time
+
+ # builtin function:
+ self.assertEqual(len.__qualname__, 'len')
+ self.assertEqual(time.time.__qualname__, 'time')
+
+ # builtin classmethod:
+ self.assertEqual(dict.fromkeys.__qualname__, 'dict.fromkeys')
+ self.assertEqual(float.__getformat__.__qualname__,
+ 'float.__getformat__')
+
+ # builtin staticmethod:
+ self.assertEqual(str.maketrans.__qualname__, 'str.maketrans')
+ self.assertEqual(bytes.maketrans.__qualname__, 'bytes.maketrans')
+
+ # builtin bound instance method:
+ self.assertEqual([1, 2, 3].append.__qualname__, 'list.append')
+ self.assertEqual({'foo': 'bar'}.pop.__qualname__, 'dict.pop')
+
+
def test_main():
support.run_unittest(FunctionPropertiesTest, InstancemethodAttrTest,
ArbitraryFunctionAttrTest, FunctionDictsTest,
FunctionDocstringTest, CellTest,
- StaticMethodAttrsTest)
+ StaticMethodAttrsTest,
+ BuiltinFunctionPropertiesTest)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
index 11e6e84..b3803da 100644
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -246,6 +246,7 @@ class TestUpdateWrapper(unittest.TestCase):
self.check_wrapper(wrapper, f)
self.assertIs(wrapper.__wrapped__, f)
self.assertEqual(wrapper.__name__, 'f')
+ self.assertEqual(wrapper.__qualname__, f.__qualname__)
self.assertEqual(wrapper.attr, 'This is also a test')
self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
self.assertNotIn('b', wrapper.__annotations__)
@@ -266,6 +267,7 @@ class TestUpdateWrapper(unittest.TestCase):
functools.update_wrapper(wrapper, f, (), ())
self.check_wrapper(wrapper, f, (), ())
self.assertEqual(wrapper.__name__, 'wrapper')
+ self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
self.assertEqual(wrapper.__doc__, None)
self.assertEqual(wrapper.__annotations__, {})
self.assertFalse(hasattr(wrapper, 'attr'))
@@ -283,6 +285,7 @@ class TestUpdateWrapper(unittest.TestCase):
functools.update_wrapper(wrapper, f, assign, update)
self.check_wrapper(wrapper, f, assign, update)
self.assertEqual(wrapper.__name__, 'wrapper')
+ self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
self.assertEqual(wrapper.__doc__, None)
self.assertEqual(wrapper.attr, 'This is a different test')
self.assertEqual(wrapper.dict_attr, f.dict_attr)
@@ -330,17 +333,18 @@ class TestWraps(TestUpdateWrapper):
def wrapper():
pass
self.check_wrapper(wrapper, f)
- return wrapper
+ return wrapper, f
def test_default_update(self):
- wrapper = self._default_update()
+ wrapper, f = self._default_update()
self.assertEqual(wrapper.__name__, 'f')
+ self.assertEqual(wrapper.__qualname__, f.__qualname__)
self.assertEqual(wrapper.attr, 'This is also a test')
@unittest.skipIf(sys.flags.optimize >= 2,
"Docstrings are omitted with -O2 and above")
def test_default_update_doc(self):
- wrapper = self._default_update()
+ wrapper, _ = self._default_update()
self.assertEqual(wrapper.__doc__, 'This is a test')
def test_no_update(self):
@@ -353,6 +357,7 @@ class TestWraps(TestUpdateWrapper):
pass
self.check_wrapper(wrapper, f, (), ())
self.assertEqual(wrapper.__name__, 'wrapper')
+ self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
self.assertEqual(wrapper.__doc__, None)
self.assertFalse(hasattr(wrapper, 'attr'))
@@ -372,6 +377,7 @@ class TestWraps(TestUpdateWrapper):
pass
self.check_wrapper(wrapper, f, assign, update)
self.assertEqual(wrapper.__name__, 'wrapper')
+ self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
self.assertEqual(wrapper.__doc__, None)
self.assertEqual(wrapper.attr, 'This is a different test')
self.assertEqual(wrapper.dict_attr, f.dict_attr)
@@ -457,19 +463,82 @@ class TestReduce(unittest.TestCase):
self.assertEqual(self.func(add, d), "".join(d.keys()))
class TestCmpToKey(unittest.TestCase):
+
def test_cmp_to_key(self):
+ def cmp1(x, y):
+ return (x > y) - (x < y)
+ key = functools.cmp_to_key(cmp1)
+ self.assertEqual(key(3), key(3))
+ self.assertGreater(key(3), key(1))
+ def cmp2(x, y):
+ return int(x) - int(y)
+ key = functools.cmp_to_key(cmp2)
+ self.assertEqual(key(4.0), key('4'))
+ self.assertLess(key(2), key('35'))
+
+ def test_cmp_to_key_arguments(self):
+ def cmp1(x, y):
+ return (x > y) - (x < y)
+ key = functools.cmp_to_key(mycmp=cmp1)
+ self.assertEqual(key(obj=3), key(obj=3))
+ self.assertGreater(key(obj=3), key(obj=1))
+ with self.assertRaises((TypeError, AttributeError)):
+ key(3) > 1 # rhs is not a K object
+ with self.assertRaises((TypeError, AttributeError)):
+ 1 < key(3) # lhs is not a K object
+ with self.assertRaises(TypeError):
+ key = functools.cmp_to_key() # too few args
+ with self.assertRaises(TypeError):
+ key = functools.cmp_to_key(cmp1, None) # too many args
+ key = functools.cmp_to_key(cmp1)
+ with self.assertRaises(TypeError):
+ key() # too few args
+ with self.assertRaises(TypeError):
+ key(None, None) # too many args
+
+ def test_bad_cmp(self):
+ def cmp1(x, y):
+ raise ZeroDivisionError
+ key = functools.cmp_to_key(cmp1)
+ with self.assertRaises(ZeroDivisionError):
+ key(3) > key(1)
+
+ class BadCmp:
+ def __lt__(self, other):
+ raise ZeroDivisionError
+ def cmp1(x, y):
+ return BadCmp()
+ with self.assertRaises(ZeroDivisionError):
+ key(3) > key(1)
+
+ def test_obj_field(self):
+ def cmp1(x, y):
+ return (x > y) - (x < y)
+ key = functools.cmp_to_key(mycmp=cmp1)
+ self.assertEqual(key(50).obj, 50)
+
+ def test_sort_int(self):
def mycmp(x, y):
return y - x
self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
[4, 3, 2, 1, 0])
+ def test_sort_int_str(self):
+ def mycmp(x, y):
+ x, y = int(x), int(y)
+ return (x > y) - (x < y)
+ values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
+ values = sorted(values, key=functools.cmp_to_key(mycmp))
+ self.assertEqual([int(value) for value in values],
+ [0, 1, 1, 2, 3, 4, 5, 7, 10])
+
def test_hash(self):
def mycmp(x, y):
return y - x
key = functools.cmp_to_key(mycmp)
k = key(10)
self.assertRaises(TypeError, hash, k)
- self.assertFalse(isinstance(k, collections.Hashable))
+ self.assertNotIsInstance(k, collections.Hashable)
class TestTotalOrdering(unittest.TestCase):
@@ -692,6 +761,22 @@ class TestLRU(unittest.TestCase):
with self.assertRaises(IndexError):
func(15)
+ def test_lru_with_types(self):
+ for maxsize in (None, 100):
+ @functools.lru_cache(maxsize=maxsize, typed=True)
+ def square(x):
+ return x * x
+ self.assertEqual(square(3), 9)
+ self.assertEqual(type(square(3)), type(9))
+ self.assertEqual(square(3.0), 9.0)
+ self.assertEqual(type(square(3.0)), type(9.0))
+ self.assertEqual(square(x=3), 9)
+ self.assertEqual(type(square(x=3)), type(9))
+ self.assertEqual(square(x=3.0), 9.0)
+ self.assertEqual(type(square(x=3.0)), type(9.0))
+ self.assertEqual(square.cache_info().hits, 4)
+ self.assertEqual(square.cache_info().misses, 4)
+
def test_main(verbose=None):
test_classes = (
TestPartial,
diff --git a/Lib/test/test_future.py b/Lib/test/test_future.py
index c6689a1..3a25eb1 100644
--- a/Lib/test/test_future.py
+++ b/Lib/test/test_future.py
@@ -13,14 +13,14 @@ def get_error_location(msg):
class FutureTest(unittest.TestCase):
def test_future1(self):
- support.unload('test_future1')
- from test import test_future1
- self.assertEqual(test_future1.result, 6)
+ support.unload('future_test1')
+ from test import future_test1
+ self.assertEqual(future_test1.result, 6)
def test_future2(self):
- support.unload('test_future2')
- from test import test_future2
- self.assertEqual(test_future2.result, 6)
+ support.unload('future_test2')
+ from test import future_test2
+ self.assertEqual(future_test2.result, 6)
def test_future3(self):
support.unload('test_future3')
diff --git a/Lib/test/test_gc.py b/Lib/test/test_gc.py
index e1c124d..c59b72e 100644
--- a/Lib/test/test_gc.py
+++ b/Lib/test/test_gc.py
@@ -1,5 +1,6 @@
import unittest
-from test.support import verbose, run_unittest, strip_python_stderr
+from test.support import (verbose, refcount_test, run_unittest,
+ strip_python_stderr)
import sys
import time
import gc
@@ -37,6 +38,20 @@ class GC_Detector(object):
# gc collects it.
self.wr = weakref.ref(C1055820(666), it_happened)
+class Uncollectable(object):
+ """Create a reference cycle with multiple __del__ methods.
+
+ An object in a reference cycle will never have zero references,
+ and so must be garbage collected. If one or more objects in the
+ cycle have __del__ methods, the gc refuses to guess an order,
+ and leaves the cycle uncollected."""
+ def __init__(self, partner=None):
+ if partner is None:
+ self.partner = Uncollectable(partner=self)
+ else:
+ self.partner = partner
+ def __del__(self):
+ pass
### Tests
###############################################################################
@@ -181,6 +196,7 @@ class GCTests(unittest.TestCase):
del d
self.assertEqual(gc.collect(), 2)
+ @refcount_test
def test_frame(self):
def f():
frame = sys._getframe()
@@ -248,6 +264,7 @@ class GCTests(unittest.TestCase):
# For example, disposed tuples are not freed, but reused.
# To minimize variations, though, we first store the get_count() results
# and check them at the end.
+ @refcount_test
def test_get_count(self):
gc.collect()
a, b, c = gc.get_count()
@@ -261,6 +278,7 @@ class GCTests(unittest.TestCase):
# created (the list).
self.assertGreater(d, a)
+ @refcount_test
def test_collect_generations(self):
gc.collect()
# This object will "trickle" into generation N + 1 after
@@ -593,6 +611,127 @@ class GCTests(unittest.TestCase):
self.assertNotIn(b"uncollectable objects at shutdown", stderr)
+class GCCallbackTests(unittest.TestCase):
+ def setUp(self):
+ # Save gc state and disable it.
+ self.enabled = gc.isenabled()
+ gc.disable()
+ self.debug = gc.get_debug()
+ gc.set_debug(0)
+ gc.callbacks.append(self.cb1)
+ gc.callbacks.append(self.cb2)
+ self.othergarbage = []
+
+ def tearDown(self):
+ # Restore gc state
+ del self.visit
+ gc.callbacks.remove(self.cb1)
+ gc.callbacks.remove(self.cb2)
+ gc.set_debug(self.debug)
+ if self.enabled:
+ gc.enable()
+ # destroy any uncollectables
+ gc.collect()
+ for obj in gc.garbage:
+ if isinstance(obj, Uncollectable):
+ obj.partner = None
+ del gc.garbage[:]
+ del self.othergarbage
+ gc.collect()
+
+ def preclean(self):
+ # Remove all fluff from the system. Invoke this function
+ # manually rather than through self.setUp() for maximum
+ # safety.
+ self.visit = []
+ gc.collect()
+ garbage, gc.garbage[:] = gc.garbage[:], []
+ self.othergarbage.append(garbage)
+ self.visit = []
+
+ def cb1(self, phase, info):
+ self.visit.append((1, phase, dict(info)))
+
+ def cb2(self, phase, info):
+ self.visit.append((2, phase, dict(info)))
+ if phase == "stop" and hasattr(self, "cleanup"):
+ # Clean Uncollectable from garbage
+ uc = [e for e in gc.garbage if isinstance(e, Uncollectable)]
+ gc.garbage[:] = [e for e in gc.garbage
+ if not isinstance(e, Uncollectable)]
+ for e in uc:
+ e.partner = None
+
+ def test_collect(self):
+ self.preclean()
+ gc.collect()
+ # Algorithmically verify the contents of self.visit
+ # because it is long and tortuous.
+
+ # Count the number of visits to each callback
+ n = [v[0] for v in self.visit]
+ n1 = [i for i in n if i == 1]
+ n2 = [i for i in n if i == 2]
+ self.assertEqual(n1, [1]*2)
+ self.assertEqual(n2, [2]*2)
+
+ # Count that we got the right number of start and stop callbacks.
+ n = [v[1] for v in self.visit]
+ n1 = [i for i in n if i == "start"]
+ n2 = [i for i in n if i == "stop"]
+ self.assertEqual(n1, ["start"]*2)
+ self.assertEqual(n2, ["stop"]*2)
+
+ # Check that we got the right info dict for all callbacks
+ for v in self.visit:
+ info = v[2]
+ self.assertTrue("generation" in info)
+ self.assertTrue("collected" in info)
+ self.assertTrue("uncollectable" in info)
+
+ def test_collect_generation(self):
+ self.preclean()
+ gc.collect(2)
+ for v in self.visit:
+ info = v[2]
+ self.assertEqual(info["generation"], 2)
+
+ def test_collect_garbage(self):
+ self.preclean()
+ # Each of these cause four objects to be garbage: Two
+ # Uncolectables and their instance dicts.
+ Uncollectable()
+ Uncollectable()
+ C1055820(666)
+ gc.collect()
+ for v in self.visit:
+ if v[1] != "stop":
+ continue
+ info = v[2]
+ self.assertEqual(info["collected"], 2)
+ self.assertEqual(info["uncollectable"], 8)
+
+ # We should now have the Uncollectables in gc.garbage
+ self.assertEqual(len(gc.garbage), 4)
+ for e in gc.garbage:
+ self.assertIsInstance(e, Uncollectable)
+
+ # Now, let our callback handle the Uncollectable instances
+ self.cleanup=True
+ self.visit = []
+ gc.garbage[:] = []
+ gc.collect()
+ for v in self.visit:
+ if v[1] != "stop":
+ continue
+ info = v[2]
+ self.assertEqual(info["collected"], 0)
+ self.assertEqual(info["uncollectable"], 4)
+
+ # Uncollectables should be gone
+ self.assertEqual(len(gc.garbage), 0)
+
+
class GCTogglingTests(unittest.TestCase):
def setUp(self):
gc.enable()
@@ -746,7 +885,7 @@ def test_main():
try:
gc.collect() # Delete 2nd generation garbage
- run_unittest(GCTests, GCTogglingTests)
+ run_unittest(GCTests, GCTogglingTests, GCCallbackTests)
finally:
gc.set_debug(debug)
# test gc.enable() even if GC is disabled by default
diff --git a/Lib/test/test_gdb.py b/Lib/test/test_gdb.py
index 6d96550..9713dc9 100644
--- a/Lib/test/test_gdb.py
+++ b/Lib/test/test_gdb.py
@@ -7,9 +7,16 @@ import os
import re
import subprocess
import sys
+import sysconfig
import unittest
import locale
+# Is this Python configured to support threads?
+try:
+ import _thread
+except ImportError:
+ _thread = None
+
from test.support import run_unittest, findfile, python_is_optimized
try:
@@ -26,6 +33,9 @@ if gdb_major_version < 7:
raise unittest.SkipTest("gdb versions before 7.0 didn't support python embedding"
" Saw:\n" + gdb_version.decode('ascii', 'replace'))
+if not sysconfig.is_python_build():
+ raise unittest.SkipTest("test_gdb only works on source builds at the moment.")
+
# Location of custom hooks file in a repository checkout.
checkout_hook_path = os.path.join(os.path.dirname(sys.executable),
'python-gdb.py')
@@ -53,9 +63,7 @@ gdbpy_version, _ = run_gdb("--eval-command=python import sys; print sys.version_
if not gdbpy_version:
raise unittest.SkipTest("gdb not built with embedded python support")
-# Verify that "gdb" can load our custom hooks. In theory this should never
-# fail, but we don't handle the case of the hooks file not existing if the
-# tests are run from an installed Python (we'll produce failures in that case).
+# Verify that "gdb" can load our custom hooks. In theory this should never fail.
cmd = ['--args', sys.executable]
_, gdbpy_errors = run_gdb('--args', sys.executable)
if "auto-loading has been declined" in gdbpy_errors:
@@ -160,7 +168,6 @@ class DebuggerTests(unittest.TestCase):
# Ensure no unexpected error messages:
self.assertEqual(err, '')
-
return out
def get_gdb_repr(self, source,
@@ -181,7 +188,7 @@ class DebuggerTests(unittest.TestCase):
# gdb can insert additional '\n' and space characters in various places
# in its output, depending on the width of the terminal it's connected
# to (using its "wrap_here" function)
- m = re.match('.*#0\s+builtin_id\s+\(self\=.*,\s+v=\s*(.*?)\)\s+at\s+Python/bltinmodule.c.*',
+ m = re.match('.*#0\s+builtin_id\s+\(self\=.*,\s+v=\s*(.*?)\)\s+at\s+\S*Python/bltinmodule.c.*',
gdb_output, re.DOTALL)
if not m:
self.fail('Unexpected gdb output: %r\n%s' % (gdb_output, gdb_output))
@@ -343,7 +350,7 @@ class Foo:
foo = Foo()
foo.an_int = 42
id(foo)''')
- m = re.match(r'<Foo\(an_int=42\) at remote 0x[0-9a-f]+>', gdb_repr)
+ m = re.match(r'<Foo\(an_int=42\) at remote 0x-?[0-9a-f]+>', gdb_repr)
self.assertTrue(m,
msg='Unexpected new-style class rendering %r' % gdb_repr)
@@ -356,7 +363,7 @@ foo = Foo()
foo += [1, 2, 3]
foo.an_int = 42
id(foo)''')
- m = re.match(r'<Foo\(an_int=42\) at remote 0x[0-9a-f]+>', gdb_repr)
+ m = re.match(r'<Foo\(an_int=42\) at remote 0x-?[0-9a-f]+>', gdb_repr)
self.assertTrue(m,
msg='Unexpected new-style class rendering %r' % gdb_repr)
@@ -371,7 +378,7 @@ class Foo(tuple):
foo = Foo((1, 2, 3))
foo.an_int = 42
id(foo)''')
- m = re.match(r'<Foo\(an_int=42\) at remote 0x[0-9a-f]+>', gdb_repr)
+ m = re.match(r'<Foo\(an_int=42\) at remote 0x-?[0-9a-f]+>', gdb_repr)
self.assertTrue(m,
msg='Unexpected new-style class rendering %r' % gdb_repr)
@@ -398,7 +405,7 @@ id(foo)''')
# Match anything for the type name; 0xDEADBEEF could point to
# something arbitrary (see http://bugs.python.org/issue8330)
- pattern = '<.* at remote 0x[0-9a-f]+>'
+ pattern = '<.* at remote 0x-?[0-9a-f]+>'
m = re.match(pattern, gdb_repr)
if not m:
@@ -444,7 +451,7 @@ id(foo)''')
# http://bugs.python.org/issue8032#msg100537 )
gdb_repr, gdb_output = self.get_gdb_repr('id(__builtins__.help)', import_site=True)
- m = re.match(r'<_Helper at remote 0x[0-9a-f]+>', gdb_repr)
+ m = re.match(r'<_Helper at remote 0x-?[0-9a-f]+>', gdb_repr)
self.assertTrue(m,
msg='Unexpected rendering %r' % gdb_repr)
@@ -475,7 +482,7 @@ class Foo:
foo = Foo()
foo.an_attr = foo
id(foo)''')
- self.assertTrue(re.match('<Foo\(an_attr=<\.\.\.>\) at remote 0x[0-9a-f]+>',
+ self.assertTrue(re.match('<Foo\(an_attr=<\.\.\.>\) at remote 0x-?[0-9a-f]+>',
gdb_repr),
'Unexpected gdb representation: %r\n%s' % \
(gdb_repr, gdb_output))
@@ -488,7 +495,7 @@ class Foo(object):
foo = Foo()
foo.an_attr = foo
id(foo)''')
- self.assertTrue(re.match('<Foo\(an_attr=<\.\.\.>\) at remote 0x[0-9a-f]+>',
+ self.assertTrue(re.match('<Foo\(an_attr=<\.\.\.>\) at remote 0x-?[0-9a-f]+>',
gdb_repr),
'Unexpected gdb representation: %r\n%s' % \
(gdb_repr, gdb_output))
@@ -502,7 +509,7 @@ b = Foo()
a.an_attr = b
b.an_attr = a
id(a)''')
- self.assertTrue(re.match('<Foo\(an_attr=<Foo\(an_attr=<\.\.\.>\) at remote 0x[0-9a-f]+>\) at remote 0x[0-9a-f]+>',
+ self.assertTrue(re.match('<Foo\(an_attr=<Foo\(an_attr=<\.\.\.>\) at remote 0x-?[0-9a-f]+>\) at remote 0x-?[0-9a-f]+>',
gdb_repr),
'Unexpected gdb representation: %r\n%s' % \
(gdb_repr, gdb_output))
@@ -537,7 +544,7 @@ id(a)''')
def test_builtin_method(self):
gdb_repr, gdb_output = self.get_gdb_repr('import sys; id(sys.stdout.readlines)')
- self.assertTrue(re.match('<built-in method readlines of _io.TextIOWrapper object at remote 0x[0-9a-f]+>',
+ self.assertTrue(re.match('<built-in method readlines of _io.TextIOWrapper object at remote 0x-?[0-9a-f]+>',
gdb_repr),
'Unexpected gdb representation: %r\n%s' % \
(gdb_repr, gdb_output))
@@ -552,7 +559,7 @@ id(foo.__code__)''',
breakpoint='builtin_id',
cmds_after_breakpoint=['print (PyFrameObject*)(((PyCodeObject*)v)->co_zombieframe)']
)
- self.assertTrue(re.match('.*\s+\$1 =\s+Frame 0x[0-9a-f]+, for file <string>, line 3, in foo \(\)\s+.*',
+ self.assertTrue(re.match('.*\s+\$1 =\s+Frame 0x-?[0-9a-f]+, for file <string>, line 3, in foo \(\)\s+.*',
gdb_output,
re.DOTALL),
'Unexpected gdb representation: %r\n%s' % (gdb_output, gdb_output))
@@ -609,7 +616,7 @@ class StackNavigationTests(DebuggerTests):
cmds_after_breakpoint=['py-up'])
self.assertMultilineMatches(bt,
r'''^.*
-#[0-9]+ Frame 0x[0-9a-f]+, for file .*gdb_sample.py, line 7, in bar \(a=1, b=2, c=3\)
+#[0-9]+ Frame 0x-?[0-9a-f]+, for file .*gdb_sample.py, line 7, in bar \(a=1, b=2, c=3\)
baz\(a, b, c\)
$''')
@@ -638,9 +645,9 @@ $''')
cmds_after_breakpoint=['py-up', 'py-down'])
self.assertMultilineMatches(bt,
r'''^.*
-#[0-9]+ Frame 0x[0-9a-f]+, for file .*gdb_sample.py, line 7, in bar \(a=1, b=2, c=3\)
+#[0-9]+ Frame 0x-?[0-9a-f]+, for file .*gdb_sample.py, line 7, in bar \(a=1, b=2, c=3\)
baz\(a, b, c\)
-#[0-9]+ Frame 0x[0-9a-f]+, for file .*gdb_sample.py, line 10, in baz \(args=\(1, 2, 3\)\)
+#[0-9]+ Frame 0x-?[0-9a-f]+, for file .*gdb_sample.py, line 10, in baz \(args=\(1, 2, 3\)\)
id\(42\)
$''')
@@ -672,14 +679,106 @@ Traceback \(most recent call first\):
cmds_after_breakpoint=['py-bt-full'])
self.assertMultilineMatches(bt,
r'''^.*
-#[0-9]+ Frame 0x[0-9a-f]+, for file .*gdb_sample.py, line 7, in bar \(a=1, b=2, c=3\)
+#[0-9]+ Frame 0x-?[0-9a-f]+, for file .*gdb_sample.py, line 7, in bar \(a=1, b=2, c=3\)
baz\(a, b, c\)
-#[0-9]+ Frame 0x[0-9a-f]+, for file .*gdb_sample.py, line 4, in foo \(a=1, b=2, c=3\)
+#[0-9]+ Frame 0x-?[0-9a-f]+, for file .*gdb_sample.py, line 4, in foo \(a=1, b=2, c=3\)
bar\(a, b, c\)
-#[0-9]+ Frame 0x[0-9a-f]+, for file .*gdb_sample.py, line 12, in <module> \(\)
+#[0-9]+ Frame 0x-?[0-9a-f]+, for file .*gdb_sample.py, line 12, in <module> \(\)
foo\(1, 2, 3\)
''')
+ @unittest.skipUnless(_thread,
+ "Python was compiled without thread support")
+ def test_threads(self):
+ 'Verify that "py-bt" indicates threads that are waiting for the GIL'
+ cmd = '''
+from threading import Thread
+
+class TestThread(Thread):
+ # These threads would run forever, but we'll interrupt things with the
+ # debugger
+ def run(self):
+ i = 0
+ while 1:
+ i += 1
+
+t = {}
+for i in range(4):
+ t[i] = TestThread()
+ t[i].start()
+
+# Trigger a breakpoint on the main thread
+id(42)
+
+'''
+ # Verify with "py-bt":
+ gdb_output = self.get_stack_trace(cmd,
+ cmds_after_breakpoint=['thread apply all py-bt'])
+ self.assertIn('Waiting for the GIL', gdb_output)
+
+ # Verify with "py-bt-full":
+ gdb_output = self.get_stack_trace(cmd,
+ cmds_after_breakpoint=['thread apply all py-bt-full'])
+ self.assertIn('Waiting for the GIL', gdb_output)
+
+ @unittest.skipIf(python_is_optimized(),
+ "Python was compiled with optimizations")
+ # Some older versions of gdb will fail with
+ # "Cannot find new threads: generic error"
+ # unless we add LD_PRELOAD=PATH-TO-libpthread.so.1 as a workaround
+ @unittest.skipUnless(_thread,
+ "Python was compiled without thread support")
+ def test_gc(self):
+ 'Verify that "py-bt" indicates if a thread is garbage-collecting'
+ cmd = ('from gc import collect\n'
+ 'id(42)\n'
+ 'def foo():\n'
+ ' collect()\n'
+ 'def bar():\n'
+ ' foo()\n'
+ 'bar()\n')
+ # Verify with "py-bt":
+ gdb_output = self.get_stack_trace(cmd,
+ cmds_after_breakpoint=['break update_refs', 'continue', 'py-bt'],
+ )
+ self.assertIn('Garbage-collecting', gdb_output)
+
+ # Verify with "py-bt-full":
+ gdb_output = self.get_stack_trace(cmd,
+ cmds_after_breakpoint=['break update_refs', 'continue', 'py-bt-full'],
+ )
+ self.assertIn('Garbage-collecting', gdb_output)
+
+ @unittest.skipIf(python_is_optimized(),
+ "Python was compiled with optimizations")
+ # Some older versions of gdb will fail with
+ # "Cannot find new threads: generic error"
+ # unless we add LD_PRELOAD=PATH-TO-libpthread.so.1 as a workaround
+ @unittest.skipUnless(_thread,
+ "Python was compiled without thread support")
+ def test_pycfunction(self):
+ 'Verify that "py-bt" displays invocations of PyCFunction instances'
+ cmd = ('from time import sleep\n'
+ 'def foo():\n'
+ ' sleep(1)\n'
+ 'def bar():\n'
+ ' foo()\n'
+ 'bar()\n')
+ # Verify with "py-bt":
+ gdb_output = self.get_stack_trace(cmd,
+ breakpoint='time_sleep',
+ cmds_after_breakpoint=['bt', 'py-bt'],
+ )
+ self.assertIn('<built-in method sleep', gdb_output)
+
+ # Verify with "py-bt-full":
+ gdb_output = self.get_stack_trace(cmd,
+ breakpoint='time_sleep',
+ cmds_after_breakpoint=['py-bt-full'],
+ )
+ self.assertIn('#0 <built-in method sleep', gdb_output)
+
+
class PyPrintTests(DebuggerTests):
@unittest.skipIf(python_is_optimized(),
"Python was compiled with optimizations")
@@ -713,7 +812,7 @@ class PyPrintTests(DebuggerTests):
bt = self.get_stack_trace(script=self.get_sample_script(),
cmds_after_breakpoint=['py-print len'])
self.assertMultilineMatches(bt,
- r".*\nbuiltin 'len' = <built-in method len of module object at remote 0x[0-9a-f]+>\n.*")
+ r".*\nbuiltin 'len' = <built-in method len of module object at remote 0x-?[0-9a-f]+>\n.*")
class PyLocalsTests(DebuggerTests):
@unittest.skipIf(python_is_optimized(),
diff --git a/Lib/test/test_generators.py b/Lib/test/test_generators.py
index 2c88373..958054a 100644
--- a/Lib/test/test_generators.py
+++ b/Lib/test/test_generators.py
@@ -729,29 +729,6 @@ Ye olde Fibonacci generator, tee style.
syntax_tests = """
->>> def f():
-... return 22
-... yield 1
-Traceback (most recent call last):
- ..
-SyntaxError: 'return' with argument inside generator
-
->>> def f():
-... yield 1
-... return 22
-Traceback (most recent call last):
- ..
-SyntaxError: 'return' with argument inside generator
-
-"return None" is not the same as "return" in a generator:
-
->>> def f():
-... yield 1
-... return None
-Traceback (most recent call last):
- ..
-SyntaxError: 'return' with argument inside generator
-
These are fine:
>>> def f():
@@ -867,20 +844,6 @@ These are fine:
>>> type(f())
<class 'generator'>
-
->>> def f():
-... if 0:
-... lambda x: x # shouldn't trigger here
-... return # or here
-... def f(i):
-... return 2*i # or here
-... if 0:
-... return 3 # but *this* sucks (line 8)
-... if 0:
-... yield 2 # because it's a generator (line 10)
-Traceback (most recent call last):
-SyntaxError: 'return' with argument inside generator
-
This one caused a crash (see SF bug 567538):
>>> def f():
@@ -1567,11 +1530,6 @@ Traceback (most recent call last):
...
SyntaxError: 'yield' outside function
->>> def f(): return lambda x=(yield): 1
-Traceback (most recent call last):
- ...
-SyntaxError: 'return' with argument inside generator
-
>>> def f(): x = yield = y
Traceback (most recent call last):
...
diff --git a/Lib/test/test_genericpath.py b/Lib/test/test_genericpath.py
index 6020923..fd8bc57 100644
--- a/Lib/test/test_genericpath.py
+++ b/Lib/test/test_genericpath.py
@@ -2,11 +2,12 @@
Tests common to genericpath, macpath, ntpath and posixpath
"""
-import unittest
-from test import support
-import os
import genericpath
+import os
import sys
+import unittest
+import warnings
+from test import support
def safe_rmdir(dirname):
@@ -16,9 +17,7 @@ def safe_rmdir(dirname):
pass
-class GenericTest(unittest.TestCase):
- # The path module to be tested
- pathmodule = genericpath
+class GenericTest:
common_attributes = ['commonprefix', 'getsize', 'getatime', 'getctime',
'getmtime', 'exists', 'isdir', 'isfile']
attributes = []
@@ -145,6 +144,16 @@ class GenericTest(unittest.TestCase):
f.close()
support.unlink(support.TESTFN)
+ @unittest.skipUnless(hasattr(os, "pipe"), "requires os.pipe()")
+ def test_exists_fd(self):
+ r, w = os.pipe()
+ try:
+ self.assertTrue(self.pathmodule.exists(r))
+ finally:
+ os.close(r)
+ os.close(w)
+ self.assertFalse(self.pathmodule.exists(r))
+
def test_isdir(self):
self.assertIs(self.pathmodule.isdir(support.TESTFN), False)
f = open(support.TESTFN, "wb")
@@ -179,13 +188,16 @@ class GenericTest(unittest.TestCase):
support.unlink(support.TESTFN)
safe_rmdir(support.TESTFN)
+class TestGenericTest(GenericTest, unittest.TestCase):
+ # Issue 16852: GenericTest can't inherit from unittest.TestCase
+ # for test discovery purposes; CommonTest inherits from GenericTest
+ # and is only meant to be inherited by others.
+ pathmodule = genericpath
# Following TestCase is not supposed to be run from test_genericpath.
# It is inherited by other test modules (macpath, ntpath, posixpath).
class CommonTest(GenericTest):
- # The path module to be tested
- pathmodule = None
common_attributes = GenericTest.common_attributes + [
# Properties
'curdir', 'pardir', 'extsep', 'sep',
@@ -258,15 +270,21 @@ class CommonTest(GenericTest):
def test_abspath(self):
self.assertIn("foo", self.pathmodule.abspath("foo"))
- self.assertIn(b"foo", self.pathmodule.abspath(b"foo"))
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", DeprecationWarning)
+ self.assertIn(b"foo", self.pathmodule.abspath(b"foo"))
# Abspath returns bytes when the arg is bytes
- for path in (b'', b'foo', b'f\xf2\xf2', b'/foo', b'C:\\'):
- self.assertIsInstance(self.pathmodule.abspath(path), bytes)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", DeprecationWarning)
+ for path in (b'', b'foo', b'f\xf2\xf2', b'/foo', b'C:\\'):
+ self.assertIsInstance(self.pathmodule.abspath(path), bytes)
def test_realpath(self):
self.assertIn("foo", self.pathmodule.realpath("foo"))
- self.assertIn(b"foo", self.pathmodule.realpath(b"foo"))
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", DeprecationWarning)
+ self.assertIn(b"foo", self.pathmodule.realpath(b"foo"))
def test_normpath_issue5827(self):
# Make sure normpath preserves unicode
@@ -282,8 +300,7 @@ class CommonTest(GenericTest):
unicwd = '\xe7w\xf0'
try:
- fsencoding = support.TESTFN_ENCODING or "ascii"
- unicwd.encode(fsencoding)
+ os.fsencode(unicwd)
except (AttributeError, UnicodeEncodeError):
# FS encoding is probably ASCII
pass
@@ -305,13 +322,12 @@ class CommonTest(GenericTest):
else:
self.skipTest("need support.TESTFN_NONASCII")
- with support.temp_cwd(name):
- self.test_abspath()
-
-
-def test_main():
- support.run_unittest(GenericTest)
+ # Test non-ASCII, non-UTF8 bytes in the path.
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", DeprecationWarning)
+ with support.temp_cwd(name):
+ self.test_abspath()
if __name__=="__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_genexps.py b/Lib/test/test_genexps.py
index d8eb550..203b336 100644
--- a/Lib/test/test_genexps.py
+++ b/Lib/test/test_genexps.py
@@ -258,11 +258,15 @@ Verify that genexps are weakly referencable
"""
+import sys
-__test__ = {'doctests' : doctests}
+# Trace function can throw off the tuple reuse test.
+if hasattr(sys, 'gettrace') and sys.gettrace():
+ __test__ = {}
+else:
+ __test__ = {'doctests' : doctests}
def test_main(verbose=None):
- import sys
from test import support
from test import test_genexps
support.run_doctest(test_genexps, verbose)
diff --git a/Lib/test/test_getargs2.py b/Lib/test/test_getargs2.py
index 3d9c06a..48ca94e 100644
--- a/Lib/test/test_getargs2.py
+++ b/Lib/test/test_getargs2.py
@@ -1,6 +1,6 @@
import unittest
from test import support
-from _testcapi import getargs_keywords
+from _testcapi import getargs_keywords, getargs_keyword_only
"""
> How about the following counterproposal. This also changes some of
@@ -214,6 +214,36 @@ class LongLong_TestCase(unittest.TestCase):
self.assertEqual(VERY_LARGE & ULLONG_MAX, getargs_K(VERY_LARGE))
+class Paradox:
+ "This statement is false."
+ def __bool__(self):
+ raise NotImplementedError
+
+class Boolean_TestCase(unittest.TestCase):
+ def test_p(self):
+ from _testcapi import getargs_p
+ self.assertEqual(0, getargs_p(False))
+ self.assertEqual(0, getargs_p(None))
+ self.assertEqual(0, getargs_p(0))
+ self.assertEqual(0, getargs_p(0.0))
+ self.assertEqual(0, getargs_p(0j))
+ self.assertEqual(0, getargs_p(''))
+ self.assertEqual(0, getargs_p(()))
+ self.assertEqual(0, getargs_p([]))
+ self.assertEqual(0, getargs_p({}))
+
+ self.assertEqual(1, getargs_p(True))
+ self.assertEqual(1, getargs_p(1))
+ self.assertEqual(1, getargs_p(1.0))
+ self.assertEqual(1, getargs_p(1j))
+ self.assertEqual(1, getargs_p('x'))
+ self.assertEqual(1, getargs_p((1,)))
+ self.assertEqual(1, getargs_p([1]))
+ self.assertEqual(1, getargs_p({1:2}))
+ self.assertEqual(1, getargs_p(unittest.TestCase))
+
+ self.assertRaises(NotImplementedError, getargs_p, Paradox())
+
class Tuple_TestCase(unittest.TestCase):
def test_tuple(self):
@@ -293,7 +323,87 @@ class Keywords_TestCase(unittest.TestCase):
else:
self.fail('TypeError should have been raised')
+class KeywordOnly_TestCase(unittest.TestCase):
+ def test_positional_args(self):
+ # using all possible positional args
+ self.assertEqual(
+ getargs_keyword_only(1, 2),
+ (1, 2, -1)
+ )
+
+ def test_mixed_args(self):
+ # positional and keyword args
+ self.assertEqual(
+ getargs_keyword_only(1, 2, keyword_only=3),
+ (1, 2, 3)
+ )
+
+ def test_keyword_args(self):
+ # all keywords
+ self.assertEqual(
+ getargs_keyword_only(required=1, optional=2, keyword_only=3),
+ (1, 2, 3)
+ )
+
+ def test_optional_args(self):
+ # missing optional keyword args, skipping tuples
+ self.assertEqual(
+ getargs_keyword_only(required=1, optional=2),
+ (1, 2, -1)
+ )
+ self.assertEqual(
+ getargs_keyword_only(required=1, keyword_only=3),
+ (1, -1, 3)
+ )
+
+ def test_required_args(self):
+ self.assertEqual(
+ getargs_keyword_only(1),
+ (1, -1, -1)
+ )
+ self.assertEqual(
+ getargs_keyword_only(required=1),
+ (1, -1, -1)
+ )
+ # required arg missing
+ with self.assertRaisesRegex(TypeError,
+ "Required argument 'required' \(pos 1\) not found"):
+ getargs_keyword_only(optional=2)
+
+ with self.assertRaisesRegex(TypeError,
+ "Required argument 'required' \(pos 1\) not found"):
+ getargs_keyword_only(keyword_only=3)
+
+ def test_too_many_args(self):
+ with self.assertRaisesRegex(TypeError,
+ "Function takes at most 2 positional arguments \(3 given\)"):
+ getargs_keyword_only(1, 2, 3)
+
+ with self.assertRaisesRegex(TypeError,
+ "function takes at most 3 arguments \(4 given\)"):
+ getargs_keyword_only(1, 2, 3, keyword_only=5)
+
+ def test_invalid_keyword(self):
+ # extraneous keyword arg
+ with self.assertRaisesRegex(TypeError,
+ "'monster' is an invalid keyword argument for this function"):
+ getargs_keyword_only(1, 2, monster=666)
+
+ def test_surrogate_keyword(self):
+ with self.assertRaisesRegex(TypeError,
+ "'\udc80' is an invalid keyword argument for this function"):
+ getargs_keyword_only(1, 2, **{'\uDC80': 10})
+
class Bytes_TestCase(unittest.TestCase):
+ def test_c(self):
+ from _testcapi import getargs_c
+ self.assertRaises(TypeError, getargs_c, b'abc') # len > 1
+ self.assertEqual(getargs_c(b'a'), b'a')
+ self.assertEqual(getargs_c(bytearray(b'a')), b'a')
+ self.assertRaises(TypeError, getargs_c, memoryview(b'a'))
+ self.assertRaises(TypeError, getargs_c, 's')
+ self.assertRaises(TypeError, getargs_c, None)
+
def test_s(self):
from _testcapi import getargs_s
self.assertEqual(getargs_s('abc\xe9'), b'abc\xc3\xa9')
@@ -430,8 +540,10 @@ def test_main():
tests = [
Signed_TestCase,
Unsigned_TestCase,
+ Boolean_TestCase,
Tuple_TestCase,
Keywords_TestCase,
+ KeywordOnly_TestCase,
Bytes_TestCase,
Unicode_TestCase,
]
diff --git a/Lib/test/test_glob.py b/Lib/test/test_glob.py
index 1cdc24a..eb9aeb5 100644
--- a/Lib/test/test_glob.py
+++ b/Lib/test/test_glob.py
@@ -4,7 +4,8 @@ import shutil
import sys
import unittest
-from test.support import run_unittest, TESTFN, skip_unless_symlink, can_symlink
+from test.support import (run_unittest, TESTFN, skip_unless_symlink,
+ can_symlink, create_empty_file)
class GlobTests(unittest.TestCase):
@@ -17,8 +18,7 @@ class GlobTests(unittest.TestCase):
base, file = os.path.split(filename)
if not os.path.exists(base):
os.makedirs(base)
- f = open(filename, 'w')
- f.close()
+ create_empty_file(filename)
def setUp(self):
self.tempdir = TESTFN + "_dir"
diff --git a/Lib/test/test_grammar.py b/Lib/test/test_grammar.py
index 268a633..6b326bd 100644
--- a/Lib/test/test_grammar.py
+++ b/Lib/test/test_grammar.py
@@ -10,7 +10,7 @@ from sys import *
class TokenTests(unittest.TestCase):
- def testBackslash(self):
+ def test_backslash(self):
# Backslash means line continuation:
x = 1 \
+ 1
@@ -20,7 +20,7 @@ class TokenTests(unittest.TestCase):
x = 0
self.assertEqual(x, 0, 'backslash ending comment')
- def testPlainIntegers(self):
+ def test_plain_integers(self):
self.assertEqual(type(000), type(0))
self.assertEqual(0xff, 255)
self.assertEqual(0o377, 255)
@@ -56,7 +56,7 @@ class TokenTests(unittest.TestCase):
else:
self.fail('Weird maxsize value %r' % maxsize)
- def testLongIntegers(self):
+ def test_long_integers(self):
x = 0
x = 0xffffffffffffffff
x = 0Xffffffffffffffff
@@ -66,7 +66,7 @@ class TokenTests(unittest.TestCase):
x = 0b100000000000000000000000000000000000000000000000000000000000000000000
x = 0B111111111111111111111111111111111111111111111111111111111111111111111
- def testFloats(self):
+ def test_floats(self):
x = 3.14
x = 314.
x = 0.314
@@ -80,7 +80,7 @@ class TokenTests(unittest.TestCase):
x = .3e14
x = 3.1e4
- def testStringLiterals(self):
+ def test_string_literals(self):
x = ''; y = ""; self.assertTrue(len(x) == 0 and x == y)
x = '\''; y = "'"; self.assertTrue(len(x) == 1 and x == y and ord(x) == 39)
x = '"'; y = "\""; self.assertTrue(len(x) == 1 and x == y and ord(x) == 34)
@@ -120,11 +120,18 @@ the \'lazy\' dog.\n\
'
self.assertEqual(x, y)
- def testEllipsis(self):
+ def test_ellipsis(self):
x = ...
self.assertTrue(x is Ellipsis)
self.assertRaises(SyntaxError, eval, ".. .")
+ def test_eof_error(self):
+ samples = ("def foo(", "\ndef foo(", "def foo(\n")
+ for s in samples:
+ with self.assertRaises(SyntaxError) as cm:
+ compile(s, "<test>", "exec")
+ self.assertIn("unexpected EOF", str(cm.exception))
+
class GrammarTests(unittest.TestCase):
# single_input: NEWLINE | simple_stmt | compound_stmt NEWLINE
@@ -136,11 +143,11 @@ class GrammarTests(unittest.TestCase):
# expr_input: testlist NEWLINE
# XXX Hard to test -- used only in calls to input()
- def testEvalInput(self):
+ def test_eval_input(self):
# testlist ENDMARKER
x = eval('1, 0 or 1')
- def testFuncdef(self):
+ def test_funcdef(self):
### [decorators] 'def' NAME parameters ['->' test] ':' suite
### decorator: '@' dotted_name [ '(' [arglist] ')' ] NEWLINE
### decorators: decorator+
@@ -324,7 +331,7 @@ class GrammarTests(unittest.TestCase):
check_syntax_error(self, "f(*g(1=2))")
check_syntax_error(self, "f(**g(1=2))")
- def testLambdef(self):
+ def test_lambdef(self):
### lambdef: 'lambda' [varargslist] ':' test
l1 = lambda : 0
self.assertEqual(l1(), 0)
@@ -346,7 +353,7 @@ class GrammarTests(unittest.TestCase):
### stmt: simple_stmt | compound_stmt
# Tested below
- def testSimpleStmt(self):
+ def test_simple_stmt(self):
### simple_stmt: small_stmt (';' small_stmt)* [';']
x = 1; pass; del x
def foo():
@@ -357,7 +364,7 @@ class GrammarTests(unittest.TestCase):
### small_stmt: expr_stmt | pass_stmt | del_stmt | flow_stmt | import_stmt | global_stmt | access_stmt
# Tested below
- def testExprStmt(self):
+ def test_expr_stmt(self):
# (exprlist '=')* exprlist
1
1, 2, 3
@@ -370,7 +377,7 @@ class GrammarTests(unittest.TestCase):
check_syntax_error(self, "x + 1 = 1")
check_syntax_error(self, "a + 1 = b + 2")
- def testDelStmt(self):
+ def test_del_stmt(self):
# 'del' exprlist
abc = [1,2,3]
x, y, z = abc
@@ -379,18 +386,18 @@ class GrammarTests(unittest.TestCase):
del abc
del x, y, (z, xyz)
- def testPassStmt(self):
+ def test_pass_stmt(self):
# 'pass'
pass
# flow_stmt: break_stmt | continue_stmt | return_stmt | raise_stmt
# Tested below
- def testBreakStmt(self):
+ def test_break_stmt(self):
# 'break'
while 1: break
- def testContinueStmt(self):
+ def test_continue_stmt(self):
# 'continue'
i = 1
while i: i = 0; continue
@@ -442,7 +449,7 @@ class GrammarTests(unittest.TestCase):
self.fail("continue then break in try/except in loop broken!")
test_inner()
- def testReturn(self):
+ def test_return(self):
# 'return' [testlist]
def g1(): return
def g2(): return 1
@@ -450,17 +457,49 @@ class GrammarTests(unittest.TestCase):
x = g2()
check_syntax_error(self, "class foo:return 1")
- def testYield(self):
+ def test_yield(self):
+ # Allowed as standalone statement
+ def g(): yield 1
+ def g(): yield from ()
+ # Allowed as RHS of assignment
+ def g(): x = yield 1
+ def g(): x = yield from ()
+ # Ordinary yield accepts implicit tuples
+ def g(): yield 1, 1
+ def g(): x = yield 1, 1
+ # 'yield from' does not
+ check_syntax_error(self, "def g(): yield from (), 1")
+ check_syntax_error(self, "def g(): x = yield from (), 1")
+ # Requires parentheses as subexpression
+ def g(): 1, (yield 1)
+ def g(): 1, (yield from ())
+ check_syntax_error(self, "def g(): 1, yield 1")
+ check_syntax_error(self, "def g(): 1, yield from ()")
+ # Requires parentheses as call argument
+ def g(): f((yield 1))
+ def g(): f((yield 1), 1)
+ def g(): f((yield from ()))
+ def g(): f((yield from ()), 1)
+ check_syntax_error(self, "def g(): f(yield 1)")
+ check_syntax_error(self, "def g(): f(yield 1, 1)")
+ check_syntax_error(self, "def g(): f(yield from ())")
+ check_syntax_error(self, "def g(): f(yield from (), 1)")
+ # Not allowed at top level
+ check_syntax_error(self, "yield")
+ check_syntax_error(self, "yield from")
+ # Not allowed at class scope
check_syntax_error(self, "class foo:yield 1")
+ check_syntax_error(self, "class foo:yield from ()")
+
- def testRaise(self):
+ def test_raise(self):
# 'raise' test [',' test]
try: raise RuntimeError('just testing')
except RuntimeError: pass
try: raise KeyboardInterrupt
except KeyboardInterrupt: pass
- def testImport(self):
+ def test_import(self):
# 'import' dotted_as_names
import sys
import time, sys
@@ -473,13 +512,13 @@ class GrammarTests(unittest.TestCase):
from sys import (path, argv)
from sys import (path, argv,)
- def testGlobal(self):
+ def test_global(self):
# 'global' NAME (',' NAME)*
global a
global a, b
global one, two, three, four, five, six, seven, eight, nine, ten
- def testNonlocal(self):
+ def test_nonlocal(self):
# 'nonlocal' NAME (',' NAME)*
x = 0
y = 0
@@ -487,7 +526,7 @@ class GrammarTests(unittest.TestCase):
nonlocal x
nonlocal x, y
- def testAssert(self):
+ def test_assert(self):
# assertTruestmt: 'assert' test [',' test]
assert 1
assert 1, 1
@@ -526,7 +565,7 @@ class GrammarTests(unittest.TestCase):
### compound_stmt: if_stmt | while_stmt | for_stmt | try_stmt | funcdef | classdef
# Tested below
- def testIf(self):
+ def test_if(self):
# 'if' test ':' suite ('elif' test ':' suite)* ['else' ':' suite]
if 1: pass
if 1: pass
@@ -539,7 +578,7 @@ class GrammarTests(unittest.TestCase):
elif 0: pass
else: pass
- def testWhile(self):
+ def test_while(self):
# 'while' test ':' suite ['else' ':' suite]
while 0: pass
while 0: pass
@@ -554,7 +593,7 @@ class GrammarTests(unittest.TestCase):
x = 2
self.assertEqual(x, 2)
- def testFor(self):
+ def test_for(self):
# 'for' exprlist 'in' exprlist ':' suite ['else' ':' suite]
for i in 1, 2, 3: pass
for i, j, k in (): pass
@@ -581,7 +620,7 @@ class GrammarTests(unittest.TestCase):
result.append(x)
self.assertEqual(result, [1, 2, 3])
- def testTry(self):
+ def test_try(self):
### try_stmt: 'try' ':' suite (except_clause ':' suite)+ ['else' ':' suite]
### | 'try' ':' suite 'finally' ':' suite
### except_clause: 'except' [expr ['as' expr]]
@@ -604,7 +643,7 @@ class GrammarTests(unittest.TestCase):
try: pass
finally: pass
- def testSuite(self):
+ def test_suite(self):
# simple_stmt | NEWLINE INDENT NEWLINE* (stmt NEWLINE*)+ DEDENT
if 1: pass
if 1:
@@ -619,7 +658,7 @@ class GrammarTests(unittest.TestCase):
pass
#
- def testTest(self):
+ def test_test(self):
### and_test ('or' and_test)*
### and_test: not_test ('and' not_test)*
### not_test: 'not' not_test | comparison
@@ -630,7 +669,7 @@ class GrammarTests(unittest.TestCase):
if not 1 and 1 and 1: pass
if 1 and 1 or 1 and 1 and 1 or not 1 and 1: pass
- def testComparison(self):
+ def test_comparison(self):
### comparison: expr (comp_op expr)*
### comp_op: '<'|'>'|'=='|'>='|'<='|'!='|'in'|'not' 'in'|'is'|'is' 'not'
if 1: pass
@@ -647,36 +686,36 @@ class GrammarTests(unittest.TestCase):
if 1 not in (): pass
if 1 < 1 > 1 == 1 >= 1 <= 1 != 1 in 1 not in 1 is 1 is not 1: pass
- def testBinaryMaskOps(self):
+ def test_binary_mask_ops(self):
x = 1 & 1
x = 1 ^ 1
x = 1 | 1
- def testShiftOps(self):
+ def test_shift_ops(self):
x = 1 << 1
x = 1 >> 1
x = 1 << 1 >> 1
- def testAdditiveOps(self):
+ def test_additive_ops(self):
x = 1
x = 1 + 1
x = 1 - 1 - 1
x = 1 - 1 + 1 - 1 + 1
- def testMultiplicativeOps(self):
+ def test_multiplicative_ops(self):
x = 1 * 1
x = 1 / 1
x = 1 % 1
x = 1 / 1 * 1 % 1
- def testUnaryOps(self):
+ def test_unary_ops(self):
x = +1
x = -1
x = ~1
x = ~1 ^ 1 & 1 | 1 & 1 ^ -1
x = -1*1/1 + 1*1 - ---1*1
- def testSelectors(self):
+ def test_selectors(self):
### trailer: '(' [testlist] ')' | '[' subscript ']' | '.' NAME
### subscript: expr | [expr] ':' [expr]
@@ -706,7 +745,7 @@ class GrammarTests(unittest.TestCase):
L.sort(key=lambda x: x if isinstance(x, tuple) else ())
self.assertEqual(str(L), '[1, (1,), (1, 2), (1, 2, 3)]')
- def testAtoms(self):
+ def test_atoms(self):
### atom: '(' [testlist] ')' | '[' [testlist] ']' | '{' [dictsetmaker] '}' | NAME | NUMBER | STRING
### dictsetmaker: (test ':' test (',' test ':' test)* [',']) | (test (',' test)* [','])
@@ -741,7 +780,7 @@ class GrammarTests(unittest.TestCase):
### testlist: test (',' test)* [',']
# These have been exercised enough above
- def testClassdef(self):
+ def test_classdef(self):
# 'class' NAME ['(' [testlist] ')'] ':' suite
class B: pass
class B2(): pass
@@ -760,14 +799,14 @@ class GrammarTests(unittest.TestCase):
@class_decorator
class G: pass
- def testDictcomps(self):
+ def test_dictcomps(self):
# dictorsetmaker: ( (test ':' test (comp_for |
# (',' test ':' test)* [','])) |
# (test (comp_for | (',' test)* [','])) )
nums = [1, 2, 3]
self.assertEqual({i:i+1 for i in nums}, {1: 2, 2: 3, 3: 4})
- def testListcomps(self):
+ def test_listcomps(self):
# list comprehension tests
nums = [1, 2, 3, 4, 5]
strs = ["Apple", "Banana", "Coconut"]
@@ -830,7 +869,7 @@ class GrammarTests(unittest.TestCase):
self.assertEqual(x, [('Boeing', 'Airliner'), ('Boeing', 'Engine'), ('Ford', 'Engine'),
('Macdonalds', 'Cheeseburger')])
- def testGenexps(self):
+ def test_genexps(self):
# generator expression tests
g = ([x for x in range(10)] for x in range(1))
self.assertEqual(next(g), [x for x in range(10)])
@@ -865,7 +904,7 @@ class GrammarTests(unittest.TestCase):
check_syntax_error(self, "foo(x for x in range(10), 100)")
check_syntax_error(self, "foo(100, x for x in range(10))")
- def testComprehensionSpecials(self):
+ def test_comprehension_specials(self):
# test for outmost iterable precomputation
x = 10; g = (i for i in range(x)); x = 5
self.assertEqual(len(list(g)), 10)
@@ -904,7 +943,7 @@ class GrammarTests(unittest.TestCase):
with manager() as x, manager():
pass
- def testIfElseExpr(self):
+ def test_if_else_expr(self):
# Test ifelse expressions in various cases
def _checkeval(msg, ret):
"helper to check that evaluation of expressions is done correctly"
diff --git a/Lib/test/test_gzip.py b/Lib/test/test_gzip.py
index ba9d7da..ebd4c43 100644
--- a/Lib/test/test_gzip.py
+++ b/Lib/test/test_gzip.py
@@ -33,7 +33,7 @@ class UnseekableIO(io.BytesIO):
raise io.UnsupportedOperation
-class TestGzip(unittest.TestCase):
+class BaseTest(unittest.TestCase):
filename = support.TESTFN
def setUp(self):
@@ -43,6 +43,7 @@ class TestGzip(unittest.TestCase):
support.unlink(self.filename)
+class TestGzip(BaseTest):
def test_write(self):
with gzip.GzipFile(self.filename, 'wb') as f:
f.write(data1 * 50)
@@ -64,6 +65,21 @@ class TestGzip(unittest.TestCase):
d = f.read()
self.assertEqual(d, data1*50)
+ def test_read1(self):
+ self.test_write()
+ blocks = []
+ nread = 0
+ with gzip.GzipFile(self.filename, 'r') as f:
+ while True:
+ d = f.read1()
+ if not d:
+ break
+ blocks.append(d)
+ nread += len(d)
+ # Check that position was updated correctly (see issue10791).
+ self.assertEqual(f.tell(), nread)
+ self.assertEqual(b''.join(blocks), data1 * 50)
+
def test_io_on_closed_object(self):
# Test that I/O operations on closed GzipFile objects raise a
# ValueError, just like the corresponding functions on file objects.
@@ -100,14 +116,14 @@ class TestGzip(unittest.TestCase):
# Bug #1074261 was triggered when reading a file that contained
# many, many members. Create such a file and verify that reading it
# works.
- with gzip.open(self.filename, 'wb', 9) as f:
+ with gzip.GzipFile(self.filename, 'wb', 9) as f:
f.write(b'a')
for i in range(0, 200):
- with gzip.open(self.filename, "ab", 9) as f: # append
+ with gzip.GzipFile(self.filename, "ab", 9) as f: # append
f.write(b'a')
# Try reading the file
- with gzip.open(self.filename, "rb") as zgfile:
+ with gzip.GzipFile(self.filename, "rb") as zgfile:
contents = b""
while 1:
ztxt = zgfile.read(8192)
@@ -124,7 +140,7 @@ class TestGzip(unittest.TestCase):
with io.BufferedReader(f) as r:
lines = [line for line in r]
- self.assertEqual(lines, 50 * data1.splitlines(True))
+ self.assertEqual(lines, 50 * data1.splitlines(keepends=True))
def test_readline(self):
self.test_write()
@@ -323,6 +339,14 @@ class TestGzip(unittest.TestCase):
self.assertEqual(f.read(100), b'')
self.assertEqual(nread, len(uncompressed))
+ def test_textio_readlines(self):
+ # Issue #10791: TextIOWrapper.readlines() fails when wrapping GzipFile.
+ lines = (data1 * 50).decode("ascii").splitlines(keepends=True)
+ self.test_write()
+ with gzip.GzipFile(self.filename, 'r') as f:
+ with io.TextIOWrapper(f, encoding="ascii") as t:
+ self.assertEqual(t.readlines(), lines)
+
def test_fileobj_from_fdopen(self):
# Issue #13781: Opening a GzipFile for writing fails when using a
# fileobj created with os.fdopen().
@@ -380,8 +404,108 @@ class TestGzip(unittest.TestCase):
self.assertRaises(EOFError, f.read, 1)
+class TestOpen(BaseTest):
+ def test_binary_modes(self):
+ uncompressed = data1 * 50
+ with gzip.open(self.filename, "wb") as f:
+ f.write(uncompressed)
+ with open(self.filename, "rb") as f:
+ file_data = gzip.decompress(f.read())
+ self.assertEqual(file_data, uncompressed)
+ with gzip.open(self.filename, "rb") as f:
+ self.assertEqual(f.read(), uncompressed)
+ with gzip.open(self.filename, "ab") as f:
+ f.write(uncompressed)
+ with open(self.filename, "rb") as f:
+ file_data = gzip.decompress(f.read())
+ self.assertEqual(file_data, uncompressed * 2)
+
+ def test_implicit_binary_modes(self):
+ # Test implicit binary modes (no "b" or "t" in mode string).
+ uncompressed = data1 * 50
+ with gzip.open(self.filename, "w") as f:
+ f.write(uncompressed)
+ with open(self.filename, "rb") as f:
+ file_data = gzip.decompress(f.read())
+ self.assertEqual(file_data, uncompressed)
+ with gzip.open(self.filename, "r") as f:
+ self.assertEqual(f.read(), uncompressed)
+ with gzip.open(self.filename, "a") as f:
+ f.write(uncompressed)
+ with open(self.filename, "rb") as f:
+ file_data = gzip.decompress(f.read())
+ self.assertEqual(file_data, uncompressed * 2)
+
+ def test_text_modes(self):
+ uncompressed = data1.decode("ascii") * 50
+ uncompressed_raw = uncompressed.replace("\n", os.linesep)
+ with gzip.open(self.filename, "wt") as f:
+ f.write(uncompressed)
+ with open(self.filename, "rb") as f:
+ file_data = gzip.decompress(f.read()).decode("ascii")
+ self.assertEqual(file_data, uncompressed_raw)
+ with gzip.open(self.filename, "rt") as f:
+ self.assertEqual(f.read(), uncompressed)
+ with gzip.open(self.filename, "at") as f:
+ f.write(uncompressed)
+ with open(self.filename, "rb") as f:
+ file_data = gzip.decompress(f.read()).decode("ascii")
+ self.assertEqual(file_data, uncompressed_raw * 2)
+
+ def test_fileobj(self):
+ uncompressed_bytes = data1 * 50
+ uncompressed_str = uncompressed_bytes.decode("ascii")
+ compressed = gzip.compress(uncompressed_bytes)
+ with gzip.open(io.BytesIO(compressed), "r") as f:
+ self.assertEqual(f.read(), uncompressed_bytes)
+ with gzip.open(io.BytesIO(compressed), "rb") as f:
+ self.assertEqual(f.read(), uncompressed_bytes)
+ with gzip.open(io.BytesIO(compressed), "rt") as f:
+ self.assertEqual(f.read(), uncompressed_str)
+
+ def test_bad_params(self):
+ # Test invalid parameter combinations.
+ with self.assertRaises(TypeError):
+ gzip.open(123.456)
+ with self.assertRaises(ValueError):
+ gzip.open(self.filename, "wbt")
+ with self.assertRaises(ValueError):
+ gzip.open(self.filename, "rb", encoding="utf-8")
+ with self.assertRaises(ValueError):
+ gzip.open(self.filename, "rb", errors="ignore")
+ with self.assertRaises(ValueError):
+ gzip.open(self.filename, "rb", newline="\n")
+
+ def test_encoding(self):
+ # Test non-default encoding.
+ uncompressed = data1.decode("ascii") * 50
+ uncompressed_raw = uncompressed.replace("\n", os.linesep)
+ with gzip.open(self.filename, "wt", encoding="utf-16") as f:
+ f.write(uncompressed)
+ with open(self.filename, "rb") as f:
+ file_data = gzip.decompress(f.read()).decode("utf-16")
+ self.assertEqual(file_data, uncompressed_raw)
+ with gzip.open(self.filename, "rt", encoding="utf-16") as f:
+ self.assertEqual(f.read(), uncompressed)
+
+ def test_encoding_error_handler(self):
+ # Test with non-default encoding error handler.
+ with gzip.open(self.filename, "wb") as f:
+ f.write(b"foo\xffbar")
+ with gzip.open(self.filename, "rt", encoding="ascii", errors="ignore") \
+ as f:
+ self.assertEqual(f.read(), "foobar")
+
+ def test_newline(self):
+ # Test with explicit newline (universal newline mode disabled).
+ uncompressed = data1.decode("ascii") * 50
+ with gzip.open(self.filename, "wt", newline="\n") as f:
+ f.write(uncompressed)
+ with gzip.open(self.filename, "rt", newline="\r") as f:
+ self.assertEqual(f.readlines(), [uncompressed])
+
def test_main(verbose=None):
- support.run_unittest(TestGzip)
+ support.run_unittest(TestGzip, TestOpen)
if __name__ == "__main__":
test_main(verbose=True)
diff --git a/Lib/test/test_hash.py b/Lib/test/test_hash.py
index c0dd34d..d1ddc76 100644
--- a/Lib/test/test_hash.py
+++ b/Lib/test/test_hash.py
@@ -45,6 +45,16 @@ class HashEqualityTestCase(unittest.TestCase):
self.same_hash(int(1.23e300), float(1.23e300))
self.same_hash(float(0.5), complex(0.5, 0.0))
+ def test_unaligned_buffers(self):
+ # The hash function for bytes-like objects shouldn't have
+ # alignment-dependent results (example in issue #16427).
+ b = b"123456789abcdefghijklmnopqrstuvwxyz" * 128
+ for i in range(16):
+ for j in range(16):
+ aligned = b[i:128+j]
+ unaligned = memoryview(b)[i:128+j]
+ self.assertEqual(hash(aligned), hash(unaligned))
+
_default_hash = object.__hash__
class DefaultHash(object): pass
@@ -113,8 +123,7 @@ class DefaultIterSeq(object):
return self.seq[index]
class HashBuiltinsTestCase(unittest.TestCase):
- hashes_to_check = [range(10),
- enumerate(range(10)),
+ hashes_to_check = [enumerate(range(10)),
iter(DefaultIterSeq()),
iter(lambda: 0, 0),
]
@@ -160,8 +169,8 @@ class StringlikeHashRandomizationTests(HashRandomizationTests):
else:
known_hash_of_obj = -1600925533
- # Randomization is disabled by default:
- self.assertEqual(self.get_hash(self.repr_), known_hash_of_obj)
+ # Randomization is enabled by default:
+ self.assertNotEqual(self.get_hash(self.repr_), known_hash_of_obj)
# It can also be disabled by setting the seed to 0:
self.assertEqual(self.get_hash(self.repr_, seed=0), known_hash_of_obj)
@@ -193,6 +202,12 @@ class BytesHashRandomizationTests(StringlikeHashRandomizationTests):
def test_empty_string(self):
self.assertEqual(hash(b""), 0)
+class MemoryviewHashRandomizationTests(StringlikeHashRandomizationTests):
+ repr_ = "memoryview(b'abc')"
+
+ def test_empty_string(self):
+ self.assertEqual(hash(memoryview(b"")), 0)
+
class DatetimeTests(HashRandomizationTests):
def get_hash_command(self, repr_):
return 'import datetime; print(hash(%s))' % repr_
@@ -213,6 +228,7 @@ def test_main():
HashBuiltinsTestCase,
StrHashRandomizationTests,
BytesHashRandomizationTests,
+ MemoryviewHashRandomizationTests,
DatetimeDateTests,
DatetimeDatetimeTests,
DatetimeTimeTests)
diff --git a/Lib/test/test_hashlib.py b/Lib/test/test_hashlib.py
index 29d3a1c..32f85e9 100644
--- a/Lib/test/test_hashlib.py
+++ b/Lib/test/test_hashlib.py
@@ -9,6 +9,7 @@
import array
import hashlib
import itertools
+import os
import sys
try:
import threading
@@ -37,7 +38,8 @@ class HashLibTestCase(unittest.TestCase):
'sha224', 'SHA224', 'sha256', 'SHA256',
'sha384', 'SHA384', 'sha512', 'SHA512' )
- _warn_on_extension_import = COMPILED_WITH_PYDEBUG
+ # Issue #14693: fallback modules are always compiled under POSIX
+ _warn_on_extension_import = os.name == 'posix' or COMPILED_WITH_PYDEBUG
def _conditional_import_module(self, module_name):
"""Import a module and return a reference to it or None on failure."""
diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py
index e0c49c1..54395c0 100644
--- a/Lib/test/test_heapq.py
+++ b/Lib/test/test_heapq.py
@@ -2,6 +2,7 @@
import sys
import random
+import unittest
from test import support
from unittest import TestCase, skipUnless
@@ -25,8 +26,7 @@ class TestModules(TestCase):
self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
-class TestHeap(TestCase):
- module = None
+class TestHeap:
def test_push_pop(self):
# 1) Push 256 random numbers and pop them off, verifying all's OK.
@@ -214,12 +214,12 @@ class TestHeap(TestCase):
self.assertRaises(TypeError, data, LE)
-class TestHeapPython(TestHeap):
+class TestHeapPython(TestHeap, TestCase):
module = py_heapq
@skipUnless(c_heapq, 'requires _heapq')
-class TestHeapC(TestHeap):
+class TestHeapC(TestHeap, TestCase):
module = c_heapq
@@ -319,8 +319,7 @@ def L(seqn):
return chain(map(lambda x:x, R(Ig(G(seqn)))))
-class TestErrorHandling(TestCase):
- module = None
+class TestErrorHandling:
def test_non_sequence(self):
for f in (self.module.heapify, self.module.heappop):
@@ -371,31 +370,13 @@ class TestErrorHandling(TestCase):
self.assertRaises(ZeroDivisionError, f, 2, E(s))
-class TestErrorHandlingPython(TestErrorHandling):
+class TestErrorHandlingPython(TestErrorHandling, TestCase):
module = py_heapq
@skipUnless(c_heapq, 'requires _heapq')
-class TestErrorHandlingC(TestErrorHandling):
+class TestErrorHandlingC(TestErrorHandling, TestCase):
module = c_heapq
-#==============================================================================
-
-
-def test_main(verbose=None):
- test_classes = [TestModules, TestHeapPython, TestHeapC,
- TestErrorHandlingPython, TestErrorHandlingC]
- support.run_unittest(*test_classes)
-
- # verify reference counting
- if verbose and hasattr(sys, "gettotalrefcount"):
- import gc
- counts = [None] * 5
- for i in range(len(counts)):
- support.run_unittest(*test_classes)
- gc.collect()
- counts[i] = sys.gettotalrefcount()
- print(counts)
-
if __name__ == "__main__":
- test_main(verbose=True)
+ unittest.main()
diff --git a/Lib/test/test_hmac.py b/Lib/test/test_hmac.py
index 4de0620..4ca7cec 100644
--- a/Lib/test/test_hmac.py
+++ b/Lib/test/test_hmac.py
@@ -234,6 +234,18 @@ class ConstructorTestCase(unittest.TestCase):
except:
self.fail("Standard constructor call raised exception.")
+ def test_with_str_key(self):
+ # Pass a key of type str, which is an error, because it expects a key
+ # of type bytes
+ with self.assertRaises(TypeError):
+ h = hmac.HMAC("key")
+
+ def test_dot_new_with_str_key(self):
+ # Pass a key of type str, which is an error, because it expects a key
+ # of type bytes
+ with self.assertRaises(TypeError):
+ h = hmac.new("key")
+
def test_withtext(self):
# Constructor call with text.
try:
@@ -302,12 +314,121 @@ class CopyTestCase(unittest.TestCase):
self.assertEqual(h1.hexdigest(), h2.hexdigest(),
"Hexdigest of copy doesn't match original hexdigest.")
+class CompareDigestTestCase(unittest.TestCase):
+
+ def test_compare_digest(self):
+ # Testing input type exception handling
+ a, b = 100, 200
+ self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ a, b = 100, b"foobar"
+ self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ a, b = b"foobar", 200
+ self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ a, b = "foobar", b"foobar"
+ self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ a, b = b"foobar", "foobar"
+ self.assertRaises(TypeError, hmac.compare_digest, a, b)
+
+ # Testing bytes of different lengths
+ a, b = b"foobar", b"foo"
+ self.assertFalse(hmac.compare_digest(a, b))
+ a, b = b"\xde\xad\xbe\xef", b"\xde\xad"
+ self.assertFalse(hmac.compare_digest(a, b))
+
+ # Testing bytes of same lengths, different values
+ a, b = b"foobar", b"foobaz"
+ self.assertFalse(hmac.compare_digest(a, b))
+ a, b = b"\xde\xad\xbe\xef", b"\xab\xad\x1d\xea"
+ self.assertFalse(hmac.compare_digest(a, b))
+
+ # Testing bytes of same lengths, same values
+ a, b = b"foobar", b"foobar"
+ self.assertTrue(hmac.compare_digest(a, b))
+ a, b = b"\xde\xad\xbe\xef", b"\xde\xad\xbe\xef"
+ self.assertTrue(hmac.compare_digest(a, b))
+
+ # Testing bytearrays of same lengths, same values
+ a, b = bytearray(b"foobar"), bytearray(b"foobar")
+ self.assertTrue(hmac.compare_digest(a, b))
+
+ # Testing bytearrays of diffeent lengths
+ a, b = bytearray(b"foobar"), bytearray(b"foo")
+ self.assertFalse(hmac.compare_digest(a, b))
+
+ # Testing bytearrays of same lengths, different values
+ a, b = bytearray(b"foobar"), bytearray(b"foobaz")
+ self.assertFalse(hmac.compare_digest(a, b))
+
+ # Testing byte and bytearray of same lengths, same values
+ a, b = bytearray(b"foobar"), b"foobar"
+ self.assertTrue(hmac.compare_digest(a, b))
+ self.assertTrue(hmac.compare_digest(b, a))
+
+ # Testing byte bytearray of diffeent lengths
+ a, b = bytearray(b"foobar"), b"foo"
+ self.assertFalse(hmac.compare_digest(a, b))
+ self.assertFalse(hmac.compare_digest(b, a))
+
+ # Testing byte and bytearray of same lengths, different values
+ a, b = bytearray(b"foobar"), b"foobaz"
+ self.assertFalse(hmac.compare_digest(a, b))
+ self.assertFalse(hmac.compare_digest(b, a))
+
+ # Testing str of same lengths
+ a, b = "foobar", "foobar"
+ self.assertTrue(hmac.compare_digest(a, b))
+
+ # Testing str of diffeent lengths
+ a, b = "foo", "foobar"
+ self.assertFalse(hmac.compare_digest(a, b))
+
+ # Testing bytes of same lengths, different values
+ a, b = "foobar", "foobaz"
+ self.assertFalse(hmac.compare_digest(a, b))
+
+ # Testing error cases
+ a, b = "foobar", b"foobar"
+ self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ a, b = b"foobar", "foobar"
+ self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ a, b = b"foobar", 1
+ self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ a, b = 100, 200
+ self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ a, b = "fooä", "fooä"
+ self.assertRaises(TypeError, hmac.compare_digest, a, b)
+
+ # subclasses are supported by ignore __eq__
+ class mystr(str):
+ def __eq__(self, other):
+ return False
+
+ a, b = mystr("foobar"), mystr("foobar")
+ self.assertTrue(hmac.compare_digest(a, b))
+ a, b = mystr("foobar"), "foobar"
+ self.assertTrue(hmac.compare_digest(a, b))
+ a, b = mystr("foobar"), mystr("foobaz")
+ self.assertFalse(hmac.compare_digest(a, b))
+
+ class mybytes(bytes):
+ def __eq__(self, other):
+ return False
+
+ a, b = mybytes(b"foobar"), mybytes(b"foobar")
+ self.assertTrue(hmac.compare_digest(a, b))
+ a, b = mybytes(b"foobar"), b"foobar"
+ self.assertTrue(hmac.compare_digest(a, b))
+ a, b = mybytes(b"foobar"), mybytes(b"foobaz")
+ self.assertFalse(hmac.compare_digest(a, b))
+
+
def test_main():
support.run_unittest(
TestVectorsTestCase,
ConstructorTestCase,
SanityTestCase,
- CopyTestCase
+ CopyTestCase,
+ CompareDigestTestCase
)
if __name__ == "__main__":
diff --git a/Lib/test/test_htmlparser.py b/Lib/test/test_htmlparser.py
index c4f80cc..c5d878d 100644
--- a/Lib/test/test_htmlparser.py
+++ b/Lib/test/test_htmlparser.py
@@ -102,7 +102,8 @@ class TestCaseBase(unittest.TestCase):
class HTMLParserStrictTestCase(TestCaseBase):
def get_collector(self):
- return EventCollector(strict=True)
+ with support.check_warnings(("", DeprecationWarning), quite=False):
+ return EventCollector(strict=True)
def test_processing_instruction_only(self):
self._run_check("<?processing instruction>", [
@@ -455,7 +456,7 @@ class HTMLParserTolerantTestCase(HTMLParserStrictTestCase):
self._run_check('<form action="/xxx.php?a=1&amp;b=2&amp", '
'method="post">', [
('starttag', 'form',
- [('action', '/xxx.php?a=1&b=2&amp'),
+ [('action', '/xxx.php?a=1&b=2&'),
(',', None), ('method', 'post')])])
def test_weird_chars_in_unquoted_attribute_values(self):
@@ -540,6 +541,11 @@ class HTMLParserTolerantTestCase(HTMLParserStrictTestCase):
self.assertEqual(p.unescape('&#0038;'),'&')
# see #12888
self.assertEqual(p.unescape('&#123; ' * 1050), '{ ' * 1050)
+ # see #15156
+ self.assertEqual(p.unescape('&Eacuteric&Eacute;ric'
+ '&alphacentauri&alpha;centauri'),
+ 'ÉricÉric&alphacentauriαcentauri')
+ self.assertEqual(p.unescape('&co;'), '&co;')
def test_broken_comments(self):
html = ('<! not really a comment >'
@@ -594,7 +600,8 @@ class HTMLParserTolerantTestCase(HTMLParserStrictTestCase):
class AttributesStrictTestCase(TestCaseBase):
def get_collector(self):
- return EventCollector(strict=True)
+ with support.check_warnings(("", DeprecationWarning), quite=False):
+ return EventCollector(strict=True)
def test_attr_syntax(self):
output = [
diff --git a/Lib/test/test_http_cookiejar.py b/Lib/test/test_http_cookiejar.py
index 41e0dfd..a35ec95 100644
--- a/Lib/test/test_http_cookiejar.py
+++ b/Lib/test/test_http_cookiejar.py
@@ -248,18 +248,19 @@ class FileCookieJarTests(unittest.TestCase):
self.assertEqual(c._cookies["www.acme.com"]["/"]["boo"].value, None)
def test_bad_magic(self):
- # IOErrors (eg. file doesn't exist) are allowed to propagate
+ # OSErrors (eg. file doesn't exist) are allowed to propagate
filename = test.support.TESTFN
for cookiejar_class in LWPCookieJar, MozillaCookieJar:
c = cookiejar_class()
try:
c.load(filename="for this test to work, a file with this "
"filename should not exist")
- except IOError as exc:
- # exactly IOError, not LoadError
- self.assertIs(exc.__class__, IOError)
+ except OSError as exc:
+ # an OSError subclass (likely FileNotFoundError), but not
+ # LoadError
+ self.assertIsNot(exc.__class__, LoadError)
else:
- self.fail("expected IOError for invalid filename")
+ self.fail("expected OSError for invalid filename")
# Invalid contents of cookies file (eg. bad magic string)
# causes a LoadError.
try:
diff --git a/Lib/test/test_http_cookies.py b/Lib/test/test_http_cookies.py
index 1f1ca58..e8327e5 100644
--- a/Lib/test/test_http_cookies.py
+++ b/Lib/test/test_http_cookies.py
@@ -34,6 +34,15 @@ class CookieTests(unittest.TestCase):
'dict': {'keebler' : 'E=mc2'},
'repr': "<SimpleCookie: keebler='E=mc2'>",
'output': 'Set-Cookie: keebler=E=mc2'},
+
+ # Cookies with ':' character in their name. Though not mentioned in
+ # RFC, servers / browsers allow it.
+
+ {'data': 'key:term=value:term',
+ 'dict': {'key:term' : 'value:term'},
+ 'repr': "<SimpleCookie: key:term='value:term'>",
+ 'output': 'Set-Cookie: key:term=value:term'},
+
]
for case in cases:
diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py
index 420302c..db123dc 100644
--- a/Lib/test/test_httplib.py
+++ b/Lib/test/test_httplib.py
@@ -192,6 +192,26 @@ class BasicTest(TestCase):
resp.close()
self.assertTrue(resp.closed)
+ def test_partial_readintos(self):
+ # if we have a length, the system knows when to close itself
+ # same behaviour than when we read the whole thing with read()
+ body = "HTTP/1.1 200 Ok\r\nContent-Length: 4\r\n\r\nText"
+ sock = FakeSocket(body)
+ resp = client.HTTPResponse(sock)
+ resp.begin()
+ b = bytearray(2)
+ n = resp.readinto(b)
+ self.assertEqual(n, 2)
+ self.assertEqual(bytes(b), b'Te')
+ self.assertFalse(resp.isclosed())
+ n = resp.readinto(b)
+ self.assertEqual(n, 2)
+ self.assertEqual(bytes(b), b'xt')
+ self.assertTrue(resp.isclosed())
+ self.assertFalse(resp.closed)
+ resp.close()
+ self.assertTrue(resp.closed)
+
def test_partial_reads_no_content_length(self):
# when no length is present, the socket should be gracefully closed when
# all data was read
@@ -208,6 +228,25 @@ class BasicTest(TestCase):
resp.close()
self.assertTrue(resp.closed)
+ def test_partial_readintos_no_content_length(self):
+ # when no length is present, the socket should be gracefully closed when
+ # all data was read
+ body = "HTTP/1.1 200 Ok\r\n\r\nText"
+ sock = FakeSocket(body)
+ resp = client.HTTPResponse(sock)
+ resp.begin()
+ b = bytearray(2)
+ n = resp.readinto(b)
+ self.assertEqual(n, 2)
+ self.assertEqual(bytes(b), b'Te')
+ self.assertFalse(resp.isclosed())
+ n = resp.readinto(b)
+ self.assertEqual(n, 2)
+ self.assertEqual(bytes(b), b'xt')
+ n = resp.readinto(b)
+ self.assertEqual(n, 0)
+ self.assertTrue(resp.isclosed())
+
def test_partial_reads_incomplete_body(self):
# if the server shuts down the connection before the whole
# content-length is delivered, the socket is gracefully closed
@@ -220,6 +259,25 @@ class BasicTest(TestCase):
self.assertEqual(resp.read(2), b'xt')
self.assertEqual(resp.read(1), b'')
self.assertTrue(resp.isclosed())
+
+ def test_partial_readintos_incomplete_body(self):
+ # if the server shuts down the connection before the whole
+ # content-length is delivered, the socket is gracefully closed
+ body = "HTTP/1.1 200 Ok\r\nContent-Length: 10\r\n\r\nText"
+ sock = FakeSocket(body)
+ resp = client.HTTPResponse(sock)
+ resp.begin()
+ b = bytearray(2)
+ n = resp.readinto(b)
+ self.assertEqual(n, 2)
+ self.assertEqual(bytes(b), b'Te')
+ self.assertFalse(resp.isclosed())
+ n = resp.readinto(b)
+ self.assertEqual(n, 2)
+ self.assertEqual(bytes(b), b'xt')
+ n = resp.readinto(b)
+ self.assertEqual(n, 0)
+ self.assertTrue(resp.isclosed())
self.assertFalse(resp.closed)
resp.close()
self.assertTrue(resp.closed)
@@ -272,6 +330,21 @@ class BasicTest(TestCase):
if resp.read():
self.fail("Did not expect response from HEAD request")
+ def test_readinto_head(self):
+ # Test that the library doesn't attempt to read any data
+ # from a HEAD request. (Tickles SF bug #622042.)
+ sock = FakeSocket(
+ 'HTTP/1.1 200 OK\r\n'
+ 'Content-Length: 14432\r\n'
+ '\r\n',
+ NoEOFStringIO)
+ resp = client.HTTPResponse(sock, method="HEAD")
+ resp.begin()
+ b = bytearray(5)
+ if resp.readinto(b) != 0:
+ self.fail("Did not expect response from HEAD request")
+ self.assertEqual(bytes(b), b'\x00'*5)
+
def test_send_file(self):
expected = (b'GET /foo HTTP/1.1\r\nHost: example.com\r\n'
b'Accept-Encoding: identity\r\nContent-Length:')
@@ -327,15 +400,28 @@ class BasicTest(TestCase):
'Transfer-Encoding: chunked\r\n\r\n'
'a\r\n'
'hello worl\r\n'
- '1\r\n'
- 'd\r\n'
+ '3\r\n'
+ 'd! \r\n'
+ '8\r\n'
+ 'and now \r\n'
+ '22\r\n'
+ 'for something completely different\r\n'
)
+ expected = b'hello world! and now for something completely different'
sock = FakeSocket(chunked_start + '0\r\n')
resp = client.HTTPResponse(sock, method="GET")
resp.begin()
- self.assertEqual(resp.read(), b'hello world')
+ self.assertEqual(resp.read(), expected)
resp.close()
+ # Various read sizes
+ for n in range(1, 12):
+ sock = FakeSocket(chunked_start + '0\r\n')
+ resp = client.HTTPResponse(sock, method="GET")
+ resp.begin()
+ self.assertEqual(resp.read(n) + resp.read(n) + resp.read(), expected)
+ resp.close()
+
for x in ('', 'foo\r\n'):
sock = FakeSocket(chunked_start + x)
resp = client.HTTPResponse(sock, method="GET")
@@ -343,9 +429,64 @@ class BasicTest(TestCase):
try:
resp.read()
except client.IncompleteRead as i:
- self.assertEqual(i.partial, b'hello world')
- self.assertEqual(repr(i),'IncompleteRead(11 bytes read)')
- self.assertEqual(str(i),'IncompleteRead(11 bytes read)')
+ self.assertEqual(i.partial, expected)
+ expected_message = 'IncompleteRead(%d bytes read)' % len(expected)
+ self.assertEqual(repr(i), expected_message)
+ self.assertEqual(str(i), expected_message)
+ else:
+ self.fail('IncompleteRead expected')
+ finally:
+ resp.close()
+
+ def test_readinto_chunked(self):
+ chunked_start = (
+ 'HTTP/1.1 200 OK\r\n'
+ 'Transfer-Encoding: chunked\r\n\r\n'
+ 'a\r\n'
+ 'hello worl\r\n'
+ '3\r\n'
+ 'd! \r\n'
+ '8\r\n'
+ 'and now \r\n'
+ '22\r\n'
+ 'for something completely different\r\n'
+ )
+ expected = b'hello world! and now for something completely different'
+ nexpected = len(expected)
+ b = bytearray(128)
+
+ sock = FakeSocket(chunked_start + '0\r\n')
+ resp = client.HTTPResponse(sock, method="GET")
+ resp.begin()
+ n = resp.readinto(b)
+ self.assertEqual(b[:nexpected], expected)
+ self.assertEqual(n, nexpected)
+ resp.close()
+
+ # Various read sizes
+ for n in range(1, 12):
+ sock = FakeSocket(chunked_start + '0\r\n')
+ resp = client.HTTPResponse(sock, method="GET")
+ resp.begin()
+ m = memoryview(b)
+ i = resp.readinto(m[0:n])
+ i += resp.readinto(m[i:n + i])
+ i += resp.readinto(m[i:])
+ self.assertEqual(b[:nexpected], expected)
+ self.assertEqual(i, nexpected)
+ resp.close()
+
+ for x in ('', 'foo\r\n'):
+ sock = FakeSocket(chunked_start + x)
+ resp = client.HTTPResponse(sock, method="GET")
+ resp.begin()
+ try:
+ n = resp.readinto(b)
+ except client.IncompleteRead as i:
+ self.assertEqual(i.partial, expected)
+ expected_message = 'IncompleteRead(%d bytes read)' % len(expected)
+ self.assertEqual(repr(i), expected_message)
+ self.assertEqual(str(i), expected_message)
else:
self.fail('IncompleteRead expected')
finally:
@@ -371,6 +512,29 @@ class BasicTest(TestCase):
resp.close()
self.assertTrue(resp.closed)
+ def test_readinto_chunked_head(self):
+ chunked_start = (
+ 'HTTP/1.1 200 OK\r\n'
+ 'Transfer-Encoding: chunked\r\n\r\n'
+ 'a\r\n'
+ 'hello world\r\n'
+ '1\r\n'
+ 'd\r\n'
+ )
+ sock = FakeSocket(chunked_start + '0\r\n')
+ resp = client.HTTPResponse(sock, method="HEAD")
+ resp.begin()
+ b = bytearray(5)
+ n = resp.readinto(b)
+ self.assertEqual(n, 0)
+ self.assertEqual(bytes(b), b'\x00'*5)
+ self.assertEqual(resp.status, 200)
+ self.assertEqual(resp.reason, 'OK')
+ self.assertTrue(resp.isclosed())
+ self.assertFalse(resp.closed)
+ resp.close()
+ self.assertTrue(resp.closed)
+
def test_negative_content_length(self):
sock = FakeSocket(
'HTTP/1.1 200 OK\r\nContent-Length: -1\r\n\r\nHello\r\n')
@@ -588,8 +752,7 @@ class HTTPSTest(TestCase):
def test_local_good_hostname(self):
# The (valid) cert validates the HTTP hostname
import ssl
- from test.ssl_servers import make_https_server
- server = make_https_server(self, CERT_localhost)
+ server = self.make_server(CERT_localhost)
context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
context.verify_mode = ssl.CERT_REQUIRED
context.load_verify_locations(CERT_localhost)
@@ -597,12 +760,12 @@ class HTTPSTest(TestCase):
h.request('GET', '/nonexistent')
resp = h.getresponse()
self.assertEqual(resp.status, 404)
+ del server
def test_local_bad_hostname(self):
# The (valid) cert doesn't validate the HTTP hostname
import ssl
- from test.ssl_servers import make_https_server
- server = make_https_server(self, CERT_fakehostname)
+ server = self.make_server(CERT_fakehostname)
context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
context.verify_mode = ssl.CERT_REQUIRED
context.load_verify_locations(CERT_fakehostname)
@@ -620,6 +783,7 @@ class HTTPSTest(TestCase):
h.request('GET', '/nonexistent')
resp = h.getresponse()
self.assertEqual(resp.status, 404)
+ del server
@unittest.skipIf(not hasattr(client, 'HTTPSConnection'),
'http.client.HTTPSConnection not available')
diff --git a/Lib/test/test_httpservers.py b/Lib/test/test_httpservers.py
index d71da1a..03c0776 100644
--- a/Lib/test/test_httpservers.py
+++ b/Lib/test/test_httpservers.py
@@ -470,6 +470,23 @@ class RejectingSocketlessRequestHandler(SocketlessRequestHandler):
self.send_error(417)
return False
+
+class AuditableBytesIO:
+
+ def __init__(self):
+ self.datas = []
+
+ def write(self, data):
+ self.datas.append(data)
+
+ def getData(self):
+ return b''.join(self.datas)
+
+ @property
+ def numWrites(self):
+ return len(self.datas)
+
+
class BaseHTTPRequestHandlerTestCase(unittest.TestCase):
"""Test the functionality of the BaseHTTPServer.
@@ -536,27 +553,49 @@ class BaseHTTPRequestHandlerTestCase(unittest.TestCase):
self.verify_get_called()
self.assertEqual(result[-1], b'<html><body>Data</body></html>\r\n')
- def test_header_buffering(self):
+ def test_header_buffering_of_send_error(self):
- def _readAndReseek(f):
- pos = f.tell()
- f.seek(0)
- data = f.read()
- f.seek(pos)
- return data
+ input = BytesIO(b'GET / HTTP/1.1\r\n\r\n')
+ output = AuditableBytesIO()
+ handler = SocketlessRequestHandler()
+ handler.rfile = input
+ handler.wfile = output
+ handler.request_version = 'HTTP/1.1'
+ handler.requestline = ''
+ handler.command = None
+
+ handler.send_error(418)
+ self.assertEqual(output.numWrites, 2)
+
+ def test_header_buffering_of_send_response_only(self):
input = BytesIO(b'GET / HTTP/1.1\r\n\r\n')
- output = BytesIO()
- self.handler.rfile = input
- self.handler.wfile = output
- self.handler.request_version = 'HTTP/1.1'
+ output = AuditableBytesIO()
+ handler = SocketlessRequestHandler()
+ handler.rfile = input
+ handler.wfile = output
+ handler.request_version = 'HTTP/1.1'
- self.handler.send_header('Foo', 'foo')
- self.handler.send_header('bar', 'bar')
- self.assertEqual(_readAndReseek(output), b'')
- self.handler.end_headers()
- self.assertEqual(_readAndReseek(output),
- b'Foo: foo\r\nbar: bar\r\n\r\n')
+ handler.send_response_only(418)
+ self.assertEqual(output.numWrites, 0)
+ handler.end_headers()
+ self.assertEqual(output.numWrites, 1)
+
+ def test_header_buffering_of_send_header(self):
+
+ input = BytesIO(b'GET / HTTP/1.1\r\n\r\n')
+ output = AuditableBytesIO()
+ handler = SocketlessRequestHandler()
+ handler.rfile = input
+ handler.wfile = output
+ handler.request_version = 'HTTP/1.1'
+
+ handler.send_header('Foo', 'foo')
+ handler.send_header('bar', 'bar')
+ self.assertEqual(output.numWrites, 0)
+ handler.end_headers()
+ self.assertEqual(output.getData(), b'Foo: foo\r\nbar: bar\r\n\r\n')
+ self.assertEqual(output.numWrites, 1)
def test_header_unbuffered_when_continue(self):
diff --git a/Lib/test/test_imaplib.py b/Lib/test/test_imaplib.py
index 6b29943..4b12d23 100644
--- a/Lib/test/test_imaplib.py
+++ b/Lib/test/test_imaplib.py
@@ -11,9 +11,9 @@ import socketserver
import time
import calendar
-from test.support import reap_threads, verbose, transient_internet, run_with_tz
+from test.support import reap_threads, verbose, transient_internet, run_with_tz, run_with_locale
import unittest
-
+from datetime import datetime, timezone, timedelta
try:
import ssl
except ImportError:
@@ -43,14 +43,30 @@ class TestImaplib(unittest.TestCase):
imaplib.Internaldate2tuple(
b'25 (INTERNALDATE "02-Apr-2000 03:30:00 +0000")'))
- def test_that_Time2Internaldate_returns_a_result(self):
- # We can check only that it successfully produces a result,
- # not the correctness of the result itself, since the result
- # depends on the timezone the machine is in.
- timevalues = [2000000000, 2000000000.0, time.localtime(2000000000),
- '"18-May-2033 05:33:20 +0200"']
- for t in timevalues:
+
+ def timevalues(self):
+ return [2000000000, 2000000000.0, time.localtime(2000000000),
+ (2033, 5, 18, 5, 33, 20, -1, -1, -1),
+ (2033, 5, 18, 5, 33, 20, -1, -1, 1),
+ datetime.fromtimestamp(2000000000,
+ timezone(timedelta(0, 2*60*60))),
+ '"18-May-2033 05:33:20 +0200"']
+
+ @run_with_locale('LC_ALL', 'de_DE', 'fr_FR')
+ @run_with_tz('STD-1DST')
+ def test_Time2Internaldate(self):
+ expected = '"18-May-2033 05:33:20 +0200"'
+
+ for t in self.timevalues():
+ internal = imaplib.Time2Internaldate(t)
+ self.assertEqual(internal, expected)
+
+ def test_that_Time2Internaldate_returns_a_result(self):
+ # Without tzset, we can check only that it successfully
+ # produces a result, not the correctness of the result itself,
+ # since the result depends on the timezone the machine is in.
+ for t in self.timevalues():
imaplib.Time2Internaldate(t)
@@ -374,11 +390,58 @@ class RemoteIMAP_SSLTest(RemoteIMAPTest):
port = 993
imap_class = IMAP4_SSL
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ def create_ssl_context(self):
+ ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ ssl_context.load_cert_chain(CERTFILE)
+ return ssl_context
+
+ def check_logincapa(self, server):
+ try:
+ for cap in server.capabilities:
+ self.assertIsInstance(cap, str)
+ self.assertNotIn('LOGINDISABLED', server.capabilities)
+ self.assertIn('AUTH=PLAIN', server.capabilities)
+ rs = server.login(self.username, self.password)
+ self.assertEqual(rs[0], 'OK')
+ finally:
+ server.logout()
+
def test_logincapa(self):
- for cap in self.server.capabilities:
- self.assertIsInstance(cap, str)
- self.assertNotIn('LOGINDISABLED', self.server.capabilities)
- self.assertIn('AUTH=PLAIN', self.server.capabilities)
+ with transient_internet(self.host):
+ _server = self.imap_class(self.host, self.port)
+ self.check_logincapa(_server)
+
+ def test_logincapa_with_client_certfile(self):
+ with transient_internet(self.host):
+ _server = self.imap_class(self.host, self.port, certfile=CERTFILE)
+ self.check_logincapa(_server)
+
+ def test_logincapa_with_client_ssl_context(self):
+ with transient_internet(self.host):
+ _server = self.imap_class(self.host, self.port, ssl_context=self.create_ssl_context())
+ self.check_logincapa(_server)
+
+ def test_logout(self):
+ with transient_internet(self.host):
+ _server = self.imap_class(self.host, self.port)
+ rs = _server.logout()
+ self.assertEqual(rs[0], 'BYE')
+
+ def test_ssl_context_certfile_exclusive(self):
+ with transient_internet(self.host):
+ self.assertRaises(ValueError, self.imap_class, self.host, self.port,
+ certfile=CERTFILE, ssl_context=self.create_ssl_context())
+
+ def test_ssl_context_keyfile_exclusive(self):
+ with transient_internet(self.host):
+ self.assertRaises(ValueError, self.imap_class, self.host, self.port,
+ keyfile=CERTFILE, ssl_context=self.create_ssl_context())
def test_main():
diff --git a/Lib/test/test_imp.py b/Lib/test/test_imp.py
index 551ad1b..65c9f25 100644
--- a/Lib/test/test_imp.py
+++ b/Lib/test/test_imp.py
@@ -1,11 +1,12 @@
import imp
+import importlib
import os
import os.path
import shutil
import sys
-import unittest
from test import support
-import importlib
+import unittest
+import warnings
class LockTests(unittest.TestCase):
@@ -58,6 +59,10 @@ class ImportTests(unittest.TestCase):
with imp.find_module('module_' + mod, self.test_path)[0] as fd:
self.assertEqual(fd.encoding, encoding)
+ path = [os.path.dirname(__file__)]
+ with self.assertRaises(SyntaxError):
+ imp.find_module('badsyntax_pep3120', path)
+
def test_issue1267(self):
for mod, encoding, _ in self.test_strings:
fp, filename, info = imp.find_module('module_' + mod,
@@ -150,18 +155,24 @@ class ImportTests(unittest.TestCase):
mod = imp.load_module(temp_mod_name, file, filename, info)
self.assertEqual(mod.a, 1)
- mod = imp.load_source(temp_mod_name, temp_mod_name + '.py')
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore')
+ mod = imp.load_source(temp_mod_name, temp_mod_name + '.py')
self.assertEqual(mod.a, 1)
- mod = imp.load_compiled(
- temp_mod_name, imp.cache_from_source(temp_mod_name + '.py'))
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore')
+ mod = imp.load_compiled(
+ temp_mod_name, imp.cache_from_source(temp_mod_name + '.py'))
self.assertEqual(mod.a, 1)
if not os.path.exists(test_package_name):
os.mkdir(test_package_name)
with open(init_file_name, 'w') as file:
file.write('b = 2\n')
- package = imp.load_package(test_package_name, test_package_name)
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore')
+ package = imp.load_package(test_package_name, test_package_name)
self.assertEqual(package.b, 2)
finally:
del sys.path[0]
@@ -175,6 +186,48 @@ class ImportTests(unittest.TestCase):
self.assertRaises(SyntaxError,
imp.find_module, "badsyntax_pep3120", [path])
+ def test_load_from_source(self):
+ # Verify that the imp module can correctly load and find .py files
+ # XXX (ncoghlan): It would be nice to use support.CleanImport
+ # here, but that breaks because the os module registers some
+ # handlers in copy_reg on import. Since CleanImport doesn't
+ # revert that registration, the module is left in a broken
+ # state after reversion. Reinitialising the module contents
+ # and just reverting os.environ to its previous state is an OK
+ # workaround
+ orig_path = os.path
+ orig_getenv = os.getenv
+ with support.EnvironmentVarGuard():
+ x = imp.find_module("os")
+ self.addCleanup(x[0].close)
+ new_os = imp.load_module("os", *x)
+ self.assertIs(os, new_os)
+ self.assertIs(orig_path, new_os.path)
+ self.assertIsNot(orig_getenv, new_os.getenv)
+
+ @support.cpython_only
+ def test_issue15828_load_extensions(self):
+ # Issue 15828 picked up that the adapter between the old imp API
+ # and importlib couldn't handle C extensions
+ example = "_heapq"
+ x = imp.find_module(example)
+ file_ = x[0]
+ if file_ is not None:
+ self.addCleanup(file_.close)
+ mod = imp.load_module(example, *x)
+ self.assertEqual(mod.__name__, example)
+
+ def test_load_dynamic_ImportError_path(self):
+ # Issue #1559549 added `name` and `path` attributes to ImportError
+ # in order to provide better detail. Issue #10854 implemented those
+ # attributes on import failures of extensions on Windows.
+ path = 'bogus file path'
+ name = 'extension'
+ with self.assertRaises(ImportError) as err:
+ imp.load_dynamic(name, path)
+ self.assertIn(path, err.exception.path)
+ self.assertEqual(name, err.exception.name)
+
class ReloadTests(unittest.TestCase):
@@ -209,72 +262,83 @@ class PEP3147Tests(unittest.TestCase):
tag = imp.get_tag()
+ @unittest.skipUnless(sys.implementation.cache_tag is not None,
+ 'requires sys.implementation.cache_tag not be None')
def test_cache_from_source(self):
# Given the path to a .py file, return the path to its PEP 3147
# defined .pyc file (i.e. under __pycache__).
- self.assertEqual(
- imp.cache_from_source('/foo/bar/baz/qux.py', True),
- '/foo/bar/baz/__pycache__/qux.{}.pyc'.format(self.tag))
+ path = os.path.join('foo', 'bar', 'baz', 'qux.py')
+ expect = os.path.join('foo', 'bar', 'baz', '__pycache__',
+ 'qux.{}.pyc'.format(self.tag))
+ self.assertEqual(imp.cache_from_source(path, True), expect)
+
+ def test_cache_from_source_no_cache_tag(self):
+ # Non cache tag means NotImplementedError.
+ with support.swap_attr(sys.implementation, 'cache_tag', None):
+ with self.assertRaises(NotImplementedError):
+ imp.cache_from_source('whatever.py')
+
+ def test_cache_from_source_no_dot(self):
+ # Directory with a dot, filename without dot.
+ path = os.path.join('foo.bar', 'file')
+ expect = os.path.join('foo.bar', '__pycache__',
+ 'file{}.pyc'.format(self.tag))
+ self.assertEqual(imp.cache_from_source(path, True), expect)
def test_cache_from_source_optimized(self):
# Given the path to a .py file, return the path to its PEP 3147
# defined .pyo file (i.e. under __pycache__).
- self.assertEqual(
- imp.cache_from_source('/foo/bar/baz/qux.py', False),
- '/foo/bar/baz/__pycache__/qux.{}.pyo'.format(self.tag))
+ path = os.path.join('foo', 'bar', 'baz', 'qux.py')
+ expect = os.path.join('foo', 'bar', 'baz', '__pycache__',
+ 'qux.{}.pyo'.format(self.tag))
+ self.assertEqual(imp.cache_from_source(path, False), expect)
def test_cache_from_source_cwd(self):
- self.assertEqual(imp.cache_from_source('foo.py', True),
- os.sep.join(('__pycache__',
- 'foo.{}.pyc'.format(self.tag))))
+ path = 'foo.py'
+ expect = os.path.join('__pycache__', 'foo.{}.pyc'.format(self.tag))
+ self.assertEqual(imp.cache_from_source(path, True), expect)
def test_cache_from_source_override(self):
# When debug_override is not None, it can be any true-ish or false-ish
# value.
- self.assertEqual(
- imp.cache_from_source('/foo/bar/baz.py', []),
- '/foo/bar/__pycache__/baz.{}.pyo'.format(self.tag))
- self.assertEqual(
- imp.cache_from_source('/foo/bar/baz.py', [17]),
- '/foo/bar/__pycache__/baz.{}.pyc'.format(self.tag))
+ path = os.path.join('foo', 'bar', 'baz.py')
+ partial_expect = os.path.join('foo', 'bar', '__pycache__',
+ 'baz.{}.py'.format(self.tag))
+ self.assertEqual(imp.cache_from_source(path, []), partial_expect + 'o')
+ self.assertEqual(imp.cache_from_source(path, [17]),
+ partial_expect + 'c')
# However if the bool-ishness can't be determined, the exception
# propagates.
class Bearish:
def __bool__(self): raise RuntimeError
- self.assertRaises(
- RuntimeError,
- imp.cache_from_source, '/foo/bar/baz.py', Bearish())
-
- @unittest.skipIf(os.altsep is None,
- 'test meaningful only where os.altsep is defined')
- def test_altsep_cache_from_source(self):
- # Windows path and PEP 3147.
- self.assertEqual(
- imp.cache_from_source('\\foo\\bar\\baz\\qux.py', True),
- '\\foo\\bar\\baz\\__pycache__\\qux.{}.pyc'.format(self.tag))
+ with self.assertRaises(RuntimeError):
+ imp.cache_from_source('/foo/bar/baz.py', Bearish())
- @unittest.skipIf(os.altsep is None,
- 'test meaningful only where os.altsep is defined')
- def test_altsep_and_sep_cache_from_source(self):
- # Windows path and PEP 3147 where altsep is right of sep.
- self.assertEqual(
- imp.cache_from_source('\\foo\\bar/baz\\qux.py', True),
- '\\foo\\bar/baz\\__pycache__\\qux.{}.pyc'.format(self.tag))
-
- @unittest.skipIf(os.altsep is None,
+ @unittest.skipUnless(os.sep == '\\' and os.altsep == '/',
'test meaningful only where os.altsep is defined')
def test_sep_altsep_and_sep_cache_from_source(self):
# Windows path and PEP 3147 where sep is right of altsep.
self.assertEqual(
imp.cache_from_source('\\foo\\bar\\baz/qux.py', True),
- '\\foo\\bar\\baz/__pycache__/qux.{}.pyc'.format(self.tag))
+ '\\foo\\bar\\baz\\__pycache__\\qux.{}.pyc'.format(self.tag))
+ @unittest.skipUnless(sys.implementation.cache_tag is not None,
+ 'requires sys.implementation.cache_tag to not be '
+ 'None')
def test_source_from_cache(self):
# Given the path to a PEP 3147 defined .pyc file, return the path to
# its source. This tests the good path.
- self.assertEqual(imp.source_from_cache(
- '/foo/bar/baz/__pycache__/qux.{}.pyc'.format(self.tag)),
- '/foo/bar/baz/qux.py')
+ path = os.path.join('foo', 'bar', 'baz', '__pycache__',
+ 'qux.{}.pyc'.format(self.tag))
+ expect = os.path.join('foo', 'bar', 'baz', 'qux.py')
+ self.assertEqual(imp.source_from_cache(path), expect)
+
+ def test_source_from_cache_no_cache_tag(self):
+ # If sys.implementation.cache_tag is None, raise NotImplementedError.
+ path = os.path.join('blah', '__pycache__', 'whatever.pyc')
+ with support.swap_attr(sys.implementation, 'cache_tag', None):
+ with self.assertRaises(NotImplementedError):
+ imp.source_from_cache(path)
def test_source_from_cache_bad_path(self):
# When the path to a pyc file is not in PEP 3147 format, a ValueError
@@ -305,6 +369,12 @@ class PEP3147Tests(unittest.TestCase):
'/foo/bar/foo.cpython-32.foo.pyc')
def test_package___file__(self):
+ try:
+ m = __import__('pep3147')
+ except ImportError:
+ pass
+ else:
+ self.fail("pep3147 module already exists: %r" % (m,))
# Test that a package's __file__ points to the right source directory.
os.mkdir('pep3147')
sys.path.insert(0, os.curdir)
@@ -314,14 +384,16 @@ class PEP3147Tests(unittest.TestCase):
shutil.rmtree('pep3147')
self.addCleanup(cleanup)
# Touch the __init__.py file.
- with open('pep3147/__init__.py', 'w'):
- pass
+ support.create_empty_file('pep3147/__init__.py')
+ importlib.invalidate_caches()
+ expected___file__ = os.sep.join(('.', 'pep3147', '__init__.py'))
m = __import__('pep3147')
+ self.assertEqual(m.__file__, expected___file__, (m.__file__, m.__path__, sys.path, sys.path_importer_cache))
# Ensure we load the pyc file.
- support.forget('pep3147')
+ support.unload('pep3147')
m = __import__('pep3147')
- self.assertEqual(m.__file__,
- os.sep.join(('.', 'pep3147', '__init__.py')))
+ support.unload('pep3147')
+ self.assertEqual(m.__file__, expected___file__, (m.__file__, m.__path__, sys.path, sys.path_importer_cache))
class NullImporterTests(unittest.TestCase):
diff --git a/Lib/test/test_import.py b/Lib/test/test_import.py
index b10f350..b0e5a95 100644
--- a/Lib/test/test_import.py
+++ b/Lib/test/test_import.py
@@ -1,7 +1,8 @@
+# We import importlib *ASAP* in order to test #15386
+import importlib
import builtins
import imp
-from importlib.test.import_ import test_relative_imports
-from importlib.test.import_ import util as importlib_util
+from test.test_importlib.import_ import util as importlib_util
import marshal
import os
import platform
@@ -12,39 +13,50 @@ import sys
import unittest
import textwrap
import errno
+import shutil
+import contextlib
+import test.support
from test.support import (
EnvironmentVarGuard, TESTFN, check_warnings, forget, is_jython,
make_legacy_pyc, rmtree, run_unittest, swap_attr, swap_item, temp_umask,
- unlink, unload)
+ unlink, unload, create_empty_file, cpython_only)
from test import script_helper
-def _files(name):
- return (name + os.extsep + "py",
- name + os.extsep + "pyc",
- name + os.extsep + "pyo",
- name + os.extsep + "pyw",
- name + "$py.class")
-
-def chmod_files(name):
- for f in _files(name):
- try:
- os.chmod(f, 0o600)
- except OSError as exc:
- if exc.errno != errno.ENOENT:
- raise
-
def remove_files(name):
- for f in _files(name):
+ for f in (name + ".py",
+ name + ".pyc",
+ name + ".pyo",
+ name + ".pyw",
+ name + "$py.class"):
unlink(f)
rmtree('__pycache__')
+@contextlib.contextmanager
+def _ready_to_import(name=None, source=""):
+ # sets up a temporary directory and removes it
+ # creates the module file
+ # temporarily clears the module from sys.modules (if any)
+ name = name or "spam"
+ with script_helper.temp_dir() as tempdir:
+ path = script_helper.make_script(tempdir, name, source)
+ old_module = sys.modules.pop(name, None)
+ try:
+ sys.path.insert(0, tempdir)
+ yield name, path
+ sys.path.remove(tempdir)
+ finally:
+ if old_module is not None:
+ sys.modules[name] = old_module
+
+
class ImportTests(unittest.TestCase):
def setUp(self):
remove_files(TESTFN)
+ importlib.invalidate_caches()
def tearDown(self):
unload(TESTFN)
@@ -82,6 +94,7 @@ class ImportTests(unittest.TestCase):
if TESTFN in sys.modules:
del sys.modules[TESTFN]
+ importlib.invalidate_caches()
try:
try:
mod = __import__(TESTFN)
@@ -107,90 +120,6 @@ class ImportTests(unittest.TestCase):
finally:
del sys.path[0]
- @unittest.skipUnless(os.name == 'posix',
- "test meaningful only on posix systems")
- def test_execute_bit_not_copied(self):
- # Issue 6070: under posix .pyc files got their execute bit set if
- # the .py file had the execute bit set, but they aren't executable.
- with temp_umask(0o022):
- sys.path.insert(0, os.curdir)
- try:
- fname = TESTFN + os.extsep + "py"
- open(fname, 'w').close()
- os.chmod(fname, (stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH |
- stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH))
- fn = imp.cache_from_source(fname)
- unlink(fn)
- __import__(TESTFN)
- if not os.path.exists(fn):
- self.fail("__import__ did not result in creation of "
- "either a .pyc or .pyo file")
- s = os.stat(fn)
- self.assertEqual(stat.S_IMODE(s.st_mode),
- stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH)
- finally:
- del sys.path[0]
- remove_files(TESTFN)
- unload(TESTFN)
-
- def test_rewrite_pyc_with_read_only_source(self):
- # Issue 6074: a long time ago on posix, and more recently on Windows,
- # a read only source file resulted in a read only pyc file, which
- # led to problems with updating it later
- sys.path.insert(0, os.curdir)
- fname = TESTFN + os.extsep + "py"
- try:
- # Write a Python file, make it read-only and import it
- with open(fname, 'w') as f:
- f.write("x = 'original'\n")
- # Tweak the mtime of the source to ensure pyc gets updated later
- s = os.stat(fname)
- os.utime(fname, (s.st_atime, s.st_mtime-100000000))
- os.chmod(fname, 0o400)
- m1 = __import__(TESTFN)
- self.assertEqual(m1.x, 'original')
- # Change the file and then reimport it
- os.chmod(fname, 0o600)
- with open(fname, 'w') as f:
- f.write("x = 'rewritten'\n")
- unload(TESTFN)
- m2 = __import__(TESTFN)
- self.assertEqual(m2.x, 'rewritten')
- # Now delete the source file and check the pyc was rewritten
- unlink(fname)
- unload(TESTFN)
- if __debug__:
- bytecode_name = fname + "c"
- else:
- bytecode_name = fname + "o"
- os.rename(imp.cache_from_source(fname), bytecode_name)
- m3 = __import__(TESTFN)
- self.assertEqual(m3.x, 'rewritten')
- finally:
- chmod_files(TESTFN)
- remove_files(TESTFN)
- unload(TESTFN)
- del sys.path[0]
-
- def test_imp_module(self):
- # Verify that the imp module can correctly load and find .py files
- # XXX (ncoghlan): It would be nice to use support.CleanImport
- # here, but that breaks because the os module registers some
- # handlers in copy_reg on import. Since CleanImport doesn't
- # revert that registration, the module is left in a broken
- # state after reversion. Reinitialising the module contents
- # and just reverting os.environ to its previous state is an OK
- # workaround
- orig_path = os.path
- orig_getenv = os.getenv
- with EnvironmentVarGuard():
- x = imp.find_module("os")
- self.addCleanup(x[0].close)
- new_os = imp.load_module("os", *x)
- self.assertIs(os, new_os)
- self.assertIs(orig_path, new_os.path)
- self.assertIsNot(orig_getenv, new_os.getenv)
-
def test_bug7732(self):
source = TESTFN + '.py'
os.mkdir(source)
@@ -220,6 +149,7 @@ class ImportTests(unittest.TestCase):
# Need to be able to load from current dir.
sys.path.append('')
+ importlib.invalidate_caches()
try:
make_legacy_pyc(filename)
@@ -239,6 +169,7 @@ class ImportTests(unittest.TestCase):
# New in 2.4, we shouldn't be able to import that no matter how often
# we try.
sys.path.insert(0, os.curdir)
+ importlib.invalidate_caches()
if TESTFN in sys.modules:
del sys.modules[TESTFN]
try:
@@ -312,6 +243,7 @@ class ImportTests(unittest.TestCase):
os.remove(source)
del sys.modules[TESTFN]
make_legacy_pyc(source)
+ importlib.invalidate_caches()
mod = __import__(TESTFN)
base, ext = os.path.splitext(mod.__file__)
self.assertIn(ext, ('.pyc', '.pyo'))
@@ -332,12 +264,6 @@ class ImportTests(unittest.TestCase):
import test.support as y
self.assertIs(y, test.support, y.__name__)
- def test_import_initless_directory_warning(self):
- with check_warnings(('', ImportWarning)):
- # Just a random non-package directory we always expect to be
- # somewhere in sys.path...
- self.assertRaises(ImportError, __import__, "site-packages")
-
def test_import_by_filename(self):
path = os.path.abspath(TESTFN)
encoding = sys.getfilesystemencoding()
@@ -347,8 +273,6 @@ class ImportTests(unittest.TestCase):
self.skipTest('path is not encodable to {}'.format(encoding))
with self.assertRaises(ImportError) as c:
__import__(path)
- self.assertEqual("Import by filename is not supported.",
- c.exception.args[0])
def test_import_in_del_does_not_crash(self):
# Issue 4236
@@ -385,6 +309,98 @@ class ImportTests(unittest.TestCase):
del sys.path[0]
remove_files(TESTFN)
+ def test_bogus_fromlist(self):
+ try:
+ __import__('http', fromlist=['blah'])
+ except ImportError:
+ self.fail("fromlist must allow bogus names")
+
+
+class FilePermissionTests(unittest.TestCase):
+ # tests for file mode on cached .pyc/.pyo files
+
+ @unittest.skipUnless(os.name == 'posix',
+ "test meaningful only on posix systems")
+ def test_creation_mode(self):
+ mask = 0o022
+ with temp_umask(mask), _ready_to_import() as (name, path):
+ cached_path = imp.cache_from_source(path)
+ module = __import__(name)
+ if not os.path.exists(cached_path):
+ self.fail("__import__ did not result in creation of "
+ "either a .pyc or .pyo file")
+ stat_info = os.stat(cached_path)
+
+ # Check that the umask is respected, and the executable bits
+ # aren't set.
+ self.assertEqual(oct(stat.S_IMODE(stat_info.st_mode)),
+ oct(0o666 & ~mask))
+
+ @unittest.skipUnless(os.name == 'posix',
+ "test meaningful only on posix systems")
+ def test_cached_mode_issue_2051(self):
+ # permissions of .pyc should match those of .py, regardless of mask
+ mode = 0o600
+ with temp_umask(0o022), _ready_to_import() as (name, path):
+ cached_path = imp.cache_from_source(path)
+ os.chmod(path, mode)
+ __import__(name)
+ if not os.path.exists(cached_path):
+ self.fail("__import__ did not result in creation of "
+ "either a .pyc or .pyo file")
+ stat_info = os.stat(cached_path)
+
+ self.assertEqual(oct(stat.S_IMODE(stat_info.st_mode)), oct(mode))
+
+ @unittest.skipUnless(os.name == 'posix',
+ "test meaningful only on posix systems")
+ def test_cached_readonly(self):
+ mode = 0o400
+ with temp_umask(0o022), _ready_to_import() as (name, path):
+ cached_path = imp.cache_from_source(path)
+ os.chmod(path, mode)
+ __import__(name)
+ if not os.path.exists(cached_path):
+ self.fail("__import__ did not result in creation of "
+ "either a .pyc or .pyo file")
+ stat_info = os.stat(cached_path)
+
+ expected = mode | 0o200 # Account for fix for issue #6074
+ self.assertEqual(oct(stat.S_IMODE(stat_info.st_mode)), oct(expected))
+
+ def test_pyc_always_writable(self):
+ # Initially read-only .pyc files on Windows used to cause problems
+ # with later updates, see issue #6074 for details
+ with _ready_to_import() as (name, path):
+ # Write a Python file, make it read-only and import it
+ with open(path, 'w') as f:
+ f.write("x = 'original'\n")
+ # Tweak the mtime of the source to ensure pyc gets updated later
+ s = os.stat(path)
+ os.utime(path, (s.st_atime, s.st_mtime-100000000))
+ os.chmod(path, 0o400)
+ m = __import__(name)
+ self.assertEqual(m.x, 'original')
+ # Change the file and then reimport it
+ os.chmod(path, 0o600)
+ with open(path, 'w') as f:
+ f.write("x = 'rewritten'\n")
+ unload(name)
+ importlib.invalidate_caches()
+ m = __import__(name)
+ self.assertEqual(m.x, 'rewritten')
+ # Now delete the source file and check the pyc was rewritten
+ unlink(path)
+ unload(name)
+ importlib.invalidate_caches()
+ if __debug__:
+ bytecode_only = path + "c"
+ else:
+ bytecode_only = path + "o"
+ os.rename(imp.cache_from_source(path), bytecode_only)
+ m = __import__(name)
+ self.assertEqual(m.x, 'rewritten')
+
class PycRewritingTests(unittest.TestCase):
# Test that the `co_filename` attribute on code objects always points
@@ -412,6 +428,7 @@ func_filename = func.__code__.co_filename
with open(self.file_name, "w") as f:
f.write(self.module_source)
sys.path.insert(0, self.dir_name)
+ importlib.invalidate_caches()
def tearDown(self):
sys.path[:] = self.sys_path
@@ -451,6 +468,7 @@ func_filename = func.__code__.co_filename
py_compile.compile(self.file_name, dfile=target)
os.remove(self.file_name)
pyc_file = make_legacy_pyc(self.file_name)
+ importlib.invalidate_caches()
mod = self.import_module()
self.assertEqual(mod.module_filename, pyc_file)
self.assertEqual(mod.code_filename, target)
@@ -459,7 +477,7 @@ func_filename = func.__code__.co_filename
def test_foreign_code(self):
py_compile.compile(self.file_name)
with open(self.compiled_name, "rb") as f:
- header = f.read(8)
+ header = f.read(12)
code = marshal.load(f)
constants = list(code.co_consts)
foreign_code = test_main.__code__
@@ -501,9 +519,11 @@ class PathsTests(unittest.TestCase):
unload("test_trailing_slash")
# Regression test for http://bugs.python.org/issue3677.
- def _test_UNC_path(self):
- with open(os.path.join(self.path, 'test_trailing_slash.py'), 'w') as f:
- f.write("testdata = 'test_trailing_slash'")
+ @unittest.skipUnless(sys.platform == 'win32', 'Windows-specific')
+ def test_UNC_path(self):
+ with open(os.path.join(self.path, 'test_unc_path.py'), 'w') as f:
+ f.write("testdata = 'test_unc_path'")
+ importlib.invalidate_caches()
# Create the UNC path, like \\myhost\c$\foo\bar.
path = os.path.abspath(self.path)
import socket
@@ -518,13 +538,15 @@ class PathsTests(unittest.TestCase):
# See issue #15338
self.skipTest("cannot access administrative share %r" % (unc,))
raise
- sys.path.append(path)
- mod = __import__("test_trailing_slash")
- self.assertEqual(mod.testdata, 'test_trailing_slash')
- unload("test_trailing_slash")
-
- if sys.platform == "win32":
- test_UNC_path = _test_UNC_path
+ sys.path.insert(0, unc)
+ try:
+ mod = __import__("test_unc_path")
+ except ImportError as e:
+ self.fail("could not import 'test_unc_path' from %r: %r"
+ % (unc, e))
+ self.assertEqual(mod.testdata, 'test_unc_path')
+ self.assertTrue(mod.__file__.startswith(unc), mod.__file__)
+ unload("test_unc_path")
class RelativeImportTests(unittest.TestCase):
@@ -565,7 +587,7 @@ class RelativeImportTests(unittest.TestCase):
# Check relative import fails with package set to a non-string
ns = dict(__package__=object())
- self.assertRaises(ValueError, check_relative)
+ self.assertRaises(TypeError, check_relative)
def test_absolute_import_without_future(self):
# If explicit relative import syntax is used, then do not try
@@ -613,6 +635,7 @@ class PycacheTests(unittest.TestCase):
with open(self.source, 'w') as fp:
print('# This is a test file written by test_import.py', file=fp)
sys.path.insert(0, os.curdir)
+ importlib.invalidate_caches()
def tearDown(self):
assert sys.path[0] == os.curdir, 'Unexpected sys.path[0]'
@@ -660,6 +683,7 @@ class PycacheTests(unittest.TestCase):
pyc_file = make_legacy_pyc(self.source)
os.remove(self.source)
unload(TESTFN)
+ importlib.invalidate_caches()
m = __import__(TESTFN)
self.assertEqual(m.__file__,
os.path.join(os.curdir, os.path.relpath(pyc_file)))
@@ -680,6 +704,7 @@ class PycacheTests(unittest.TestCase):
pyc_file = make_legacy_pyc(self.source)
os.remove(self.source)
unload(TESTFN)
+ importlib.invalidate_caches()
m = __import__(TESTFN)
self.assertEqual(m.__cached__,
os.path.join(os.curdir, os.path.relpath(pyc_file)))
@@ -688,6 +713,8 @@ class PycacheTests(unittest.TestCase):
# Like test___cached__ but for packages.
def cleanup():
rmtree('pep3147')
+ unload('pep3147.foo')
+ unload('pep3147')
os.mkdir('pep3147')
self.addCleanup(cleanup)
# Touch the __init__.py
@@ -695,8 +722,7 @@ class PycacheTests(unittest.TestCase):
pass
with open(os.path.join('pep3147', 'foo.py'), 'w'):
pass
- unload('pep3147.foo')
- unload('pep3147')
+ importlib.invalidate_caches()
m = __import__('pep3147.foo')
init_pyc = imp.cache_from_source(
os.path.join('pep3147', '__init__.py'))
@@ -710,18 +736,20 @@ class PycacheTests(unittest.TestCase):
# PEP 3147 pyc file.
def cleanup():
rmtree('pep3147')
+ unload('pep3147.foo')
+ unload('pep3147')
os.mkdir('pep3147')
self.addCleanup(cleanup)
- unload('pep3147.foo')
- unload('pep3147')
# Touch the __init__.py
with open(os.path.join('pep3147', '__init__.py'), 'w'):
pass
with open(os.path.join('pep3147', 'foo.py'), 'w'):
pass
+ importlib.invalidate_caches()
m = __import__('pep3147.foo')
unload('pep3147.foo')
unload('pep3147')
+ importlib.invalidate_caches()
m = __import__('pep3147.foo')
init_pyc = imp.cache_from_source(
os.path.join('pep3147', '__init__.py'))
@@ -730,22 +758,256 @@ class PycacheTests(unittest.TestCase):
self.assertEqual(sys.modules['pep3147.foo'].__cached__,
os.path.join(os.curdir, foo_pyc))
+ def test_recompute_pyc_same_second(self):
+ # Even when the source file doesn't change timestamp, a change in
+ # source size is enough to trigger recomputation of the pyc file.
+ __import__(TESTFN)
+ unload(TESTFN)
+ with open(self.source, 'a') as fp:
+ print("x = 5", file=fp)
+ m = __import__(TESTFN)
+ self.assertEqual(m.x, 5)
+
-class RelativeImportFromImportlibTests(test_relative_imports.RelativeImports):
+class TestSymbolicallyLinkedPackage(unittest.TestCase):
+ package_name = 'sample'
+ tagged = package_name + '-tagged'
def setUp(self):
- self._importlib_util_flag = importlib_util.using___import__
- importlib_util.using___import__ = True
+ test.support.rmtree(self.tagged)
+ test.support.rmtree(self.package_name)
+ self.orig_sys_path = sys.path[:]
+
+ # create a sample package; imagine you have a package with a tag and
+ # you want to symbolically link it from its untagged name.
+ os.mkdir(self.tagged)
+ self.addCleanup(test.support.rmtree, self.tagged)
+ init_file = os.path.join(self.tagged, '__init__.py')
+ test.support.create_empty_file(init_file)
+ assert os.path.exists(init_file)
+
+ # now create a symlink to the tagged package
+ # sample -> sample-tagged
+ os.symlink(self.tagged, self.package_name, target_is_directory=True)
+ self.addCleanup(test.support.unlink, self.package_name)
+ importlib.invalidate_caches()
+
+ self.assertEqual(os.path.isdir(self.package_name), True)
+
+ assert os.path.isfile(os.path.join(self.package_name, '__init__.py'))
def tearDown(self):
- importlib_util.using___import__ = self._importlib_util_flag
+ sys.path[:] = self.orig_sys_path
+
+ # regression test for issue6727
+ @unittest.skipUnless(
+ not hasattr(sys, 'getwindowsversion')
+ or sys.getwindowsversion() >= (6, 0),
+ "Windows Vista or later required")
+ @test.support.skip_unless_symlink
+ def test_symlinked_dir_importable(self):
+ # make sure sample can only be imported from the current directory.
+ sys.path[:] = ['.']
+ assert os.path.exists(self.package_name)
+ assert os.path.exists(os.path.join(self.package_name, '__init__.py'))
+
+ # Try to import the package
+ importlib.import_module(self.package_name)
+
+
+@cpython_only
+class ImportlibBootstrapTests(unittest.TestCase):
+ # These tests check that importlib is bootstrapped.
+
+ def test_frozen_importlib(self):
+ mod = sys.modules['_frozen_importlib']
+ self.assertTrue(mod)
+
+ def test_frozen_importlib_is_bootstrap(self):
+ from importlib import _bootstrap
+ mod = sys.modules['_frozen_importlib']
+ self.assertIs(mod, _bootstrap)
+ self.assertEqual(mod.__name__, 'importlib._bootstrap')
+ self.assertEqual(mod.__package__, 'importlib')
+ self.assertTrue(mod.__file__.endswith('_bootstrap.py'), mod.__file__)
+
+ def test_there_can_be_only_one(self):
+ # Issue #15386 revealed a tricky loophole in the bootstrapping
+ # This test is technically redundant, since the bug caused importing
+ # this test module to crash completely, but it helps prove the point
+ from importlib import machinery
+ mod = sys.modules['_frozen_importlib']
+ self.assertIs(machinery.FileFinder, mod.FileFinder)
+ self.assertIs(imp.new_module, mod.new_module)
+
+
+class ImportTracebackTests(unittest.TestCase):
+
+ def setUp(self):
+ os.mkdir(TESTFN)
+ self.old_path = sys.path[:]
+ sys.path.insert(0, TESTFN)
+
+ def tearDown(self):
+ sys.path[:] = self.old_path
+ rmtree(TESTFN)
+
+ def create_module(self, mod, contents, ext=".py"):
+ fname = os.path.join(TESTFN, mod + ext)
+ with open(fname, "w") as f:
+ f.write(contents)
+ self.addCleanup(unload, mod)
+ importlib.invalidate_caches()
+ return fname
+
+ def assert_traceback(self, tb, files):
+ deduped_files = []
+ while tb:
+ code = tb.tb_frame.f_code
+ fn = code.co_filename
+ if not deduped_files or fn != deduped_files[-1]:
+ deduped_files.append(fn)
+ tb = tb.tb_next
+ self.assertEqual(len(deduped_files), len(files), deduped_files)
+ for fn, pat in zip(deduped_files, files):
+ self.assertIn(pat, fn)
+
+ def test_nonexistent_module(self):
+ try:
+ # assertRaises() clears __traceback__
+ import nonexistent_xyzzy
+ except ImportError as e:
+ tb = e.__traceback__
+ else:
+ self.fail("ImportError should have been raised")
+ self.assert_traceback(tb, [__file__])
+
+ def test_nonexistent_module_nested(self):
+ self.create_module("foo", "import nonexistent_xyzzy")
+ try:
+ import foo
+ except ImportError as e:
+ tb = e.__traceback__
+ else:
+ self.fail("ImportError should have been raised")
+ self.assert_traceback(tb, [__file__, 'foo.py'])
+
+ def test_exec_failure(self):
+ self.create_module("foo", "1/0")
+ try:
+ import foo
+ except ZeroDivisionError as e:
+ tb = e.__traceback__
+ else:
+ self.fail("ZeroDivisionError should have been raised")
+ self.assert_traceback(tb, [__file__, 'foo.py'])
+
+ def test_exec_failure_nested(self):
+ self.create_module("foo", "import bar")
+ self.create_module("bar", "1/0")
+ try:
+ import foo
+ except ZeroDivisionError as e:
+ tb = e.__traceback__
+ else:
+ self.fail("ZeroDivisionError should have been raised")
+ self.assert_traceback(tb, [__file__, 'foo.py', 'bar.py'])
+
+ # A few more examples from issue #15425
+ def test_syntax_error(self):
+ self.create_module("foo", "invalid syntax is invalid")
+ try:
+ import foo
+ except SyntaxError as e:
+ tb = e.__traceback__
+ else:
+ self.fail("SyntaxError should have been raised")
+ self.assert_traceback(tb, [__file__])
+
+ def _setup_broken_package(self, parent, child):
+ pkg_name = "_parent_foo"
+ self.addCleanup(unload, pkg_name)
+ pkg_path = os.path.join(TESTFN, pkg_name)
+ os.mkdir(pkg_path)
+ # Touch the __init__.py
+ init_path = os.path.join(pkg_path, '__init__.py')
+ with open(init_path, 'w') as f:
+ f.write(parent)
+ bar_path = os.path.join(pkg_path, 'bar.py')
+ with open(bar_path, 'w') as f:
+ f.write(child)
+ importlib.invalidate_caches()
+ return init_path, bar_path
+
+ def test_broken_submodule(self):
+ init_path, bar_path = self._setup_broken_package("", "1/0")
+ try:
+ import _parent_foo.bar
+ except ZeroDivisionError as e:
+ tb = e.__traceback__
+ else:
+ self.fail("ZeroDivisionError should have been raised")
+ self.assert_traceback(tb, [__file__, bar_path])
+
+ def test_broken_from(self):
+ init_path, bar_path = self._setup_broken_package("", "1/0")
+ try:
+ from _parent_foo import bar
+ except ZeroDivisionError as e:
+ tb = e.__traceback__
+ else:
+ self.fail("ImportError should have been raised")
+ self.assert_traceback(tb, [__file__, bar_path])
+
+ def test_broken_parent(self):
+ init_path, bar_path = self._setup_broken_package("1/0", "")
+ try:
+ import _parent_foo.bar
+ except ZeroDivisionError as e:
+ tb = e.__traceback__
+ else:
+ self.fail("ZeroDivisionError should have been raised")
+ self.assert_traceback(tb, [__file__, init_path])
+
+ def test_broken_parent_from(self):
+ init_path, bar_path = self._setup_broken_package("1/0", "")
+ try:
+ from _parent_foo import bar
+ except ZeroDivisionError as e:
+ tb = e.__traceback__
+ else:
+ self.fail("ZeroDivisionError should have been raised")
+ self.assert_traceback(tb, [__file__, init_path])
+
+ @cpython_only
+ def test_import_bug(self):
+ # We simulate a bug in importlib and check that it's not stripped
+ # away from the traceback.
+ self.create_module("foo", "")
+ importlib = sys.modules['_frozen_importlib']
+ old_load_module = importlib.SourceLoader.load_module
+ try:
+ def load_module(*args):
+ 1/0
+ importlib.SourceLoader.load_module = load_module
+ try:
+ import foo
+ except ZeroDivisionError as e:
+ tb = e.__traceback__
+ else:
+ self.fail("ZeroDivisionError should have been raised")
+ self.assert_traceback(tb, [__file__, '<frozen importlib', __file__])
+ finally:
+ importlib.SourceLoader.load_module = old_load_module
def test_main(verbose=None):
- run_unittest(ImportTests, PycacheTests,
+ run_unittest(ImportTests, PycacheTests, FilePermissionTests,
PycRewritingTests, PathsTests, RelativeImportTests,
OverridingImportBuiltinTests,
- RelativeImportFromImportlibTests)
+ ImportlibBootstrapTests,
+ TestSymbolicallyLinkedPackage,
+ ImportTracebackTests)
if __name__ == '__main__':
diff --git a/Lib/test/test_importhooks.py b/Lib/test/test_importhooks.py
index ec6730e..2a22d1a 100644
--- a/Lib/test/test_importhooks.py
+++ b/Lib/test/test_importhooks.py
@@ -51,7 +51,7 @@ class TestImporter:
def __init__(self, path=test_path):
if path != test_path:
- # if out class is on sys.path_hooks, we must raise
+ # if our class is on sys.path_hooks, we must raise
# ImportError for any path item that we can't handle.
raise ImportError
self.path = path
@@ -215,7 +215,7 @@ class ImportHooksTestCase(ImportHooksBaseTestCase):
self.doTestImports(i)
def testPathHook(self):
- sys.path_hooks.append(PathImporter)
+ sys.path_hooks.insert(0, PathImporter)
sys.path.append(test_path)
self.doTestImports()
@@ -228,8 +228,10 @@ class ImportHooksTestCase(ImportHooksBaseTestCase):
def testImpWrapper(self):
i = ImpWrapper()
sys.meta_path.append(i)
- sys.path_hooks.append(ImpWrapper)
- mnames = ("colorsys", "urllib.parse", "distutils.core")
+ sys.path_hooks.insert(0, ImpWrapper)
+ mnames = (
+ "colorsys", "urllib.parse", "distutils.core", "sys",
+ )
for mname in mnames:
parent = mname.split(".")[0]
for n in list(sys.modules):
@@ -237,7 +239,8 @@ class ImportHooksTestCase(ImportHooksBaseTestCase):
del sys.modules[n]
for mname in mnames:
m = __import__(mname, globals(), locals(), ["__dummy__"])
- m.__loader__ # to make sure we actually handled the import
+ # to make sure we actually handled the import
+ self.assertTrue(hasattr(m, "__loader__"))
def test_main():
diff --git a/Lib/test/test_importlib.py b/Lib/test/test_importlib.py
deleted file mode 100644
index 6ed0585..0000000
--- a/Lib/test/test_importlib.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from importlib.test.__main__ import test_main
-
-
-if __name__ == '__main__':
- test_main()
diff --git a/Lib/test/test_importlib/__init__.py b/Lib/test/test_importlib/__init__.py
new file mode 100644
index 0000000..0e345cd
--- /dev/null
+++ b/Lib/test/test_importlib/__init__.py
@@ -0,0 +1,33 @@
+import os
+import sys
+from test import support
+import unittest
+
+def test_suite(package=__package__, directory=os.path.dirname(__file__)):
+ suite = unittest.TestSuite()
+ for name in os.listdir(directory):
+ if name.startswith(('.', '__')):
+ continue
+ path = os.path.join(directory, name)
+ if (os.path.isfile(path) and name.startswith('test_') and
+ name.endswith('.py')):
+ submodule_name = os.path.splitext(name)[0]
+ module_name = "{0}.{1}".format(package, submodule_name)
+ __import__(module_name, level=0)
+ module_tests = unittest.findTestCases(sys.modules[module_name])
+ suite.addTest(module_tests)
+ elif os.path.isdir(path):
+ package_name = "{0}.{1}".format(package, name)
+ __import__(package_name, level=0)
+ package_tests = getattr(sys.modules[package_name], 'test_suite')()
+ suite.addTest(package_tests)
+ else:
+ continue
+ return suite
+
+
+def test_main():
+ start_dir = os.path.dirname(__file__)
+ top_dir = os.path.dirname(os.path.dirname(start_dir))
+ test_loader = unittest.TestLoader()
+ support.run_unittest(test_loader.discover(start_dir, top_level_dir=top_dir))
diff --git a/Lib/test/test_importlib/__main__.py b/Lib/test/test_importlib/__main__.py
new file mode 100644
index 0000000..c397128
--- /dev/null
+++ b/Lib/test/test_importlib/__main__.py
@@ -0,0 +1,20 @@
+"""Run importlib's test suite.
+
+Specifying the ``--builtin`` flag will run tests, where applicable, with
+builtins.__import__ instead of importlib.__import__.
+
+"""
+from . import test_main
+
+
+if __name__ == '__main__':
+ import argparse
+
+ parser = argparse.ArgumentParser(description='Execute the importlib test '
+ 'suite')
+ parser.add_argument('-b', '--builtin', action='store_true', default=False,
+ help='use builtins.__import__() instead of importlib')
+ args = parser.parse_args()
+ if args.builtin:
+ util.using___import__ = True
+ test_main()
diff --git a/Lib/test/test_importlib/abc.py b/Lib/test/test_importlib/abc.py
new file mode 100644
index 0000000..2c17ac3
--- /dev/null
+++ b/Lib/test/test_importlib/abc.py
@@ -0,0 +1,99 @@
+import abc
+import unittest
+
+
+class FinderTests(unittest.TestCase, metaclass=abc.ABCMeta):
+
+ """Basic tests for a finder to pass."""
+
+ @abc.abstractmethod
+ def test_module(self):
+ # Test importing a top-level module.
+ pass
+
+ @abc.abstractmethod
+ def test_package(self):
+ # Test importing a package.
+ pass
+
+ @abc.abstractmethod
+ def test_module_in_package(self):
+ # Test importing a module contained within a package.
+ # A value for 'path' should be used if for a meta_path finder.
+ pass
+
+ @abc.abstractmethod
+ def test_package_in_package(self):
+ # Test importing a subpackage.
+ # A value for 'path' should be used if for a meta_path finder.
+ pass
+
+ @abc.abstractmethod
+ def test_package_over_module(self):
+ # Test that packages are chosen over modules.
+ pass
+
+ @abc.abstractmethod
+ def test_failure(self):
+ # Test trying to find a module that cannot be handled.
+ pass
+
+
+class LoaderTests(unittest.TestCase, metaclass=abc.ABCMeta):
+
+ @abc.abstractmethod
+ def test_module(self):
+ """A module should load without issue.
+
+ After the loader returns the module should be in sys.modules.
+
+ Attributes to verify:
+
+ * __file__
+ * __loader__
+ * __name__
+ * No __path__
+
+ """
+ pass
+
+ @abc.abstractmethod
+ def test_package(self):
+ """Loading a package should work.
+
+ After the loader returns the module should be in sys.modules.
+
+ Attributes to verify:
+
+ * __name__
+ * __file__
+ * __package__
+ * __path__
+ * __loader__
+
+ """
+ pass
+
+ @abc.abstractmethod
+ def test_lacking_parent(self):
+ """A loader should not be dependent on it's parent package being
+ imported."""
+ pass
+
+ @abc.abstractmethod
+ def test_module_reuse(self):
+ """If a module is already in sys.modules, it should be reused."""
+ pass
+
+ @abc.abstractmethod
+ def test_state_after_failure(self):
+ """If a module is already in sys.modules and a reload fails
+ (e.g. a SyntaxError), the module should be in the state it was before
+ the reload began."""
+ pass
+
+ @abc.abstractmethod
+ def test_unloadable(self):
+ """Test ImportError is raised when the loader is asked to load a module
+ it can't."""
+ pass
diff --git a/Lib/test/test_importlib/builtin/__init__.py b/Lib/test/test_importlib/builtin/__init__.py
new file mode 100644
index 0000000..15c0ade
--- /dev/null
+++ b/Lib/test/test_importlib/builtin/__init__.py
@@ -0,0 +1,12 @@
+from .. import test_suite
+import os
+
+
+def test_suite():
+ directory = os.path.dirname(__file__)
+ return test_suite('importlib.test.builtin', directory)
+
+
+if __name__ == '__main__':
+ from test.support import run_unittest
+ run_unittest(test_suite())
diff --git a/Lib/test/test_importlib/builtin/test_finder.py b/Lib/test/test_importlib/builtin/test_finder.py
new file mode 100644
index 0000000..146538d
--- /dev/null
+++ b/Lib/test/test_importlib/builtin/test_finder.py
@@ -0,0 +1,55 @@
+from importlib import machinery
+from .. import abc
+from .. import util
+from . import util as builtin_util
+
+import sys
+import unittest
+
+class FinderTests(abc.FinderTests):
+
+ """Test find_module() for built-in modules."""
+
+ def test_module(self):
+ # Common case.
+ with util.uncache(builtin_util.NAME):
+ found = machinery.BuiltinImporter.find_module(builtin_util.NAME)
+ self.assertTrue(found)
+
+ def test_package(self):
+ # Built-in modules cannot be a package.
+ pass
+
+ def test_module_in_package(self):
+ # Built-in modules cannobt be in a package.
+ pass
+
+ def test_package_in_package(self):
+ # Built-in modules cannot be a package.
+ pass
+
+ def test_package_over_module(self):
+ # Built-in modules cannot be a package.
+ pass
+
+ def test_failure(self):
+ assert 'importlib' not in sys.builtin_module_names
+ loader = machinery.BuiltinImporter.find_module('importlib')
+ self.assertIsNone(loader)
+
+ def test_ignore_path(self):
+ # The value for 'path' should always trigger a failed import.
+ with util.uncache(builtin_util.NAME):
+ loader = machinery.BuiltinImporter.find_module(builtin_util.NAME,
+ ['pkg'])
+ self.assertIsNone(loader)
+
+
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(FinderTests)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/builtin/test_loader.py b/Lib/test/test_importlib/builtin/test_loader.py
new file mode 100644
index 0000000..8e186e7
--- /dev/null
+++ b/Lib/test/test_importlib/builtin/test_loader.py
@@ -0,0 +1,105 @@
+import importlib
+from importlib import machinery
+from .. import abc
+from .. import util
+from . import util as builtin_util
+
+import sys
+import types
+import unittest
+
+
+class LoaderTests(abc.LoaderTests):
+
+ """Test load_module() for built-in modules."""
+
+ verification = {'__name__': 'errno', '__package__': '',
+ '__loader__': machinery.BuiltinImporter}
+
+ def verify(self, module):
+ """Verify that the module matches against what it should have."""
+ self.assertIsInstance(module, types.ModuleType)
+ for attr, value in self.verification.items():
+ self.assertEqual(getattr(module, attr), value)
+ self.assertIn(module.__name__, sys.modules)
+
+ load_module = staticmethod(lambda name:
+ machinery.BuiltinImporter.load_module(name))
+
+ def test_module(self):
+ # Common case.
+ with util.uncache(builtin_util.NAME):
+ module = self.load_module(builtin_util.NAME)
+ self.verify(module)
+
+ def test_package(self):
+ # Built-in modules cannot be a package.
+ pass
+
+ def test_lacking_parent(self):
+ # Built-in modules cannot be a package.
+ pass
+
+ def test_state_after_failure(self):
+ # Not way to force an imoprt failure.
+ pass
+
+ def test_module_reuse(self):
+ # Test that the same module is used in a reload.
+ with util.uncache(builtin_util.NAME):
+ module1 = self.load_module(builtin_util.NAME)
+ module2 = self.load_module(builtin_util.NAME)
+ self.assertIs(module1, module2)
+
+ def test_unloadable(self):
+ name = 'dssdsdfff'
+ assert name not in sys.builtin_module_names
+ with self.assertRaises(ImportError) as cm:
+ self.load_module(name)
+ self.assertEqual(cm.exception.name, name)
+
+ def test_already_imported(self):
+ # Using the name of a module already imported but not a built-in should
+ # still fail.
+ assert hasattr(importlib, '__file__') # Not a built-in.
+ with self.assertRaises(ImportError) as cm:
+ self.load_module('importlib')
+ self.assertEqual(cm.exception.name, 'importlib')
+
+
+class InspectLoaderTests(unittest.TestCase):
+
+ """Tests for InspectLoader methods for BuiltinImporter."""
+
+ def test_get_code(self):
+ # There is no code object.
+ result = machinery.BuiltinImporter.get_code(builtin_util.NAME)
+ self.assertIsNone(result)
+
+ def test_get_source(self):
+ # There is no source.
+ result = machinery.BuiltinImporter.get_source(builtin_util.NAME)
+ self.assertIsNone(result)
+
+ def test_is_package(self):
+ # Cannot be a package.
+ result = machinery.BuiltinImporter.is_package(builtin_util.NAME)
+ self.assertTrue(not result)
+
+ def test_not_builtin(self):
+ # Modules not built-in should raise ImportError.
+ for meth_name in ('get_code', 'get_source', 'is_package'):
+ method = getattr(machinery.BuiltinImporter, meth_name)
+ with self.assertRaises(ImportError) as cm:
+ method(builtin_util.BAD_NAME)
+ self.assertRaises(builtin_util.BAD_NAME)
+
+
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(LoaderTests, InspectLoaderTests)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/builtin/util.py b/Lib/test/test_importlib/builtin/util.py
new file mode 100644
index 0000000..5704699
--- /dev/null
+++ b/Lib/test/test_importlib/builtin/util.py
@@ -0,0 +1,7 @@
+import sys
+
+assert 'errno' in sys.builtin_module_names
+NAME = 'errno'
+
+assert 'importlib' not in sys.builtin_module_names
+BAD_NAME = 'importlib'
diff --git a/Lib/test/test_importlib/extension/__init__.py b/Lib/test/test_importlib/extension/__init__.py
new file mode 100644
index 0000000..c033923
--- /dev/null
+++ b/Lib/test/test_importlib/extension/__init__.py
@@ -0,0 +1,13 @@
+from .. import test_suite
+import os.path
+import unittest
+
+
+def test_suite():
+ directory = os.path.dirname(__file__)
+ return test_suite('importlib.test.extension', directory)
+
+
+if __name__ == '__main__':
+ from test.support import run_unittest
+ run_unittest(test_suite())
diff --git a/Lib/test/test_importlib/extension/test_case_sensitivity.py b/Lib/test/test_importlib/extension/test_case_sensitivity.py
new file mode 100644
index 0000000..76c53e4
--- /dev/null
+++ b/Lib/test/test_importlib/extension/test_case_sensitivity.py
@@ -0,0 +1,50 @@
+import imp
+import sys
+from test import support
+import unittest
+from importlib import _bootstrap
+from .. import util
+from . import util as ext_util
+
+
+@util.case_insensitive_tests
+class ExtensionModuleCaseSensitivityTest(unittest.TestCase):
+
+ def find_module(self):
+ good_name = ext_util.NAME
+ bad_name = good_name.upper()
+ assert good_name != bad_name
+ finder = _bootstrap.FileFinder(ext_util.PATH,
+ (_bootstrap.ExtensionFileLoader,
+ _bootstrap.EXTENSION_SUFFIXES))
+ return finder.find_module(bad_name)
+
+ def test_case_sensitive(self):
+ with support.EnvironmentVarGuard() as env:
+ env.unset('PYTHONCASEOK')
+ if b'PYTHONCASEOK' in _bootstrap._os.environ:
+ self.skipTest('os.environ changes not reflected in '
+ '_os.environ')
+ loader = self.find_module()
+ self.assertIsNone(loader)
+
+ def test_case_insensitivity(self):
+ with support.EnvironmentVarGuard() as env:
+ env.set('PYTHONCASEOK', '1')
+ if b'PYTHONCASEOK' not in _bootstrap._os.environ:
+ self.skipTest('os.environ changes not reflected in '
+ '_os.environ')
+ loader = self.find_module()
+ self.assertTrue(hasattr(loader, 'load_module'))
+
+
+
+
+def test_main():
+ if ext_util.FILENAME is None:
+ return
+ support.run_unittest(ExtensionModuleCaseSensitivityTest)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/extension/test_finder.py b/Lib/test/test_importlib/extension/test_finder.py
new file mode 100644
index 0000000..a63cfdb
--- /dev/null
+++ b/Lib/test/test_importlib/extension/test_finder.py
@@ -0,0 +1,46 @@
+from importlib import machinery
+from .. import abc
+from . import util
+
+import unittest
+
+class FinderTests(abc.FinderTests):
+
+ """Test the finder for extension modules."""
+
+ def find_module(self, fullname):
+ importer = machinery.FileFinder(util.PATH,
+ (machinery.ExtensionFileLoader,
+ machinery.EXTENSION_SUFFIXES))
+ return importer.find_module(fullname)
+
+ def test_module(self):
+ self.assertTrue(self.find_module(util.NAME))
+
+ def test_package(self):
+ # No extension module as an __init__ available for testing.
+ pass
+
+ def test_module_in_package(self):
+ # No extension module in a package available for testing.
+ pass
+
+ def test_package_in_package(self):
+ # No extension module as an __init__ available for testing.
+ pass
+
+ def test_package_over_module(self):
+ # Extension modules cannot be an __init__ for a package.
+ pass
+
+ def test_failure(self):
+ self.assertIsNone(self.find_module('asdfjkl;'))
+
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(FinderTests)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/extension/test_loader.py b/Lib/test/test_importlib/extension/test_loader.py
new file mode 100644
index 0000000..ca5af20
--- /dev/null
+++ b/Lib/test/test_importlib/extension/test_loader.py
@@ -0,0 +1,79 @@
+from importlib import machinery
+from . import util as ext_util
+from .. import abc
+from .. import util
+
+import os.path
+import sys
+import unittest
+
+
+class LoaderTests(abc.LoaderTests):
+
+ """Test load_module() for extension modules."""
+
+ def setUp(self):
+ self.loader = machinery.ExtensionFileLoader(ext_util.NAME,
+ ext_util.FILEPATH)
+
+ def load_module(self, fullname):
+ return self.loader.load_module(fullname)
+
+ def test_load_module_API(self):
+ # Test the default argument for load_module().
+ self.loader.load_module()
+ self.loader.load_module(None)
+ with self.assertRaises(ImportError):
+ self.load_module('XXX')
+
+
+ def test_module(self):
+ with util.uncache(ext_util.NAME):
+ module = self.load_module(ext_util.NAME)
+ for attr, value in [('__name__', ext_util.NAME),
+ ('__file__', ext_util.FILEPATH),
+ ('__package__', '')]:
+ self.assertEqual(getattr(module, attr), value)
+ self.assertIn(ext_util.NAME, sys.modules)
+ self.assertIsInstance(module.__loader__,
+ machinery.ExtensionFileLoader)
+
+ def test_package(self):
+ # No extension module as __init__ available for testing.
+ pass
+
+ def test_lacking_parent(self):
+ # No extension module in a package available for testing.
+ pass
+
+ def test_module_reuse(self):
+ with util.uncache(ext_util.NAME):
+ module1 = self.load_module(ext_util.NAME)
+ module2 = self.load_module(ext_util.NAME)
+ self.assertIs(module1, module2)
+
+ def test_state_after_failure(self):
+ # No easy way to trigger a failure after a successful import.
+ pass
+
+ def test_unloadable(self):
+ name = 'asdfjkl;'
+ with self.assertRaises(ImportError) as cm:
+ self.load_module(name)
+ self.assertEqual(cm.exception.name, name)
+
+ def test_is_package(self):
+ self.assertFalse(self.loader.is_package(ext_util.NAME))
+ for suffix in machinery.EXTENSION_SUFFIXES:
+ path = os.path.join('some', 'path', 'pkg', '__init__' + suffix)
+ loader = machinery.ExtensionFileLoader('pkg', path)
+ self.assertTrue(loader.is_package('pkg'))
+
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(LoaderTests)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/extension/test_path_hook.py b/Lib/test/test_importlib/extension/test_path_hook.py
new file mode 100644
index 0000000..1d969a1
--- /dev/null
+++ b/Lib/test/test_importlib/extension/test_path_hook.py
@@ -0,0 +1,32 @@
+from importlib import machinery
+from . import util
+
+import collections
+import imp
+import sys
+import unittest
+
+
+class PathHookTests(unittest.TestCase):
+
+ """Test the path hook for extension modules."""
+ # XXX Should it only succeed for pre-existing directories?
+ # XXX Should it only work for directories containing an extension module?
+
+ def hook(self, entry):
+ return machinery.FileFinder.path_hook((machinery.ExtensionFileLoader,
+ machinery.EXTENSION_SUFFIXES))(entry)
+
+ def test_success(self):
+ # Path hook should handle a directory where a known extension module
+ # exists.
+ self.assertTrue(hasattr(self.hook(util.PATH), 'find_module'))
+
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(PathHookTests)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/extension/util.py b/Lib/test/test_importlib/extension/util.py
new file mode 100644
index 0000000..a266dd9
--- /dev/null
+++ b/Lib/test/test_importlib/extension/util.py
@@ -0,0 +1,20 @@
+import imp
+from importlib import machinery
+import os
+import sys
+
+PATH = None
+EXT = None
+FILENAME = None
+NAME = '_testcapi'
+try:
+ for PATH in sys.path:
+ for EXT in machinery.EXTENSION_SUFFIXES:
+ FILENAME = NAME + EXT
+ FILEPATH = os.path.join(PATH, FILENAME)
+ if os.path.exists(os.path.join(PATH, FILENAME)):
+ raise StopIteration
+ else:
+ PATH = EXT = FILENAME = FILEPATH = None
+except StopIteration:
+ pass
diff --git a/Lib/test/test_importlib/frozen/__init__.py b/Lib/test/test_importlib/frozen/__init__.py
new file mode 100644
index 0000000..9ef103b
--- /dev/null
+++ b/Lib/test/test_importlib/frozen/__init__.py
@@ -0,0 +1,13 @@
+from .. import test_suite
+import os.path
+import unittest
+
+
+def test_suite():
+ directory = os.path.dirname(__file__)
+ return test_suite('importlib.test.frozen', directory)
+
+
+if __name__ == '__main__':
+ from test.support import run_unittest
+ run_unittest(test_suite())
diff --git a/Lib/test/test_importlib/frozen/test_finder.py b/Lib/test/test_importlib/frozen/test_finder.py
new file mode 100644
index 0000000..fa0c2a0
--- /dev/null
+++ b/Lib/test/test_importlib/frozen/test_finder.py
@@ -0,0 +1,47 @@
+from importlib import machinery
+from .. import abc
+
+import unittest
+
+
+class FinderTests(abc.FinderTests):
+
+ """Test finding frozen modules."""
+
+ def find(self, name, path=None):
+ finder = machinery.FrozenImporter
+ return finder.find_module(name, path)
+
+ def test_module(self):
+ name = '__hello__'
+ loader = self.find(name)
+ self.assertTrue(hasattr(loader, 'load_module'))
+
+ def test_package(self):
+ loader = self.find('__phello__')
+ self.assertTrue(hasattr(loader, 'load_module'))
+
+ def test_module_in_package(self):
+ loader = self.find('__phello__.spam', ['__phello__'])
+ self.assertTrue(hasattr(loader, 'load_module'))
+
+ def test_package_in_package(self):
+ # No frozen package within another package to test with.
+ pass
+
+ def test_package_over_module(self):
+ # No easy way to test.
+ pass
+
+ def test_failure(self):
+ loader = self.find('<not real>')
+ self.assertIsNone(loader)
+
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(FinderTests)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/frozen/test_loader.py b/Lib/test/test_importlib/frozen/test_loader.py
new file mode 100644
index 0000000..4b8ec15
--- /dev/null
+++ b/Lib/test/test_importlib/frozen/test_loader.py
@@ -0,0 +1,121 @@
+from importlib import machinery
+import imp
+import unittest
+from .. import abc
+from .. import util
+from test.support import captured_stdout
+
+class LoaderTests(abc.LoaderTests):
+
+ def test_module(self):
+ with util.uncache('__hello__'), captured_stdout() as stdout:
+ module = machinery.FrozenImporter.load_module('__hello__')
+ check = {'__name__': '__hello__',
+ '__package__': '',
+ '__loader__': machinery.FrozenImporter,
+ }
+ for attr, value in check.items():
+ self.assertEqual(getattr(module, attr), value)
+ self.assertEqual(stdout.getvalue(), 'Hello world!\n')
+ self.assertFalse(hasattr(module, '__file__'))
+
+ def test_package(self):
+ with util.uncache('__phello__'), captured_stdout() as stdout:
+ module = machinery.FrozenImporter.load_module('__phello__')
+ check = {'__name__': '__phello__',
+ '__package__': '__phello__',
+ '__path__': ['__phello__'],
+ '__loader__': machinery.FrozenImporter,
+ }
+ for attr, value in check.items():
+ attr_value = getattr(module, attr)
+ self.assertEqual(attr_value, value,
+ "for __phello__.%s, %r != %r" %
+ (attr, attr_value, value))
+ self.assertEqual(stdout.getvalue(), 'Hello world!\n')
+ self.assertFalse(hasattr(module, '__file__'))
+
+ def test_lacking_parent(self):
+ with util.uncache('__phello__', '__phello__.spam'), \
+ captured_stdout() as stdout:
+ module = machinery.FrozenImporter.load_module('__phello__.spam')
+ check = {'__name__': '__phello__.spam',
+ '__package__': '__phello__',
+ '__loader__': machinery.FrozenImporter,
+ }
+ for attr, value in check.items():
+ attr_value = getattr(module, attr)
+ self.assertEqual(attr_value, value,
+ "for __phello__.spam.%s, %r != %r" %
+ (attr, attr_value, value))
+ self.assertEqual(stdout.getvalue(), 'Hello world!\n')
+ self.assertFalse(hasattr(module, '__file__'))
+
+ def test_module_reuse(self):
+ with util.uncache('__hello__'), captured_stdout() as stdout:
+ module1 = machinery.FrozenImporter.load_module('__hello__')
+ module2 = machinery.FrozenImporter.load_module('__hello__')
+ self.assertIs(module1, module2)
+ self.assertEqual(stdout.getvalue(),
+ 'Hello world!\nHello world!\n')
+
+ def test_module_repr(self):
+ with util.uncache('__hello__'), captured_stdout():
+ module = machinery.FrozenImporter.load_module('__hello__')
+ self.assertEqual(repr(module),
+ "<module '__hello__' (frozen)>")
+
+ def test_state_after_failure(self):
+ # No way to trigger an error in a frozen module.
+ pass
+
+ def test_unloadable(self):
+ assert machinery.FrozenImporter.find_module('_not_real') is None
+ with self.assertRaises(ImportError) as cm:
+ machinery.FrozenImporter.load_module('_not_real')
+ self.assertEqual(cm.exception.name, '_not_real')
+
+
+class InspectLoaderTests(unittest.TestCase):
+
+ """Tests for the InspectLoader methods for FrozenImporter."""
+
+ def test_get_code(self):
+ # Make sure that the code object is good.
+ name = '__hello__'
+ with captured_stdout() as stdout:
+ code = machinery.FrozenImporter.get_code(name)
+ mod = imp.new_module(name)
+ exec(code, mod.__dict__)
+ self.assertTrue(hasattr(mod, 'initialized'))
+ self.assertEqual(stdout.getvalue(), 'Hello world!\n')
+
+ def test_get_source(self):
+ # Should always return None.
+ result = machinery.FrozenImporter.get_source('__hello__')
+ self.assertIsNone(result)
+
+ def test_is_package(self):
+ # Should be able to tell what is a package.
+ test_for = (('__hello__', False), ('__phello__', True),
+ ('__phello__.spam', False))
+ for name, is_package in test_for:
+ result = machinery.FrozenImporter.is_package(name)
+ self.assertEqual(bool(result), is_package)
+
+ def test_failure(self):
+ # Raise ImportError for modules that are not frozen.
+ for meth_name in ('get_code', 'get_source', 'is_package'):
+ method = getattr(machinery.FrozenImporter, meth_name)
+ with self.assertRaises(ImportError) as cm:
+ method('importlib')
+ self.assertEqual(cm.exception.name, 'importlib')
+
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(LoaderTests, InspectLoaderTests)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/import_/__init__.py b/Lib/test/test_importlib/import_/__init__.py
new file mode 100644
index 0000000..366e531
--- /dev/null
+++ b/Lib/test/test_importlib/import_/__init__.py
@@ -0,0 +1,13 @@
+from .. import test_suite
+import os.path
+import unittest
+
+
+def test_suite():
+ directory = os.path.dirname(__file__)
+ return test_suite('importlib.test.import_', directory)
+
+
+if __name__ == '__main__':
+ from test.support import run_unittest
+ run_unittest(test_suite())
diff --git a/Lib/test/test_importlib/import_/test___package__.py b/Lib/test/test_importlib/import_/test___package__.py
new file mode 100644
index 0000000..783cde1
--- /dev/null
+++ b/Lib/test/test_importlib/import_/test___package__.py
@@ -0,0 +1,119 @@
+"""PEP 366 ("Main module explicit relative imports") specifies the
+semantics for the __package__ attribute on modules. This attribute is
+used, when available, to detect which package a module belongs to (instead
+of using the typical __path__/__name__ test).
+
+"""
+import unittest
+from .. import util
+from . import util as import_util
+
+
+class Using__package__(unittest.TestCase):
+
+ """Use of __package__ supercedes the use of __name__/__path__ to calculate
+ what package a module belongs to. The basic algorithm is [__package__]::
+
+ def resolve_name(name, package, level):
+ level -= 1
+ base = package.rsplit('.', level)[0]
+ return '{0}.{1}'.format(base, name)
+
+ But since there is no guarantee that __package__ has been set (or not been
+ set to None [None]), there has to be a way to calculate the attribute's value
+ [__name__]::
+
+ def calc_package(caller_name, has___path__):
+ if has__path__:
+ return caller_name
+ else:
+ return caller_name.rsplit('.', 1)[0]
+
+ Then the normal algorithm for relative name imports can proceed as if
+ __package__ had been set.
+
+ """
+
+ def test_using___package__(self):
+ # [__package__]
+ with util.mock_modules('pkg.__init__', 'pkg.fake') as importer:
+ with util.import_state(meta_path=[importer]):
+ import_util.import_('pkg.fake')
+ module = import_util.import_('',
+ globals={'__package__': 'pkg.fake'},
+ fromlist=['attr'], level=2)
+ self.assertEqual(module.__name__, 'pkg')
+
+ def test_using___name__(self, package_as_None=False):
+ # [__name__]
+ globals_ = {'__name__': 'pkg.fake', '__path__': []}
+ if package_as_None:
+ globals_['__package__'] = None
+ with util.mock_modules('pkg.__init__', 'pkg.fake') as importer:
+ with util.import_state(meta_path=[importer]):
+ import_util.import_('pkg.fake')
+ module = import_util.import_('', globals= globals_,
+ fromlist=['attr'], level=2)
+ self.assertEqual(module.__name__, 'pkg')
+
+ def test_None_as___package__(self):
+ # [None]
+ self.test_using___name__(package_as_None=True)
+
+ def test_bad__package__(self):
+ globals = {'__package__': '<not real>'}
+ with self.assertRaises(SystemError):
+ import_util.import_('', globals, {}, ['relimport'], 1)
+
+ def test_bunk__package__(self):
+ globals = {'__package__': 42}
+ with self.assertRaises(TypeError):
+ import_util.import_('', globals, {}, ['relimport'], 1)
+
+
+@import_util.importlib_only
+class Setting__package__(unittest.TestCase):
+
+ """Because __package__ is a new feature, it is not always set by a loader.
+ Import will set it as needed to help with the transition to relying on
+ __package__.
+
+ For a top-level module, __package__ is set to None [top-level]. For a
+ package __name__ is used for __package__ [package]. For submodules the
+ value is __name__.rsplit('.', 1)[0] [submodule].
+
+ """
+
+ # [top-level]
+ def test_top_level(self):
+ with util.mock_modules('top_level') as mock:
+ with util.import_state(meta_path=[mock]):
+ del mock['top_level'].__package__
+ module = import_util.import_('top_level')
+ self.assertEqual(module.__package__, '')
+
+ # [package]
+ def test_package(self):
+ with util.mock_modules('pkg.__init__') as mock:
+ with util.import_state(meta_path=[mock]):
+ del mock['pkg'].__package__
+ module = import_util.import_('pkg')
+ self.assertEqual(module.__package__, 'pkg')
+
+ # [submodule]
+ def test_submodule(self):
+ with util.mock_modules('pkg.__init__', 'pkg.mod') as mock:
+ with util.import_state(meta_path=[mock]):
+ del mock['pkg.mod'].__package__
+ pkg = import_util.import_('pkg.mod')
+ module = getattr(pkg, 'mod')
+ self.assertEqual(module.__package__, 'pkg')
+
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(Using__package__, Setting__package__)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/import_/test_api.py b/Lib/test/test_importlib/import_/test_api.py
new file mode 100644
index 0000000..3d4cd94
--- /dev/null
+++ b/Lib/test/test_importlib/import_/test_api.py
@@ -0,0 +1,67 @@
+from .. import util as importlib_test_util
+from . import util
+import imp
+import sys
+import unittest
+
+
+class BadLoaderFinder:
+ bad = 'fine.bogus'
+ @classmethod
+ def find_module(cls, fullname, path):
+ if fullname == cls.bad:
+ return cls
+ @classmethod
+ def load_module(cls, fullname):
+ if fullname == cls.bad:
+ raise ImportError('I cannot be loaded!')
+
+
+class APITest(unittest.TestCase):
+
+ """Test API-specific details for __import__ (e.g. raising the right
+ exception when passing in an int for the module name)."""
+
+ def test_name_requires_rparition(self):
+ # Raise TypeError if a non-string is passed in for the module name.
+ with self.assertRaises(TypeError):
+ util.import_(42)
+
+ def test_negative_level(self):
+ # Raise ValueError when a negative level is specified.
+ # PEP 328 did away with sys.module None entries and the ambiguity of
+ # absolute/relative imports.
+ with self.assertRaises(ValueError):
+ util.import_('os', globals(), level=-1)
+
+ def test_nonexistent_fromlist_entry(self):
+ # If something in fromlist doesn't exist, that's okay.
+ # issue15715
+ mod = imp.new_module('fine')
+ mod.__path__ = ['XXX']
+ with importlib_test_util.import_state(meta_path=[BadLoaderFinder]):
+ with importlib_test_util.uncache('fine'):
+ sys.modules['fine'] = mod
+ util.import_('fine', fromlist=['not here'])
+
+ def test_fromlist_load_error_propagates(self):
+ # If something in fromlist triggers an exception not related to not
+ # existing, let that exception propagate.
+ # issue15316
+ mod = imp.new_module('fine')
+ mod.__path__ = ['XXX']
+ with importlib_test_util.import_state(meta_path=[BadLoaderFinder]):
+ with importlib_test_util.uncache('fine'):
+ sys.modules['fine'] = mod
+ with self.assertRaises(ImportError):
+ util.import_('fine', fromlist=['bogus'])
+
+
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(APITest)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/import_/test_caching.py b/Lib/test/test_importlib/import_/test_caching.py
new file mode 100644
index 0000000..bf68027
--- /dev/null
+++ b/Lib/test/test_importlib/import_/test_caching.py
@@ -0,0 +1,88 @@
+"""Test that sys.modules is used properly by import."""
+from .. import util
+from . import util as import_util
+import sys
+from types import MethodType
+import unittest
+
+
+class UseCache(unittest.TestCase):
+
+ """When it comes to sys.modules, import prefers it over anything else.
+
+ Once a name has been resolved, sys.modules is checked to see if it contains
+ the module desired. If so, then it is returned [use cache]. If it is not
+ found, then the proper steps are taken to perform the import, but
+ sys.modules is still used to return the imported module (e.g., not what a
+ loader returns) [from cache on return]. This also applies to imports of
+ things contained within a package and thus get assigned as an attribute
+ [from cache to attribute] or pulled in thanks to a fromlist import
+ [from cache for fromlist]. But if sys.modules contains None then
+ ImportError is raised [None in cache].
+
+ """
+ def test_using_cache(self):
+ # [use cache]
+ module_to_use = "some module found!"
+ with util.uncache(module_to_use):
+ sys.modules['some_module'] = module_to_use
+ module = import_util.import_('some_module')
+ self.assertEqual(id(module_to_use), id(module))
+
+ def test_None_in_cache(self):
+ #[None in cache]
+ name = 'using_None'
+ with util.uncache(name):
+ sys.modules[name] = None
+ with self.assertRaises(ImportError) as cm:
+ import_util.import_(name)
+ self.assertEqual(cm.exception.name, name)
+
+ def create_mock(self, *names, return_=None):
+ mock = util.mock_modules(*names)
+ original_load = mock.load_module
+ def load_module(self, fullname):
+ original_load(fullname)
+ return return_
+ mock.load_module = MethodType(load_module, mock)
+ return mock
+
+ # __import__ inconsistent between loaders and built-in import when it comes
+ # to when to use the module in sys.modules and when not to.
+ @import_util.importlib_only
+ def test_using_cache_after_loader(self):
+ # [from cache on return]
+ with self.create_mock('module') as mock:
+ with util.import_state(meta_path=[mock]):
+ module = import_util.import_('module')
+ self.assertEqual(id(module), id(sys.modules['module']))
+
+ # See test_using_cache_after_loader() for reasoning.
+ @import_util.importlib_only
+ def test_using_cache_for_assigning_to_attribute(self):
+ # [from cache to attribute]
+ with self.create_mock('pkg.__init__', 'pkg.module') as importer:
+ with util.import_state(meta_path=[importer]):
+ module = import_util.import_('pkg.module')
+ self.assertTrue(hasattr(module, 'module'))
+ self.assertEqual(id(module.module),
+ id(sys.modules['pkg.module']))
+
+ # See test_using_cache_after_loader() for reasoning.
+ @import_util.importlib_only
+ def test_using_cache_for_fromlist(self):
+ # [from cache for fromlist]
+ with self.create_mock('pkg.__init__', 'pkg.module') as importer:
+ with util.import_state(meta_path=[importer]):
+ module = import_util.import_('pkg', fromlist=['module'])
+ self.assertTrue(hasattr(module, 'module'))
+ self.assertEqual(id(module.module),
+ id(sys.modules['pkg.module']))
+
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(UseCache)
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/import_/test_fromlist.py b/Lib/test/test_importlib/import_/test_fromlist.py
new file mode 100644
index 0000000..c16c337
--- /dev/null
+++ b/Lib/test/test_importlib/import_/test_fromlist.py
@@ -0,0 +1,127 @@
+"""Test that the semantics relating to the 'fromlist' argument are correct."""
+from .. import util
+from . import util as import_util
+import imp
+import unittest
+
+class ReturnValue(unittest.TestCase):
+
+ """The use of fromlist influences what import returns.
+
+ If direct ``import ...`` statement is used, the root module or package is
+ returned [import return]. But if fromlist is set, then the specified module
+ is actually returned (whether it is a relative import or not)
+ [from return].
+
+ """
+
+ def test_return_from_import(self):
+ # [import return]
+ with util.mock_modules('pkg.__init__', 'pkg.module') as importer:
+ with util.import_state(meta_path=[importer]):
+ module = import_util.import_('pkg.module')
+ self.assertEqual(module.__name__, 'pkg')
+
+ def test_return_from_from_import(self):
+ # [from return]
+ with util.mock_modules('pkg.__init__', 'pkg.module')as importer:
+ with util.import_state(meta_path=[importer]):
+ module = import_util.import_('pkg.module', fromlist=['attr'])
+ self.assertEqual(module.__name__, 'pkg.module')
+
+
+class HandlingFromlist(unittest.TestCase):
+
+ """Using fromlist triggers different actions based on what is being asked
+ of it.
+
+ If fromlist specifies an object on a module, nothing special happens
+ [object case]. This is even true if the object does not exist [bad object].
+
+ If a package is being imported, then what is listed in fromlist may be
+ treated as a module to be imported [module]. And this extends to what is
+ contained in __all__ when '*' is imported [using *]. And '*' does not need
+ to be the only name in the fromlist [using * with others].
+
+ """
+
+ def test_object(self):
+ # [object case]
+ with util.mock_modules('module') as importer:
+ with util.import_state(meta_path=[importer]):
+ module = import_util.import_('module', fromlist=['attr'])
+ self.assertEqual(module.__name__, 'module')
+
+ def test_nonexistent_object(self):
+ # [bad object]
+ with util.mock_modules('module') as importer:
+ with util.import_state(meta_path=[importer]):
+ module = import_util.import_('module', fromlist=['non_existent'])
+ self.assertEqual(module.__name__, 'module')
+ self.assertTrue(not hasattr(module, 'non_existent'))
+
+ def test_module_from_package(self):
+ # [module]
+ with util.mock_modules('pkg.__init__', 'pkg.module') as importer:
+ with util.import_state(meta_path=[importer]):
+ module = import_util.import_('pkg', fromlist=['module'])
+ self.assertEqual(module.__name__, 'pkg')
+ self.assertTrue(hasattr(module, 'module'))
+ self.assertEqual(module.module.__name__, 'pkg.module')
+
+ def test_module_from_package_triggers_ImportError(self):
+ # If a submodule causes an ImportError because it tries to import
+ # a module which doesn't exist, that should let the ImportError
+ # propagate.
+ def module_code():
+ import i_do_not_exist
+ with util.mock_modules('pkg.__init__', 'pkg.mod',
+ module_code={'pkg.mod': module_code}) as importer:
+ with util.import_state(meta_path=[importer]):
+ with self.assertRaises(ImportError) as exc:
+ import_util.import_('pkg', fromlist=['mod'])
+ self.assertEqual('i_do_not_exist', exc.exception.name)
+
+ def test_empty_string(self):
+ with util.mock_modules('pkg.__init__', 'pkg.mod') as importer:
+ with util.import_state(meta_path=[importer]):
+ module = import_util.import_('pkg.mod', fromlist=[''])
+ self.assertEqual(module.__name__, 'pkg.mod')
+
+ def basic_star_test(self, fromlist=['*']):
+ # [using *]
+ with util.mock_modules('pkg.__init__', 'pkg.module') as mock:
+ with util.import_state(meta_path=[mock]):
+ mock['pkg'].__all__ = ['module']
+ module = import_util.import_('pkg', fromlist=fromlist)
+ self.assertEqual(module.__name__, 'pkg')
+ self.assertTrue(hasattr(module, 'module'))
+ self.assertEqual(module.module.__name__, 'pkg.module')
+
+ def test_using_star(self):
+ # [using *]
+ self.basic_star_test()
+
+ def test_fromlist_as_tuple(self):
+ self.basic_star_test(('*',))
+
+ def test_star_with_others(self):
+ # [using * with others]
+ context = util.mock_modules('pkg.__init__', 'pkg.module1', 'pkg.module2')
+ with context as mock:
+ with util.import_state(meta_path=[mock]):
+ mock['pkg'].__all__ = ['module1']
+ module = import_util.import_('pkg', fromlist=['module2', '*'])
+ self.assertEqual(module.__name__, 'pkg')
+ self.assertTrue(hasattr(module, 'module1'))
+ self.assertTrue(hasattr(module, 'module2'))
+ self.assertEqual(module.module1.__name__, 'pkg.module1')
+ self.assertEqual(module.module2.__name__, 'pkg.module2')
+
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(ReturnValue, HandlingFromlist)
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/import_/test_meta_path.py b/Lib/test/test_importlib/import_/test_meta_path.py
new file mode 100644
index 0000000..4d85f80
--- /dev/null
+++ b/Lib/test/test_importlib/import_/test_meta_path.py
@@ -0,0 +1,115 @@
+from .. import util
+from . import util as import_util
+import importlib._bootstrap
+import sys
+from types import MethodType
+import unittest
+import warnings
+
+
+class CallingOrder(unittest.TestCase):
+
+ """Calls to the importers on sys.meta_path happen in order that they are
+ specified in the sequence, starting with the first importer
+ [first called], and then continuing on down until one is found that doesn't
+ return None [continuing]."""
+
+
+ def test_first_called(self):
+ # [first called]
+ mod = 'top_level'
+ first = util.mock_modules(mod)
+ second = util.mock_modules(mod)
+ with util.mock_modules(mod) as first, util.mock_modules(mod) as second:
+ first.modules[mod] = 42
+ second.modules[mod] = -13
+ with util.import_state(meta_path=[first, second]):
+ self.assertEqual(import_util.import_(mod), 42)
+
+ def test_continuing(self):
+ # [continuing]
+ mod_name = 'for_real'
+ with util.mock_modules('nonexistent') as first, \
+ util.mock_modules(mod_name) as second:
+ first.find_module = lambda self, fullname, path=None: None
+ second.modules[mod_name] = 42
+ with util.import_state(meta_path=[first, second]):
+ self.assertEqual(import_util.import_(mod_name), 42)
+
+ def test_empty(self):
+ # Raise an ImportWarning if sys.meta_path is empty.
+ module_name = 'nothing'
+ try:
+ del sys.modules[module_name]
+ except KeyError:
+ pass
+ with util.import_state(meta_path=[]):
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter('always')
+ self.assertIsNone(importlib._bootstrap._find_module('nothing',
+ None))
+ self.assertEqual(len(w), 1)
+ self.assertTrue(issubclass(w[-1].category, ImportWarning))
+
+
+class CallSignature(unittest.TestCase):
+
+ """If there is no __path__ entry on the parent module, then 'path' is None
+ [no path]. Otherwise, the value for __path__ is passed in for the 'path'
+ argument [path set]."""
+
+ def log(self, fxn):
+ log = []
+ def wrapper(self, *args, **kwargs):
+ log.append([args, kwargs])
+ return fxn(*args, **kwargs)
+ return log, wrapper
+
+
+ def test_no_path(self):
+ # [no path]
+ mod_name = 'top_level'
+ assert '.' not in mod_name
+ with util.mock_modules(mod_name) as importer:
+ log, wrapped_call = self.log(importer.find_module)
+ importer.find_module = MethodType(wrapped_call, importer)
+ with util.import_state(meta_path=[importer]):
+ import_util.import_(mod_name)
+ assert len(log) == 1
+ args = log[0][0]
+ kwargs = log[0][1]
+ # Assuming all arguments are positional.
+ self.assertEqual(len(args), 2)
+ self.assertEqual(len(kwargs), 0)
+ self.assertEqual(args[0], mod_name)
+ self.assertIsNone(args[1])
+
+ def test_with_path(self):
+ # [path set]
+ pkg_name = 'pkg'
+ mod_name = pkg_name + '.module'
+ path = [42]
+ assert '.' in mod_name
+ with util.mock_modules(pkg_name+'.__init__', mod_name) as importer:
+ importer.modules[pkg_name].__path__ = path
+ log, wrapped_call = self.log(importer.find_module)
+ importer.find_module = MethodType(wrapped_call, importer)
+ with util.import_state(meta_path=[importer]):
+ import_util.import_(mod_name)
+ assert len(log) == 2
+ args = log[1][0]
+ kwargs = log[1][1]
+ # Assuming all arguments are positional.
+ self.assertTrue(not kwargs)
+ self.assertEqual(args[0], mod_name)
+ self.assertIs(args[1], path)
+
+
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(CallingOrder, CallSignature)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/import_/test_packages.py b/Lib/test/test_importlib/import_/test_packages.py
new file mode 100644
index 0000000..bfa18dc
--- /dev/null
+++ b/Lib/test/test_importlib/import_/test_packages.py
@@ -0,0 +1,112 @@
+from .. import util
+from . import util as import_util
+import sys
+import unittest
+import importlib
+from test import support
+
+
+class ParentModuleTests(unittest.TestCase):
+
+ """Importing a submodule should import the parent modules."""
+
+ def test_import_parent(self):
+ with util.mock_modules('pkg.__init__', 'pkg.module') as mock:
+ with util.import_state(meta_path=[mock]):
+ module = import_util.import_('pkg.module')
+ self.assertIn('pkg', sys.modules)
+
+ def test_bad_parent(self):
+ with util.mock_modules('pkg.module') as mock:
+ with util.import_state(meta_path=[mock]):
+ with self.assertRaises(ImportError) as cm:
+ import_util.import_('pkg.module')
+ self.assertEqual(cm.exception.name, 'pkg')
+
+ def test_raising_parent_after_importing_child(self):
+ def __init__():
+ import pkg.module
+ 1/0
+ mock = util.mock_modules('pkg.__init__', 'pkg.module',
+ module_code={'pkg': __init__})
+ with mock:
+ with util.import_state(meta_path=[mock]):
+ with self.assertRaises(ZeroDivisionError):
+ import_util.import_('pkg')
+ self.assertNotIn('pkg', sys.modules)
+ self.assertIn('pkg.module', sys.modules)
+ with self.assertRaises(ZeroDivisionError):
+ import_util.import_('pkg.module')
+ self.assertNotIn('pkg', sys.modules)
+ self.assertIn('pkg.module', sys.modules)
+
+ def test_raising_parent_after_relative_importing_child(self):
+ def __init__():
+ from . import module
+ 1/0
+ mock = util.mock_modules('pkg.__init__', 'pkg.module',
+ module_code={'pkg': __init__})
+ with mock:
+ with util.import_state(meta_path=[mock]):
+ with self.assertRaises((ZeroDivisionError, ImportError)):
+ # This raises ImportError on the "from . import module"
+ # line, not sure why.
+ import_util.import_('pkg')
+ self.assertNotIn('pkg', sys.modules)
+ with self.assertRaises((ZeroDivisionError, ImportError)):
+ import_util.import_('pkg.module')
+ self.assertNotIn('pkg', sys.modules)
+ # XXX False
+ #self.assertIn('pkg.module', sys.modules)
+
+ def test_raising_parent_after_double_relative_importing_child(self):
+ def __init__():
+ from ..subpkg import module
+ 1/0
+ mock = util.mock_modules('pkg.__init__', 'pkg.subpkg.__init__',
+ 'pkg.subpkg.module',
+ module_code={'pkg.subpkg': __init__})
+ with mock:
+ with util.import_state(meta_path=[mock]):
+ with self.assertRaises((ZeroDivisionError, ImportError)):
+ # This raises ImportError on the "from ..subpkg import module"
+ # line, not sure why.
+ import_util.import_('pkg.subpkg')
+ self.assertNotIn('pkg.subpkg', sys.modules)
+ with self.assertRaises((ZeroDivisionError, ImportError)):
+ import_util.import_('pkg.subpkg.module')
+ self.assertNotIn('pkg.subpkg', sys.modules)
+ # XXX False
+ #self.assertIn('pkg.subpkg.module', sys.modules)
+
+ def test_module_not_package(self):
+ # Try to import a submodule from a non-package should raise ImportError.
+ assert not hasattr(sys, '__path__')
+ with self.assertRaises(ImportError) as cm:
+ import_util.import_('sys.no_submodules_here')
+ self.assertEqual(cm.exception.name, 'sys.no_submodules_here')
+
+ def test_module_not_package_but_side_effects(self):
+ # If a module injects something into sys.modules as a side-effect, then
+ # pick up on that fact.
+ name = 'mod'
+ subname = name + '.b'
+ def module_injection():
+ sys.modules[subname] = 'total bunk'
+ mock_modules = util.mock_modules('mod',
+ module_code={'mod': module_injection})
+ with mock_modules as mock:
+ with util.import_state(meta_path=[mock]):
+ try:
+ submodule = import_util.import_(subname)
+ finally:
+ support.unload(subname)
+
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(ParentModuleTests)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/import_/test_path.py b/Lib/test/test_importlib/import_/test_path.py
new file mode 100644
index 0000000..d82b7f6
--- /dev/null
+++ b/Lib/test/test_importlib/import_/test_path.py
@@ -0,0 +1,120 @@
+from importlib import _bootstrap
+from importlib import machinery
+from importlib import import_module
+from .. import util
+from . import util as import_util
+import os
+import sys
+from types import ModuleType
+import unittest
+import warnings
+import zipimport
+
+
+class FinderTests(unittest.TestCase):
+
+ """Tests for PathFinder."""
+
+ def test_failure(self):
+ # Test None returned upon not finding a suitable finder.
+ module = '<test module>'
+ with util.import_state():
+ self.assertIsNone(machinery.PathFinder.find_module(module))
+
+ def test_sys_path(self):
+ # Test that sys.path is used when 'path' is None.
+ # Implicitly tests that sys.path_importer_cache is used.
+ module = '<test module>'
+ path = '<test path>'
+ importer = util.mock_modules(module)
+ with util.import_state(path_importer_cache={path: importer},
+ path=[path]):
+ loader = machinery.PathFinder.find_module(module)
+ self.assertIs(loader, importer)
+
+ def test_path(self):
+ # Test that 'path' is used when set.
+ # Implicitly tests that sys.path_importer_cache is used.
+ module = '<test module>'
+ path = '<test path>'
+ importer = util.mock_modules(module)
+ with util.import_state(path_importer_cache={path: importer}):
+ loader = machinery.PathFinder.find_module(module, [path])
+ self.assertIs(loader, importer)
+
+ def test_empty_list(self):
+ # An empty list should not count as asking for sys.path.
+ module = 'module'
+ path = '<test path>'
+ importer = util.mock_modules(module)
+ with util.import_state(path_importer_cache={path: importer},
+ path=[path]):
+ self.assertIsNone(machinery.PathFinder.find_module('module', []))
+
+ def test_path_hooks(self):
+ # Test that sys.path_hooks is used.
+ # Test that sys.path_importer_cache is set.
+ module = '<test module>'
+ path = '<test path>'
+ importer = util.mock_modules(module)
+ hook = import_util.mock_path_hook(path, importer=importer)
+ with util.import_state(path_hooks=[hook]):
+ loader = machinery.PathFinder.find_module(module, [path])
+ self.assertIs(loader, importer)
+ self.assertIn(path, sys.path_importer_cache)
+ self.assertIs(sys.path_importer_cache[path], importer)
+
+ def test_empty_path_hooks(self):
+ # Test that if sys.path_hooks is empty a warning is raised,
+ # sys.path_importer_cache gets None set, and PathFinder returns None.
+ path_entry = 'bogus_path'
+ with util.import_state(path_importer_cache={}, path_hooks=[],
+ path=[path_entry]):
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter('always')
+ self.assertIsNone(machinery.PathFinder.find_module('os'))
+ self.assertIsNone(sys.path_importer_cache[path_entry])
+ self.assertEqual(len(w), 1)
+ self.assertTrue(issubclass(w[-1].category, ImportWarning))
+
+ def test_path_importer_cache_empty_string(self):
+ # The empty string should create a finder using the cwd.
+ path = ''
+ module = '<test module>'
+ importer = util.mock_modules(module)
+ hook = import_util.mock_path_hook(os.curdir, importer=importer)
+ with util.import_state(path=[path], path_hooks=[hook]):
+ loader = machinery.PathFinder.find_module(module)
+ self.assertIs(loader, importer)
+ self.assertIn(os.curdir, sys.path_importer_cache)
+
+ def test_None_on_sys_path(self):
+ # Putting None in sys.path[0] caused an import regression from Python
+ # 3.2: http://bugs.python.org/issue16514
+ new_path = sys.path[:]
+ new_path.insert(0, None)
+ new_path_importer_cache = sys.path_importer_cache.copy()
+ new_path_importer_cache.pop(None, None)
+ new_path_hooks = [zipimport.zipimporter,
+ _bootstrap.FileFinder.path_hook(
+ *_bootstrap._get_supported_file_loaders())]
+ missing = object()
+ email = sys.modules.pop('email', missing)
+ try:
+ with util.import_state(meta_path=sys.meta_path[:],
+ path=new_path,
+ path_importer_cache=new_path_importer_cache,
+ path_hooks=new_path_hooks):
+ module = import_module('email')
+ self.assertIsInstance(module, ModuleType)
+ finally:
+ if email is not missing:
+ sys.modules['email'] = email
+
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(FinderTests)
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/import_/test_relative_imports.py b/Lib/test/test_importlib/import_/test_relative_imports.py
new file mode 100644
index 0000000..4569c26
--- /dev/null
+++ b/Lib/test/test_importlib/import_/test_relative_imports.py
@@ -0,0 +1,217 @@
+"""Test relative imports (PEP 328)."""
+from .. import util
+from . import util as import_util
+import sys
+import unittest
+
+class RelativeImports(unittest.TestCase):
+
+ """PEP 328 introduced relative imports. This allows for imports to occur
+ from within a package without having to specify the actual package name.
+
+ A simple example is to import another module within the same package
+ [module from module]::
+
+ # From pkg.mod1 with pkg.mod2 being a module.
+ from . import mod2
+
+ This also works for getting an attribute from a module that is specified
+ in a relative fashion [attr from module]::
+
+ # From pkg.mod1.
+ from .mod2 import attr
+
+ But this is in no way restricted to working between modules; it works
+ from [package to module],::
+
+ # From pkg, importing pkg.module which is a module.
+ from . import module
+
+ [module to package],::
+
+ # Pull attr from pkg, called from pkg.module which is a module.
+ from . import attr
+
+ and [package to package]::
+
+ # From pkg.subpkg1 (both pkg.subpkg[1,2] are packages).
+ from .. import subpkg2
+
+ The number of dots used is in no way restricted [deep import]::
+
+ # Import pkg.attr from pkg.pkg1.pkg2.pkg3.pkg4.pkg5.
+ from ...... import attr
+
+ To prevent someone from accessing code that is outside of a package, one
+ cannot reach the location containing the root package itself::
+
+ # From pkg.__init__ [too high from package]
+ from .. import top_level
+
+ # From pkg.module [too high from module]
+ from .. import top_level
+
+ Relative imports are the only type of import that allow for an empty
+ module name for an import [empty name].
+
+ """
+
+ def relative_import_test(self, create, globals_, callback):
+ """Abstract out boilerplace for setting up for an import test."""
+ uncache_names = []
+ for name in create:
+ if not name.endswith('.__init__'):
+ uncache_names.append(name)
+ else:
+ uncache_names.append(name[:-len('.__init__')])
+ with util.mock_modules(*create) as importer:
+ with util.import_state(meta_path=[importer]):
+ for global_ in globals_:
+ with util.uncache(*uncache_names):
+ callback(global_)
+
+
+ def test_module_from_module(self):
+ # [module from module]
+ create = 'pkg.__init__', 'pkg.mod2'
+ globals_ = {'__package__': 'pkg'}, {'__name__': 'pkg.mod1'}
+ def callback(global_):
+ import_util.import_('pkg') # For __import__().
+ module = import_util.import_('', global_, fromlist=['mod2'], level=1)
+ self.assertEqual(module.__name__, 'pkg')
+ self.assertTrue(hasattr(module, 'mod2'))
+ self.assertEqual(module.mod2.attr, 'pkg.mod2')
+ self.relative_import_test(create, globals_, callback)
+
+ def test_attr_from_module(self):
+ # [attr from module]
+ create = 'pkg.__init__', 'pkg.mod2'
+ globals_ = {'__package__': 'pkg'}, {'__name__': 'pkg.mod1'}
+ def callback(global_):
+ import_util.import_('pkg') # For __import__().
+ module = import_util.import_('mod2', global_, fromlist=['attr'],
+ level=1)
+ self.assertEqual(module.__name__, 'pkg.mod2')
+ self.assertEqual(module.attr, 'pkg.mod2')
+ self.relative_import_test(create, globals_, callback)
+
+ def test_package_to_module(self):
+ # [package to module]
+ create = 'pkg.__init__', 'pkg.module'
+ globals_ = ({'__package__': 'pkg'},
+ {'__name__': 'pkg', '__path__': ['blah']})
+ def callback(global_):
+ import_util.import_('pkg') # For __import__().
+ module = import_util.import_('', global_, fromlist=['module'],
+ level=1)
+ self.assertEqual(module.__name__, 'pkg')
+ self.assertTrue(hasattr(module, 'module'))
+ self.assertEqual(module.module.attr, 'pkg.module')
+ self.relative_import_test(create, globals_, callback)
+
+ def test_module_to_package(self):
+ # [module to package]
+ create = 'pkg.__init__', 'pkg.module'
+ globals_ = {'__package__': 'pkg'}, {'__name__': 'pkg.module'}
+ def callback(global_):
+ import_util.import_('pkg') # For __import__().
+ module = import_util.import_('', global_, fromlist=['attr'], level=1)
+ self.assertEqual(module.__name__, 'pkg')
+ self.relative_import_test(create, globals_, callback)
+
+ def test_package_to_package(self):
+ # [package to package]
+ create = ('pkg.__init__', 'pkg.subpkg1.__init__',
+ 'pkg.subpkg2.__init__')
+ globals_ = ({'__package__': 'pkg.subpkg1'},
+ {'__name__': 'pkg.subpkg1', '__path__': ['blah']})
+ def callback(global_):
+ module = import_util.import_('', global_, fromlist=['subpkg2'],
+ level=2)
+ self.assertEqual(module.__name__, 'pkg')
+ self.assertTrue(hasattr(module, 'subpkg2'))
+ self.assertEqual(module.subpkg2.attr, 'pkg.subpkg2.__init__')
+
+ def test_deep_import(self):
+ # [deep import]
+ create = ['pkg.__init__']
+ for count in range(1,6):
+ create.append('{0}.pkg{1}.__init__'.format(
+ create[-1][:-len('.__init__')], count))
+ globals_ = ({'__package__': 'pkg.pkg1.pkg2.pkg3.pkg4.pkg5'},
+ {'__name__': 'pkg.pkg1.pkg2.pkg3.pkg4.pkg5',
+ '__path__': ['blah']})
+ def callback(global_):
+ import_util.import_(globals_[0]['__package__'])
+ module = import_util.import_('', global_, fromlist=['attr'], level=6)
+ self.assertEqual(module.__name__, 'pkg')
+ self.relative_import_test(create, globals_, callback)
+
+ def test_too_high_from_package(self):
+ # [too high from package]
+ create = ['top_level', 'pkg.__init__']
+ globals_ = ({'__package__': 'pkg'},
+ {'__name__': 'pkg', '__path__': ['blah']})
+ def callback(global_):
+ import_util.import_('pkg')
+ with self.assertRaises(ValueError):
+ import_util.import_('', global_, fromlist=['top_level'],
+ level=2)
+ self.relative_import_test(create, globals_, callback)
+
+ def test_too_high_from_module(self):
+ # [too high from module]
+ create = ['top_level', 'pkg.__init__', 'pkg.module']
+ globals_ = {'__package__': 'pkg'}, {'__name__': 'pkg.module'}
+ def callback(global_):
+ import_util.import_('pkg')
+ with self.assertRaises(ValueError):
+ import_util.import_('', global_, fromlist=['top_level'],
+ level=2)
+ self.relative_import_test(create, globals_, callback)
+
+ def test_empty_name_w_level_0(self):
+ # [empty name]
+ with self.assertRaises(ValueError):
+ import_util.import_('')
+
+ def test_import_from_different_package(self):
+ # Test importing from a different package than the caller.
+ # in pkg.subpkg1.mod
+ # from ..subpkg2 import mod
+ create = ['__runpy_pkg__.__init__',
+ '__runpy_pkg__.__runpy_pkg__.__init__',
+ '__runpy_pkg__.uncle.__init__',
+ '__runpy_pkg__.uncle.cousin.__init__',
+ '__runpy_pkg__.uncle.cousin.nephew']
+ globals_ = {'__package__': '__runpy_pkg__.__runpy_pkg__'}
+ def callback(global_):
+ import_util.import_('__runpy_pkg__.__runpy_pkg__')
+ module = import_util.import_('uncle.cousin', globals_, {},
+ fromlist=['nephew'],
+ level=2)
+ self.assertEqual(module.__name__, '__runpy_pkg__.uncle.cousin')
+ self.relative_import_test(create, globals_, callback)
+
+ def test_import_relative_import_no_fromlist(self):
+ # Import a relative module w/ no fromlist.
+ create = ['crash.__init__', 'crash.mod']
+ globals_ = [{'__package__': 'crash', '__name__': 'crash'}]
+ def callback(global_):
+ import_util.import_('crash')
+ mod = import_util.import_('mod', global_, {}, [], 1)
+ self.assertEqual(mod.__name__, 'crash.mod')
+ self.relative_import_test(create, globals_, callback)
+
+ def test_relative_import_no_globals(self):
+ # No globals for a relative import is an error.
+ with self.assertRaises(KeyError):
+ import_util.import_('sys', level=1)
+
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(RelativeImports)
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/import_/util.py b/Lib/test/test_importlib/import_/util.py
new file mode 100644
index 0000000..86ac065
--- /dev/null
+++ b/Lib/test/test_importlib/import_/util.py
@@ -0,0 +1,28 @@
+import functools
+import importlib
+import unittest
+
+
+using___import__ = False
+
+
+def import_(*args, **kwargs):
+ """Delegate to allow for injecting different implementations of import."""
+ if using___import__:
+ return __import__(*args, **kwargs)
+ else:
+ return importlib.__import__(*args, **kwargs)
+
+
+def importlib_only(fxn):
+ """Decorator to skip a test if using __builtins__.__import__."""
+ return unittest.skipIf(using___import__, "importlib-specific test")(fxn)
+
+
+def mock_path_hook(*entries, importer):
+ """A mock sys.path_hooks entry."""
+ def hook(entry):
+ if entry not in entries:
+ raise ImportError
+ return importer
+ return hook
diff --git a/Lib/test/test_importlib/regrtest.py b/Lib/test/test_importlib/regrtest.py
new file mode 100644
index 0000000..a5be11f
--- /dev/null
+++ b/Lib/test/test_importlib/regrtest.py
@@ -0,0 +1,17 @@
+"""Run Python's standard test suite using importlib.__import__.
+
+Tests known to fail because of assumptions that importlib (properly)
+invalidates are automatically skipped if the entire test suite is run.
+Otherwise all command-line options valid for test.regrtest are also valid for
+this script.
+
+"""
+import importlib
+import sys
+from test import regrtest
+
+if __name__ == '__main__':
+ __builtins__.__import__ = importlib.__import__
+ sys.path_importer_cache.clear()
+
+ regrtest.main(quiet=True, verbose2=True)
diff --git a/Lib/test/test_importlib/source/__init__.py b/Lib/test/test_importlib/source/__init__.py
new file mode 100644
index 0000000..3ef97f3
--- /dev/null
+++ b/Lib/test/test_importlib/source/__init__.py
@@ -0,0 +1,13 @@
+from .. import test_suite
+import os.path
+import unittest
+
+
+def test_suite():
+ directory = os.path.dirname(__file__)
+ return test.test_suite('importlib.test.source', directory)
+
+
+if __name__ == '__main__':
+ from test.support import run_unittest
+ run_unittest(test_suite())
diff --git a/Lib/test/test_importlib/source/test_abc_loader.py b/Lib/test/test_importlib/source/test_abc_loader.py
new file mode 100644
index 0000000..0d912b6
--- /dev/null
+++ b/Lib/test/test_importlib/source/test_abc_loader.py
@@ -0,0 +1,906 @@
+import importlib
+from importlib import abc
+
+from .. import abc as testing_abc
+from .. import util
+from . import util as source_util
+
+import imp
+import inspect
+import io
+import marshal
+import os
+import sys
+import types
+import unittest
+import warnings
+
+
+class SourceOnlyLoaderMock(abc.SourceLoader):
+
+ # Globals that should be defined for all modules.
+ source = (b"_ = '::'.join([__name__, __file__, __cached__, __package__, "
+ b"repr(__loader__)])")
+
+ def __init__(self, path):
+ self.path = path
+
+ def get_data(self, path):
+ assert self.path == path
+ return self.source
+
+ def get_filename(self, fullname):
+ return self.path
+
+ def module_repr(self, module):
+ return '<module>'
+
+
+class SourceLoaderMock(SourceOnlyLoaderMock):
+
+ source_mtime = 1
+
+ def __init__(self, path, magic=imp.get_magic()):
+ super().__init__(path)
+ self.bytecode_path = imp.cache_from_source(self.path)
+ self.source_size = len(self.source)
+ data = bytearray(magic)
+ data.extend(importlib._w_long(self.source_mtime))
+ data.extend(importlib._w_long(self.source_size))
+ code_object = compile(self.source, self.path, 'exec',
+ dont_inherit=True)
+ data.extend(marshal.dumps(code_object))
+ self.bytecode = bytes(data)
+ self.written = {}
+
+ def get_data(self, path):
+ if path == self.path:
+ return super().get_data(path)
+ elif path == self.bytecode_path:
+ return self.bytecode
+ else:
+ raise IOError
+
+ def path_stats(self, path):
+ assert path == self.path
+ return {'mtime': self.source_mtime, 'size': self.source_size}
+
+ def set_data(self, path, data):
+ self.written[path] = bytes(data)
+ return path == self.bytecode_path
+
+
+class PyLoaderMock(abc.PyLoader):
+
+ # Globals that should be defined for all modules.
+ source = (b"_ = '::'.join([__name__, __file__, __package__, "
+ b"repr(__loader__)])")
+
+ def __init__(self, data):
+ """Take a dict of 'module_name: path' pairings.
+
+ Paths should have no file extension, allowing packages to be denoted by
+ ending in '__init__'.
+
+ """
+ self.module_paths = data
+ self.path_to_module = {val:key for key,val in data.items()}
+
+ def get_data(self, path):
+ if path not in self.path_to_module:
+ raise IOError
+ return self.source
+
+ def is_package(self, name):
+ filename = os.path.basename(self.get_filename(name))
+ return os.path.splitext(filename)[0] == '__init__'
+
+ def source_path(self, name):
+ try:
+ return self.module_paths[name]
+ except KeyError:
+ raise ImportError
+
+ def get_filename(self, name):
+ """Silence deprecation warning."""
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ path = super().get_filename(name)
+ assert len(w) == 1
+ assert issubclass(w[0].category, DeprecationWarning)
+ return path
+
+ def module_repr(self):
+ return '<module>'
+
+
+class PyLoaderCompatMock(PyLoaderMock):
+
+ """Mock that matches what is suggested to have a loader that is compatible
+ from Python 3.1 onwards."""
+
+ def get_filename(self, fullname):
+ try:
+ return self.module_paths[fullname]
+ except KeyError:
+ raise ImportError
+
+ def source_path(self, fullname):
+ try:
+ return self.get_filename(fullname)
+ except ImportError:
+ return None
+
+
+class PyPycLoaderMock(abc.PyPycLoader, PyLoaderMock):
+
+ default_mtime = 1
+
+ def __init__(self, source, bc={}):
+ """Initialize mock.
+
+ 'bc' is a dict keyed on a module's name. The value is dict with
+ possible keys of 'path', 'mtime', 'magic', and 'bc'. Except for 'path',
+ each of those keys control if any part of created bytecode is to
+ deviate from default values.
+
+ """
+ super().__init__(source)
+ self.module_bytecode = {}
+ self.path_to_bytecode = {}
+ self.bytecode_to_path = {}
+ for name, data in bc.items():
+ self.path_to_bytecode[data['path']] = name
+ self.bytecode_to_path[name] = data['path']
+ magic = data.get('magic', imp.get_magic())
+ mtime = importlib._w_long(data.get('mtime', self.default_mtime))
+ source_size = importlib._w_long(len(self.source) & 0xFFFFFFFF)
+ if 'bc' in data:
+ bc = data['bc']
+ else:
+ bc = self.compile_bc(name)
+ self.module_bytecode[name] = magic + mtime + source_size + bc
+
+ def compile_bc(self, name):
+ source_path = self.module_paths.get(name, '<test>') or '<test>'
+ code = compile(self.source, source_path, 'exec')
+ return marshal.dumps(code)
+
+ def source_mtime(self, name):
+ if name in self.module_paths:
+ return self.default_mtime
+ elif name in self.module_bytecode:
+ return None
+ else:
+ raise ImportError
+
+ def bytecode_path(self, name):
+ try:
+ return self.bytecode_to_path[name]
+ except KeyError:
+ if name in self.module_paths:
+ return None
+ else:
+ raise ImportError
+
+ def write_bytecode(self, name, bytecode):
+ self.module_bytecode[name] = bytecode
+ return True
+
+ def get_data(self, path):
+ if path in self.path_to_module:
+ return super().get_data(path)
+ elif path in self.path_to_bytecode:
+ name = self.path_to_bytecode[path]
+ return self.module_bytecode[name]
+ else:
+ raise IOError
+
+ def is_package(self, name):
+ try:
+ return super().is_package(name)
+ except TypeError:
+ return '__init__' in self.bytecode_to_path[name]
+
+ def get_code(self, name):
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ code_object = super().get_code(name)
+ assert len(w) == 1
+ assert issubclass(w[0].category, DeprecationWarning)
+ return code_object
+
+class PyLoaderTests(testing_abc.LoaderTests):
+
+ """Tests for importlib.abc.PyLoader."""
+
+ mocker = PyLoaderMock
+
+ def eq_attrs(self, ob, **kwargs):
+ for attr, val in kwargs.items():
+ found = getattr(ob, attr)
+ self.assertEqual(found, val,
+ "{} attribute: {} != {}".format(attr, found, val))
+
+ def test_module(self):
+ name = '<module>'
+ path = os.path.join('', 'path', 'to', 'module')
+ mock = self.mocker({name: path})
+ with util.uncache(name):
+ module = mock.load_module(name)
+ self.assertIn(name, sys.modules)
+ self.eq_attrs(module, __name__=name, __file__=path, __package__='',
+ __loader__=mock)
+ self.assertTrue(not hasattr(module, '__path__'))
+ return mock, name
+
+ def test_package(self):
+ name = '<pkg>'
+ path = os.path.join('path', 'to', name, '__init__')
+ mock = self.mocker({name: path})
+ with util.uncache(name):
+ module = mock.load_module(name)
+ self.assertIn(name, sys.modules)
+ self.eq_attrs(module, __name__=name, __file__=path,
+ __path__=[os.path.dirname(path)], __package__=name,
+ __loader__=mock)
+ return mock, name
+
+ def test_lacking_parent(self):
+ name = 'pkg.mod'
+ path = os.path.join('path', 'to', 'pkg', 'mod')
+ mock = self.mocker({name: path})
+ with util.uncache(name):
+ module = mock.load_module(name)
+ self.assertIn(name, sys.modules)
+ self.eq_attrs(module, __name__=name, __file__=path, __package__='pkg',
+ __loader__=mock)
+ self.assertFalse(hasattr(module, '__path__'))
+ return mock, name
+
+ def test_module_reuse(self):
+ name = 'mod'
+ path = os.path.join('path', 'to', 'mod')
+ module = imp.new_module(name)
+ mock = self.mocker({name: path})
+ with util.uncache(name):
+ sys.modules[name] = module
+ loaded_module = mock.load_module(name)
+ self.assertIs(loaded_module, module)
+ self.assertIs(sys.modules[name], module)
+ return mock, name
+
+ def test_state_after_failure(self):
+ name = "mod"
+ module = imp.new_module(name)
+ module.blah = None
+ mock = self.mocker({name: os.path.join('path', 'to', 'mod')})
+ mock.source = b"1/0"
+ with util.uncache(name):
+ sys.modules[name] = module
+ with self.assertRaises(ZeroDivisionError):
+ mock.load_module(name)
+ self.assertIs(sys.modules[name], module)
+ self.assertTrue(hasattr(module, 'blah'))
+ return mock
+
+ def test_unloadable(self):
+ name = "mod"
+ mock = self.mocker({name: os.path.join('path', 'to', 'mod')})
+ mock.source = b"1/0"
+ with util.uncache(name):
+ with self.assertRaises(ZeroDivisionError):
+ mock.load_module(name)
+ self.assertNotIn(name, sys.modules)
+ return mock
+
+
+class PyLoaderCompatTests(PyLoaderTests):
+
+ """Test that the suggested code to make a loader that is compatible from
+ Python 3.1 forward works."""
+
+ mocker = PyLoaderCompatMock
+
+
+class PyLoaderInterfaceTests(unittest.TestCase):
+
+ """Tests for importlib.abc.PyLoader to make sure that when source_path()
+ doesn't return a path everything works as expected."""
+
+ def test_no_source_path(self):
+ # No source path should lead to ImportError.
+ name = 'mod'
+ mock = PyLoaderMock({})
+ with util.uncache(name), self.assertRaises(ImportError):
+ mock.load_module(name)
+
+ def test_source_path_is_None(self):
+ name = 'mod'
+ mock = PyLoaderMock({name: None})
+ with util.uncache(name), self.assertRaises(ImportError):
+ mock.load_module(name)
+
+ def test_get_filename_with_source_path(self):
+ # get_filename() should return what source_path() returns.
+ name = 'mod'
+ path = os.path.join('path', 'to', 'source')
+ mock = PyLoaderMock({name: path})
+ with util.uncache(name):
+ self.assertEqual(mock.get_filename(name), path)
+
+ def test_get_filename_no_source_path(self):
+ # get_filename() should raise ImportError if source_path returns None.
+ name = 'mod'
+ mock = PyLoaderMock({name: None})
+ with util.uncache(name), self.assertRaises(ImportError):
+ mock.get_filename(name)
+
+
+class PyPycLoaderTests(PyLoaderTests):
+
+ """Tests for importlib.abc.PyPycLoader."""
+
+ mocker = PyPycLoaderMock
+
+ @source_util.writes_bytecode_files
+ def verify_bytecode(self, mock, name):
+ assert name in mock.module_paths
+ self.assertIn(name, mock.module_bytecode)
+ magic = mock.module_bytecode[name][:4]
+ self.assertEqual(magic, imp.get_magic())
+ mtime = importlib._r_long(mock.module_bytecode[name][4:8])
+ self.assertEqual(mtime, 1)
+ source_size = mock.module_bytecode[name][8:12]
+ self.assertEqual(len(mock.source) & 0xFFFFFFFF,
+ importlib._r_long(source_size))
+ bc = mock.module_bytecode[name][12:]
+ self.assertEqual(bc, mock.compile_bc(name))
+
+ def test_module(self):
+ mock, name = super().test_module()
+ self.verify_bytecode(mock, name)
+
+ def test_package(self):
+ mock, name = super().test_package()
+ self.verify_bytecode(mock, name)
+
+ def test_lacking_parent(self):
+ mock, name = super().test_lacking_parent()
+ self.verify_bytecode(mock, name)
+
+ def test_module_reuse(self):
+ mock, name = super().test_module_reuse()
+ self.verify_bytecode(mock, name)
+
+ def test_state_after_failure(self):
+ super().test_state_after_failure()
+
+ def test_unloadable(self):
+ super().test_unloadable()
+
+
+class PyPycLoaderInterfaceTests(unittest.TestCase):
+
+ """Test for the interface of importlib.abc.PyPycLoader."""
+
+ def get_filename_check(self, src_path, bc_path, expect):
+ name = 'mod'
+ mock = PyPycLoaderMock({name: src_path}, {name: {'path': bc_path}})
+ with util.uncache(name):
+ assert mock.source_path(name) == src_path
+ assert mock.bytecode_path(name) == bc_path
+ self.assertEqual(mock.get_filename(name), expect)
+
+ def test_filename_with_source_bc(self):
+ # When source and bytecode paths present, return the source path.
+ self.get_filename_check('source_path', 'bc_path', 'source_path')
+
+ def test_filename_with_source_no_bc(self):
+ # With source but no bc, return source path.
+ self.get_filename_check('source_path', None, 'source_path')
+
+ def test_filename_with_no_source_bc(self):
+ # With not source but bc, return the bc path.
+ self.get_filename_check(None, 'bc_path', 'bc_path')
+
+ def test_filename_with_no_source_or_bc(self):
+ # With no source or bc, raise ImportError.
+ name = 'mod'
+ mock = PyPycLoaderMock({name: None}, {name: {'path': None}})
+ with util.uncache(name), self.assertRaises(ImportError):
+ mock.get_filename(name)
+
+
+class SkipWritingBytecodeTests(unittest.TestCase):
+
+ """Test that bytecode is properly handled based on
+ sys.dont_write_bytecode."""
+
+ @source_util.writes_bytecode_files
+ def run_test(self, dont_write_bytecode):
+ name = 'mod'
+ mock = PyPycLoaderMock({name: os.path.join('path', 'to', 'mod')})
+ sys.dont_write_bytecode = dont_write_bytecode
+ with util.uncache(name):
+ mock.load_module(name)
+ self.assertIsNot(name in mock.module_bytecode, dont_write_bytecode)
+
+ def test_no_bytecode_written(self):
+ self.run_test(True)
+
+ def test_bytecode_written(self):
+ self.run_test(False)
+
+
+class RegeneratedBytecodeTests(unittest.TestCase):
+
+ """Test that bytecode is regenerated as expected."""
+
+ @source_util.writes_bytecode_files
+ def test_different_magic(self):
+ # A different magic number should lead to new bytecode.
+ name = 'mod'
+ bad_magic = b'\x00\x00\x00\x00'
+ assert bad_magic != imp.get_magic()
+ mock = PyPycLoaderMock({name: os.path.join('path', 'to', 'mod')},
+ {name: {'path': os.path.join('path', 'to',
+ 'mod.bytecode'),
+ 'magic': bad_magic}})
+ with util.uncache(name):
+ mock.load_module(name)
+ self.assertIn(name, mock.module_bytecode)
+ magic = mock.module_bytecode[name][:4]
+ self.assertEqual(magic, imp.get_magic())
+
+ @source_util.writes_bytecode_files
+ def test_old_mtime(self):
+ # Bytecode with an older mtime should be regenerated.
+ name = 'mod'
+ old_mtime = PyPycLoaderMock.default_mtime - 1
+ mock = PyPycLoaderMock({name: os.path.join('path', 'to', 'mod')},
+ {name: {'path': 'path/to/mod.bytecode', 'mtime': old_mtime}})
+ with util.uncache(name):
+ mock.load_module(name)
+ self.assertIn(name, mock.module_bytecode)
+ mtime = importlib._r_long(mock.module_bytecode[name][4:8])
+ self.assertEqual(mtime, PyPycLoaderMock.default_mtime)
+
+
+class BadBytecodeFailureTests(unittest.TestCase):
+
+ """Test import failures when there is no source and parts of the bytecode
+ is bad."""
+
+ def test_bad_magic(self):
+ # A bad magic number should lead to an ImportError.
+ name = 'mod'
+ bad_magic = b'\x00\x00\x00\x00'
+ bc = {name:
+ {'path': os.path.join('path', 'to', 'mod'),
+ 'magic': bad_magic}}
+ mock = PyPycLoaderMock({name: None}, bc)
+ with util.uncache(name), self.assertRaises(ImportError) as cm:
+ mock.load_module(name)
+ self.assertEqual(cm.exception.name, name)
+
+ def test_no_bytecode(self):
+ # Missing code object bytecode should lead to an EOFError.
+ name = 'mod'
+ bc = {name: {'path': os.path.join('path', 'to', 'mod'), 'bc': b''}}
+ mock = PyPycLoaderMock({name: None}, bc)
+ with util.uncache(name), self.assertRaises(EOFError):
+ mock.load_module(name)
+
+ def test_bad_bytecode(self):
+ # Malformed code object bytecode should lead to a ValueError.
+ name = 'mod'
+ bc = {name: {'path': os.path.join('path', 'to', 'mod'), 'bc': b'1234'}}
+ mock = PyPycLoaderMock({name: None}, bc)
+ with util.uncache(name), self.assertRaises(ValueError):
+ mock.load_module(name)
+
+
+def raise_ImportError(*args, **kwargs):
+ raise ImportError
+
+class MissingPathsTests(unittest.TestCase):
+
+ """Test what happens when a source or bytecode path does not exist (either
+ from *_path returning None or raising ImportError)."""
+
+ def test_source_path_None(self):
+ # Bytecode should be used when source_path returns None, along with
+ # __file__ being set to the bytecode path.
+ name = 'mod'
+ bytecode_path = 'path/to/mod'
+ mock = PyPycLoaderMock({name: None}, {name: {'path': bytecode_path}})
+ with util.uncache(name):
+ module = mock.load_module(name)
+ self.assertEqual(module.__file__, bytecode_path)
+
+ # Testing for bytecode_path returning None handled by all tests where no
+ # bytecode initially exists.
+
+ def test_all_paths_None(self):
+ # If all *_path methods return None, raise ImportError.
+ name = 'mod'
+ mock = PyPycLoaderMock({name: None})
+ with util.uncache(name), self.assertRaises(ImportError) as cm:
+ mock.load_module(name)
+ self.assertEqual(cm.exception.name, name)
+
+ def test_source_path_ImportError(self):
+ # An ImportError from source_path should trigger an ImportError.
+ name = 'mod'
+ mock = PyPycLoaderMock({}, {name: {'path': os.path.join('path', 'to',
+ 'mod')}})
+ with util.uncache(name), self.assertRaises(ImportError):
+ mock.load_module(name)
+
+ def test_bytecode_path_ImportError(self):
+ # An ImportError from bytecode_path should trigger an ImportError.
+ name = 'mod'
+ mock = PyPycLoaderMock({name: os.path.join('path', 'to', 'mod')})
+ bad_meth = types.MethodType(raise_ImportError, mock)
+ mock.bytecode_path = bad_meth
+ with util.uncache(name), self.assertRaises(ImportError) as cm:
+ mock.load_module(name)
+
+
+class SourceLoaderTestHarness(unittest.TestCase):
+
+ def setUp(self, *, is_package=True, **kwargs):
+ self.package = 'pkg'
+ if is_package:
+ self.path = os.path.join(self.package, '__init__.py')
+ self.name = self.package
+ else:
+ module_name = 'mod'
+ self.path = os.path.join(self.package, '.'.join(['mod', 'py']))
+ self.name = '.'.join([self.package, module_name])
+ self.cached = imp.cache_from_source(self.path)
+ self.loader = self.loader_mock(self.path, **kwargs)
+
+ def verify_module(self, module):
+ self.assertEqual(module.__name__, self.name)
+ self.assertEqual(module.__file__, self.path)
+ self.assertEqual(module.__cached__, self.cached)
+ self.assertEqual(module.__package__, self.package)
+ self.assertEqual(module.__loader__, self.loader)
+ values = module._.split('::')
+ self.assertEqual(values[0], self.name)
+ self.assertEqual(values[1], self.path)
+ self.assertEqual(values[2], self.cached)
+ self.assertEqual(values[3], self.package)
+ self.assertEqual(values[4], repr(self.loader))
+
+ def verify_code(self, code_object):
+ module = imp.new_module(self.name)
+ module.__file__ = self.path
+ module.__cached__ = self.cached
+ module.__package__ = self.package
+ module.__loader__ = self.loader
+ module.__path__ = []
+ exec(code_object, module.__dict__)
+ self.verify_module(module)
+
+
+class SourceOnlyLoaderTests(SourceLoaderTestHarness):
+
+ """Test importlib.abc.SourceLoader for source-only loading.
+
+ Reload testing is subsumed by the tests for
+ importlib.util.module_for_loader.
+
+ """
+
+ loader_mock = SourceOnlyLoaderMock
+
+ def test_get_source(self):
+ # Verify the source code is returned as a string.
+ # If an IOError is raised by get_data then raise ImportError.
+ expected_source = self.loader.source.decode('utf-8')
+ self.assertEqual(self.loader.get_source(self.name), expected_source)
+ def raise_IOError(path):
+ raise IOError
+ self.loader.get_data = raise_IOError
+ with self.assertRaises(ImportError) as cm:
+ self.loader.get_source(self.name)
+ self.assertEqual(cm.exception.name, self.name)
+
+ def test_is_package(self):
+ # Properly detect when loading a package.
+ self.setUp(is_package=False)
+ self.assertFalse(self.loader.is_package(self.name))
+ self.setUp(is_package=True)
+ self.assertTrue(self.loader.is_package(self.name))
+ self.assertFalse(self.loader.is_package(self.name + '.__init__'))
+
+ def test_get_code(self):
+ # Verify the code object is created.
+ code_object = self.loader.get_code(self.name)
+ self.verify_code(code_object)
+
+ def test_load_module(self):
+ # Loading a module should set __name__, __loader__, __package__,
+ # __path__ (for packages), __file__, and __cached__.
+ # The module should also be put into sys.modules.
+ with util.uncache(self.name):
+ module = self.loader.load_module(self.name)
+ self.verify_module(module)
+ self.assertEqual(module.__path__, [os.path.dirname(self.path)])
+ self.assertIn(self.name, sys.modules)
+
+ def test_package_settings(self):
+ # __package__ needs to be set, while __path__ is set on if the module
+ # is a package.
+ # Testing the values for a package are covered by test_load_module.
+ self.setUp(is_package=False)
+ with util.uncache(self.name):
+ module = self.loader.load_module(self.name)
+ self.verify_module(module)
+ self.assertTrue(not hasattr(module, '__path__'))
+
+ def test_get_source_encoding(self):
+ # Source is considered encoded in UTF-8 by default unless otherwise
+ # specified by an encoding line.
+ source = "_ = 'ü'"
+ self.loader.source = source.encode('utf-8')
+ returned_source = self.loader.get_source(self.name)
+ self.assertEqual(returned_source, source)
+ source = "# coding: latin-1\n_ = ü"
+ self.loader.source = source.encode('latin-1')
+ returned_source = self.loader.get_source(self.name)
+ self.assertEqual(returned_source, source)
+
+
+@unittest.skipIf(sys.dont_write_bytecode, "sys.dont_write_bytecode is true")
+class SourceLoaderBytecodeTests(SourceLoaderTestHarness):
+
+ """Test importlib.abc.SourceLoader's use of bytecode.
+
+ Source-only testing handled by SourceOnlyLoaderTests.
+
+ """
+
+ loader_mock = SourceLoaderMock
+
+ def verify_code(self, code_object, *, bytecode_written=False):
+ super().verify_code(code_object)
+ if bytecode_written:
+ self.assertIn(self.cached, self.loader.written)
+ data = bytearray(imp.get_magic())
+ data.extend(importlib._w_long(self.loader.source_mtime))
+ data.extend(importlib._w_long(self.loader.source_size))
+ data.extend(marshal.dumps(code_object))
+ self.assertEqual(self.loader.written[self.cached], bytes(data))
+
+ def test_code_with_everything(self):
+ # When everything should work.
+ code_object = self.loader.get_code(self.name)
+ self.verify_code(code_object)
+
+ def test_no_bytecode(self):
+ # If no bytecode exists then move on to the source.
+ self.loader.bytecode_path = "<does not exist>"
+ # Sanity check
+ with self.assertRaises(IOError):
+ bytecode_path = imp.cache_from_source(self.path)
+ self.loader.get_data(bytecode_path)
+ code_object = self.loader.get_code(self.name)
+ self.verify_code(code_object, bytecode_written=True)
+
+ def test_code_bad_timestamp(self):
+ # Bytecode is only used when the timestamp matches the source EXACTLY.
+ for source_mtime in (0, 2):
+ assert source_mtime != self.loader.source_mtime
+ original = self.loader.source_mtime
+ self.loader.source_mtime = source_mtime
+ # If bytecode is used then EOFError would be raised by marshal.
+ self.loader.bytecode = self.loader.bytecode[8:]
+ code_object = self.loader.get_code(self.name)
+ self.verify_code(code_object, bytecode_written=True)
+ self.loader.source_mtime = original
+
+ def test_code_bad_magic(self):
+ # Skip over bytecode with a bad magic number.
+ self.setUp(magic=b'0000')
+ # If bytecode is used then EOFError would be raised by marshal.
+ self.loader.bytecode = self.loader.bytecode[8:]
+ code_object = self.loader.get_code(self.name)
+ self.verify_code(code_object, bytecode_written=True)
+
+ def test_dont_write_bytecode(self):
+ # Bytecode is not written if sys.dont_write_bytecode is true.
+ # Can assume it is false already thanks to the skipIf class decorator.
+ try:
+ sys.dont_write_bytecode = True
+ self.loader.bytecode_path = "<does not exist>"
+ code_object = self.loader.get_code(self.name)
+ self.assertNotIn(self.cached, self.loader.written)
+ finally:
+ sys.dont_write_bytecode = False
+
+ def test_no_set_data(self):
+ # If set_data is not defined, one can still read bytecode.
+ self.setUp(magic=b'0000')
+ original_set_data = self.loader.__class__.set_data
+ try:
+ del self.loader.__class__.set_data
+ code_object = self.loader.get_code(self.name)
+ self.verify_code(code_object)
+ finally:
+ self.loader.__class__.set_data = original_set_data
+
+ def test_set_data_raises_exceptions(self):
+ # Raising NotImplementedError or IOError is okay for set_data.
+ def raise_exception(exc):
+ def closure(*args, **kwargs):
+ raise exc
+ return closure
+
+ self.setUp(magic=b'0000')
+ self.loader.set_data = raise_exception(NotImplementedError)
+ code_object = self.loader.get_code(self.name)
+ self.verify_code(code_object)
+
+
+class SourceLoaderGetSourceTests(unittest.TestCase):
+
+ """Tests for importlib.abc.SourceLoader.get_source()."""
+
+ def test_default_encoding(self):
+ # Should have no problems with UTF-8 text.
+ name = 'mod'
+ mock = SourceOnlyLoaderMock('mod.file')
+ source = 'x = "ü"'
+ mock.source = source.encode('utf-8')
+ returned_source = mock.get_source(name)
+ self.assertEqual(returned_source, source)
+
+ def test_decoded_source(self):
+ # Decoding should work.
+ name = 'mod'
+ mock = SourceOnlyLoaderMock("mod.file")
+ source = "# coding: Latin-1\nx='ü'"
+ assert source.encode('latin-1') != source.encode('utf-8')
+ mock.source = source.encode('latin-1')
+ returned_source = mock.get_source(name)
+ self.assertEqual(returned_source, source)
+
+ def test_universal_newlines(self):
+ # PEP 302 says universal newlines should be used.
+ name = 'mod'
+ mock = SourceOnlyLoaderMock('mod.file')
+ source = "x = 42\r\ny = -13\r\n"
+ mock.source = source.encode('utf-8')
+ expect = io.IncrementalNewlineDecoder(None, True).decode(source)
+ self.assertEqual(mock.get_source(name), expect)
+
+
+class AbstractMethodImplTests(unittest.TestCase):
+
+ """Test the concrete abstractmethod implementations."""
+
+ class MetaPathFinder(abc.MetaPathFinder):
+ def find_module(self, fullname, path):
+ super().find_module(fullname, path)
+
+ class PathEntryFinder(abc.PathEntryFinder):
+ def find_module(self, _):
+ super().find_module(_)
+
+ def find_loader(self, _):
+ super().find_loader(_)
+
+ class Finder(abc.Finder):
+ def find_module(self, fullname, path):
+ super().find_module(fullname, path)
+
+ class Loader(abc.Loader):
+ def load_module(self, fullname):
+ super().load_module(fullname)
+ def module_repr(self, module):
+ super().module_repr(module)
+
+ class ResourceLoader(Loader, abc.ResourceLoader):
+ def get_data(self, _):
+ super().get_data(_)
+
+ class InspectLoader(Loader, abc.InspectLoader):
+ def is_package(self, _):
+ super().is_package(_)
+
+ def get_code(self, _):
+ super().get_code(_)
+
+ def get_source(self, _):
+ super().get_source(_)
+
+ class ExecutionLoader(InspectLoader, abc.ExecutionLoader):
+ def get_filename(self, _):
+ super().get_filename(_)
+
+ class SourceLoader(ResourceLoader, ExecutionLoader, abc.SourceLoader):
+ pass
+
+ class PyLoader(ResourceLoader, InspectLoader, abc.PyLoader):
+ def source_path(self, _):
+ super().source_path(_)
+
+ class PyPycLoader(PyLoader, abc.PyPycLoader):
+ def bytecode_path(self, _):
+ super().bytecode_path(_)
+
+ def source_mtime(self, _):
+ super().source_mtime(_)
+
+ def write_bytecode(self, _, _2):
+ super().write_bytecode(_, _2)
+
+ def raises_NotImplementedError(self, ins, *args):
+ for method_name in args:
+ method = getattr(ins, method_name)
+ arg_count = len(inspect.getfullargspec(method)[0]) - 1
+ args = [''] * arg_count
+ try:
+ method(*args)
+ except NotImplementedError:
+ pass
+ else:
+ msg = "{}.{} did not raise NotImplementedError"
+ self.fail(msg.format(ins.__class__.__name__, method_name))
+
+ def test_Loader(self):
+ self.raises_NotImplementedError(self.Loader(), 'load_module')
+
+ # XXX misplaced; should be somewhere else
+ def test_Finder(self):
+ self.raises_NotImplementedError(self.Finder(), 'find_module')
+
+ def test_ResourceLoader(self):
+ self.raises_NotImplementedError(self.ResourceLoader(), 'load_module',
+ 'get_data')
+
+ def test_InspectLoader(self):
+ self.raises_NotImplementedError(self.InspectLoader(), 'load_module',
+ 'is_package', 'get_code', 'get_source')
+
+ def test_ExecutionLoader(self):
+ self.raises_NotImplementedError(self.ExecutionLoader(), 'load_module',
+ 'is_package', 'get_code', 'get_source',
+ 'get_filename')
+
+ def test_SourceLoader(self):
+ ins = self.SourceLoader()
+ # Required abstractmethods.
+ self.raises_NotImplementedError(ins, 'get_filename', 'get_data')
+ # Optional abstractmethods.
+ self.raises_NotImplementedError(ins,'path_stats', 'set_data')
+
+ def test_PyLoader(self):
+ self.raises_NotImplementedError(self.PyLoader(), 'source_path',
+ 'get_data', 'is_package')
+
+ def test_PyPycLoader(self):
+ self.raises_NotImplementedError(self.PyPycLoader(), 'source_path',
+ 'source_mtime', 'bytecode_path',
+ 'write_bytecode')
+
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(PyLoaderTests, PyLoaderCompatTests,
+ PyLoaderInterfaceTests,
+ PyPycLoaderTests, PyPycLoaderInterfaceTests,
+ SkipWritingBytecodeTests, RegeneratedBytecodeTests,
+ BadBytecodeFailureTests, MissingPathsTests,
+ SourceOnlyLoaderTests,
+ SourceLoaderBytecodeTests,
+ SourceLoaderGetSourceTests,
+ AbstractMethodImplTests)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/source/test_case_sensitivity.py b/Lib/test/test_importlib/source/test_case_sensitivity.py
new file mode 100644
index 0000000..241173f
--- /dev/null
+++ b/Lib/test/test_importlib/source/test_case_sensitivity.py
@@ -0,0 +1,70 @@
+"""Test case-sensitivity (PEP 235)."""
+from importlib import _bootstrap
+from importlib import machinery
+from .. import util
+from . import util as source_util
+import imp
+import os
+import sys
+from test import support as test_support
+import unittest
+
+
+@util.case_insensitive_tests
+class CaseSensitivityTest(unittest.TestCase):
+
+ """PEP 235 dictates that on case-preserving, case-insensitive file systems
+ that imports are case-sensitive unless the PYTHONCASEOK environment
+ variable is set."""
+
+ name = 'MoDuLe'
+ assert name != name.lower()
+
+ def find(self, path):
+ finder = machinery.FileFinder(path,
+ (machinery.SourceFileLoader,
+ machinery.SOURCE_SUFFIXES),
+ (machinery.SourcelessFileLoader,
+ machinery.BYTECODE_SUFFIXES))
+ return finder.find_module(self.name)
+
+ def sensitivity_test(self):
+ """Look for a module with matching and non-matching sensitivity."""
+ sensitive_pkg = 'sensitive.{0}'.format(self.name)
+ insensitive_pkg = 'insensitive.{0}'.format(self.name.lower())
+ context = source_util.create_modules(insensitive_pkg, sensitive_pkg)
+ with context as mapping:
+ sensitive_path = os.path.join(mapping['.root'], 'sensitive')
+ insensitive_path = os.path.join(mapping['.root'], 'insensitive')
+ return self.find(sensitive_path), self.find(insensitive_path)
+
+ def test_sensitive(self):
+ with test_support.EnvironmentVarGuard() as env:
+ env.unset('PYTHONCASEOK')
+ if b'PYTHONCASEOK' in _bootstrap._os.environ:
+ self.skipTest('os.environ changes not reflected in '
+ '_os.environ')
+ sensitive, insensitive = self.sensitivity_test()
+ self.assertTrue(hasattr(sensitive, 'load_module'))
+ self.assertIn(self.name, sensitive.get_filename(self.name))
+ self.assertIsNone(insensitive)
+
+ def test_insensitive(self):
+ with test_support.EnvironmentVarGuard() as env:
+ env.set('PYTHONCASEOK', '1')
+ if b'PYTHONCASEOK' not in _bootstrap._os.environ:
+ self.skipTest('os.environ changes not reflected in '
+ '_os.environ')
+ sensitive, insensitive = self.sensitivity_test()
+ self.assertTrue(hasattr(sensitive, 'load_module'))
+ self.assertIn(self.name, sensitive.get_filename(self.name))
+ self.assertTrue(hasattr(insensitive, 'load_module'))
+ self.assertIn(self.name, insensitive.get_filename(self.name))
+
+
+def test_main():
+ test_support.run_unittest(CaseSensitivityTest)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/source/test_file_loader.py b/Lib/test/test_importlib/source/test_file_loader.py
new file mode 100644
index 0000000..90f9d30
--- /dev/null
+++ b/Lib/test/test_importlib/source/test_file_loader.py
@@ -0,0 +1,484 @@
+from importlib import machinery
+import importlib
+import importlib.abc
+from .. import abc
+from .. import util
+from . import util as source_util
+
+import errno
+import imp
+import marshal
+import os
+import py_compile
+import shutil
+import stat
+import sys
+import unittest
+
+from test.support import make_legacy_pyc
+
+
+class SimpleTest(unittest.TestCase):
+
+ """Should have no issue importing a source module [basic]. And if there is
+ a syntax error, it should raise a SyntaxError [syntax error].
+
+ """
+
+ def test_load_module_API(self):
+ # If fullname is not specified that assume self.name is desired.
+ class TesterMixin(importlib.abc.Loader):
+ def load_module(self, fullname): return fullname
+ def module_repr(self, module): return '<module>'
+
+ class Tester(importlib.abc.FileLoader, TesterMixin):
+ def get_code(self, _): pass
+ def get_source(self, _): pass
+ def is_package(self, _): pass
+
+ name = 'mod_name'
+ loader = Tester(name, 'some_path')
+ self.assertEqual(name, loader.load_module())
+ self.assertEqual(name, loader.load_module(None))
+ self.assertEqual(name, loader.load_module(name))
+ with self.assertRaises(ImportError):
+ loader.load_module(loader.name + 'XXX')
+
+ def test_get_filename_API(self):
+ # If fullname is not set then assume self.path is desired.
+ class Tester(importlib.abc.FileLoader):
+ def get_code(self, _): pass
+ def get_source(self, _): pass
+ def is_package(self, _): pass
+ def module_repr(self, _): pass
+
+ path = 'some_path'
+ name = 'some_name'
+ loader = Tester(name, path)
+ self.assertEqual(path, loader.get_filename(name))
+ self.assertEqual(path, loader.get_filename())
+ self.assertEqual(path, loader.get_filename(None))
+ with self.assertRaises(ImportError):
+ loader.get_filename(name + 'XXX')
+
+ # [basic]
+ def test_module(self):
+ with source_util.create_modules('_temp') as mapping:
+ loader = machinery.SourceFileLoader('_temp', mapping['_temp'])
+ module = loader.load_module('_temp')
+ self.assertIn('_temp', sys.modules)
+ check = {'__name__': '_temp', '__file__': mapping['_temp'],
+ '__package__': ''}
+ for attr, value in check.items():
+ self.assertEqual(getattr(module, attr), value)
+
+ def test_package(self):
+ with source_util.create_modules('_pkg.__init__') as mapping:
+ loader = machinery.SourceFileLoader('_pkg',
+ mapping['_pkg.__init__'])
+ module = loader.load_module('_pkg')
+ self.assertIn('_pkg', sys.modules)
+ check = {'__name__': '_pkg', '__file__': mapping['_pkg.__init__'],
+ '__path__': [os.path.dirname(mapping['_pkg.__init__'])],
+ '__package__': '_pkg'}
+ for attr, value in check.items():
+ self.assertEqual(getattr(module, attr), value)
+
+
+ def test_lacking_parent(self):
+ with source_util.create_modules('_pkg.__init__', '_pkg.mod')as mapping:
+ loader = machinery.SourceFileLoader('_pkg.mod',
+ mapping['_pkg.mod'])
+ module = loader.load_module('_pkg.mod')
+ self.assertIn('_pkg.mod', sys.modules)
+ check = {'__name__': '_pkg.mod', '__file__': mapping['_pkg.mod'],
+ '__package__': '_pkg'}
+ for attr, value in check.items():
+ self.assertEqual(getattr(module, attr), value)
+
+ def fake_mtime(self, fxn):
+ """Fake mtime to always be higher than expected."""
+ return lambda name: fxn(name) + 1
+
+ def test_module_reuse(self):
+ with source_util.create_modules('_temp') as mapping:
+ loader = machinery.SourceFileLoader('_temp', mapping['_temp'])
+ module = loader.load_module('_temp')
+ module_id = id(module)
+ module_dict_id = id(module.__dict__)
+ with open(mapping['_temp'], 'w') as file:
+ file.write("testing_var = 42\n")
+ module = loader.load_module('_temp')
+ self.assertIn('testing_var', module.__dict__,
+ "'testing_var' not in "
+ "{0}".format(list(module.__dict__.keys())))
+ self.assertEqual(module, sys.modules['_temp'])
+ self.assertEqual(id(module), module_id)
+ self.assertEqual(id(module.__dict__), module_dict_id)
+
+ def test_state_after_failure(self):
+ # A failed reload should leave the original module intact.
+ attributes = ('__file__', '__path__', '__package__')
+ value = '<test>'
+ name = '_temp'
+ with source_util.create_modules(name) as mapping:
+ orig_module = imp.new_module(name)
+ for attr in attributes:
+ setattr(orig_module, attr, value)
+ with open(mapping[name], 'w') as file:
+ file.write('+++ bad syntax +++')
+ loader = machinery.SourceFileLoader('_temp', mapping['_temp'])
+ with self.assertRaises(SyntaxError):
+ loader.load_module(name)
+ for attr in attributes:
+ self.assertEqual(getattr(orig_module, attr), value)
+
+ # [syntax error]
+ def test_bad_syntax(self):
+ with source_util.create_modules('_temp') as mapping:
+ with open(mapping['_temp'], 'w') as file:
+ file.write('=')
+ loader = machinery.SourceFileLoader('_temp', mapping['_temp'])
+ with self.assertRaises(SyntaxError):
+ loader.load_module('_temp')
+ self.assertNotIn('_temp', sys.modules)
+
+ def test_file_from_empty_string_dir(self):
+ # Loading a module found from an empty string entry on sys.path should
+ # not only work, but keep all attributes relative.
+ file_path = '_temp.py'
+ with open(file_path, 'w') as file:
+ file.write("# test file for importlib")
+ try:
+ with util.uncache('_temp'):
+ loader = machinery.SourceFileLoader('_temp', file_path)
+ mod = loader.load_module('_temp')
+ self.assertEqual(file_path, mod.__file__)
+ self.assertEqual(imp.cache_from_source(file_path),
+ mod.__cached__)
+ finally:
+ os.unlink(file_path)
+ pycache = os.path.dirname(imp.cache_from_source(file_path))
+ shutil.rmtree(pycache)
+
+ def test_timestamp_overflow(self):
+ # When a modification timestamp is larger than 2**32, it should be
+ # truncated rather than raise an OverflowError.
+ with source_util.create_modules('_temp') as mapping:
+ source = mapping['_temp']
+ compiled = imp.cache_from_source(source)
+ with open(source, 'w') as f:
+ f.write("x = 5")
+ try:
+ os.utime(source, (2 ** 33 - 5, 2 ** 33 - 5))
+ except OverflowError:
+ self.skipTest("cannot set modification time to large integer")
+ except OSError as e:
+ if e.errno != getattr(errno, 'EOVERFLOW', None):
+ raise
+ self.skipTest("cannot set modification time to large integer ({})".format(e))
+ loader = machinery.SourceFileLoader('_temp', mapping['_temp'])
+ mod = loader.load_module('_temp')
+ # Sanity checks.
+ self.assertEqual(mod.__cached__, compiled)
+ self.assertEqual(mod.x, 5)
+ # The pyc file was created.
+ os.stat(compiled)
+
+
+class BadBytecodeTest(unittest.TestCase):
+
+ def import_(self, file, module_name):
+ loader = self.loader(module_name, file)
+ module = loader.load_module(module_name)
+ self.assertIn(module_name, sys.modules)
+
+ def manipulate_bytecode(self, name, mapping, manipulator, *,
+ del_source=False):
+ """Manipulate the bytecode of a module by passing it into a callable
+ that returns what to use as the new bytecode."""
+ try:
+ del sys.modules['_temp']
+ except KeyError:
+ pass
+ py_compile.compile(mapping[name])
+ if not del_source:
+ bytecode_path = imp.cache_from_source(mapping[name])
+ else:
+ os.unlink(mapping[name])
+ bytecode_path = make_legacy_pyc(mapping[name])
+ if manipulator:
+ with open(bytecode_path, 'rb') as file:
+ bc = file.read()
+ new_bc = manipulator(bc)
+ with open(bytecode_path, 'wb') as file:
+ if new_bc is not None:
+ file.write(new_bc)
+ return bytecode_path
+
+ def _test_empty_file(self, test, *, del_source=False):
+ with source_util.create_modules('_temp') as mapping:
+ bc_path = self.manipulate_bytecode('_temp', mapping,
+ lambda bc: b'',
+ del_source=del_source)
+ test('_temp', mapping, bc_path)
+
+ @source_util.writes_bytecode_files
+ def _test_partial_magic(self, test, *, del_source=False):
+ # When their are less than 4 bytes to a .pyc, regenerate it if
+ # possible, else raise ImportError.
+ with source_util.create_modules('_temp') as mapping:
+ bc_path = self.manipulate_bytecode('_temp', mapping,
+ lambda bc: bc[:3],
+ del_source=del_source)
+ test('_temp', mapping, bc_path)
+
+ def _test_magic_only(self, test, *, del_source=False):
+ with source_util.create_modules('_temp') as mapping:
+ bc_path = self.manipulate_bytecode('_temp', mapping,
+ lambda bc: bc[:4],
+ del_source=del_source)
+ test('_temp', mapping, bc_path)
+
+ def _test_partial_timestamp(self, test, *, del_source=False):
+ with source_util.create_modules('_temp') as mapping:
+ bc_path = self.manipulate_bytecode('_temp', mapping,
+ lambda bc: bc[:7],
+ del_source=del_source)
+ test('_temp', mapping, bc_path)
+
+ def _test_partial_size(self, test, *, del_source=False):
+ with source_util.create_modules('_temp') as mapping:
+ bc_path = self.manipulate_bytecode('_temp', mapping,
+ lambda bc: bc[:11],
+ del_source=del_source)
+ test('_temp', mapping, bc_path)
+
+ def _test_no_marshal(self, *, del_source=False):
+ with source_util.create_modules('_temp') as mapping:
+ bc_path = self.manipulate_bytecode('_temp', mapping,
+ lambda bc: bc[:12],
+ del_source=del_source)
+ file_path = mapping['_temp'] if not del_source else bc_path
+ with self.assertRaises(EOFError):
+ self.import_(file_path, '_temp')
+
+ def _test_non_code_marshal(self, *, del_source=False):
+ with source_util.create_modules('_temp') as mapping:
+ bytecode_path = self.manipulate_bytecode('_temp', mapping,
+ lambda bc: bc[:12] + marshal.dumps(b'abcd'),
+ del_source=del_source)
+ file_path = mapping['_temp'] if not del_source else bytecode_path
+ with self.assertRaises(ImportError) as cm:
+ self.import_(file_path, '_temp')
+ self.assertEqual(cm.exception.name, '_temp')
+ self.assertEqual(cm.exception.path, bytecode_path)
+
+ def _test_bad_marshal(self, *, del_source=False):
+ with source_util.create_modules('_temp') as mapping:
+ bytecode_path = self.manipulate_bytecode('_temp', mapping,
+ lambda bc: bc[:12] + b'<test>',
+ del_source=del_source)
+ file_path = mapping['_temp'] if not del_source else bytecode_path
+ with self.assertRaises(EOFError):
+ self.import_(file_path, '_temp')
+
+ def _test_bad_magic(self, test, *, del_source=False):
+ with source_util.create_modules('_temp') as mapping:
+ bc_path = self.manipulate_bytecode('_temp', mapping,
+ lambda bc: b'\x00\x00\x00\x00' + bc[4:])
+ test('_temp', mapping, bc_path)
+
+
+class SourceLoaderBadBytecodeTest(BadBytecodeTest):
+
+ loader = machinery.SourceFileLoader
+
+ @source_util.writes_bytecode_files
+ def test_empty_file(self):
+ # When a .pyc is empty, regenerate it if possible, else raise
+ # ImportError.
+ def test(name, mapping, bytecode_path):
+ self.import_(mapping[name], name)
+ with open(bytecode_path, 'rb') as file:
+ self.assertGreater(len(file.read()), 12)
+
+ self._test_empty_file(test)
+
+ def test_partial_magic(self):
+ def test(name, mapping, bytecode_path):
+ self.import_(mapping[name], name)
+ with open(bytecode_path, 'rb') as file:
+ self.assertGreater(len(file.read()), 12)
+
+ self._test_partial_magic(test)
+
+ @source_util.writes_bytecode_files
+ def test_magic_only(self):
+ # When there is only the magic number, regenerate the .pyc if possible,
+ # else raise EOFError.
+ def test(name, mapping, bytecode_path):
+ self.import_(mapping[name], name)
+ with open(bytecode_path, 'rb') as file:
+ self.assertGreater(len(file.read()), 12)
+
+ self._test_magic_only(test)
+
+ @source_util.writes_bytecode_files
+ def test_bad_magic(self):
+ # When the magic number is different, the bytecode should be
+ # regenerated.
+ def test(name, mapping, bytecode_path):
+ self.import_(mapping[name], name)
+ with open(bytecode_path, 'rb') as bytecode_file:
+ self.assertEqual(bytecode_file.read(4), imp.get_magic())
+
+ self._test_bad_magic(test)
+
+ @source_util.writes_bytecode_files
+ def test_partial_timestamp(self):
+ # When the timestamp is partial, regenerate the .pyc, else
+ # raise EOFError.
+ def test(name, mapping, bc_path):
+ self.import_(mapping[name], name)
+ with open(bc_path, 'rb') as file:
+ self.assertGreater(len(file.read()), 12)
+
+ self._test_partial_timestamp(test)
+
+ @source_util.writes_bytecode_files
+ def test_partial_size(self):
+ # When the size is partial, regenerate the .pyc, else
+ # raise EOFError.
+ def test(name, mapping, bc_path):
+ self.import_(mapping[name], name)
+ with open(bc_path, 'rb') as file:
+ self.assertGreater(len(file.read()), 12)
+
+ self._test_partial_size(test)
+
+ @source_util.writes_bytecode_files
+ def test_no_marshal(self):
+ # When there is only the magic number and timestamp, raise EOFError.
+ self._test_no_marshal()
+
+ @source_util.writes_bytecode_files
+ def test_non_code_marshal(self):
+ self._test_non_code_marshal()
+ # XXX ImportError when sourceless
+
+ # [bad marshal]
+ @source_util.writes_bytecode_files
+ def test_bad_marshal(self):
+ # Bad marshal data should raise a ValueError.
+ self._test_bad_marshal()
+
+ # [bad timestamp]
+ @source_util.writes_bytecode_files
+ def test_old_timestamp(self):
+ # When the timestamp is older than the source, bytecode should be
+ # regenerated.
+ zeros = b'\x00\x00\x00\x00'
+ with source_util.create_modules('_temp') as mapping:
+ py_compile.compile(mapping['_temp'])
+ bytecode_path = imp.cache_from_source(mapping['_temp'])
+ with open(bytecode_path, 'r+b') as bytecode_file:
+ bytecode_file.seek(4)
+ bytecode_file.write(zeros)
+ self.import_(mapping['_temp'], '_temp')
+ source_mtime = os.path.getmtime(mapping['_temp'])
+ source_timestamp = importlib._w_long(source_mtime)
+ with open(bytecode_path, 'rb') as bytecode_file:
+ bytecode_file.seek(4)
+ self.assertEqual(bytecode_file.read(4), source_timestamp)
+
+ # [bytecode read-only]
+ @source_util.writes_bytecode_files
+ def test_read_only_bytecode(self):
+ # When bytecode is read-only but should be rewritten, fail silently.
+ with source_util.create_modules('_temp') as mapping:
+ # Create bytecode that will need to be re-created.
+ py_compile.compile(mapping['_temp'])
+ bytecode_path = imp.cache_from_source(mapping['_temp'])
+ with open(bytecode_path, 'r+b') as bytecode_file:
+ bytecode_file.seek(0)
+ bytecode_file.write(b'\x00\x00\x00\x00')
+ # Make the bytecode read-only.
+ os.chmod(bytecode_path,
+ stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH)
+ try:
+ # Should not raise IOError!
+ self.import_(mapping['_temp'], '_temp')
+ finally:
+ # Make writable for eventual clean-up.
+ os.chmod(bytecode_path, stat.S_IWUSR)
+
+
+class SourcelessLoaderBadBytecodeTest(BadBytecodeTest):
+
+ loader = machinery.SourcelessFileLoader
+
+ def test_empty_file(self):
+ def test(name, mapping, bytecode_path):
+ with self.assertRaises(ImportError) as cm:
+ self.import_(bytecode_path, name)
+ self.assertEqual(cm.exception.name, name)
+ self.assertEqual(cm.exception.path, bytecode_path)
+
+ self._test_empty_file(test, del_source=True)
+
+ def test_partial_magic(self):
+ def test(name, mapping, bytecode_path):
+ with self.assertRaises(ImportError) as cm:
+ self.import_(bytecode_path, name)
+ self.assertEqual(cm.exception.name, name)
+ self.assertEqual(cm.exception.path, bytecode_path)
+ self._test_partial_magic(test, del_source=True)
+
+ def test_magic_only(self):
+ def test(name, mapping, bytecode_path):
+ with self.assertRaises(EOFError):
+ self.import_(bytecode_path, name)
+
+ self._test_magic_only(test, del_source=True)
+
+ def test_bad_magic(self):
+ def test(name, mapping, bytecode_path):
+ with self.assertRaises(ImportError) as cm:
+ self.import_(bytecode_path, name)
+ self.assertEqual(cm.exception.name, name)
+ self.assertEqual(cm.exception.path, bytecode_path)
+
+ self._test_bad_magic(test, del_source=True)
+
+ def test_partial_timestamp(self):
+ def test(name, mapping, bytecode_path):
+ with self.assertRaises(EOFError):
+ self.import_(bytecode_path, name)
+
+ self._test_partial_timestamp(test, del_source=True)
+
+ def test_partial_size(self):
+ def test(name, mapping, bytecode_path):
+ with self.assertRaises(EOFError):
+ self.import_(bytecode_path, name)
+
+ self._test_partial_size(test, del_source=True)
+
+ def test_no_marshal(self):
+ self._test_no_marshal(del_source=True)
+
+ def test_non_code_marshal(self):
+ self._test_non_code_marshal(del_source=True)
+
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(SimpleTest,
+ SourceLoaderBadBytecodeTest,
+ SourcelessLoaderBadBytecodeTest
+ )
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/source/test_finder.py b/Lib/test/test_importlib/source/test_finder.py
new file mode 100644
index 0000000..8e49868
--- /dev/null
+++ b/Lib/test/test_importlib/source/test_finder.py
@@ -0,0 +1,191 @@
+from .. import abc
+from . import util as source_util
+
+from importlib import machinery
+import errno
+import imp
+import os
+import py_compile
+import stat
+import sys
+import tempfile
+from test.support import make_legacy_pyc
+import unittest
+import warnings
+
+
+class FinderTests(abc.FinderTests):
+
+ """For a top-level module, it should just be found directly in the
+ directory being searched. This is true for a directory with source
+ [top-level source], bytecode [top-level bc], or both [top-level both].
+ There is also the possibility that it is a package [top-level package], in
+ which case there will be a directory with the module name and an
+ __init__.py file. If there is a directory without an __init__.py an
+ ImportWarning is returned [empty dir].
+
+ For sub-modules and sub-packages, the same happens as above but only use
+ the tail end of the name [sub module] [sub package] [sub empty].
+
+ When there is a conflict between a package and module having the same name
+ in the same directory, the package wins out [package over module]. This is
+ so that imports of modules within the package can occur rather than trigger
+ an import error.
+
+ When there is a package and module with the same name, always pick the
+ package over the module [package over module]. This is so that imports from
+ the package have the possibility of succeeding.
+
+ """
+
+ def get_finder(self, root):
+ loader_details = [(machinery.SourceFileLoader,
+ machinery.SOURCE_SUFFIXES),
+ (machinery.SourcelessFileLoader,
+ machinery.BYTECODE_SUFFIXES)]
+ return machinery.FileFinder(root, *loader_details)
+
+ def import_(self, root, module):
+ return self.get_finder(root).find_module(module)
+
+ def run_test(self, test, create=None, *, compile_=None, unlink=None):
+ """Test the finding of 'test' with the creation of modules listed in
+ 'create'.
+
+ Any names listed in 'compile_' are byte-compiled. Modules
+ listed in 'unlink' have their source files deleted.
+
+ """
+ if create is None:
+ create = {test}
+ with source_util.create_modules(*create) as mapping:
+ if compile_:
+ for name in compile_:
+ py_compile.compile(mapping[name])
+ if unlink:
+ for name in unlink:
+ os.unlink(mapping[name])
+ try:
+ make_legacy_pyc(mapping[name])
+ except OSError as error:
+ # Some tests do not set compile_=True so the source
+ # module will not get compiled and there will be no
+ # PEP 3147 pyc file to rename.
+ if error.errno != errno.ENOENT:
+ raise
+ loader = self.import_(mapping['.root'], test)
+ self.assertTrue(hasattr(loader, 'load_module'))
+ return loader
+
+ def test_module(self):
+ # [top-level source]
+ self.run_test('top_level')
+ # [top-level bc]
+ self.run_test('top_level', compile_={'top_level'},
+ unlink={'top_level'})
+ # [top-level both]
+ self.run_test('top_level', compile_={'top_level'})
+
+ # [top-level package]
+ def test_package(self):
+ # Source.
+ self.run_test('pkg', {'pkg.__init__'})
+ # Bytecode.
+ self.run_test('pkg', {'pkg.__init__'}, compile_={'pkg.__init__'},
+ unlink={'pkg.__init__'})
+ # Both.
+ self.run_test('pkg', {'pkg.__init__'}, compile_={'pkg.__init__'})
+
+ # [sub module]
+ def test_module_in_package(self):
+ with source_util.create_modules('pkg.__init__', 'pkg.sub') as mapping:
+ pkg_dir = os.path.dirname(mapping['pkg.__init__'])
+ loader = self.import_(pkg_dir, 'pkg.sub')
+ self.assertTrue(hasattr(loader, 'load_module'))
+
+ # [sub package]
+ def test_package_in_package(self):
+ context = source_util.create_modules('pkg.__init__', 'pkg.sub.__init__')
+ with context as mapping:
+ pkg_dir = os.path.dirname(mapping['pkg.__init__'])
+ loader = self.import_(pkg_dir, 'pkg.sub')
+ self.assertTrue(hasattr(loader, 'load_module'))
+
+ # [package over modules]
+ def test_package_over_module(self):
+ name = '_temp'
+ loader = self.run_test(name, {'{0}.__init__'.format(name), name})
+ self.assertIn('__init__', loader.get_filename(name))
+
+ def test_failure(self):
+ with source_util.create_modules('blah') as mapping:
+ nothing = self.import_(mapping['.root'], 'sdfsadsadf')
+ self.assertIsNone(nothing)
+
+ def test_empty_string_for_dir(self):
+ # The empty string from sys.path means to search in the cwd.
+ finder = machinery.FileFinder('', (machinery.SourceFileLoader,
+ machinery.SOURCE_SUFFIXES))
+ with open('mod.py', 'w') as file:
+ file.write("# test file for importlib")
+ try:
+ loader = finder.find_module('mod')
+ self.assertTrue(hasattr(loader, 'load_module'))
+ finally:
+ os.unlink('mod.py')
+
+ def test_invalidate_caches(self):
+ # invalidate_caches() should reset the mtime.
+ finder = machinery.FileFinder('', (machinery.SourceFileLoader,
+ machinery.SOURCE_SUFFIXES))
+ finder._path_mtime = 42
+ finder.invalidate_caches()
+ self.assertEqual(finder._path_mtime, -1)
+
+ # Regression test for http://bugs.python.org/issue14846
+ def test_dir_removal_handling(self):
+ mod = 'mod'
+ with source_util.create_modules(mod) as mapping:
+ finder = self.get_finder(mapping['.root'])
+ self.assertIsNotNone(finder.find_module(mod))
+ self.assertIsNone(finder.find_module(mod))
+
+ @unittest.skipUnless(sys.platform != 'win32',
+ 'os.chmod() does not support the needed arguments under Windows')
+ def test_no_read_directory(self):
+ # Issue #16730
+ tempdir = tempfile.TemporaryDirectory()
+ original_mode = os.stat(tempdir.name).st_mode
+ def cleanup(tempdir):
+ """Cleanup function for the temporary directory.
+
+ Since we muck with the permissions, we want to set them back to
+ their original values to make sure the directory can be properly
+ cleaned up.
+
+ """
+ os.chmod(tempdir.name, original_mode)
+ # If this is not explicitly called then the __del__ method is used,
+ # but since already mucking around might as well explicitly clean
+ # up.
+ tempdir.__exit__(None, None, None)
+ self.addCleanup(cleanup, tempdir)
+ os.chmod(tempdir.name, stat.S_IWUSR | stat.S_IXUSR)
+ finder = self.get_finder(tempdir.name)
+ self.assertEqual((None, []), finder.find_loader('doesnotexist'))
+
+ def test_ignore_file(self):
+ # If a directory got changed to a file from underneath us, then don't
+ # worry about looking for submodules.
+ with tempfile.NamedTemporaryFile() as file_obj:
+ finder = self.get_finder(file_obj.name)
+ self.assertEqual((None, []), finder.find_loader('doesnotexist'))
+
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(FinderTests)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/source/test_path_hook.py b/Lib/test/test_importlib/source/test_path_hook.py
new file mode 100644
index 0000000..6a78792
--- /dev/null
+++ b/Lib/test/test_importlib/source/test_path_hook.py
@@ -0,0 +1,32 @@
+from . import util as source_util
+
+from importlib import machinery
+import imp
+import unittest
+
+
+class PathHookTest(unittest.TestCase):
+
+ """Test the path hook for source."""
+
+ def path_hook(self):
+ return machinery.FileFinder.path_hook((machinery.SourceFileLoader,
+ machinery.SOURCE_SUFFIXES))
+
+ def test_success(self):
+ with source_util.create_modules('dummy') as mapping:
+ self.assertTrue(hasattr(self.path_hook()(mapping['.root']),
+ 'find_module'))
+
+ def test_empty_string(self):
+ # The empty string represents the cwd.
+ self.assertTrue(hasattr(self.path_hook()(''), 'find_module'))
+
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(PathHookTest)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/source/test_source_encoding.py b/Lib/test/test_importlib/source/test_source_encoding.py
new file mode 100644
index 0000000..0ca5195
--- /dev/null
+++ b/Lib/test/test_importlib/source/test_source_encoding.py
@@ -0,0 +1,123 @@
+from . import util as source_util
+
+from importlib import _bootstrap
+import codecs
+import re
+import sys
+# Because sys.path gets essentially blanked, need to have unicodedata already
+# imported for the parser to use.
+import unicodedata
+import unittest
+
+
+CODING_RE = re.compile(r'coding[:=]\s*([-\w.]+)')
+
+
+class EncodingTest(unittest.TestCase):
+
+ """PEP 3120 makes UTF-8 the default encoding for source code
+ [default encoding].
+
+ PEP 263 specifies how that can change on a per-file basis. Either the first
+ or second line can contain the encoding line [encoding first line]
+ encoding second line]. If the file has the BOM marker it is considered UTF-8
+ implicitly [BOM]. If any encoding is specified it must be UTF-8, else it is
+ an error [BOM and utf-8][BOM conflict].
+
+ """
+
+ variable = '\u00fc'
+ character = '\u00c9'
+ source_line = "{0} = '{1}'\n".format(variable, character)
+ module_name = '_temp'
+
+ def run_test(self, source):
+ with source_util.create_modules(self.module_name) as mapping:
+ with open(mapping[self.module_name], 'wb') as file:
+ file.write(source)
+ loader = _bootstrap.SourceFileLoader(self.module_name,
+ mapping[self.module_name])
+ return loader.load_module(self.module_name)
+
+ def create_source(self, encoding):
+ encoding_line = "# coding={0}".format(encoding)
+ assert CODING_RE.search(encoding_line)
+ source_lines = [encoding_line.encode('utf-8')]
+ source_lines.append(self.source_line.encode(encoding))
+ return b'\n'.join(source_lines)
+
+ def test_non_obvious_encoding(self):
+ # Make sure that an encoding that has never been a standard one for
+ # Python works.
+ encoding_line = "# coding=koi8-r"
+ assert CODING_RE.search(encoding_line)
+ source = "{0}\na=42\n".format(encoding_line).encode("koi8-r")
+ self.run_test(source)
+
+ # [default encoding]
+ def test_default_encoding(self):
+ self.run_test(self.source_line.encode('utf-8'))
+
+ # [encoding first line]
+ def test_encoding_on_first_line(self):
+ encoding = 'Latin-1'
+ source = self.create_source(encoding)
+ self.run_test(source)
+
+ # [encoding second line]
+ def test_encoding_on_second_line(self):
+ source = b"#/usr/bin/python\n" + self.create_source('Latin-1')
+ self.run_test(source)
+
+ # [BOM]
+ def test_bom(self):
+ self.run_test(codecs.BOM_UTF8 + self.source_line.encode('utf-8'))
+
+ # [BOM and utf-8]
+ def test_bom_and_utf_8(self):
+ source = codecs.BOM_UTF8 + self.create_source('utf-8')
+ self.run_test(source)
+
+ # [BOM conflict]
+ def test_bom_conflict(self):
+ source = codecs.BOM_UTF8 + self.create_source('latin-1')
+ with self.assertRaises(SyntaxError):
+ self.run_test(source)
+
+
+class LineEndingTest(unittest.TestCase):
+
+ r"""Source written with the three types of line endings (\n, \r\n, \r)
+ need to be readable [cr][crlf][lf]."""
+
+ def run_test(self, line_ending):
+ module_name = '_temp'
+ source_lines = [b"a = 42", b"b = -13", b'']
+ source = line_ending.join(source_lines)
+ with source_util.create_modules(module_name) as mapping:
+ with open(mapping[module_name], 'wb') as file:
+ file.write(source)
+ loader = _bootstrap.SourceFileLoader(module_name,
+ mapping[module_name])
+ return loader.load_module(module_name)
+
+ # [cr]
+ def test_cr(self):
+ self.run_test(b'\r')
+
+ # [crlf]
+ def test_crlf(self):
+ self.run_test(b'\r\n')
+
+ # [lf]
+ def test_lf(self):
+ self.run_test(b'\n')
+
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(EncodingTest, LineEndingTest)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/source/util.py b/Lib/test/test_importlib/source/util.py
new file mode 100644
index 0000000..ae65663
--- /dev/null
+++ b/Lib/test/test_importlib/source/util.py
@@ -0,0 +1,97 @@
+from .. import util
+import contextlib
+import errno
+import functools
+import imp
+import os
+import os.path
+import sys
+import tempfile
+from test import support
+
+
+def writes_bytecode_files(fxn):
+ """Decorator to protect sys.dont_write_bytecode from mutation and to skip
+ tests that require it to be set to False."""
+ if sys.dont_write_bytecode:
+ return lambda *args, **kwargs: None
+ @functools.wraps(fxn)
+ def wrapper(*args, **kwargs):
+ original = sys.dont_write_bytecode
+ sys.dont_write_bytecode = False
+ try:
+ to_return = fxn(*args, **kwargs)
+ finally:
+ sys.dont_write_bytecode = original
+ return to_return
+ return wrapper
+
+
+def ensure_bytecode_path(bytecode_path):
+ """Ensure that the __pycache__ directory for PEP 3147 pyc file exists.
+
+ :param bytecode_path: File system path to PEP 3147 pyc file.
+ """
+ try:
+ os.mkdir(os.path.dirname(bytecode_path))
+ except OSError as error:
+ if error.errno != errno.EEXIST:
+ raise
+
+
+@contextlib.contextmanager
+def create_modules(*names):
+ """Temporarily create each named module with an attribute (named 'attr')
+ that contains the name passed into the context manager that caused the
+ creation of the module.
+
+ All files are created in a temporary directory returned by
+ tempfile.mkdtemp(). This directory is inserted at the beginning of
+ sys.path. When the context manager exits all created files (source and
+ bytecode) are explicitly deleted.
+
+ No magic is performed when creating packages! This means that if you create
+ a module within a package you must also create the package's __init__ as
+ well.
+
+ """
+ source = 'attr = {0!r}'
+ created_paths = []
+ mapping = {}
+ state_manager = None
+ uncache_manager = None
+ try:
+ temp_dir = tempfile.mkdtemp()
+ mapping['.root'] = temp_dir
+ import_names = set()
+ for name in names:
+ if not name.endswith('__init__'):
+ import_name = name
+ else:
+ import_name = name[:-len('.__init__')]
+ import_names.add(import_name)
+ if import_name in sys.modules:
+ del sys.modules[import_name]
+ name_parts = name.split('.')
+ file_path = temp_dir
+ for directory in name_parts[:-1]:
+ file_path = os.path.join(file_path, directory)
+ if not os.path.exists(file_path):
+ os.mkdir(file_path)
+ created_paths.append(file_path)
+ file_path = os.path.join(file_path, name_parts[-1] + '.py')
+ with open(file_path, 'w') as file:
+ file.write(source.format(name))
+ created_paths.append(file_path)
+ mapping[name] = file_path
+ uncache_manager = util.uncache(*import_names)
+ uncache_manager.__enter__()
+ state_manager = util.import_state(path=[temp_dir])
+ state_manager.__enter__()
+ yield mapping
+ finally:
+ if state_manager is not None:
+ state_manager.__exit__(None, None, None)
+ if uncache_manager is not None:
+ uncache_manager.__exit__(None, None, None)
+ support.rmtree(temp_dir)
diff --git a/Lib/test/test_importlib/test_abc.py b/Lib/test/test_importlib/test_abc.py
new file mode 100644
index 0000000..c620c37
--- /dev/null
+++ b/Lib/test/test_importlib/test_abc.py
@@ -0,0 +1,103 @@
+from importlib import abc
+from importlib import machinery
+import inspect
+import unittest
+
+
+class InheritanceTests:
+
+ """Test that the specified class is a subclass/superclass of the expected
+ classes."""
+
+ subclasses = []
+ superclasses = []
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ assert self.subclasses or self.superclasses, self.__class__
+ self.__test = getattr(abc, self.__class__.__name__)
+
+ def test_subclasses(self):
+ # Test that the expected subclasses inherit.
+ for subclass in self.subclasses:
+ self.assertTrue(issubclass(subclass, self.__test),
+ "{0} is not a subclass of {1}".format(subclass, self.__test))
+
+ def test_superclasses(self):
+ # Test that the class inherits from the expected superclasses.
+ for superclass in self.superclasses:
+ self.assertTrue(issubclass(self.__test, superclass),
+ "{0} is not a superclass of {1}".format(superclass, self.__test))
+
+
+class MetaPathFinder(InheritanceTests, unittest.TestCase):
+
+ superclasses = [abc.Finder]
+ subclasses = [machinery.BuiltinImporter, machinery.FrozenImporter,
+ machinery.PathFinder, machinery.WindowsRegistryFinder]
+
+
+class PathEntryFinder(InheritanceTests, unittest.TestCase):
+
+ superclasses = [abc.Finder]
+ subclasses = [machinery.FileFinder]
+
+
+class Loader(InheritanceTests, unittest.TestCase):
+
+ subclasses = [abc.PyLoader]
+
+
+class ResourceLoader(InheritanceTests, unittest.TestCase):
+
+ superclasses = [abc.Loader]
+
+
+class InspectLoader(InheritanceTests, unittest.TestCase):
+
+ superclasses = [abc.Loader]
+ subclasses = [abc.PyLoader, machinery.BuiltinImporter,
+ machinery.FrozenImporter, machinery.ExtensionFileLoader]
+
+
+class ExecutionLoader(InheritanceTests, unittest.TestCase):
+
+ superclasses = [abc.InspectLoader]
+ subclasses = [abc.PyLoader]
+
+
+class FileLoader(InheritanceTests, unittest.TestCase):
+
+ superclasses = [abc.ResourceLoader, abc.ExecutionLoader]
+ subclasses = [machinery.SourceFileLoader, machinery.SourcelessFileLoader]
+
+
+class SourceLoader(InheritanceTests, unittest.TestCase):
+
+ superclasses = [abc.ResourceLoader, abc.ExecutionLoader]
+ subclasses = [machinery.SourceFileLoader]
+
+
+class PyLoader(InheritanceTests, unittest.TestCase):
+
+ superclasses = [abc.Loader, abc.ResourceLoader, abc.ExecutionLoader]
+
+
+class PyPycLoader(InheritanceTests, unittest.TestCase):
+
+ superclasses = [abc.PyLoader]
+
+
+def test_main():
+ from test.support import run_unittest
+ classes = []
+ for class_ in globals().values():
+ if (inspect.isclass(class_) and
+ issubclass(class_, unittest.TestCase) and
+ issubclass(class_, InheritanceTests)):
+ classes.append(class_)
+ run_unittest(*classes)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/test_api.py b/Lib/test/test_importlib/test_api.py
new file mode 100644
index 0000000..b1a5894
--- /dev/null
+++ b/Lib/test/test_importlib/test_api.py
@@ -0,0 +1,202 @@
+from . import util
+import imp
+import importlib
+from importlib import machinery
+import sys
+from test import support
+import types
+import unittest
+
+
+class ImportModuleTests(unittest.TestCase):
+
+ """Test importlib.import_module."""
+
+ def test_module_import(self):
+ # Test importing a top-level module.
+ with util.mock_modules('top_level') as mock:
+ with util.import_state(meta_path=[mock]):
+ module = importlib.import_module('top_level')
+ self.assertEqual(module.__name__, 'top_level')
+
+ def test_absolute_package_import(self):
+ # Test importing a module from a package with an absolute name.
+ pkg_name = 'pkg'
+ pkg_long_name = '{0}.__init__'.format(pkg_name)
+ name = '{0}.mod'.format(pkg_name)
+ with util.mock_modules(pkg_long_name, name) as mock:
+ with util.import_state(meta_path=[mock]):
+ module = importlib.import_module(name)
+ self.assertEqual(module.__name__, name)
+
+ def test_shallow_relative_package_import(self):
+ # Test importing a module from a package through a relative import.
+ pkg_name = 'pkg'
+ pkg_long_name = '{0}.__init__'.format(pkg_name)
+ module_name = 'mod'
+ absolute_name = '{0}.{1}'.format(pkg_name, module_name)
+ relative_name = '.{0}'.format(module_name)
+ with util.mock_modules(pkg_long_name, absolute_name) as mock:
+ with util.import_state(meta_path=[mock]):
+ importlib.import_module(pkg_name)
+ module = importlib.import_module(relative_name, pkg_name)
+ self.assertEqual(module.__name__, absolute_name)
+
+ def test_deep_relative_package_import(self):
+ modules = ['a.__init__', 'a.b.__init__', 'a.c']
+ with util.mock_modules(*modules) as mock:
+ with util.import_state(meta_path=[mock]):
+ importlib.import_module('a')
+ importlib.import_module('a.b')
+ module = importlib.import_module('..c', 'a.b')
+ self.assertEqual(module.__name__, 'a.c')
+
+ def test_absolute_import_with_package(self):
+ # Test importing a module from a package with an absolute name with
+ # the 'package' argument given.
+ pkg_name = 'pkg'
+ pkg_long_name = '{0}.__init__'.format(pkg_name)
+ name = '{0}.mod'.format(pkg_name)
+ with util.mock_modules(pkg_long_name, name) as mock:
+ with util.import_state(meta_path=[mock]):
+ importlib.import_module(pkg_name)
+ module = importlib.import_module(name, pkg_name)
+ self.assertEqual(module.__name__, name)
+
+ def test_relative_import_wo_package(self):
+ # Relative imports cannot happen without the 'package' argument being
+ # set.
+ with self.assertRaises(TypeError):
+ importlib.import_module('.support')
+
+
+ def test_loaded_once(self):
+ # Issue #13591: Modules should only be loaded once when
+ # initializing the parent package attempts to import the
+ # module currently being imported.
+ b_load_count = 0
+ def load_a():
+ importlib.import_module('a.b')
+ def load_b():
+ nonlocal b_load_count
+ b_load_count += 1
+ code = {'a': load_a, 'a.b': load_b}
+ modules = ['a.__init__', 'a.b']
+ with util.mock_modules(*modules, module_code=code) as mock:
+ with util.import_state(meta_path=[mock]):
+ importlib.import_module('a.b')
+ self.assertEqual(b_load_count, 1)
+
+
+class FindLoaderTests(unittest.TestCase):
+
+ class FakeMetaFinder:
+ @staticmethod
+ def find_module(name, path=None): return name, path
+
+ def test_sys_modules(self):
+ # If a module with __loader__ is in sys.modules, then return it.
+ name = 'some_mod'
+ with util.uncache(name):
+ module = imp.new_module(name)
+ loader = 'a loader!'
+ module.__loader__ = loader
+ sys.modules[name] = module
+ found = importlib.find_loader(name)
+ self.assertEqual(loader, found)
+
+ def test_sys_modules_loader_is_None(self):
+ # If sys.modules[name].__loader__ is None, raise ValueError.
+ name = 'some_mod'
+ with util.uncache(name):
+ module = imp.new_module(name)
+ module.__loader__ = None
+ sys.modules[name] = module
+ with self.assertRaises(ValueError):
+ importlib.find_loader(name)
+
+ def test_success(self):
+ # Return the loader found on sys.meta_path.
+ name = 'some_mod'
+ with util.uncache(name):
+ with util.import_state(meta_path=[self.FakeMetaFinder]):
+ self.assertEqual((name, None), importlib.find_loader(name))
+
+ def test_success_path(self):
+ # Searching on a path should work.
+ name = 'some_mod'
+ path = 'path to some place'
+ with util.uncache(name):
+ with util.import_state(meta_path=[self.FakeMetaFinder]):
+ self.assertEqual((name, path),
+ importlib.find_loader(name, path))
+
+ def test_nothing(self):
+ # None is returned upon failure to find a loader.
+ self.assertIsNone(importlib.find_loader('nevergoingtofindthismodule'))
+
+
+class InvalidateCacheTests(unittest.TestCase):
+
+ def test_method_called(self):
+ # If defined the method should be called.
+ class InvalidatingNullFinder:
+ def __init__(self, *ignored):
+ self.called = False
+ def find_module(self, *args):
+ return None
+ def invalidate_caches(self):
+ self.called = True
+
+ key = 'gobledeegook'
+ meta_ins = InvalidatingNullFinder()
+ path_ins = InvalidatingNullFinder()
+ sys.meta_path.insert(0, meta_ins)
+ self.addCleanup(lambda: sys.path_importer_cache.__delitem__(key))
+ sys.path_importer_cache[key] = path_ins
+ self.addCleanup(lambda: sys.meta_path.remove(meta_ins))
+ importlib.invalidate_caches()
+ self.assertTrue(meta_ins.called)
+ self.assertTrue(path_ins.called)
+
+ def test_method_lacking(self):
+ # There should be no issues if the method is not defined.
+ key = 'gobbledeegook'
+ sys.path_importer_cache[key] = imp.NullImporter('abc')
+ self.addCleanup(lambda: sys.path_importer_cache.__delitem__(key))
+ importlib.invalidate_caches() # Shouldn't trigger an exception.
+
+
+class FrozenImportlibTests(unittest.TestCase):
+
+ def test_no_frozen_importlib(self):
+ # Should be able to import w/o _frozen_importlib being defined.
+ module = support.import_fresh_module('importlib', blocked=['_frozen_importlib'])
+ self.assertFalse(isinstance(module.__loader__,
+ machinery.FrozenImporter))
+
+
+class StartupTests(unittest.TestCase):
+
+ def test_everyone_has___loader__(self):
+ # Issue #17098: all modules should have __loader__ defined.
+ for name, module in sys.modules.items():
+ if isinstance(module, types.ModuleType):
+ if name in sys.builtin_module_names:
+ self.assertEqual(importlib.machinery.BuiltinImporter,
+ module.__loader__)
+ elif imp.is_frozen(name):
+ self.assertEqual(importlib.machinery.FrozenImporter,
+ module.__loader__)
+
+def test_main():
+ from test.support import run_unittest
+ run_unittest(ImportModuleTests,
+ FindLoaderTests,
+ InvalidateCacheTests,
+ FrozenImportlibTests,
+ StartupTests)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/test_locks.py b/Lib/test/test_importlib/test_locks.py
new file mode 100644
index 0000000..c373b11
--- /dev/null
+++ b/Lib/test/test_importlib/test_locks.py
@@ -0,0 +1,129 @@
+from importlib import _bootstrap
+import sys
+import time
+import unittest
+import weakref
+
+from test import support
+
+try:
+ import threading
+except ImportError:
+ threading = None
+else:
+ from test import lock_tests
+
+
+LockType = _bootstrap._ModuleLock
+DeadlockError = _bootstrap._DeadlockError
+
+
+if threading is not None:
+ class ModuleLockAsRLockTests(lock_tests.RLockTests):
+ locktype = staticmethod(lambda: LockType("some_lock"))
+
+ # _is_owned() unsupported
+ test__is_owned = None
+ # acquire(blocking=False) unsupported
+ test_try_acquire = None
+ test_try_acquire_contended = None
+ # `with` unsupported
+ test_with = None
+ # acquire(timeout=...) unsupported
+ test_timeout = None
+ # _release_save() unsupported
+ test_release_save_unacquired = None
+
+else:
+ class ModuleLockAsRLockTests(unittest.TestCase):
+ pass
+
+
+@unittest.skipUnless(threading, "threads needed for this test")
+class DeadlockAvoidanceTests(unittest.TestCase):
+
+ def setUp(self):
+ try:
+ self.old_switchinterval = sys.getswitchinterval()
+ sys.setswitchinterval(0.000001)
+ except AttributeError:
+ self.old_switchinterval = None
+
+ def tearDown(self):
+ if self.old_switchinterval is not None:
+ sys.setswitchinterval(self.old_switchinterval)
+
+ def run_deadlock_avoidance_test(self, create_deadlock):
+ NLOCKS = 10
+ locks = [LockType(str(i)) for i in range(NLOCKS)]
+ pairs = [(locks[i], locks[(i+1)%NLOCKS]) for i in range(NLOCKS)]
+ if create_deadlock:
+ NTHREADS = NLOCKS
+ else:
+ NTHREADS = NLOCKS - 1
+ barrier = threading.Barrier(NTHREADS)
+ results = []
+ def _acquire(lock):
+ """Try to acquire the lock. Return True on success, False on deadlock."""
+ try:
+ lock.acquire()
+ except DeadlockError:
+ return False
+ else:
+ return True
+ def f():
+ a, b = pairs.pop()
+ ra = _acquire(a)
+ barrier.wait()
+ rb = _acquire(b)
+ results.append((ra, rb))
+ if rb:
+ b.release()
+ if ra:
+ a.release()
+ lock_tests.Bunch(f, NTHREADS).wait_for_finished()
+ self.assertEqual(len(results), NTHREADS)
+ return results
+
+ def test_deadlock(self):
+ results = self.run_deadlock_avoidance_test(True)
+ # At least one of the threads detected a potential deadlock on its
+ # second acquire() call. It may be several of them, because the
+ # deadlock avoidance mechanism is conservative.
+ nb_deadlocks = results.count((True, False))
+ self.assertGreaterEqual(nb_deadlocks, 1)
+ self.assertEqual(results.count((True, True)), len(results) - nb_deadlocks)
+
+ def test_no_deadlock(self):
+ results = self.run_deadlock_avoidance_test(False)
+ self.assertEqual(results.count((True, False)), 0)
+ self.assertEqual(results.count((True, True)), len(results))
+
+
+class LifetimeTests(unittest.TestCase):
+
+ def test_lock_lifetime(self):
+ name = "xyzzy"
+ self.assertNotIn(name, _bootstrap._module_locks)
+ lock = _bootstrap._get_module_lock(name)
+ self.assertIn(name, _bootstrap._module_locks)
+ wr = weakref.ref(lock)
+ del lock
+ support.gc_collect()
+ self.assertNotIn(name, _bootstrap._module_locks)
+ self.assertIsNone(wr())
+
+ def test_all_locks(self):
+ support.gc_collect()
+ self.assertEqual(0, len(_bootstrap._module_locks), _bootstrap._module_locks)
+
+
+@support.reap_threads
+def test_main():
+ support.run_unittest(ModuleLockAsRLockTests,
+ DeadlockAvoidanceTests,
+ LifetimeTests)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/test_util.py b/Lib/test/test_importlib/test_util.py
new file mode 100644
index 0000000..efc8977
--- /dev/null
+++ b/Lib/test/test_importlib/test_util.py
@@ -0,0 +1,208 @@
+from importlib import util
+from . import util as test_util
+import imp
+import sys
+import types
+import unittest
+
+
+class ModuleForLoaderTests(unittest.TestCase):
+
+ """Tests for importlib.util.module_for_loader."""
+
+ def return_module(self, name):
+ fxn = util.module_for_loader(lambda self, module: module)
+ return fxn(self, name)
+
+ def raise_exception(self, name):
+ def to_wrap(self, module):
+ raise ImportError
+ fxn = util.module_for_loader(to_wrap)
+ try:
+ fxn(self, name)
+ except ImportError:
+ pass
+
+ def test_new_module(self):
+ # Test that when no module exists in sys.modules a new module is
+ # created.
+ module_name = 'a.b.c'
+ with test_util.uncache(module_name):
+ module = self.return_module(module_name)
+ self.assertIn(module_name, sys.modules)
+ self.assertIsInstance(module, types.ModuleType)
+ self.assertEqual(module.__name__, module_name)
+
+ def test_reload(self):
+ # Test that a module is reused if already in sys.modules.
+ name = 'a.b.c'
+ module = imp.new_module('a.b.c')
+ with test_util.uncache(name):
+ sys.modules[name] = module
+ returned_module = self.return_module(name)
+ self.assertIs(returned_module, sys.modules[name])
+
+ def test_new_module_failure(self):
+ # Test that a module is removed from sys.modules if added but an
+ # exception is raised.
+ name = 'a.b.c'
+ with test_util.uncache(name):
+ self.raise_exception(name)
+ self.assertNotIn(name, sys.modules)
+
+ def test_reload_failure(self):
+ # Test that a failure on reload leaves the module in-place.
+ name = 'a.b.c'
+ module = imp.new_module(name)
+ with test_util.uncache(name):
+ sys.modules[name] = module
+ self.raise_exception(name)
+ self.assertIs(module, sys.modules[name])
+
+ def test_decorator_attrs(self):
+ def fxn(self, module): pass
+ wrapped = util.module_for_loader(fxn)
+ self.assertEqual(wrapped.__name__, fxn.__name__)
+ self.assertEqual(wrapped.__qualname__, fxn.__qualname__)
+
+ def test_false_module(self):
+ # If for some odd reason a module is considered false, still return it
+ # from sys.modules.
+ class FalseModule(types.ModuleType):
+ def __bool__(self): return False
+
+ name = 'mod'
+ module = FalseModule(name)
+ with test_util.uncache(name):
+ self.assertFalse(module)
+ sys.modules[name] = module
+ given = self.return_module(name)
+ self.assertIs(given, module)
+
+ def test_attributes_set(self):
+ # __name__, __loader__, and __package__ should be set (when
+ # is_package() is defined; undefined implicitly tested elsewhere).
+ class FakeLoader:
+ def __init__(self, is_package):
+ self._pkg = is_package
+ def is_package(self, name):
+ return self._pkg
+ @util.module_for_loader
+ def load_module(self, module):
+ return module
+
+ name = 'pkg.mod'
+ with test_util.uncache(name):
+ loader = FakeLoader(False)
+ module = loader.load_module(name)
+ self.assertEqual(module.__name__, name)
+ self.assertIs(module.__loader__, loader)
+ self.assertEqual(module.__package__, 'pkg')
+
+ name = 'pkg.sub'
+ with test_util.uncache(name):
+ loader = FakeLoader(True)
+ module = loader.load_module(name)
+ self.assertEqual(module.__name__, name)
+ self.assertIs(module.__loader__, loader)
+ self.assertEqual(module.__package__, name)
+
+
+class SetPackageTests(unittest.TestCase):
+
+ """Tests for importlib.util.set_package."""
+
+ def verify(self, module, expect):
+ """Verify the module has the expected value for __package__ after
+ passing through set_package."""
+ fxn = lambda: module
+ wrapped = util.set_package(fxn)
+ wrapped()
+ self.assertTrue(hasattr(module, '__package__'))
+ self.assertEqual(expect, module.__package__)
+
+ def test_top_level(self):
+ # __package__ should be set to the empty string if a top-level module.
+ # Implicitly tests when package is set to None.
+ module = imp.new_module('module')
+ module.__package__ = None
+ self.verify(module, '')
+
+ def test_package(self):
+ # Test setting __package__ for a package.
+ module = imp.new_module('pkg')
+ module.__path__ = ['<path>']
+ module.__package__ = None
+ self.verify(module, 'pkg')
+
+ def test_submodule(self):
+ # Test __package__ for a module in a package.
+ module = imp.new_module('pkg.mod')
+ module.__package__ = None
+ self.verify(module, 'pkg')
+
+ def test_setting_if_missing(self):
+ # __package__ should be set if it is missing.
+ module = imp.new_module('mod')
+ if hasattr(module, '__package__'):
+ delattr(module, '__package__')
+ self.verify(module, '')
+
+ def test_leaving_alone(self):
+ # If __package__ is set and not None then leave it alone.
+ for value in (True, False):
+ module = imp.new_module('mod')
+ module.__package__ = value
+ self.verify(module, value)
+
+ def test_decorator_attrs(self):
+ def fxn(module): pass
+ wrapped = util.set_package(fxn)
+ self.assertEqual(wrapped.__name__, fxn.__name__)
+ self.assertEqual(wrapped.__qualname__, fxn.__qualname__)
+
+
+class ResolveNameTests(unittest.TestCase):
+
+ """Tests importlib.util.resolve_name()."""
+
+ def test_absolute(self):
+ # bacon
+ self.assertEqual('bacon', util.resolve_name('bacon', None))
+
+ def test_aboslute_within_package(self):
+ # bacon in spam
+ self.assertEqual('bacon', util.resolve_name('bacon', 'spam'))
+
+ def test_no_package(self):
+ # .bacon in ''
+ with self.assertRaises(ValueError):
+ util.resolve_name('.bacon', '')
+
+ def test_in_package(self):
+ # .bacon in spam
+ self.assertEqual('spam.eggs.bacon',
+ util.resolve_name('.bacon', 'spam.eggs'))
+
+ def test_other_package(self):
+ # ..bacon in spam.bacon
+ self.assertEqual('spam.bacon',
+ util.resolve_name('..bacon', 'spam.eggs'))
+
+ def test_escape(self):
+ # ..bacon in spam
+ with self.assertRaises(ValueError):
+ util.resolve_name('..bacon', 'spam')
+
+
+def test_main():
+ from test import support
+ support.run_unittest(
+ ModuleForLoaderTests,
+ SetPackageTests,
+ ResolveNameTests
+ )
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_importlib/util.py b/Lib/test/test_importlib/util.py
new file mode 100644
index 0000000..ef32f7d
--- /dev/null
+++ b/Lib/test/test_importlib/util.py
@@ -0,0 +1,140 @@
+from contextlib import contextmanager
+import imp
+import os.path
+from test import support
+import unittest
+import sys
+
+
+CASE_INSENSITIVE_FS = True
+# Windows is the only OS that is *always* case-insensitive
+# (OS X *can* be case-sensitive).
+if sys.platform not in ('win32', 'cygwin'):
+ changed_name = __file__.upper()
+ if changed_name == __file__:
+ changed_name = __file__.lower()
+ if not os.path.exists(changed_name):
+ CASE_INSENSITIVE_FS = False
+
+
+def case_insensitive_tests(test):
+ """Class decorator that nullifies tests requiring a case-insensitive
+ file system."""
+ return unittest.skipIf(not CASE_INSENSITIVE_FS,
+ "requires a case-insensitive filesystem")(test)
+
+
+@contextmanager
+def uncache(*names):
+ """Uncache a module from sys.modules.
+
+ A basic sanity check is performed to prevent uncaching modules that either
+ cannot/shouldn't be uncached.
+
+ """
+ for name in names:
+ if name in ('sys', 'marshal', 'imp'):
+ raise ValueError(
+ "cannot uncache {0}".format(name))
+ try:
+ del sys.modules[name]
+ except KeyError:
+ pass
+ try:
+ yield
+ finally:
+ for name in names:
+ try:
+ del sys.modules[name]
+ except KeyError:
+ pass
+
+@contextmanager
+def import_state(**kwargs):
+ """Context manager to manage the various importers and stored state in the
+ sys module.
+
+ The 'modules' attribute is not supported as the interpreter state stores a
+ pointer to the dict that the interpreter uses internally;
+ reassigning to sys.modules does not have the desired effect.
+
+ """
+ originals = {}
+ try:
+ for attr, default in (('meta_path', []), ('path', []),
+ ('path_hooks', []),
+ ('path_importer_cache', {})):
+ originals[attr] = getattr(sys, attr)
+ if attr in kwargs:
+ new_value = kwargs[attr]
+ del kwargs[attr]
+ else:
+ new_value = default
+ setattr(sys, attr, new_value)
+ if len(kwargs):
+ raise ValueError(
+ 'unrecognized arguments: {0}'.format(kwargs.keys()))
+ yield
+ finally:
+ for attr, value in originals.items():
+ setattr(sys, attr, value)
+
+
+class mock_modules:
+
+ """A mock importer/loader."""
+
+ def __init__(self, *names, module_code={}):
+ self.modules = {}
+ self.module_code = {}
+ for name in names:
+ if not name.endswith('.__init__'):
+ import_name = name
+ else:
+ import_name = name[:-len('.__init__')]
+ if '.' not in name:
+ package = None
+ elif import_name == name:
+ package = name.rsplit('.', 1)[0]
+ else:
+ package = import_name
+ module = imp.new_module(import_name)
+ module.__loader__ = self
+ module.__file__ = '<mock __file__>'
+ module.__package__ = package
+ module.attr = name
+ if import_name != name:
+ module.__path__ = ['<mock __path__>']
+ self.modules[import_name] = module
+ if import_name in module_code:
+ self.module_code[import_name] = module_code[import_name]
+
+ def __getitem__(self, name):
+ return self.modules[name]
+
+ def find_module(self, fullname, path=None):
+ if fullname not in self.modules:
+ return None
+ else:
+ return self
+
+ def load_module(self, fullname):
+ if fullname not in self.modules:
+ raise ImportError
+ else:
+ sys.modules[fullname] = self.modules[fullname]
+ if fullname in self.module_code:
+ try:
+ self.module_code[fullname]()
+ except Exception:
+ del sys.modules[fullname]
+ raise
+ return self.modules[fullname]
+
+ def __enter__(self):
+ self._uncache = uncache(*self.modules.keys())
+ self._uncache.__enter__()
+ return self
+
+ def __exit__(self, *exc_info):
+ self._uncache.__exit__(None, None, None)
diff --git a/Lib/test/test_inspect.py b/Lib/test/test_inspect.py
index 4b7ee4e..6e3f04e 100644
--- a/Lib/test/test_inspect.py
+++ b/Lib/test/test_inspect.py
@@ -41,11 +41,6 @@ def revise(filename, *args):
import builtins
-try:
- 1/0
-except:
- tb = sys.exc_info()[2]
-
git = mod.StupidGit()
class IsTestBase(unittest.TestCase):
@@ -79,23 +74,31 @@ class TestPredicates(IsTestBase):
def test_excluding_predicates(self):
+ global tb
self.istest(inspect.isbuiltin, 'sys.exit')
self.istest(inspect.isbuiltin, '[].append')
self.istest(inspect.iscode, 'mod.spam.__code__')
- self.istest(inspect.isframe, 'tb.tb_frame')
+ try:
+ 1/0
+ except:
+ tb = sys.exc_info()[2]
+ self.istest(inspect.isframe, 'tb.tb_frame')
+ self.istest(inspect.istraceback, 'tb')
+ if hasattr(types, 'GetSetDescriptorType'):
+ self.istest(inspect.isgetsetdescriptor,
+ 'type(tb.tb_frame).f_locals')
+ else:
+ self.assertFalse(inspect.isgetsetdescriptor(type(tb.tb_frame).f_locals))
+ finally:
+ # Clear traceback and all the frames and local variables hanging to it.
+ tb = None
self.istest(inspect.isfunction, 'mod.spam')
self.istest(inspect.isfunction, 'mod.StupidGit.abuse')
self.istest(inspect.ismethod, 'git.argue')
self.istest(inspect.ismodule, 'mod')
- self.istest(inspect.istraceback, 'tb')
self.istest(inspect.isdatadescriptor, 'collections.defaultdict.default_factory')
self.istest(inspect.isgenerator, '(x for x in range(2))')
self.istest(inspect.isgeneratorfunction, 'generator_function_example')
- if hasattr(types, 'GetSetDescriptorType'):
- self.istest(inspect.isgetsetdescriptor,
- 'type(tb.tb_frame).f_locals')
- else:
- self.assertFalse(inspect.isgetsetdescriptor(type(tb.tb_frame).f_locals))
if hasattr(types, 'MemberDescriptorType'):
self.istest(inspect.ismemberdescriptor, 'datetime.timedelta.days')
else:
@@ -282,7 +285,10 @@ class TestRetrievingSourceCode(GetSourceBase):
co = compile("None", fn, "exec")
self.assertEqual(inspect.getsourcefile(co), None)
linecache.cache[co.co_filename] = (1, None, "None", co.co_filename)
- self.assertEqual(normcase(inspect.getsourcefile(co)), fn)
+ try:
+ self.assertEqual(normcase(inspect.getsourcefile(co)), fn)
+ finally:
+ del linecache.cache[co.co_filename]
def test_getfile(self):
self.assertEqual(inspect.getfile(mod.StupidGit), mod.__file__)
@@ -304,7 +310,7 @@ class TestRetrievingSourceCode(GetSourceBase):
getlines = linecache.getlines
def monkey(filename, module_globals=None):
if filename == fn:
- return source.splitlines(True)
+ return source.splitlines(keepends=True)
else:
return getlines(filename, module_globals)
linecache.getlines = monkey
@@ -404,8 +410,11 @@ class TestBuggyCases(GetSourceBase):
self.assertRaises(IOError, inspect.findsource, co)
self.assertRaises(IOError, inspect.getsource, co)
linecache.cache[co.co_filename] = (1, None, lines, co.co_filename)
- self.assertEqual(inspect.findsource(co), (lines,0))
- self.assertEqual(inspect.getsource(co), lines[0])
+ try:
+ self.assertEqual(inspect.findsource(co), (lines,0))
+ self.assertEqual(inspect.getsource(co), lines[0])
+ finally:
+ del linecache.cache[co.co_filename]
class TestNoEOL(GetSourceBase):
def __init__(self, *args, **kwargs):
@@ -662,6 +671,129 @@ class TestClassesAndFunctions(unittest.TestCase):
self.assertIn(('f', b.f), inspect.getmembers(b, inspect.ismethod))
+_global_ref = object()
+class TestGetClosureVars(unittest.TestCase):
+
+ def test_name_resolution(self):
+ # Basic test of the 4 different resolution mechanisms
+ def f(nonlocal_ref):
+ def g(local_ref):
+ print(local_ref, nonlocal_ref, _global_ref, unbound_ref)
+ return g
+ _arg = object()
+ nonlocal_vars = {"nonlocal_ref": _arg}
+ global_vars = {"_global_ref": _global_ref}
+ builtin_vars = {"print": print}
+ unbound_names = {"unbound_ref"}
+ expected = inspect.ClosureVars(nonlocal_vars, global_vars,
+ builtin_vars, unbound_names)
+ self.assertEqual(inspect.getclosurevars(f(_arg)), expected)
+
+ def test_generator_closure(self):
+ def f(nonlocal_ref):
+ def g(local_ref):
+ print(local_ref, nonlocal_ref, _global_ref, unbound_ref)
+ yield
+ return g
+ _arg = object()
+ nonlocal_vars = {"nonlocal_ref": _arg}
+ global_vars = {"_global_ref": _global_ref}
+ builtin_vars = {"print": print}
+ unbound_names = {"unbound_ref"}
+ expected = inspect.ClosureVars(nonlocal_vars, global_vars,
+ builtin_vars, unbound_names)
+ self.assertEqual(inspect.getclosurevars(f(_arg)), expected)
+
+ def test_method_closure(self):
+ class C:
+ def f(self, nonlocal_ref):
+ def g(local_ref):
+ print(local_ref, nonlocal_ref, _global_ref, unbound_ref)
+ return g
+ _arg = object()
+ nonlocal_vars = {"nonlocal_ref": _arg}
+ global_vars = {"_global_ref": _global_ref}
+ builtin_vars = {"print": print}
+ unbound_names = {"unbound_ref"}
+ expected = inspect.ClosureVars(nonlocal_vars, global_vars,
+ builtin_vars, unbound_names)
+ self.assertEqual(inspect.getclosurevars(C().f(_arg)), expected)
+
+ def test_nonlocal_vars(self):
+ # More complex tests of nonlocal resolution
+ def _nonlocal_vars(f):
+ return inspect.getclosurevars(f).nonlocals
+
+ def make_adder(x):
+ def add(y):
+ return x + y
+ return add
+
+ def curry(func, arg1):
+ return lambda arg2: func(arg1, arg2)
+
+ def less_than(a, b):
+ return a < b
+
+ # The infamous Y combinator.
+ def Y(le):
+ def g(f):
+ return le(lambda x: f(f)(x))
+ Y.g_ref = g
+ return g(g)
+
+ def check_y_combinator(func):
+ self.assertEqual(_nonlocal_vars(func), {'f': Y.g_ref})
+
+ inc = make_adder(1)
+ add_two = make_adder(2)
+ greater_than_five = curry(less_than, 5)
+
+ self.assertEqual(_nonlocal_vars(inc), {'x': 1})
+ self.assertEqual(_nonlocal_vars(add_two), {'x': 2})
+ self.assertEqual(_nonlocal_vars(greater_than_five),
+ {'arg1': 5, 'func': less_than})
+ self.assertEqual(_nonlocal_vars((lambda x: lambda y: x + y)(3)),
+ {'x': 3})
+ Y(check_y_combinator)
+
+ def test_getclosurevars_empty(self):
+ def foo(): pass
+ _empty = inspect.ClosureVars({}, {}, {}, set())
+ self.assertEqual(inspect.getclosurevars(lambda: True), _empty)
+ self.assertEqual(inspect.getclosurevars(foo), _empty)
+
+ def test_getclosurevars_error(self):
+ class T: pass
+ self.assertRaises(TypeError, inspect.getclosurevars, 1)
+ self.assertRaises(TypeError, inspect.getclosurevars, list)
+ self.assertRaises(TypeError, inspect.getclosurevars, {})
+
+ def _private_globals(self):
+ code = """def f(): print(path)"""
+ ns = {}
+ exec(code, ns)
+ return ns["f"], ns
+
+ def test_builtins_fallback(self):
+ f, ns = self._private_globals()
+ ns.pop("__builtins__", None)
+ expected = inspect.ClosureVars({}, {}, {"print":print}, {"path"})
+ self.assertEqual(inspect.getclosurevars(f), expected)
+
+ def test_builtins_as_dict(self):
+ f, ns = self._private_globals()
+ ns["__builtins__"] = {"path":1}
+ expected = inspect.ClosureVars({}, {}, {"path":1}, {"print"})
+ self.assertEqual(inspect.getclosurevars(f), expected)
+
+ def test_builtins_as_module(self):
+ f, ns = self._private_globals()
+ ns["__builtins__"] = os
+ expected = inspect.ClosureVars({}, {}, {"path":os.path}, {"print"})
+ self.assertEqual(inspect.getclosurevars(f), expected)
+
+
class TestGetcallargsFunctions(unittest.TestCase):
def assertEqualCallArgs(self, func, call_params_string, locs=None):
@@ -1169,6 +1301,982 @@ class TestGetGeneratorState(unittest.TestCase):
self.assertIn(name, repr(state))
self.assertIn(name, str(state))
+ def test_getgeneratorlocals(self):
+ def each(lst, a=None):
+ b=(1, 2, 3)
+ for v in lst:
+ if v == 3:
+ c = 12
+ yield v
+
+ numbers = each([1, 2, 3])
+ self.assertEqual(inspect.getgeneratorlocals(numbers),
+ {'a': None, 'lst': [1, 2, 3]})
+ next(numbers)
+ self.assertEqual(inspect.getgeneratorlocals(numbers),
+ {'a': None, 'lst': [1, 2, 3], 'v': 1,
+ 'b': (1, 2, 3)})
+ next(numbers)
+ self.assertEqual(inspect.getgeneratorlocals(numbers),
+ {'a': None, 'lst': [1, 2, 3], 'v': 2,
+ 'b': (1, 2, 3)})
+ next(numbers)
+ self.assertEqual(inspect.getgeneratorlocals(numbers),
+ {'a': None, 'lst': [1, 2, 3], 'v': 3,
+ 'b': (1, 2, 3), 'c': 12})
+ try:
+ next(numbers)
+ except StopIteration:
+ pass
+ self.assertEqual(inspect.getgeneratorlocals(numbers), {})
+
+ def test_getgeneratorlocals_empty(self):
+ def yield_one():
+ yield 1
+ one = yield_one()
+ self.assertEqual(inspect.getgeneratorlocals(one), {})
+ try:
+ next(one)
+ except StopIteration:
+ pass
+ self.assertEqual(inspect.getgeneratorlocals(one), {})
+
+ def test_getgeneratorlocals_error(self):
+ self.assertRaises(TypeError, inspect.getgeneratorlocals, 1)
+ self.assertRaises(TypeError, inspect.getgeneratorlocals, lambda x: True)
+ self.assertRaises(TypeError, inspect.getgeneratorlocals, set)
+ self.assertRaises(TypeError, inspect.getgeneratorlocals, (2,3))
+
+
+class TestSignatureObject(unittest.TestCase):
+ @staticmethod
+ def signature(func):
+ sig = inspect.signature(func)
+ return (tuple((param.name,
+ (... if param.default is param.empty else param.default),
+ (... if param.annotation is param.empty
+ else param.annotation),
+ str(param.kind).lower())
+ for param in sig.parameters.values()),
+ (... if sig.return_annotation is sig.empty
+ else sig.return_annotation))
+
+ def test_signature_object(self):
+ S = inspect.Signature
+ P = inspect.Parameter
+
+ self.assertEqual(str(S()), '()')
+
+ def test(po, pk, *args, ko, **kwargs):
+ pass
+ sig = inspect.signature(test)
+ po = sig.parameters['po'].replace(kind=P.POSITIONAL_ONLY)
+ pk = sig.parameters['pk']
+ args = sig.parameters['args']
+ ko = sig.parameters['ko']
+ kwargs = sig.parameters['kwargs']
+
+ S((po, pk, args, ko, kwargs))
+
+ with self.assertRaisesRegex(ValueError, 'wrong parameter order'):
+ S((pk, po, args, ko, kwargs))
+
+ with self.assertRaisesRegex(ValueError, 'wrong parameter order'):
+ S((po, args, pk, ko, kwargs))
+
+ with self.assertRaisesRegex(ValueError, 'wrong parameter order'):
+ S((args, po, pk, ko, kwargs))
+
+ with self.assertRaisesRegex(ValueError, 'wrong parameter order'):
+ S((po, pk, args, kwargs, ko))
+
+ kwargs2 = kwargs.replace(name='args')
+ with self.assertRaisesRegex(ValueError, 'duplicate parameter name'):
+ S((po, pk, args, kwargs2, ko))
+
+ def test_signature_immutability(self):
+ def test(a):
+ pass
+ sig = inspect.signature(test)
+
+ with self.assertRaises(AttributeError):
+ sig.foo = 'bar'
+
+ with self.assertRaises(TypeError):
+ sig.parameters['a'] = None
+
+ def test_signature_on_noarg(self):
+ def test():
+ pass
+ self.assertEqual(self.signature(test), ((), ...))
+
+ def test_signature_on_wargs(self):
+ def test(a, b:'foo') -> 123:
+ pass
+ self.assertEqual(self.signature(test),
+ ((('a', ..., ..., "positional_or_keyword"),
+ ('b', ..., 'foo', "positional_or_keyword")),
+ 123))
+
+ def test_signature_on_wkwonly(self):
+ def test(*, a:float, b:str) -> int:
+ pass
+ self.assertEqual(self.signature(test),
+ ((('a', ..., float, "keyword_only"),
+ ('b', ..., str, "keyword_only")),
+ int))
+
+ def test_signature_on_complex_args(self):
+ def test(a, b:'foo'=10, *args:'bar', spam:'baz', ham=123, **kwargs:int):
+ pass
+ self.assertEqual(self.signature(test),
+ ((('a', ..., ..., "positional_or_keyword"),
+ ('b', 10, 'foo', "positional_or_keyword"),
+ ('args', ..., 'bar', "var_positional"),
+ ('spam', ..., 'baz', "keyword_only"),
+ ('ham', 123, ..., "keyword_only"),
+ ('kwargs', ..., int, "var_keyword")),
+ ...))
+
+ def test_signature_on_builtin_function(self):
+ with self.assertRaisesRegex(ValueError, 'not supported by signature'):
+ inspect.signature(type)
+ with self.assertRaisesRegex(ValueError, 'not supported by signature'):
+ # support for 'wrapper_descriptor'
+ inspect.signature(type.__call__)
+ with self.assertRaisesRegex(ValueError, 'not supported by signature'):
+ # support for 'method-wrapper'
+ inspect.signature(min.__call__)
+ with self.assertRaisesRegex(ValueError,
+ 'no signature found for builtin function'):
+ # support for 'method-wrapper'
+ inspect.signature(min)
+
+ def test_signature_on_non_function(self):
+ with self.assertRaisesRegex(TypeError, 'is not a callable object'):
+ inspect.signature(42)
+
+ with self.assertRaisesRegex(TypeError, 'is not a Python function'):
+ inspect.Signature.from_function(42)
+
+ def test_signature_on_method(self):
+ class Test:
+ def foo(self, arg1, arg2=1) -> int:
+ pass
+
+ meth = Test().foo
+
+ self.assertEqual(self.signature(meth),
+ ((('arg1', ..., ..., "positional_or_keyword"),
+ ('arg2', 1, ..., "positional_or_keyword")),
+ int))
+
+ def test_signature_on_classmethod(self):
+ class Test:
+ @classmethod
+ def foo(cls, arg1, *, arg2=1):
+ pass
+
+ meth = Test().foo
+ self.assertEqual(self.signature(meth),
+ ((('arg1', ..., ..., "positional_or_keyword"),
+ ('arg2', 1, ..., "keyword_only")),
+ ...))
+
+ meth = Test.foo
+ self.assertEqual(self.signature(meth),
+ ((('arg1', ..., ..., "positional_or_keyword"),
+ ('arg2', 1, ..., "keyword_only")),
+ ...))
+
+ def test_signature_on_staticmethod(self):
+ class Test:
+ @staticmethod
+ def foo(cls, *, arg):
+ pass
+
+ meth = Test().foo
+ self.assertEqual(self.signature(meth),
+ ((('cls', ..., ..., "positional_or_keyword"),
+ ('arg', ..., ..., "keyword_only")),
+ ...))
+
+ meth = Test.foo
+ self.assertEqual(self.signature(meth),
+ ((('cls', ..., ..., "positional_or_keyword"),
+ ('arg', ..., ..., "keyword_only")),
+ ...))
+
+ def test_signature_on_partial(self):
+ from functools import partial
+
+ def test():
+ pass
+
+ self.assertEqual(self.signature(partial(test)), ((), ...))
+
+ with self.assertRaisesRegex(ValueError, "has incorrect arguments"):
+ inspect.signature(partial(test, 1))
+
+ with self.assertRaisesRegex(ValueError, "has incorrect arguments"):
+ inspect.signature(partial(test, a=1))
+
+ def test(a, b, *, c, d):
+ pass
+
+ self.assertEqual(self.signature(partial(test)),
+ ((('a', ..., ..., "positional_or_keyword"),
+ ('b', ..., ..., "positional_or_keyword"),
+ ('c', ..., ..., "keyword_only"),
+ ('d', ..., ..., "keyword_only")),
+ ...))
+
+ self.assertEqual(self.signature(partial(test, 1)),
+ ((('b', ..., ..., "positional_or_keyword"),
+ ('c', ..., ..., "keyword_only"),
+ ('d', ..., ..., "keyword_only")),
+ ...))
+
+ self.assertEqual(self.signature(partial(test, 1, c=2)),
+ ((('b', ..., ..., "positional_or_keyword"),
+ ('c', 2, ..., "keyword_only"),
+ ('d', ..., ..., "keyword_only")),
+ ...))
+
+ self.assertEqual(self.signature(partial(test, b=1, c=2)),
+ ((('a', ..., ..., "positional_or_keyword"),
+ ('b', 1, ..., "positional_or_keyword"),
+ ('c', 2, ..., "keyword_only"),
+ ('d', ..., ..., "keyword_only")),
+ ...))
+
+ self.assertEqual(self.signature(partial(test, 0, b=1, c=2)),
+ ((('b', 1, ..., "positional_or_keyword"),
+ ('c', 2, ..., "keyword_only"),
+ ('d', ..., ..., "keyword_only"),),
+ ...))
+
+ def test(a, *args, b, **kwargs):
+ pass
+
+ self.assertEqual(self.signature(partial(test, 1)),
+ ((('args', ..., ..., "var_positional"),
+ ('b', ..., ..., "keyword_only"),
+ ('kwargs', ..., ..., "var_keyword")),
+ ...))
+
+ self.assertEqual(self.signature(partial(test, 1, 2, 3)),
+ ((('args', ..., ..., "var_positional"),
+ ('b', ..., ..., "keyword_only"),
+ ('kwargs', ..., ..., "var_keyword")),
+ ...))
+
+
+ self.assertEqual(self.signature(partial(test, 1, 2, 3, test=True)),
+ ((('args', ..., ..., "var_positional"),
+ ('b', ..., ..., "keyword_only"),
+ ('kwargs', ..., ..., "var_keyword")),
+ ...))
+
+ self.assertEqual(self.signature(partial(test, 1, 2, 3, test=1, b=0)),
+ ((('args', ..., ..., "var_positional"),
+ ('b', 0, ..., "keyword_only"),
+ ('kwargs', ..., ..., "var_keyword")),
+ ...))
+
+ self.assertEqual(self.signature(partial(test, b=0)),
+ ((('a', ..., ..., "positional_or_keyword"),
+ ('args', ..., ..., "var_positional"),
+ ('b', 0, ..., "keyword_only"),
+ ('kwargs', ..., ..., "var_keyword")),
+ ...))
+
+ self.assertEqual(self.signature(partial(test, b=0, test=1)),
+ ((('a', ..., ..., "positional_or_keyword"),
+ ('args', ..., ..., "var_positional"),
+ ('b', 0, ..., "keyword_only"),
+ ('kwargs', ..., ..., "var_keyword")),
+ ...))
+
+ def test(a, b, c:int) -> 42:
+ pass
+
+ sig = test.__signature__ = inspect.signature(test)
+
+ self.assertEqual(self.signature(partial(partial(test, 1))),
+ ((('b', ..., ..., "positional_or_keyword"),
+ ('c', ..., int, "positional_or_keyword")),
+ 42))
+
+ self.assertEqual(self.signature(partial(partial(test, 1), 2)),
+ ((('c', ..., int, "positional_or_keyword"),),
+ 42))
+
+ psig = inspect.signature(partial(partial(test, 1), 2))
+
+ def foo(a):
+ return a
+ _foo = partial(partial(foo, a=10), a=20)
+ self.assertEqual(self.signature(_foo),
+ ((('a', 20, ..., "positional_or_keyword"),),
+ ...))
+ # check that we don't have any side-effects in signature(),
+ # and the partial object is still functioning
+ self.assertEqual(_foo(), 20)
+
+ def foo(a, b, c):
+ return a, b, c
+ _foo = partial(partial(foo, 1, b=20), b=30)
+ self.assertEqual(self.signature(_foo),
+ ((('b', 30, ..., "positional_or_keyword"),
+ ('c', ..., ..., "positional_or_keyword")),
+ ...))
+ self.assertEqual(_foo(c=10), (1, 30, 10))
+ _foo = partial(_foo, 2) # now 'b' has two values -
+ # positional and keyword
+ with self.assertRaisesRegex(ValueError, "has incorrect arguments"):
+ inspect.signature(_foo)
+
+ def foo(a, b, c, *, d):
+ return a, b, c, d
+ _foo = partial(partial(foo, d=20, c=20), b=10, d=30)
+ self.assertEqual(self.signature(_foo),
+ ((('a', ..., ..., "positional_or_keyword"),
+ ('b', 10, ..., "positional_or_keyword"),
+ ('c', 20, ..., "positional_or_keyword"),
+ ('d', 30, ..., "keyword_only")),
+ ...))
+ ba = inspect.signature(_foo).bind(a=200, b=11)
+ self.assertEqual(_foo(*ba.args, **ba.kwargs), (200, 11, 20, 30))
+
+ def foo(a=1, b=2, c=3):
+ return a, b, c
+ _foo = partial(foo, a=10, c=13)
+ ba = inspect.signature(_foo).bind(11)
+ self.assertEqual(_foo(*ba.args, **ba.kwargs), (11, 2, 13))
+ ba = inspect.signature(_foo).bind(11, 12)
+ self.assertEqual(_foo(*ba.args, **ba.kwargs), (11, 12, 13))
+ ba = inspect.signature(_foo).bind(11, b=12)
+ self.assertEqual(_foo(*ba.args, **ba.kwargs), (11, 12, 13))
+ ba = inspect.signature(_foo).bind(b=12)
+ self.assertEqual(_foo(*ba.args, **ba.kwargs), (10, 12, 13))
+ _foo = partial(_foo, b=10)
+ ba = inspect.signature(_foo).bind(12, 14)
+ self.assertEqual(_foo(*ba.args, **ba.kwargs), (12, 14, 13))
+
+ def test_signature_on_decorated(self):
+ import functools
+
+ def decorator(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs) -> int:
+ return func(*args, **kwargs)
+ return wrapper
+
+ class Foo:
+ @decorator
+ def bar(self, a, b):
+ pass
+
+ self.assertEqual(self.signature(Foo.bar),
+ ((('self', ..., ..., "positional_or_keyword"),
+ ('a', ..., ..., "positional_or_keyword"),
+ ('b', ..., ..., "positional_or_keyword")),
+ ...))
+
+ self.assertEqual(self.signature(Foo().bar),
+ ((('a', ..., ..., "positional_or_keyword"),
+ ('b', ..., ..., "positional_or_keyword")),
+ ...))
+
+ # Test that we handle method wrappers correctly
+ def decorator(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs) -> int:
+ return func(42, *args, **kwargs)
+ sig = inspect.signature(func)
+ new_params = tuple(sig.parameters.values())[1:]
+ wrapper.__signature__ = sig.replace(parameters=new_params)
+ return wrapper
+
+ class Foo:
+ @decorator
+ def __call__(self, a, b):
+ pass
+
+ self.assertEqual(self.signature(Foo.__call__),
+ ((('a', ..., ..., "positional_or_keyword"),
+ ('b', ..., ..., "positional_or_keyword")),
+ ...))
+
+ self.assertEqual(self.signature(Foo().__call__),
+ ((('b', ..., ..., "positional_or_keyword"),),
+ ...))
+
+ def test_signature_on_class(self):
+ class C:
+ def __init__(self, a):
+ pass
+
+ self.assertEqual(self.signature(C),
+ ((('a', ..., ..., "positional_or_keyword"),),
+ ...))
+
+ class CM(type):
+ def __call__(cls, a):
+ pass
+ class C(metaclass=CM):
+ def __init__(self, b):
+ pass
+
+ self.assertEqual(self.signature(C),
+ ((('a', ..., ..., "positional_or_keyword"),),
+ ...))
+
+ class CM(type):
+ def __new__(mcls, name, bases, dct, *, foo=1):
+ return super().__new__(mcls, name, bases, dct)
+ class C(metaclass=CM):
+ def __init__(self, b):
+ pass
+
+ self.assertEqual(self.signature(C),
+ ((('b', ..., ..., "positional_or_keyword"),),
+ ...))
+
+ self.assertEqual(self.signature(CM),
+ ((('name', ..., ..., "positional_or_keyword"),
+ ('bases', ..., ..., "positional_or_keyword"),
+ ('dct', ..., ..., "positional_or_keyword"),
+ ('foo', 1, ..., "keyword_only")),
+ ...))
+
+ class CMM(type):
+ def __new__(mcls, name, bases, dct, *, foo=1):
+ return super().__new__(mcls, name, bases, dct)
+ def __call__(cls, nm, bs, dt):
+ return type(nm, bs, dt)
+ class CM(type, metaclass=CMM):
+ def __new__(mcls, name, bases, dct, *, bar=2):
+ return super().__new__(mcls, name, bases, dct)
+ class C(metaclass=CM):
+ def __init__(self, b):
+ pass
+
+ self.assertEqual(self.signature(CMM),
+ ((('name', ..., ..., "positional_or_keyword"),
+ ('bases', ..., ..., "positional_or_keyword"),
+ ('dct', ..., ..., "positional_or_keyword"),
+ ('foo', 1, ..., "keyword_only")),
+ ...))
+
+ self.assertEqual(self.signature(CM),
+ ((('nm', ..., ..., "positional_or_keyword"),
+ ('bs', ..., ..., "positional_or_keyword"),
+ ('dt', ..., ..., "positional_or_keyword")),
+ ...))
+
+ self.assertEqual(self.signature(C),
+ ((('b', ..., ..., "positional_or_keyword"),),
+ ...))
+
+ class CM(type):
+ def __init__(cls, name, bases, dct, *, bar=2):
+ return super().__init__(name, bases, dct)
+ class C(metaclass=CM):
+ def __init__(self, b):
+ pass
+
+ self.assertEqual(self.signature(CM),
+ ((('name', ..., ..., "positional_or_keyword"),
+ ('bases', ..., ..., "positional_or_keyword"),
+ ('dct', ..., ..., "positional_or_keyword"),
+ ('bar', 2, ..., "keyword_only")),
+ ...))
+
+ def test_signature_on_callable_objects(self):
+ class Foo:
+ def __call__(self, a):
+ pass
+
+ self.assertEqual(self.signature(Foo()),
+ ((('a', ..., ..., "positional_or_keyword"),),
+ ...))
+
+ class Spam:
+ pass
+ with self.assertRaisesRegex(TypeError, "is not a callable object"):
+ inspect.signature(Spam())
+
+ class Bar(Spam, Foo):
+ pass
+
+ self.assertEqual(self.signature(Bar()),
+ ((('a', ..., ..., "positional_or_keyword"),),
+ ...))
+
+ class ToFail:
+ __call__ = type
+ with self.assertRaisesRegex(ValueError, "not supported by signature"):
+ inspect.signature(ToFail())
+
+
+ class Wrapped:
+ pass
+ Wrapped.__wrapped__ = lambda a: None
+ self.assertEqual(self.signature(Wrapped),
+ ((('a', ..., ..., "positional_or_keyword"),),
+ ...))
+
+ def test_signature_on_lambdas(self):
+ self.assertEqual(self.signature((lambda a=10: a)),
+ ((('a', 10, ..., "positional_or_keyword"),),
+ ...))
+
+ def test_signature_equality(self):
+ def foo(a, *, b:int) -> float: pass
+ self.assertNotEqual(inspect.signature(foo), 42)
+
+ def bar(a, *, b:int) -> float: pass
+ self.assertEqual(inspect.signature(foo), inspect.signature(bar))
+
+ def bar(a, *, b:int) -> int: pass
+ self.assertNotEqual(inspect.signature(foo), inspect.signature(bar))
+
+ def bar(a, *, b:int): pass
+ self.assertNotEqual(inspect.signature(foo), inspect.signature(bar))
+
+ def bar(a, *, b:int=42) -> float: pass
+ self.assertNotEqual(inspect.signature(foo), inspect.signature(bar))
+
+ def bar(a, *, c) -> float: pass
+ self.assertNotEqual(inspect.signature(foo), inspect.signature(bar))
+
+ def bar(a, b:int) -> float: pass
+ self.assertNotEqual(inspect.signature(foo), inspect.signature(bar))
+ def spam(b:int, a) -> float: pass
+ self.assertNotEqual(inspect.signature(spam), inspect.signature(bar))
+
+ def foo(*, a, b, c): pass
+ def bar(*, c, b, a): pass
+ self.assertEqual(inspect.signature(foo), inspect.signature(bar))
+
+ def foo(*, a=1, b, c): pass
+ def bar(*, c, b, a=1): pass
+ self.assertEqual(inspect.signature(foo), inspect.signature(bar))
+
+ def foo(pos, *, a=1, b, c): pass
+ def bar(pos, *, c, b, a=1): pass
+ self.assertEqual(inspect.signature(foo), inspect.signature(bar))
+
+ def foo(pos, *, a, b, c): pass
+ def bar(pos, *, c, b, a=1): pass
+ self.assertNotEqual(inspect.signature(foo), inspect.signature(bar))
+
+ def foo(pos, *args, a=42, b, c, **kwargs:int): pass
+ def bar(pos, *args, c, b, a=42, **kwargs:int): pass
+ self.assertEqual(inspect.signature(foo), inspect.signature(bar))
+
+ def test_signature_unhashable(self):
+ def foo(a): pass
+ sig = inspect.signature(foo)
+ with self.assertRaisesRegex(TypeError, 'unhashable type'):
+ hash(sig)
+
+ def test_signature_str(self):
+ def foo(a:int=1, *, b, c=None, **kwargs) -> 42:
+ pass
+ self.assertEqual(str(inspect.signature(foo)),
+ '(a:int=1, *, b, c=None, **kwargs) -> 42')
+
+ def foo(a:int=1, *args, b, c=None, **kwargs) -> 42:
+ pass
+ self.assertEqual(str(inspect.signature(foo)),
+ '(a:int=1, *args, b, c=None, **kwargs) -> 42')
+
+ def foo():
+ pass
+ self.assertEqual(str(inspect.signature(foo)), '()')
+
+ def test_signature_str_positional_only(self):
+ P = inspect.Parameter
+
+ def test(a_po, *, b, **kwargs):
+ return a_po, kwargs
+
+ sig = inspect.signature(test)
+ new_params = list(sig.parameters.values())
+ new_params[0] = new_params[0].replace(kind=P.POSITIONAL_ONLY)
+ test.__signature__ = sig.replace(parameters=new_params)
+
+ self.assertEqual(str(inspect.signature(test)),
+ '(<a_po>, *, b, **kwargs)')
+
+ sig = inspect.signature(test)
+ new_params = list(sig.parameters.values())
+ new_params[0] = new_params[0].replace(name=None)
+ test.__signature__ = sig.replace(parameters=new_params)
+ self.assertEqual(str(inspect.signature(test)),
+ '(<0>, *, b, **kwargs)')
+
+ def test_signature_replace_anno(self):
+ def test() -> 42:
+ pass
+
+ sig = inspect.signature(test)
+ sig = sig.replace(return_annotation=None)
+ self.assertIs(sig.return_annotation, None)
+ sig = sig.replace(return_annotation=sig.empty)
+ self.assertIs(sig.return_annotation, sig.empty)
+ sig = sig.replace(return_annotation=42)
+ self.assertEqual(sig.return_annotation, 42)
+ self.assertEqual(sig, inspect.signature(test))
+
+
+class TestParameterObject(unittest.TestCase):
+ def test_signature_parameter_kinds(self):
+ P = inspect.Parameter
+ self.assertTrue(P.POSITIONAL_ONLY < P.POSITIONAL_OR_KEYWORD < \
+ P.VAR_POSITIONAL < P.KEYWORD_ONLY < P.VAR_KEYWORD)
+
+ self.assertEqual(str(P.POSITIONAL_ONLY), 'POSITIONAL_ONLY')
+ self.assertTrue('POSITIONAL_ONLY' in repr(P.POSITIONAL_ONLY))
+
+ def test_signature_parameter_object(self):
+ p = inspect.Parameter('foo', default=10,
+ kind=inspect.Parameter.POSITIONAL_ONLY)
+ self.assertEqual(p.name, 'foo')
+ self.assertEqual(p.default, 10)
+ self.assertIs(p.annotation, p.empty)
+ self.assertEqual(p.kind, inspect.Parameter.POSITIONAL_ONLY)
+
+ with self.assertRaisesRegex(ValueError, 'invalid value'):
+ inspect.Parameter('foo', default=10, kind='123')
+
+ with self.assertRaisesRegex(ValueError, 'not a valid parameter name'):
+ inspect.Parameter('1', kind=inspect.Parameter.VAR_KEYWORD)
+
+ with self.assertRaisesRegex(ValueError,
+ 'non-positional-only parameter'):
+ inspect.Parameter(None, kind=inspect.Parameter.VAR_KEYWORD)
+
+ with self.assertRaisesRegex(ValueError, 'cannot have default values'):
+ inspect.Parameter('a', default=42,
+ kind=inspect.Parameter.VAR_KEYWORD)
+
+ with self.assertRaisesRegex(ValueError, 'cannot have default values'):
+ inspect.Parameter('a', default=42,
+ kind=inspect.Parameter.VAR_POSITIONAL)
+
+ p = inspect.Parameter('a', default=42,
+ kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
+ with self.assertRaisesRegex(ValueError, 'cannot have default values'):
+ p.replace(kind=inspect.Parameter.VAR_POSITIONAL)
+
+ self.assertTrue(repr(p).startswith('<Parameter'))
+
+ def test_signature_parameter_equality(self):
+ P = inspect.Parameter
+ p = P('foo', default=42, kind=inspect.Parameter.KEYWORD_ONLY)
+
+ self.assertEqual(p, p)
+ self.assertNotEqual(p, 42)
+
+ self.assertEqual(p, P('foo', default=42,
+ kind=inspect.Parameter.KEYWORD_ONLY))
+
+ def test_signature_parameter_unhashable(self):
+ p = inspect.Parameter('foo', default=42,
+ kind=inspect.Parameter.KEYWORD_ONLY)
+
+ with self.assertRaisesRegex(TypeError, 'unhashable type'):
+ hash(p)
+
+ def test_signature_parameter_replace(self):
+ p = inspect.Parameter('foo', default=42,
+ kind=inspect.Parameter.KEYWORD_ONLY)
+
+ self.assertIsNot(p, p.replace())
+ self.assertEqual(p, p.replace())
+
+ p2 = p.replace(annotation=1)
+ self.assertEqual(p2.annotation, 1)
+ p2 = p2.replace(annotation=p2.empty)
+ self.assertEqual(p, p2)
+
+ p2 = p2.replace(name='bar')
+ self.assertEqual(p2.name, 'bar')
+ self.assertNotEqual(p2, p)
+
+ with self.assertRaisesRegex(ValueError, 'not a valid parameter name'):
+ p2 = p2.replace(name=p2.empty)
+
+ p2 = p2.replace(name='foo', default=None)
+ self.assertIs(p2.default, None)
+ self.assertNotEqual(p2, p)
+
+ p2 = p2.replace(name='foo', default=p2.empty)
+ self.assertIs(p2.default, p2.empty)
+
+
+ p2 = p2.replace(default=42, kind=p2.POSITIONAL_OR_KEYWORD)
+ self.assertEqual(p2.kind, p2.POSITIONAL_OR_KEYWORD)
+ self.assertNotEqual(p2, p)
+
+ with self.assertRaisesRegex(ValueError, 'invalid value for'):
+ p2 = p2.replace(kind=p2.empty)
+
+ p2 = p2.replace(kind=p2.KEYWORD_ONLY)
+ self.assertEqual(p2, p)
+
+ def test_signature_parameter_positional_only(self):
+ p = inspect.Parameter(None, kind=inspect.Parameter.POSITIONAL_ONLY)
+ self.assertEqual(str(p), '<>')
+
+ p = p.replace(name='1')
+ self.assertEqual(str(p), '<1>')
+
+ def test_signature_parameter_immutability(self):
+ p = inspect.Parameter(None, kind=inspect.Parameter.POSITIONAL_ONLY)
+
+ with self.assertRaises(AttributeError):
+ p.foo = 'bar'
+
+ with self.assertRaises(AttributeError):
+ p.kind = 123
+
+
+class TestSignatureBind(unittest.TestCase):
+ @staticmethod
+ def call(func, *args, **kwargs):
+ sig = inspect.signature(func)
+ ba = sig.bind(*args, **kwargs)
+ return func(*ba.args, **ba.kwargs)
+
+ def test_signature_bind_empty(self):
+ def test():
+ return 42
+
+ self.assertEqual(self.call(test), 42)
+ with self.assertRaisesRegex(TypeError, 'too many positional arguments'):
+ self.call(test, 1)
+ with self.assertRaisesRegex(TypeError, 'too many positional arguments'):
+ self.call(test, 1, spam=10)
+ with self.assertRaisesRegex(TypeError, 'too many keyword arguments'):
+ self.call(test, spam=1)
+
+ def test_signature_bind_var(self):
+ def test(*args, **kwargs):
+ return args, kwargs
+
+ self.assertEqual(self.call(test), ((), {}))
+ self.assertEqual(self.call(test, 1), ((1,), {}))
+ self.assertEqual(self.call(test, 1, 2), ((1, 2), {}))
+ self.assertEqual(self.call(test, foo='bar'), ((), {'foo': 'bar'}))
+ self.assertEqual(self.call(test, 1, foo='bar'), ((1,), {'foo': 'bar'}))
+ self.assertEqual(self.call(test, args=10), ((), {'args': 10}))
+ self.assertEqual(self.call(test, 1, 2, foo='bar'),
+ ((1, 2), {'foo': 'bar'}))
+
+ def test_signature_bind_just_args(self):
+ def test(a, b, c):
+ return a, b, c
+
+ self.assertEqual(self.call(test, 1, 2, 3), (1, 2, 3))
+
+ with self.assertRaisesRegex(TypeError, 'too many positional arguments'):
+ self.call(test, 1, 2, 3, 4)
+
+ with self.assertRaisesRegex(TypeError, "'b' parameter lacking default"):
+ self.call(test, 1)
+
+ with self.assertRaisesRegex(TypeError, "'a' parameter lacking default"):
+ self.call(test)
+
+ def test(a, b, c=10):
+ return a, b, c
+ self.assertEqual(self.call(test, 1, 2, 3), (1, 2, 3))
+ self.assertEqual(self.call(test, 1, 2), (1, 2, 10))
+
+ def test(a=1, b=2, c=3):
+ return a, b, c
+ self.assertEqual(self.call(test, a=10, c=13), (10, 2, 13))
+ self.assertEqual(self.call(test, a=10), (10, 2, 3))
+ self.assertEqual(self.call(test, b=10), (1, 10, 3))
+
+ def test_signature_bind_varargs_order(self):
+ def test(*args):
+ return args
+
+ self.assertEqual(self.call(test), ())
+ self.assertEqual(self.call(test, 1, 2, 3), (1, 2, 3))
+
+ def test_signature_bind_args_and_varargs(self):
+ def test(a, b, c=3, *args):
+ return a, b, c, args
+
+ self.assertEqual(self.call(test, 1, 2, 3, 4, 5), (1, 2, 3, (4, 5)))
+ self.assertEqual(self.call(test, 1, 2), (1, 2, 3, ()))
+ self.assertEqual(self.call(test, b=1, a=2), (2, 1, 3, ()))
+ self.assertEqual(self.call(test, 1, b=2), (1, 2, 3, ()))
+
+ with self.assertRaisesRegex(TypeError,
+ "multiple values for argument 'c'"):
+ self.call(test, 1, 2, 3, c=4)
+
+ def test_signature_bind_just_kwargs(self):
+ def test(**kwargs):
+ return kwargs
+
+ self.assertEqual(self.call(test), {})
+ self.assertEqual(self.call(test, foo='bar', spam='ham'),
+ {'foo': 'bar', 'spam': 'ham'})
+
+ def test_signature_bind_args_and_kwargs(self):
+ def test(a, b, c=3, **kwargs):
+ return a, b, c, kwargs
+
+ self.assertEqual(self.call(test, 1, 2), (1, 2, 3, {}))
+ self.assertEqual(self.call(test, 1, 2, foo='bar', spam='ham'),
+ (1, 2, 3, {'foo': 'bar', 'spam': 'ham'}))
+ self.assertEqual(self.call(test, b=2, a=1, foo='bar', spam='ham'),
+ (1, 2, 3, {'foo': 'bar', 'spam': 'ham'}))
+ self.assertEqual(self.call(test, a=1, b=2, foo='bar', spam='ham'),
+ (1, 2, 3, {'foo': 'bar', 'spam': 'ham'}))
+ self.assertEqual(self.call(test, 1, b=2, foo='bar', spam='ham'),
+ (1, 2, 3, {'foo': 'bar', 'spam': 'ham'}))
+ self.assertEqual(self.call(test, 1, b=2, c=4, foo='bar', spam='ham'),
+ (1, 2, 4, {'foo': 'bar', 'spam': 'ham'}))
+ self.assertEqual(self.call(test, 1, 2, 4, foo='bar'),
+ (1, 2, 4, {'foo': 'bar'}))
+ self.assertEqual(self.call(test, c=5, a=4, b=3),
+ (4, 3, 5, {}))
+
+ def test_signature_bind_kwonly(self):
+ def test(*, foo):
+ return foo
+ with self.assertRaisesRegex(TypeError,
+ 'too many positional arguments'):
+ self.call(test, 1)
+ self.assertEqual(self.call(test, foo=1), 1)
+
+ def test(a, *, foo=1, bar):
+ return foo
+ with self.assertRaisesRegex(TypeError,
+ "'bar' parameter lacking default value"):
+ self.call(test, 1)
+
+ def test(foo, *, bar):
+ return foo, bar
+ self.assertEqual(self.call(test, 1, bar=2), (1, 2))
+ self.assertEqual(self.call(test, bar=2, foo=1), (1, 2))
+
+ with self.assertRaisesRegex(TypeError,
+ 'too many keyword arguments'):
+ self.call(test, bar=2, foo=1, spam=10)
+
+ with self.assertRaisesRegex(TypeError,
+ 'too many positional arguments'):
+ self.call(test, 1, 2)
+
+ with self.assertRaisesRegex(TypeError,
+ 'too many positional arguments'):
+ self.call(test, 1, 2, bar=2)
+
+ with self.assertRaisesRegex(TypeError,
+ 'too many keyword arguments'):
+ self.call(test, 1, bar=2, spam='ham')
+
+ with self.assertRaisesRegex(TypeError,
+ "'bar' parameter lacking default value"):
+ self.call(test, 1)
+
+ def test(foo, *, bar, **bin):
+ return foo, bar, bin
+ self.assertEqual(self.call(test, 1, bar=2), (1, 2, {}))
+ self.assertEqual(self.call(test, foo=1, bar=2), (1, 2, {}))
+ self.assertEqual(self.call(test, 1, bar=2, spam='ham'),
+ (1, 2, {'spam': 'ham'}))
+ self.assertEqual(self.call(test, spam='ham', foo=1, bar=2),
+ (1, 2, {'spam': 'ham'}))
+ with self.assertRaisesRegex(TypeError,
+ "'foo' parameter lacking default value"):
+ self.call(test, spam='ham', bar=2)
+ self.assertEqual(self.call(test, 1, bar=2, bin=1, spam=10),
+ (1, 2, {'bin': 1, 'spam': 10}))
+
+ def test_signature_bind_arguments(self):
+ def test(a, *args, b, z=100, **kwargs):
+ pass
+ sig = inspect.signature(test)
+ ba = sig.bind(10, 20, b=30, c=40, args=50, kwargs=60)
+ # we won't have 'z' argument in the bound arguments object, as we didn't
+ # pass it to the 'bind'
+ self.assertEqual(tuple(ba.arguments.items()),
+ (('a', 10), ('args', (20,)), ('b', 30),
+ ('kwargs', {'c': 40, 'args': 50, 'kwargs': 60})))
+ self.assertEqual(ba.kwargs,
+ {'b': 30, 'c': 40, 'args': 50, 'kwargs': 60})
+ self.assertEqual(ba.args, (10, 20))
+
+ def test_signature_bind_positional_only(self):
+ P = inspect.Parameter
+
+ def test(a_po, b_po, c_po=3, foo=42, *, bar=50, **kwargs):
+ return a_po, b_po, c_po, foo, bar, kwargs
+
+ sig = inspect.signature(test)
+ new_params = collections.OrderedDict(tuple(sig.parameters.items()))
+ for name in ('a_po', 'b_po', 'c_po'):
+ new_params[name] = new_params[name].replace(kind=P.POSITIONAL_ONLY)
+ new_sig = sig.replace(parameters=new_params.values())
+ test.__signature__ = new_sig
+
+ self.assertEqual(self.call(test, 1, 2, 4, 5, bar=6),
+ (1, 2, 4, 5, 6, {}))
+
+ with self.assertRaisesRegex(TypeError, "parameter is positional only"):
+ self.call(test, 1, 2, c_po=4)
+
+ with self.assertRaisesRegex(TypeError, "parameter is positional only"):
+ self.call(test, a_po=1, b_po=2)
+
+ def test_signature_bind_with_self_arg(self):
+ # Issue #17071: one of the parameters is named "self
+ def test(a, self, b):
+ pass
+ sig = inspect.signature(test)
+ ba = sig.bind(1, 2, 3)
+ self.assertEqual(ba.args, (1, 2, 3))
+ ba = sig.bind(1, self=2, b=3)
+ self.assertEqual(ba.args, (1, 2, 3))
+
+
+class TestBoundArguments(unittest.TestCase):
+ def test_signature_bound_arguments_unhashable(self):
+ def foo(a): pass
+ ba = inspect.signature(foo).bind(1)
+
+ with self.assertRaisesRegex(TypeError, 'unhashable type'):
+ hash(ba)
+
+ def test_signature_bound_arguments_equality(self):
+ def foo(a): pass
+ ba = inspect.signature(foo).bind(1)
+ self.assertEqual(ba, ba)
+
+ ba2 = inspect.signature(foo).bind(1)
+ self.assertEqual(ba, ba2)
+
+ ba3 = inspect.signature(foo).bind(2)
+ self.assertNotEqual(ba, ba3)
+ ba3.arguments['a'] = 1
+ self.assertEqual(ba, ba3)
+
+ def bar(b): pass
+ ba4 = inspect.signature(bar).bind(1)
+ self.assertNotEqual(ba, ba4)
+
def test_main():
run_unittest(
@@ -1176,7 +2284,8 @@ def test_main():
TestInterpreterStack, TestClassesAndFunctions, TestPredicates,
TestGetcallargsFunctions, TestGetcallargsMethods,
TestGetcallargsUnboundMethods, TestGetattrStatic, TestGetGeneratorState,
- TestNoEOL
+ TestNoEOL, TestSignatureObject, TestSignatureBind, TestParameterObject,
+ TestBoundArguments, TestGetClosureVars
)
if __name__ == "__main__":
diff --git a/Lib/test/test_int.py b/Lib/test/test_int.py
index 671b20a..c35a42f 100644
--- a/Lib/test/test_int.py
+++ b/Lib/test/test_int.py
@@ -2,7 +2,6 @@ import sys
import unittest
from test import support
-from test.support import run_unittest
L = [
('0', 0),
@@ -226,6 +225,9 @@ class IntTestCases(unittest.TestCase):
self.assertIs(int(b'10'), 10)
self.assertIs(int(b'-1'), -1)
+ def test_no_args(self):
+ self.assertEqual(int(), 0)
+
def test_keyword_args(self):
# Test invoking int() using keyword arguments.
self.assertEqual(int(x=1.2), 1)
@@ -234,6 +236,27 @@ class IntTestCases(unittest.TestCase):
self.assertRaises(TypeError, int, base=10)
self.assertRaises(TypeError, int, base=0)
+ def test_non_numeric_input_types(self):
+ # Test possible non-numeric types for the argument x, including
+ # subclasses of the explicitly documented accepted types.
+ class CustomStr(str): pass
+ class CustomBytes(bytes): pass
+ class CustomByteArray(bytearray): pass
+
+ values = [b'100',
+ bytearray(b'100'),
+ CustomStr('100'),
+ CustomBytes(b'100'),
+ CustomByteArray(b'100')]
+
+ for x in values:
+ msg = 'x has type %s' % type(x).__name__
+ self.assertEqual(int(x), 100, msg=msg)
+ self.assertEqual(int(x, 2), 4, msg=msg)
+
+ def test_string_float(self):
+ self.assertRaises(ValueError, int, '1.2')
+
def test_intconversion(self):
# Test __int__()
class ClassicMissingMethods:
@@ -318,6 +341,18 @@ class IntTestCases(unittest.TestCase):
self.fail("Failed to raise TypeError with %s" %
((base, trunc_result_base),))
+ # Regression test for bugs.python.org/issue16060.
+ class BadInt(trunc_result_base):
+ def __int__(self):
+ return 42.0
+
+ class TruncReturnsBadInt(base):
+ def __trunc__(self):
+ return BadInt()
+
+ with self.assertRaises(TypeError):
+ int(TruncReturnsBadInt())
+
def test_error_message(self):
testlist = ('\xbd', '123\xbd', ' 123 456 ')
for s in testlist:
@@ -329,7 +364,7 @@ class IntTestCases(unittest.TestCase):
self.fail("Expected int(%r) to raise a ValueError", s)
def test_main():
- run_unittest(IntTestCases)
+ support.run_unittest(IntTestCases)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_io.py b/Lib/test/test_io.py
index a7e596a..57b22c6 100644
--- a/Lib/test/test_io.py
+++ b/Lib/test/test_io.py
@@ -19,21 +19,22 @@
# test both implementations. This file has lots of examples.
################################################################################
+import abc
+import array
+import errno
+import locale
import os
+import pickle
+import random
+import signal
import sys
import time
-import array
-import random
import unittest
-import weakref
-import abc
-import signal
-import errno
import warnings
-import pickle
+import weakref
import _testcapi
-from itertools import cycle, count
from collections import deque, UserList
+from itertools import cycle, count
from test import support
import codecs
@@ -50,7 +51,7 @@ except ImportError:
def _default_chunk_size():
"""Get the default TextIOWrapper chunk size"""
- with open(__file__, "r", encoding="latin1") as f:
+ with open(__file__, "r", encoding="latin-1") as f:
return f._CHUNK_SIZE
@@ -603,6 +604,7 @@ class IOTest(unittest.TestCase):
raise IOError()
f.flush = bad_flush
self.assertRaises(IOError, f.close) # exception not swallowed
+ self.assertTrue(f.closed)
def test_multi_close(self):
f = self.open(support.TESTFN, "wb", buffering=0)
@@ -635,6 +637,15 @@ class IOTest(unittest.TestCase):
for obj in test:
self.assertTrue(hasattr(obj, "__dict__"))
+ def test_opener(self):
+ with self.open(support.TESTFN, "w") as f:
+ f.write("egg\n")
+ fd = os.open(support.TESTFN, os.O_RDONLY)
+ def opener(path, flags):
+ return fd
+ with self.open("non-existent", "r", opener=opener) as f:
+ self.assertEqual(f.read(), "egg\n")
+
def test_fileio_closefd(self):
# Issue #4841
with self.open(__file__, 'rb') as f1, \
@@ -697,7 +708,7 @@ class CommonBufferedTests:
bufio = self.tp(rawio)
# Invalid whence
self.assertRaises(ValueError, bufio.seek, 0, -1)
- self.assertRaises(ValueError, bufio.seek, 0, 3)
+ self.assertRaises(ValueError, bufio.seek, 0, 9)
def test_override_destructor(self):
tp = self.tp
@@ -771,6 +782,22 @@ class CommonBufferedTests:
raw.flush = bad_flush
b = self.tp(raw)
self.assertRaises(IOError, b.close) # exception not swallowed
+ self.assertTrue(b.closed)
+
+ def test_close_error_on_close(self):
+ raw = self.MockRawIO()
+ def bad_flush():
+ raise IOError('flush')
+ def bad_close():
+ raise IOError('close')
+ raw.close = bad_close
+ b = self.tp(raw)
+ b.flush = bad_flush
+ with self.assertRaises(IOError) as err: # exception not swallowed
+ b.close()
+ self.assertEqual(err.exception.args, ('close',))
+ self.assertEqual(err.exception.__context__.args, ('flush',))
+ self.assertFalse(b.closed)
def test_multi_close(self):
raw = self.MockRawIO()
@@ -863,6 +890,12 @@ class BufferedReaderTest(unittest.TestCase, CommonBufferedTests):
self.assertEqual(b, b"gf")
self.assertEqual(bufio.readinto(b), 0)
self.assertEqual(b, b"gf")
+ rawio = self.MockRawIO((b"abc", None))
+ bufio = self.tp(rawio)
+ self.assertEqual(bufio.readinto(b), 2)
+ self.assertEqual(b, b"ab")
+ self.assertEqual(bufio.readinto(b), 1)
+ self.assertEqual(b, b"cb")
def test_readlines(self):
def bufio():
@@ -1277,11 +1310,20 @@ class BufferedWriterTest(unittest.TestCase, CommonBufferedTests):
self.assertRaises(IOError, bufio.tell)
self.assertRaises(IOError, bufio.write, b"abcdef")
- def test_max_buffer_size_deprecation(self):
- with support.check_warnings(("max_buffer_size is deprecated",
- DeprecationWarning)):
+ def test_max_buffer_size_removal(self):
+ with self.assertRaises(TypeError):
self.tp(self.MockRawIO(), 8, 12)
+ def test_write_error_on_close(self):
+ raw = self.MockRawIO()
+ def bad_write(b):
+ raise IOError()
+ raw.write = bad_write
+ b = self.tp(raw)
+ b.write(b'spam')
+ self.assertRaises(IOError, b.close) # exception not swallowed
+ self.assertTrue(b.closed)
+
class CBufferedWriterTest(BufferedWriterTest, SizeofTest):
tp = io.BufferedWriter
@@ -1335,9 +1377,8 @@ class BufferedRWPairTest(unittest.TestCase):
pair = self.tp(self.MockRawIO(), self.MockRawIO())
self.assertRaises(self.UnsupportedOperation, pair.detach)
- def test_constructor_max_buffer_size_deprecation(self):
- with support.check_warnings(("max_buffer_size is deprecated",
- DeprecationWarning)):
+ def test_constructor_max_buffer_size_removal(self):
+ with self.assertRaises(TypeError):
self.tp(self.MockRawIO(), self.MockRawIO(), 8, 12)
def test_constructor_with_not_readable(self):
@@ -1853,11 +1894,11 @@ class TextIOWrapperTest(unittest.TestCase):
r = self.BytesIO(b"\xc3\xa9\n\n")
b = self.BufferedReader(r, 1000)
t = self.TextIOWrapper(b)
- t.__init__(b, encoding="latin1", newline="\r\n")
- self.assertEqual(t.encoding, "latin1")
+ t.__init__(b, encoding="latin-1", newline="\r\n")
+ self.assertEqual(t.encoding, "latin-1")
self.assertEqual(t.line_buffering, False)
- t.__init__(b, encoding="utf8", line_buffering=True)
- self.assertEqual(t.encoding, "utf8")
+ t.__init__(b, encoding="utf-8", line_buffering=True)
+ self.assertEqual(t.encoding, "utf-8")
self.assertEqual(t.line_buffering, True)
self.assertEqual("\xe9\n", t.readline())
self.assertRaises(TypeError, t.__init__, b, newline=42)
@@ -1904,6 +1945,24 @@ class TextIOWrapperTest(unittest.TestCase):
t.write("A\rB")
self.assertEqual(r.getvalue(), b"XY\nZA\rB")
+ def test_default_encoding(self):
+ old_environ = dict(os.environ)
+ try:
+ # try to get a user preferred encoding different than the current
+ # locale encoding to check that TextIOWrapper() uses the current
+ # locale encoding and not the user preferred encoding
+ for key in ('LC_ALL', 'LANG', 'LC_CTYPE'):
+ if key in os.environ:
+ del os.environ[key]
+
+ current_locale_encoding = locale.getpreferredencoding(False)
+ b = self.BytesIO()
+ t = self.TextIOWrapper(b)
+ self.assertEqual(t.encoding, current_locale_encoding)
+ finally:
+ os.environ.clear()
+ os.environ.update(old_environ)
+
# Issue 15989
def test_device_encoding(self):
b = self.BytesIO()
@@ -1915,8 +1974,8 @@ class TextIOWrapperTest(unittest.TestCase):
def test_encoding(self):
# Check the encoding attribute is always set, and valid
b = self.BytesIO()
- t = self.TextIOWrapper(b, encoding="utf8")
- self.assertEqual(t.encoding, "utf8")
+ t = self.TextIOWrapper(b, encoding="utf-8")
+ self.assertEqual(t.encoding, "utf-8")
t = self.TextIOWrapper(b)
self.assertTrue(t.encoding is not None)
codecs.lookup(t.encoding)
@@ -2009,8 +2068,8 @@ class TextIOWrapperTest(unittest.TestCase):
testdata = b"AAA\nBB\x00B\nCCC\rDDD\rEEE\r\nFFF\r\nGGG"
normalized = testdata.replace(b"\r\n", b"\n").replace(b"\r", b"\n")
for newline, expected in [
- (None, normalized.decode("ascii").splitlines(True)),
- ("", testdata.decode("ascii").splitlines(True)),
+ (None, normalized.decode("ascii").splitlines(keepends=True)),
+ ("", testdata.decode("ascii").splitlines(keepends=True)),
("\n", ["AAA\n", "BB\x00B\n", "CCC\rDDD\rEEE\r\n", "FFF\r\n", "GGG"]),
("\r\n", ["AAA\nBB\x00B\nCCC\rDDD\rEEE\r\n", "FFF\r\n", "GGG"]),
("\r", ["AAA\nBB\x00B\nCCC\r", "DDD\r", "EEE\r", "\nFFF\r", "\nGGG"]),
@@ -2095,7 +2154,7 @@ class TextIOWrapperTest(unittest.TestCase):
def test_basic_io(self):
for chunksize in (1, 2, 3, 4, 5, 15, 16, 17, 31, 32, 33, 63, 64, 65):
- for enc in "ascii", "latin1", "utf8" :# , "utf-16-be", "utf-16-le":
+ for enc in "ascii", "latin-1", "utf-8" :# , "utf-16-be", "utf-16-le":
f = self.open(support.TESTFN, "w+", encoding=enc)
f._CHUNK_SIZE = chunksize
self.assertEqual(f.write("abc"), 3)
@@ -2145,7 +2204,7 @@ class TextIOWrapperTest(unittest.TestCase):
self.assertEqual(rlines, wlines)
def test_telling(self):
- f = self.open(support.TESTFN, "w+", encoding="utf8")
+ f = self.open(support.TESTFN, "w+", encoding="utf-8")
p0 = f.tell()
f.write("\xff\n")
p1 = f.tell()
@@ -2413,6 +2472,7 @@ class TextIOWrapperTest(unittest.TestCase):
with self.open(support.TESTFN, "w", errors="replace") as f:
self.assertEqual(f.errors, "replace")
+ @support.no_tracing
@unittest.skipUnless(threading, 'Threading required for this test.')
def test_threads_write(self):
# Issue6750: concurrent writes could duplicate data
@@ -2441,6 +2501,7 @@ class TextIOWrapperTest(unittest.TestCase):
raise IOError()
txt.flush = bad_flush
self.assertRaises(IOError, txt.close) # exception not swallowed
+ self.assertTrue(txt.closed)
def test_multi_close(self):
txt = self.TextIOWrapper(self.BytesIO(self.testdata), encoding="ascii")
@@ -2754,12 +2815,6 @@ class MiscIOTest(unittest.TestCase):
def test_blockingioerror(self):
# Various BlockingIOError issues
- self.assertRaises(TypeError, self.BlockingIOError)
- self.assertRaises(TypeError, self.BlockingIOError, 1)
- self.assertRaises(TypeError, self.BlockingIOError, 1, 2, 3, 4)
- self.assertRaises(TypeError, self.BlockingIOError, 1, "", None)
- b = self.BlockingIOError(1, "")
- self.assertEqual(b.characters_written, 0)
class C(str):
pass
c = C("")
@@ -2901,6 +2956,7 @@ class MiscIOTest(unittest.TestCase):
except self.BlockingIOError as e:
self.assertEqual(e.args[0], errno.EAGAIN)
+ self.assertEqual(e.args[2], e.characters_written)
sent[-1] = sent[-1][:e.characters_written]
received.append(rf.read())
msg = b'BLOCKED'
@@ -2913,6 +2969,7 @@ class MiscIOTest(unittest.TestCase):
break
except self.BlockingIOError as e:
self.assertEqual(e.args[0], errno.EAGAIN)
+ self.assertEqual(e.args[2], e.characters_written)
self.assertEqual(e.characters_written, 0)
received.append(rf.read())
@@ -2923,6 +2980,24 @@ class MiscIOTest(unittest.TestCase):
self.assertTrue(wf.closed)
self.assertTrue(rf.closed)
+ def test_create_fail(self):
+ # 'x' mode fails if file is existing
+ with self.open(support.TESTFN, 'w'):
+ pass
+ self.assertRaises(FileExistsError, self.open, support.TESTFN, 'x')
+
+ def test_create_writes(self):
+ # 'x' mode opens for writing
+ with self.open(support.TESTFN, 'xb') as f:
+ f.write(b"spam")
+ with self.open(support.TESTFN, 'rb') as f:
+ self.assertEqual(b"spam", f.read())
+
+ def test_open_allargs(self):
+ # there used to be a buffer overflow in the parser for rawmode
+ self.assertRaises(ValueError, self.open, support.TESTFN, 'rwax+')
+
+
class CMiscIOTest(MiscIOTest):
io = io
@@ -2943,14 +3018,14 @@ class SignalsTest(unittest.TestCase):
1/0
@unittest.skipUnless(threading, 'Threading required for this test.')
- @unittest.skipIf(sys.platform in ('freebsd5', 'freebsd6', 'freebsd7'),
- 'issue #12429: skip test on FreeBSD <= 7')
def check_interrupted_write(self, item, bytes, **fdopen_kwargs):
"""Check that a partial write, when it gets interrupted, properly
invokes the signal handler, and bubbles up the exception raised
in the latter."""
read_results = []
def _read():
+ if hasattr(signal, 'pthread_sigmask'):
+ signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGALRM])
s = os.read(r, 1)
read_results.append(s)
t = threading.Thread(target=_read)
@@ -2968,7 +3043,7 @@ class SignalsTest(unittest.TestCase):
# The buffered IO layer must check for pending signal
# handlers, which in this case will invoke alarm_interrupt().
self.assertRaises(ZeroDivisionError,
- wio.write, item * (1024 * 1024))
+ wio.write, item * (support.PIPE_MAX_SIZE // len(item)))
t.join()
# We got one byte, get another one and check that it isn't a
# repeat of the first one.
@@ -2995,6 +3070,7 @@ class SignalsTest(unittest.TestCase):
def test_interrupted_write_text(self):
self.check_interrupted_write("xy", b"xy", mode="w", encoding="ascii")
+ @support.no_tracing
def check_reentrant_write(self, data, **fdopen_kwargs):
def on_alarm(*args):
# Will be called reentrantly from the same thread
diff --git a/Lib/test/test_ipaddress.py b/Lib/test/test_ipaddress.py
new file mode 100644
index 0000000..99c54f1
--- /dev/null
+++ b/Lib/test/test_ipaddress.py
@@ -0,0 +1,1649 @@
+# Copyright 2007 Google Inc.
+# Licensed to PSF under a Contributor Agreement.
+
+"""Unittest for ipaddress module."""
+
+
+import unittest
+import re
+import contextlib
+import operator
+import ipaddress
+
+class BaseTestCase(unittest.TestCase):
+ # One big change in ipaddress over the original ipaddr module is
+ # error reporting that tries to assume users *don't know the rules*
+ # for what constitutes an RFC compliant IP address
+
+ # Ensuring these errors are emitted correctly in all relevant cases
+ # meant moving to a more systematic test structure that allows the
+ # test structure to map more directly to the module structure
+
+ # Note that if the constructors are refactored so that addresses with
+ # multiple problems get classified differently, that's OK - just
+ # move the affected examples to the newly appropriate test case.
+
+ # There is some duplication between the original relatively ad hoc
+ # test suite and the new systematic tests. While some redundancy in
+ # testing is considered preferable to accidentally deleting a valid
+ # test, the original test suite will likely be reduced over time as
+ # redundant tests are identified.
+
+ @property
+ def factory(self):
+ raise NotImplementedError
+
+ @contextlib.contextmanager
+ def assertCleanError(self, exc_type, details, *args):
+ """
+ Ensure exception does not display a context by default
+
+ Wraps unittest.TestCase.assertRaisesRegex
+ """
+ if args:
+ details = details % args
+ cm = self.assertRaisesRegex(exc_type, details)
+ with cm as exc:
+ yield exc
+ # Ensure we produce clean tracebacks on failure
+ if exc.exception.__context__ is not None:
+ self.assertTrue(exc.exception.__suppress_context__)
+
+ def assertAddressError(self, details, *args):
+ """Ensure a clean AddressValueError"""
+ return self.assertCleanError(ipaddress.AddressValueError,
+ details, *args)
+
+ def assertNetmaskError(self, details, *args):
+ """Ensure a clean NetmaskValueError"""
+ return self.assertCleanError(ipaddress.NetmaskValueError,
+ details, *args)
+
+ def assertInstancesEqual(self, lhs, rhs):
+ """Check constructor arguments produce equivalent instances"""
+ self.assertEqual(self.factory(lhs), self.factory(rhs))
+
+class CommonTestMixin:
+
+ def test_empty_address(self):
+ with self.assertAddressError("Address cannot be empty"):
+ self.factory("")
+
+ def test_floats_rejected(self):
+ with self.assertAddressError(re.escape(repr("1.0"))):
+ self.factory(1.0)
+
+ def test_not_an_index_issue15559(self):
+ # Implementing __index__ makes for a very nasty interaction with the
+ # bytes constructor. Thus, we disallow implicit use as an integer
+ self.assertRaises(TypeError, operator.index, self.factory(1))
+ self.assertRaises(TypeError, hex, self.factory(1))
+ self.assertRaises(TypeError, bytes, self.factory(1))
+
+
+class CommonTestMixin_v4(CommonTestMixin):
+
+ def test_leading_zeros(self):
+ self.assertInstancesEqual("000.000.000.000", "0.0.0.0")
+ self.assertInstancesEqual("192.168.000.001", "192.168.0.1")
+
+ def test_int(self):
+ self.assertInstancesEqual(0, "0.0.0.0")
+ self.assertInstancesEqual(3232235521, "192.168.0.1")
+
+ def test_packed(self):
+ self.assertInstancesEqual(bytes.fromhex("00000000"), "0.0.0.0")
+ self.assertInstancesEqual(bytes.fromhex("c0a80001"), "192.168.0.1")
+
+ def test_negative_ints_rejected(self):
+ msg = "-1 (< 0) is not permitted as an IPv4 address"
+ with self.assertAddressError(re.escape(msg)):
+ self.factory(-1)
+
+ def test_large_ints_rejected(self):
+ msg = "%d (>= 2**32) is not permitted as an IPv4 address"
+ with self.assertAddressError(re.escape(msg % 2**32)):
+ self.factory(2**32)
+
+ def test_bad_packed_length(self):
+ def assertBadLength(length):
+ addr = bytes(length)
+ msg = "%r (len %d != 4) is not permitted as an IPv4 address"
+ with self.assertAddressError(re.escape(msg % (addr, length))):
+ self.factory(addr)
+
+ assertBadLength(3)
+ assertBadLength(5)
+
+class CommonTestMixin_v6(CommonTestMixin):
+
+ def test_leading_zeros(self):
+ self.assertInstancesEqual("0000::0000", "::")
+ self.assertInstancesEqual("000::c0a8:0001", "::c0a8:1")
+
+ def test_int(self):
+ self.assertInstancesEqual(0, "::")
+ self.assertInstancesEqual(3232235521, "::c0a8:1")
+
+ def test_packed(self):
+ addr = bytes(12) + bytes.fromhex("00000000")
+ self.assertInstancesEqual(addr, "::")
+ addr = bytes(12) + bytes.fromhex("c0a80001")
+ self.assertInstancesEqual(addr, "::c0a8:1")
+ addr = bytes.fromhex("c0a80001") + bytes(12)
+ self.assertInstancesEqual(addr, "c0a8:1::")
+
+ def test_negative_ints_rejected(self):
+ msg = "-1 (< 0) is not permitted as an IPv6 address"
+ with self.assertAddressError(re.escape(msg)):
+ self.factory(-1)
+
+ def test_large_ints_rejected(self):
+ msg = "%d (>= 2**128) is not permitted as an IPv6 address"
+ with self.assertAddressError(re.escape(msg % 2**128)):
+ self.factory(2**128)
+
+ def test_bad_packed_length(self):
+ def assertBadLength(length):
+ addr = bytes(length)
+ msg = "%r (len %d != 16) is not permitted as an IPv6 address"
+ with self.assertAddressError(re.escape(msg % (addr, length))):
+ self.factory(addr)
+ self.factory(addr)
+
+ assertBadLength(15)
+ assertBadLength(17)
+
+
+class AddressTestCase_v4(BaseTestCase, CommonTestMixin_v4):
+ factory = ipaddress.IPv4Address
+
+ def test_network_passed_as_address(self):
+ addr = "127.0.0.1/24"
+ with self.assertAddressError("Unexpected '/' in %r", addr):
+ ipaddress.IPv4Address(addr)
+
+ def test_bad_address_split(self):
+ def assertBadSplit(addr):
+ with self.assertAddressError("Expected 4 octets in %r", addr):
+ ipaddress.IPv4Address(addr)
+
+ assertBadSplit("127.0.1")
+ assertBadSplit("42.42.42.42.42")
+ assertBadSplit("42.42.42")
+ assertBadSplit("42.42")
+ assertBadSplit("42")
+ assertBadSplit("42..42.42.42")
+ assertBadSplit("42.42.42.42.")
+ assertBadSplit("42.42.42.42...")
+ assertBadSplit(".42.42.42.42")
+ assertBadSplit("...42.42.42.42")
+ assertBadSplit("016.016.016")
+ assertBadSplit("016.016")
+ assertBadSplit("016")
+ assertBadSplit("000")
+ assertBadSplit("0x0a.0x0a.0x0a")
+ assertBadSplit("0x0a.0x0a")
+ assertBadSplit("0x0a")
+ assertBadSplit(".")
+ assertBadSplit("bogus")
+ assertBadSplit("bogus.com")
+ assertBadSplit("1000")
+ assertBadSplit("1000000000000000")
+ assertBadSplit("192.168.0.1.com")
+
+ def test_empty_octet(self):
+ def assertBadOctet(addr):
+ with self.assertAddressError("Empty octet not permitted in %r",
+ addr):
+ ipaddress.IPv4Address(addr)
+
+ assertBadOctet("42..42.42")
+ assertBadOctet("...")
+
+ def test_invalid_characters(self):
+ def assertBadOctet(addr, octet):
+ msg = "Only decimal digits permitted in %r in %r" % (octet, addr)
+ with self.assertAddressError(re.escape(msg)):
+ ipaddress.IPv4Address(addr)
+
+ assertBadOctet("0x0a.0x0a.0x0a.0x0a", "0x0a")
+ assertBadOctet("0xa.0x0a.0x0a.0x0a", "0xa")
+ assertBadOctet("42.42.42.-0", "-0")
+ assertBadOctet("42.42.42.+0", "+0")
+ assertBadOctet("42.42.42.-42", "-42")
+ assertBadOctet("+1.+2.+3.4", "+1")
+ assertBadOctet("1.2.3.4e0", "4e0")
+ assertBadOctet("1.2.3.4::", "4::")
+ assertBadOctet("1.a.2.3", "a")
+
+ def test_octal_decimal_ambiguity(self):
+ def assertBadOctet(addr, octet):
+ msg = "Ambiguous (octal/decimal) value in %r not permitted in %r"
+ with self.assertAddressError(re.escape(msg % (octet, addr))):
+ ipaddress.IPv4Address(addr)
+
+ assertBadOctet("016.016.016.016", "016")
+ assertBadOctet("001.000.008.016", "008")
+
+ def test_octet_length(self):
+ def assertBadOctet(addr, octet):
+ msg = "At most 3 characters permitted in %r in %r"
+ with self.assertAddressError(re.escape(msg % (octet, addr))):
+ ipaddress.IPv4Address(addr)
+
+ assertBadOctet("0000.000.000.000", "0000")
+ assertBadOctet("12345.67899.-54321.-98765", "12345")
+
+ def test_octet_limit(self):
+ def assertBadOctet(addr, octet):
+ msg = "Octet %d (> 255) not permitted in %r" % (octet, addr)
+ with self.assertAddressError(re.escape(msg)):
+ ipaddress.IPv4Address(addr)
+
+ assertBadOctet("257.0.0.0", 257)
+ assertBadOctet("192.168.0.999", 999)
+
+
+class AddressTestCase_v6(BaseTestCase, CommonTestMixin_v6):
+ factory = ipaddress.IPv6Address
+
+ def test_network_passed_as_address(self):
+ addr = "::1/24"
+ with self.assertAddressError("Unexpected '/' in %r", addr):
+ ipaddress.IPv6Address(addr)
+
+ def test_bad_address_split_v6_not_enough_parts(self):
+ def assertBadSplit(addr):
+ msg = "At least 3 parts expected in %r"
+ with self.assertAddressError(msg, addr):
+ ipaddress.IPv6Address(addr)
+
+ assertBadSplit(":")
+ assertBadSplit(":1")
+ assertBadSplit("FEDC:9878")
+
+ def test_bad_address_split_v6_too_many_colons(self):
+ def assertBadSplit(addr):
+ msg = "At most 8 colons permitted in %r"
+ with self.assertAddressError(msg, addr):
+ ipaddress.IPv6Address(addr)
+
+ assertBadSplit("9:8:7:6:5:4:3::2:1")
+ assertBadSplit("10:9:8:7:6:5:4:3:2:1")
+ assertBadSplit("::8:7:6:5:4:3:2:1")
+ assertBadSplit("8:7:6:5:4:3:2:1::")
+ # A trailing IPv4 address is two parts
+ assertBadSplit("10:9:8:7:6:5:4:3:42.42.42.42")
+
+ def test_bad_address_split_v6_too_many_parts(self):
+ def assertBadSplit(addr):
+ msg = "Exactly 8 parts expected without '::' in %r"
+ with self.assertAddressError(msg, addr):
+ ipaddress.IPv6Address(addr)
+
+ assertBadSplit("3ffe:0:0:0:0:0:0:0:1")
+ assertBadSplit("9:8:7:6:5:4:3:2:1")
+ assertBadSplit("7:6:5:4:3:2:1")
+ # A trailing IPv4 address is two parts
+ assertBadSplit("9:8:7:6:5:4:3:42.42.42.42")
+ assertBadSplit("7:6:5:4:3:42.42.42.42")
+
+ def test_bad_address_split_v6_too_many_parts_with_double_colon(self):
+ def assertBadSplit(addr):
+ msg = "Expected at most 7 other parts with '::' in %r"
+ with self.assertAddressError(msg, addr):
+ ipaddress.IPv6Address(addr)
+
+ assertBadSplit("1:2:3:4::5:6:7:8")
+
+ def test_bad_address_split_v6_repeated_double_colon(self):
+ def assertBadSplit(addr):
+ msg = "At most one '::' permitted in %r"
+ with self.assertAddressError(msg, addr):
+ ipaddress.IPv6Address(addr)
+
+ assertBadSplit("3ffe::1::1")
+ assertBadSplit("1::2::3::4:5")
+ assertBadSplit("2001::db:::1")
+ assertBadSplit("3ffe::1::")
+ assertBadSplit("::3ffe::1")
+ assertBadSplit(":3ffe::1::1")
+ assertBadSplit("3ffe::1::1:")
+ assertBadSplit(":3ffe::1::1:")
+ assertBadSplit(":::")
+ assertBadSplit('2001:db8:::1')
+
+ def test_bad_address_split_v6_leading_colon(self):
+ def assertBadSplit(addr):
+ msg = "Leading ':' only permitted as part of '::' in %r"
+ with self.assertAddressError(msg, addr):
+ ipaddress.IPv6Address(addr)
+
+ assertBadSplit(":2001:db8::1")
+ assertBadSplit(":1:2:3:4:5:6:7")
+ assertBadSplit(":1:2:3:4:5:6:")
+ assertBadSplit(":6:5:4:3:2:1::")
+
+ def test_bad_address_split_v6_trailing_colon(self):
+ def assertBadSplit(addr):
+ msg = "Trailing ':' only permitted as part of '::' in %r"
+ with self.assertAddressError(msg, addr):
+ ipaddress.IPv6Address(addr)
+
+ assertBadSplit("2001:db8::1:")
+ assertBadSplit("1:2:3:4:5:6:7:")
+ assertBadSplit("::1.2.3.4:")
+ assertBadSplit("::7:6:5:4:3:2:")
+
+ def test_bad_v4_part_in(self):
+ def assertBadAddressPart(addr, v4_error):
+ with self.assertAddressError("%s in %r", v4_error, addr):
+ ipaddress.IPv6Address(addr)
+
+ assertBadAddressPart("3ffe::1.net", "Expected 4 octets in '1.net'")
+ assertBadAddressPart("3ffe::127.0.1",
+ "Expected 4 octets in '127.0.1'")
+ assertBadAddressPart("::1.2.3",
+ "Expected 4 octets in '1.2.3'")
+ assertBadAddressPart("::1.2.3.4.5",
+ "Expected 4 octets in '1.2.3.4.5'")
+ assertBadAddressPart("3ffe::1.1.1.net",
+ "Only decimal digits permitted in 'net' "
+ "in '1.1.1.net'")
+
+ def test_invalid_characters(self):
+ def assertBadPart(addr, part):
+ msg = "Only hex digits permitted in %r in %r" % (part, addr)
+ with self.assertAddressError(re.escape(msg)):
+ ipaddress.IPv6Address(addr)
+
+ assertBadPart("3ffe::goog", "goog")
+ assertBadPart("3ffe::-0", "-0")
+ assertBadPart("3ffe::+0", "+0")
+ assertBadPart("3ffe::-1", "-1")
+ assertBadPart("1.2.3.4::", "1.2.3.4")
+ assertBadPart('1234:axy::b', "axy")
+
+ def test_part_length(self):
+ def assertBadPart(addr, part):
+ msg = "At most 4 characters permitted in %r in %r"
+ with self.assertAddressError(msg, part, addr):
+ ipaddress.IPv6Address(addr)
+
+ assertBadPart("::00000", "00000")
+ assertBadPart("3ffe::10000", "10000")
+ assertBadPart("02001:db8::", "02001")
+ assertBadPart('2001:888888::1', "888888")
+
+
+class NetmaskTestMixin_v4(CommonTestMixin_v4):
+ """Input validation on interfaces and networks is very similar"""
+
+ def test_split_netmask(self):
+ addr = "1.2.3.4/32/24"
+ with self.assertAddressError("Only one '/' permitted in %r" % addr):
+ self.factory(addr)
+
+ def test_address_errors(self):
+ def assertBadAddress(addr, details):
+ with self.assertAddressError(details):
+ self.factory(addr)
+
+ assertBadAddress("/", "Address cannot be empty")
+ assertBadAddress("/8", "Address cannot be empty")
+ assertBadAddress("bogus", "Expected 4 octets")
+ assertBadAddress("google.com", "Expected 4 octets")
+ assertBadAddress("10/8", "Expected 4 octets")
+ assertBadAddress("::1.2.3.4", "Only decimal digits")
+ assertBadAddress("1.2.3.256", re.escape("256 (> 255)"))
+
+ def test_netmask_errors(self):
+ def assertBadNetmask(addr, netmask):
+ msg = "%r is not a valid netmask"
+ with self.assertNetmaskError(msg % netmask):
+ self.factory("%s/%s" % (addr, netmask))
+
+ assertBadNetmask("1.2.3.4", "")
+ assertBadNetmask("1.2.3.4", "33")
+ assertBadNetmask("1.2.3.4", "254.254.255.256")
+ assertBadNetmask("1.1.1.1", "254.xyz.2.3")
+ assertBadNetmask("1.1.1.1", "240.255.0.0")
+ assertBadNetmask("1.1.1.1", "pudding")
+
+class InterfaceTestCase_v4(BaseTestCase, NetmaskTestMixin_v4):
+ factory = ipaddress.IPv4Interface
+
+class NetworkTestCase_v4(BaseTestCase, NetmaskTestMixin_v4):
+ factory = ipaddress.IPv4Network
+
+
+class NetmaskTestMixin_v6(CommonTestMixin_v6):
+ """Input validation on interfaces and networks is very similar"""
+
+ def test_split_netmask(self):
+ addr = "cafe:cafe::/128/190"
+ with self.assertAddressError("Only one '/' permitted in %r" % addr):
+ self.factory(addr)
+
+ def test_address_errors(self):
+ def assertBadAddress(addr, details):
+ with self.assertAddressError(details):
+ self.factory(addr)
+
+ assertBadAddress("/", "Address cannot be empty")
+ assertBadAddress("/8", "Address cannot be empty")
+ assertBadAddress("google.com", "At least 3 parts")
+ assertBadAddress("1.2.3.4", "At least 3 parts")
+ assertBadAddress("10/8", "At least 3 parts")
+ assertBadAddress("1234:axy::b", "Only hex digits")
+
+ def test_netmask_errors(self):
+ def assertBadNetmask(addr, netmask):
+ msg = "%r is not a valid netmask"
+ with self.assertNetmaskError(msg % netmask):
+ self.factory("%s/%s" % (addr, netmask))
+
+ assertBadNetmask("::1", "")
+ assertBadNetmask("::1", "::1")
+ assertBadNetmask("::1", "1::")
+ assertBadNetmask("::1", "129")
+ assertBadNetmask("::1", "pudding")
+
+class InterfaceTestCase_v6(BaseTestCase, NetmaskTestMixin_v6):
+ factory = ipaddress.IPv6Interface
+
+class NetworkTestCase_v6(BaseTestCase, NetmaskTestMixin_v6):
+ factory = ipaddress.IPv6Network
+
+
+class FactoryFunctionErrors(BaseTestCase):
+
+ def assertFactoryError(self, factory, kind):
+ """Ensure a clean ValueError with the expected message"""
+ addr = "camelot"
+ msg = '%r does not appear to be an IPv4 or IPv6 %s'
+ with self.assertCleanError(ValueError, msg, addr, kind):
+ factory(addr)
+
+ def test_ip_address(self):
+ self.assertFactoryError(ipaddress.ip_address, "address")
+
+ def test_ip_interface(self):
+ self.assertFactoryError(ipaddress.ip_interface, "interface")
+
+ def test_ip_network(self):
+ self.assertFactoryError(ipaddress.ip_network, "network")
+
+
+class ComparisonTests(unittest.TestCase):
+
+ v4addr = ipaddress.IPv4Address(1)
+ v4net = ipaddress.IPv4Network(1)
+ v4intf = ipaddress.IPv4Interface(1)
+ v6addr = ipaddress.IPv6Address(1)
+ v6net = ipaddress.IPv6Network(1)
+ v6intf = ipaddress.IPv6Interface(1)
+
+ v4_addresses = [v4addr, v4intf]
+ v4_objects = v4_addresses + [v4net]
+ v6_addresses = [v6addr, v6intf]
+ v6_objects = v6_addresses + [v6net]
+ objects = v4_objects + v6_objects
+
+ def test_foreign_type_equality(self):
+ # __eq__ should never raise TypeError directly
+ other = object()
+ for obj in self.objects:
+ self.assertNotEqual(obj, other)
+ self.assertFalse(obj == other)
+ self.assertEqual(obj.__eq__(other), NotImplemented)
+ self.assertEqual(obj.__ne__(other), NotImplemented)
+
+ def test_mixed_type_equality(self):
+ # Ensure none of the internal objects accidentally
+ # expose the right set of attributes to become "equal"
+ for lhs in self.objects:
+ for rhs in self.objects:
+ if lhs is rhs:
+ continue
+ self.assertNotEqual(lhs, rhs)
+
+ def test_containment(self):
+ for obj in self.v4_addresses:
+ self.assertIn(obj, self.v4net)
+ for obj in self.v6_addresses:
+ self.assertIn(obj, self.v6net)
+ for obj in self.v4_objects + [self.v6net]:
+ self.assertNotIn(obj, self.v6net)
+ for obj in self.v6_objects + [self.v4net]:
+ self.assertNotIn(obj, self.v4net)
+
+ def test_mixed_type_ordering(self):
+ for lhs in self.objects:
+ for rhs in self.objects:
+ if isinstance(lhs, type(rhs)) or isinstance(rhs, type(lhs)):
+ continue
+ self.assertRaises(TypeError, lambda: lhs < rhs)
+ self.assertRaises(TypeError, lambda: lhs > rhs)
+ self.assertRaises(TypeError, lambda: lhs <= rhs)
+ self.assertRaises(TypeError, lambda: lhs >= rhs)
+
+ def test_mixed_type_key(self):
+ # with get_mixed_type_key, you can sort addresses and network.
+ v4_ordered = [self.v4addr, self.v4net, self.v4intf]
+ v6_ordered = [self.v6addr, self.v6net, self.v6intf]
+ self.assertEqual(v4_ordered,
+ sorted(self.v4_objects,
+ key=ipaddress.get_mixed_type_key))
+ self.assertEqual(v6_ordered,
+ sorted(self.v6_objects,
+ key=ipaddress.get_mixed_type_key))
+ self.assertEqual(v4_ordered + v6_ordered,
+ sorted(self.objects,
+ key=ipaddress.get_mixed_type_key))
+ self.assertEqual(NotImplemented, ipaddress.get_mixed_type_key(object))
+
+ def test_incompatible_versions(self):
+ # These should always raise TypeError
+ v4addr = ipaddress.ip_address('1.1.1.1')
+ v4net = ipaddress.ip_network('1.1.1.1')
+ v6addr = ipaddress.ip_address('::1')
+ v6net = ipaddress.ip_address('::1')
+
+ self.assertRaises(TypeError, v4addr.__lt__, v6addr)
+ self.assertRaises(TypeError, v4addr.__gt__, v6addr)
+ self.assertRaises(TypeError, v4net.__lt__, v6net)
+ self.assertRaises(TypeError, v4net.__gt__, v6net)
+
+ self.assertRaises(TypeError, v6addr.__lt__, v4addr)
+ self.assertRaises(TypeError, v6addr.__gt__, v4addr)
+ self.assertRaises(TypeError, v6net.__lt__, v4net)
+ self.assertRaises(TypeError, v6net.__gt__, v4net)
+
+
+
+class IpaddrUnitTest(unittest.TestCase):
+
+ def setUp(self):
+ self.ipv4_address = ipaddress.IPv4Address('1.2.3.4')
+ self.ipv4_interface = ipaddress.IPv4Interface('1.2.3.4/24')
+ self.ipv4_network = ipaddress.IPv4Network('1.2.3.0/24')
+ #self.ipv4_hostmask = ipaddress.IPv4Interface('10.0.0.1/0.255.255.255')
+ self.ipv6_address = ipaddress.IPv6Interface(
+ '2001:658:22a:cafe:200:0:0:1')
+ self.ipv6_interface = ipaddress.IPv6Interface(
+ '2001:658:22a:cafe:200:0:0:1/64')
+ self.ipv6_network = ipaddress.IPv6Network('2001:658:22a:cafe::/64')
+
+ def testRepr(self):
+ self.assertEqual("IPv4Interface('1.2.3.4/32')",
+ repr(ipaddress.IPv4Interface('1.2.3.4')))
+ self.assertEqual("IPv6Interface('::1/128')",
+ repr(ipaddress.IPv6Interface('::1')))
+
+ # issue57
+ def testAddressIntMath(self):
+ self.assertEqual(ipaddress.IPv4Address('1.1.1.1') + 255,
+ ipaddress.IPv4Address('1.1.2.0'))
+ self.assertEqual(ipaddress.IPv4Address('1.1.1.1') - 256,
+ ipaddress.IPv4Address('1.1.0.1'))
+ self.assertEqual(ipaddress.IPv6Address('::1') + (2**16 - 2),
+ ipaddress.IPv6Address('::ffff'))
+ self.assertEqual(ipaddress.IPv6Address('::ffff') - (2**16 - 2),
+ ipaddress.IPv6Address('::1'))
+
+ def testInvalidIntToBytes(self):
+ self.assertRaises(ValueError, ipaddress.v4_int_to_packed, -1)
+ self.assertRaises(ValueError, ipaddress.v4_int_to_packed,
+ 2 ** ipaddress.IPV4LENGTH)
+ self.assertRaises(ValueError, ipaddress.v6_int_to_packed, -1)
+ self.assertRaises(ValueError, ipaddress.v6_int_to_packed,
+ 2 ** ipaddress.IPV6LENGTH)
+
+ def testInternals(self):
+ first, last = ipaddress._find_address_range([
+ ipaddress.IPv4Address('10.10.10.10'),
+ ipaddress.IPv4Address('10.10.10.12')])
+ self.assertEqual(first, last)
+ self.assertEqual(128, ipaddress._count_righthand_zero_bits(0, 128))
+ self.assertEqual("IPv4Network('1.2.3.0/24')", repr(self.ipv4_network))
+
+ def testMissingAddressVersion(self):
+ class Broken(ipaddress._BaseAddress):
+ pass
+ broken = Broken('127.0.0.1')
+ with self.assertRaisesRegex(NotImplementedError, "Broken.*version"):
+ broken.version
+
+ def testMissingNetworkVersion(self):
+ class Broken(ipaddress._BaseNetwork):
+ pass
+ broken = Broken('127.0.0.1')
+ with self.assertRaisesRegex(NotImplementedError, "Broken.*version"):
+ broken.version
+
+ def testMissingAddressClass(self):
+ class Broken(ipaddress._BaseNetwork):
+ pass
+ broken = Broken('127.0.0.1')
+ with self.assertRaisesRegex(NotImplementedError, "Broken.*address"):
+ broken._address_class
+
+ def testGetNetwork(self):
+ self.assertEqual(int(self.ipv4_network.network_address), 16909056)
+ self.assertEqual(str(self.ipv4_network.network_address), '1.2.3.0')
+
+ self.assertEqual(int(self.ipv6_network.network_address),
+ 42540616829182469433403647294022090752)
+ self.assertEqual(str(self.ipv6_network.network_address),
+ '2001:658:22a:cafe::')
+ self.assertEqual(str(self.ipv6_network.hostmask),
+ '::ffff:ffff:ffff:ffff')
+
+ def testIpFromInt(self):
+ self.assertEqual(self.ipv4_interface._ip,
+ ipaddress.IPv4Interface(16909060)._ip)
+
+ ipv4 = ipaddress.ip_network('1.2.3.4')
+ ipv6 = ipaddress.ip_network('2001:658:22a:cafe:200:0:0:1')
+ self.assertEqual(ipv4, ipaddress.ip_network(int(ipv4.network_address)))
+ self.assertEqual(ipv6, ipaddress.ip_network(int(ipv6.network_address)))
+
+ v6_int = 42540616829182469433547762482097946625
+ self.assertEqual(self.ipv6_interface._ip,
+ ipaddress.IPv6Interface(v6_int)._ip)
+
+ self.assertEqual(ipaddress.ip_network(self.ipv4_address._ip).version,
+ 4)
+ self.assertEqual(ipaddress.ip_network(self.ipv6_address._ip).version,
+ 6)
+
+ def testIpFromPacked(self):
+ address = ipaddress.ip_address
+ self.assertEqual(self.ipv4_interface._ip,
+ ipaddress.ip_interface(b'\x01\x02\x03\x04')._ip)
+ self.assertEqual(address('255.254.253.252'),
+ address(b'\xff\xfe\xfd\xfc'))
+ self.assertEqual(self.ipv6_interface.ip,
+ ipaddress.ip_interface(
+ b'\x20\x01\x06\x58\x02\x2a\xca\xfe'
+ b'\x02\x00\x00\x00\x00\x00\x00\x01').ip)
+ self.assertEqual(address('ffff:2:3:4:ffff::'),
+ address(b'\xff\xff\x00\x02\x00\x03\x00\x04' +
+ b'\xff\xff' + b'\x00' * 6))
+ self.assertEqual(address('::'),
+ address(b'\x00' * 16))
+
+ def testGetIp(self):
+ self.assertEqual(int(self.ipv4_interface.ip), 16909060)
+ self.assertEqual(str(self.ipv4_interface.ip), '1.2.3.4')
+
+ self.assertEqual(int(self.ipv6_interface.ip),
+ 42540616829182469433547762482097946625)
+ self.assertEqual(str(self.ipv6_interface.ip),
+ '2001:658:22a:cafe:200::1')
+
+ def testGetNetmask(self):
+ self.assertEqual(int(self.ipv4_network.netmask), 4294967040)
+ self.assertEqual(str(self.ipv4_network.netmask), '255.255.255.0')
+ self.assertEqual(int(self.ipv6_network.netmask),
+ 340282366920938463444927863358058659840)
+ self.assertEqual(self.ipv6_network.prefixlen, 64)
+
+ def testZeroNetmask(self):
+ ipv4_zero_netmask = ipaddress.IPv4Interface('1.2.3.4/0')
+ self.assertEqual(int(ipv4_zero_netmask.network.netmask), 0)
+ self.assertTrue(ipv4_zero_netmask.network._is_valid_netmask(
+ str(0)))
+ self.assertTrue(ipv4_zero_netmask._is_valid_netmask('0'))
+ self.assertTrue(ipv4_zero_netmask._is_valid_netmask('0.0.0.0'))
+ self.assertFalse(ipv4_zero_netmask._is_valid_netmask('invalid'))
+
+ ipv6_zero_netmask = ipaddress.IPv6Interface('::1/0')
+ self.assertEqual(int(ipv6_zero_netmask.network.netmask), 0)
+ self.assertTrue(ipv6_zero_netmask.network._is_valid_netmask(
+ str(0)))
+
+ def testIPv4NetAndHostmasks(self):
+ net = self.ipv4_network
+ self.assertFalse(net._is_valid_netmask('invalid'))
+ self.assertTrue(net._is_valid_netmask('128.128.128.128'))
+ self.assertFalse(net._is_valid_netmask('128.128.128.127'))
+ self.assertFalse(net._is_valid_netmask('128.128.128.255'))
+ self.assertTrue(net._is_valid_netmask('255.128.128.128'))
+
+ self.assertFalse(net._is_hostmask('invalid'))
+ self.assertTrue(net._is_hostmask('128.255.255.255'))
+ self.assertFalse(net._is_hostmask('255.255.255.255'))
+ self.assertFalse(net._is_hostmask('1.2.3.4'))
+
+ net = ipaddress.IPv4Network('127.0.0.0/0.0.0.255')
+ self.assertEqual(24, net.prefixlen)
+
+ def testGetBroadcast(self):
+ self.assertEqual(int(self.ipv4_network.broadcast_address), 16909311)
+ self.assertEqual(str(self.ipv4_network.broadcast_address), '1.2.3.255')
+
+ self.assertEqual(int(self.ipv6_network.broadcast_address),
+ 42540616829182469451850391367731642367)
+ self.assertEqual(str(self.ipv6_network.broadcast_address),
+ '2001:658:22a:cafe:ffff:ffff:ffff:ffff')
+
+ def testGetPrefixlen(self):
+ self.assertEqual(self.ipv4_interface.network.prefixlen, 24)
+ self.assertEqual(self.ipv6_interface.network.prefixlen, 64)
+
+ def testGetSupernet(self):
+ self.assertEqual(self.ipv4_network.supernet().prefixlen, 23)
+ self.assertEqual(str(self.ipv4_network.supernet().network_address),
+ '1.2.2.0')
+ self.assertEqual(
+ ipaddress.IPv4Interface('0.0.0.0/0').network.supernet(),
+ ipaddress.IPv4Network('0.0.0.0/0'))
+
+ self.assertEqual(self.ipv6_network.supernet().prefixlen, 63)
+ self.assertEqual(str(self.ipv6_network.supernet().network_address),
+ '2001:658:22a:cafe::')
+ self.assertEqual(ipaddress.IPv6Interface('::0/0').network.supernet(),
+ ipaddress.IPv6Network('::0/0'))
+
+ def testGetSupernet3(self):
+ self.assertEqual(self.ipv4_network.supernet(3).prefixlen, 21)
+ self.assertEqual(str(self.ipv4_network.supernet(3).network_address),
+ '1.2.0.0')
+
+ self.assertEqual(self.ipv6_network.supernet(3).prefixlen, 61)
+ self.assertEqual(str(self.ipv6_network.supernet(3).network_address),
+ '2001:658:22a:caf8::')
+
+ def testGetSupernet4(self):
+ self.assertRaises(ValueError, self.ipv4_network.supernet,
+ prefixlen_diff=2, new_prefix=1)
+ self.assertRaises(ValueError, self.ipv4_network.supernet,
+ new_prefix=25)
+ self.assertEqual(self.ipv4_network.supernet(prefixlen_diff=2),
+ self.ipv4_network.supernet(new_prefix=22))
+
+ self.assertRaises(ValueError, self.ipv6_network.supernet,
+ prefixlen_diff=2, new_prefix=1)
+ self.assertRaises(ValueError, self.ipv6_network.supernet,
+ new_prefix=65)
+ self.assertEqual(self.ipv6_network.supernet(prefixlen_diff=2),
+ self.ipv6_network.supernet(new_prefix=62))
+
+ def testHosts(self):
+ hosts = list(self.ipv4_network.hosts())
+ self.assertEqual(254, len(hosts))
+ self.assertEqual(ipaddress.IPv4Address('1.2.3.1'), hosts[0])
+ self.assertEqual(ipaddress.IPv4Address('1.2.3.254'), hosts[-1])
+
+ # special case where only 1 bit is left for address
+ self.assertEqual([ipaddress.IPv4Address('2.0.0.0'),
+ ipaddress.IPv4Address('2.0.0.1')],
+ list(ipaddress.ip_network('2.0.0.0/31').hosts()))
+
+ def testFancySubnetting(self):
+ self.assertEqual(sorted(self.ipv4_network.subnets(prefixlen_diff=3)),
+ sorted(self.ipv4_network.subnets(new_prefix=27)))
+ self.assertRaises(ValueError, list,
+ self.ipv4_network.subnets(new_prefix=23))
+ self.assertRaises(ValueError, list,
+ self.ipv4_network.subnets(prefixlen_diff=3,
+ new_prefix=27))
+ self.assertEqual(sorted(self.ipv6_network.subnets(prefixlen_diff=4)),
+ sorted(self.ipv6_network.subnets(new_prefix=68)))
+ self.assertRaises(ValueError, list,
+ self.ipv6_network.subnets(new_prefix=63))
+ self.assertRaises(ValueError, list,
+ self.ipv6_network.subnets(prefixlen_diff=4,
+ new_prefix=68))
+
+ def testGetSubnets(self):
+ self.assertEqual(list(self.ipv4_network.subnets())[0].prefixlen, 25)
+ self.assertEqual(str(list(
+ self.ipv4_network.subnets())[0].network_address),
+ '1.2.3.0')
+ self.assertEqual(str(list(
+ self.ipv4_network.subnets())[1].network_address),
+ '1.2.3.128')
+
+ self.assertEqual(list(self.ipv6_network.subnets())[0].prefixlen, 65)
+
+ def testGetSubnetForSingle32(self):
+ ip = ipaddress.IPv4Network('1.2.3.4/32')
+ subnets1 = [str(x) for x in ip.subnets()]
+ subnets2 = [str(x) for x in ip.subnets(2)]
+ self.assertEqual(subnets1, ['1.2.3.4/32'])
+ self.assertEqual(subnets1, subnets2)
+
+ def testGetSubnetForSingle128(self):
+ ip = ipaddress.IPv6Network('::1/128')
+ subnets1 = [str(x) for x in ip.subnets()]
+ subnets2 = [str(x) for x in ip.subnets(2)]
+ self.assertEqual(subnets1, ['::1/128'])
+ self.assertEqual(subnets1, subnets2)
+
+ def testSubnet2(self):
+ ips = [str(x) for x in self.ipv4_network.subnets(2)]
+ self.assertEqual(
+ ips,
+ ['1.2.3.0/26', '1.2.3.64/26', '1.2.3.128/26', '1.2.3.192/26'])
+
+ ipsv6 = [str(x) for x in self.ipv6_network.subnets(2)]
+ self.assertEqual(
+ ipsv6,
+ ['2001:658:22a:cafe::/66',
+ '2001:658:22a:cafe:4000::/66',
+ '2001:658:22a:cafe:8000::/66',
+ '2001:658:22a:cafe:c000::/66'])
+
+ def testSubnetFailsForLargeCidrDiff(self):
+ self.assertRaises(ValueError, list,
+ self.ipv4_interface.network.subnets(9))
+ self.assertRaises(ValueError, list,
+ self.ipv4_network.subnets(9))
+ self.assertRaises(ValueError, list,
+ self.ipv6_interface.network.subnets(65))
+ self.assertRaises(ValueError, list,
+ self.ipv6_network.subnets(65))
+
+ def testSupernetFailsForLargeCidrDiff(self):
+ self.assertRaises(ValueError,
+ self.ipv4_interface.network.supernet, 25)
+ self.assertRaises(ValueError,
+ self.ipv6_interface.network.supernet, 65)
+
+ def testSubnetFailsForNegativeCidrDiff(self):
+ self.assertRaises(ValueError, list,
+ self.ipv4_interface.network.subnets(-1))
+ self.assertRaises(ValueError, list,
+ self.ipv4_network.subnets(-1))
+ self.assertRaises(ValueError, list,
+ self.ipv6_interface.network.subnets(-1))
+ self.assertRaises(ValueError, list,
+ self.ipv6_network.subnets(-1))
+
+ def testGetNum_Addresses(self):
+ self.assertEqual(self.ipv4_network.num_addresses, 256)
+ self.assertEqual(list(self.ipv4_network.subnets())[0].num_addresses,
+ 128)
+ self.assertEqual(self.ipv4_network.supernet().num_addresses, 512)
+
+ self.assertEqual(self.ipv6_network.num_addresses, 18446744073709551616)
+ self.assertEqual(list(self.ipv6_network.subnets())[0].num_addresses,
+ 9223372036854775808)
+ self.assertEqual(self.ipv6_network.supernet().num_addresses,
+ 36893488147419103232)
+
+ def testContains(self):
+ self.assertTrue(ipaddress.IPv4Interface('1.2.3.128/25') in
+ self.ipv4_network)
+ self.assertFalse(ipaddress.IPv4Interface('1.2.4.1/24') in
+ self.ipv4_network)
+ # We can test addresses and string as well.
+ addr1 = ipaddress.IPv4Address('1.2.3.37')
+ self.assertTrue(addr1 in self.ipv4_network)
+ # issue 61, bad network comparison on like-ip'd network objects
+ # with identical broadcast addresses.
+ self.assertFalse(ipaddress.IPv4Network('1.1.0.0/16').__contains__(
+ ipaddress.IPv4Network('1.0.0.0/15')))
+
+ def testNth(self):
+ self.assertEqual(str(self.ipv4_network[5]), '1.2.3.5')
+ self.assertRaises(IndexError, self.ipv4_network.__getitem__, 256)
+
+ self.assertEqual(str(self.ipv6_network[5]),
+ '2001:658:22a:cafe::5')
+
+ def testGetitem(self):
+ # http://code.google.com/p/ipaddr-py/issues/detail?id=15
+ addr = ipaddress.IPv4Network('172.31.255.128/255.255.255.240')
+ self.assertEqual(28, addr.prefixlen)
+ addr_list = list(addr)
+ self.assertEqual('172.31.255.128', str(addr_list[0]))
+ self.assertEqual('172.31.255.128', str(addr[0]))
+ self.assertEqual('172.31.255.143', str(addr_list[-1]))
+ self.assertEqual('172.31.255.143', str(addr[-1]))
+ self.assertEqual(addr_list[-1], addr[-1])
+
+ def testEqual(self):
+ self.assertTrue(self.ipv4_interface ==
+ ipaddress.IPv4Interface('1.2.3.4/24'))
+ self.assertFalse(self.ipv4_interface ==
+ ipaddress.IPv4Interface('1.2.3.4/23'))
+ self.assertFalse(self.ipv4_interface ==
+ ipaddress.IPv6Interface('::1.2.3.4/24'))
+ self.assertFalse(self.ipv4_interface == '')
+ self.assertFalse(self.ipv4_interface == [])
+ self.assertFalse(self.ipv4_interface == 2)
+
+ self.assertTrue(self.ipv6_interface ==
+ ipaddress.IPv6Interface('2001:658:22a:cafe:200::1/64'))
+ self.assertFalse(self.ipv6_interface ==
+ ipaddress.IPv6Interface('2001:658:22a:cafe:200::1/63'))
+ self.assertFalse(self.ipv6_interface ==
+ ipaddress.IPv4Interface('1.2.3.4/23'))
+ self.assertFalse(self.ipv6_interface == '')
+ self.assertFalse(self.ipv6_interface == [])
+ self.assertFalse(self.ipv6_interface == 2)
+
+ def testNotEqual(self):
+ self.assertFalse(self.ipv4_interface !=
+ ipaddress.IPv4Interface('1.2.3.4/24'))
+ self.assertTrue(self.ipv4_interface !=
+ ipaddress.IPv4Interface('1.2.3.4/23'))
+ self.assertTrue(self.ipv4_interface !=
+ ipaddress.IPv6Interface('::1.2.3.4/24'))
+ self.assertTrue(self.ipv4_interface != '')
+ self.assertTrue(self.ipv4_interface != [])
+ self.assertTrue(self.ipv4_interface != 2)
+
+ self.assertTrue(self.ipv4_address !=
+ ipaddress.IPv4Address('1.2.3.5'))
+ self.assertTrue(self.ipv4_address != '')
+ self.assertTrue(self.ipv4_address != [])
+ self.assertTrue(self.ipv4_address != 2)
+
+ self.assertFalse(self.ipv6_interface !=
+ ipaddress.IPv6Interface('2001:658:22a:cafe:200::1/64'))
+ self.assertTrue(self.ipv6_interface !=
+ ipaddress.IPv6Interface('2001:658:22a:cafe:200::1/63'))
+ self.assertTrue(self.ipv6_interface !=
+ ipaddress.IPv4Interface('1.2.3.4/23'))
+ self.assertTrue(self.ipv6_interface != '')
+ self.assertTrue(self.ipv6_interface != [])
+ self.assertTrue(self.ipv6_interface != 2)
+
+ self.assertTrue(self.ipv6_address !=
+ ipaddress.IPv4Address('1.2.3.4'))
+ self.assertTrue(self.ipv6_address != '')
+ self.assertTrue(self.ipv6_address != [])
+ self.assertTrue(self.ipv6_address != 2)
+
+ def testSlash32Constructor(self):
+ self.assertEqual(str(ipaddress.IPv4Interface(
+ '1.2.3.4/255.255.255.255')), '1.2.3.4/32')
+
+ def testSlash128Constructor(self):
+ self.assertEqual(str(ipaddress.IPv6Interface('::1/128')),
+ '::1/128')
+
+ def testSlash0Constructor(self):
+ self.assertEqual(str(ipaddress.IPv4Interface('1.2.3.4/0.0.0.0')),
+ '1.2.3.4/0')
+
+ def testCollapsing(self):
+ # test only IP addresses including some duplicates
+ ip1 = ipaddress.IPv4Address('1.1.1.0')
+ ip2 = ipaddress.IPv4Address('1.1.1.1')
+ ip3 = ipaddress.IPv4Address('1.1.1.2')
+ ip4 = ipaddress.IPv4Address('1.1.1.3')
+ ip5 = ipaddress.IPv4Address('1.1.1.4')
+ ip6 = ipaddress.IPv4Address('1.1.1.0')
+ # check that addreses are subsumed properly.
+ collapsed = ipaddress.collapse_addresses(
+ [ip1, ip2, ip3, ip4, ip5, ip6])
+ self.assertEqual(list(collapsed),
+ [ipaddress.IPv4Network('1.1.1.0/30'),
+ ipaddress.IPv4Network('1.1.1.4/32')])
+
+ # test a mix of IP addresses and networks including some duplicates
+ ip1 = ipaddress.IPv4Address('1.1.1.0')
+ ip2 = ipaddress.IPv4Address('1.1.1.1')
+ ip3 = ipaddress.IPv4Address('1.1.1.2')
+ ip4 = ipaddress.IPv4Address('1.1.1.3')
+ #ip5 = ipaddress.IPv4Interface('1.1.1.4/30')
+ #ip6 = ipaddress.IPv4Interface('1.1.1.4/30')
+ # check that addreses are subsumed properly.
+ collapsed = ipaddress.collapse_addresses([ip1, ip2, ip3, ip4])
+ self.assertEqual(list(collapsed),
+ [ipaddress.IPv4Network('1.1.1.0/30')])
+
+ # test only IP networks
+ ip1 = ipaddress.IPv4Network('1.1.0.0/24')
+ ip2 = ipaddress.IPv4Network('1.1.1.0/24')
+ ip3 = ipaddress.IPv4Network('1.1.2.0/24')
+ ip4 = ipaddress.IPv4Network('1.1.3.0/24')
+ ip5 = ipaddress.IPv4Network('1.1.4.0/24')
+ # stored in no particular order b/c we want CollapseAddr to call
+ # [].sort
+ ip6 = ipaddress.IPv4Network('1.1.0.0/22')
+ # check that addreses are subsumed properly.
+ collapsed = ipaddress.collapse_addresses([ip1, ip2, ip3, ip4, ip5,
+ ip6])
+ self.assertEqual(list(collapsed),
+ [ipaddress.IPv4Network('1.1.0.0/22'),
+ ipaddress.IPv4Network('1.1.4.0/24')])
+
+ # test that two addresses are supernet'ed properly
+ collapsed = ipaddress.collapse_addresses([ip1, ip2])
+ self.assertEqual(list(collapsed),
+ [ipaddress.IPv4Network('1.1.0.0/23')])
+
+ # test same IP networks
+ ip_same1 = ip_same2 = ipaddress.IPv4Network('1.1.1.1/32')
+ self.assertEqual(list(ipaddress.collapse_addresses(
+ [ip_same1, ip_same2])),
+ [ip_same1])
+
+ # test same IP addresses
+ ip_same1 = ip_same2 = ipaddress.IPv4Address('1.1.1.1')
+ self.assertEqual(list(ipaddress.collapse_addresses(
+ [ip_same1, ip_same2])),
+ [ipaddress.ip_network('1.1.1.1/32')])
+ ip1 = ipaddress.IPv6Network('2001::/100')
+ ip2 = ipaddress.IPv6Network('2001::/120')
+ ip3 = ipaddress.IPv6Network('2001::/96')
+ # test that ipv6 addresses are subsumed properly.
+ collapsed = ipaddress.collapse_addresses([ip1, ip2, ip3])
+ self.assertEqual(list(collapsed), [ip3])
+
+ # the toejam test
+ addr_tuples = [
+ (ipaddress.ip_address('1.1.1.1'),
+ ipaddress.ip_address('::1')),
+ (ipaddress.IPv4Network('1.1.0.0/24'),
+ ipaddress.IPv6Network('2001::/120')),
+ (ipaddress.IPv4Network('1.1.0.0/32'),
+ ipaddress.IPv6Network('2001::/128')),
+ ]
+ for ip1, ip2 in addr_tuples:
+ self.assertRaises(TypeError, ipaddress.collapse_addresses,
+ [ip1, ip2])
+
+ def testSummarizing(self):
+ #ip = ipaddress.ip_address
+ #ipnet = ipaddress.ip_network
+ summarize = ipaddress.summarize_address_range
+ ip1 = ipaddress.ip_address('1.1.1.0')
+ ip2 = ipaddress.ip_address('1.1.1.255')
+
+ # summarize works only for IPv4 & IPv6
+ class IPv7Address(ipaddress.IPv6Address):
+ @property
+ def version(self):
+ return 7
+ ip_invalid1 = IPv7Address('::1')
+ ip_invalid2 = IPv7Address('::1')
+ self.assertRaises(ValueError, list,
+ summarize(ip_invalid1, ip_invalid2))
+ # test that a summary over ip4 & ip6 fails
+ self.assertRaises(TypeError, list,
+ summarize(ip1, ipaddress.IPv6Address('::1')))
+ # test a /24 is summarized properly
+ self.assertEqual(list(summarize(ip1, ip2))[0],
+ ipaddress.ip_network('1.1.1.0/24'))
+ # test an IPv4 range that isn't on a network byte boundary
+ ip2 = ipaddress.ip_address('1.1.1.8')
+ self.assertEqual(list(summarize(ip1, ip2)),
+ [ipaddress.ip_network('1.1.1.0/29'),
+ ipaddress.ip_network('1.1.1.8')])
+ # all!
+ ip1 = ipaddress.IPv4Address(0)
+ ip2 = ipaddress.IPv4Address(ipaddress.IPv4Address._ALL_ONES)
+ self.assertEqual([ipaddress.IPv4Network('0.0.0.0/0')],
+ list(summarize(ip1, ip2)))
+
+ ip1 = ipaddress.ip_address('1::')
+ ip2 = ipaddress.ip_address('1:ffff:ffff:ffff:ffff:ffff:ffff:ffff')
+ # test a IPv6 is sumamrized properly
+ self.assertEqual(list(summarize(ip1, ip2))[0],
+ ipaddress.ip_network('1::/16'))
+ # test an IPv6 range that isn't on a network byte boundary
+ ip2 = ipaddress.ip_address('2::')
+ self.assertEqual(list(summarize(ip1, ip2)),
+ [ipaddress.ip_network('1::/16'),
+ ipaddress.ip_network('2::/128')])
+
+ # test exception raised when first is greater than last
+ self.assertRaises(ValueError, list,
+ summarize(ipaddress.ip_address('1.1.1.0'),
+ ipaddress.ip_address('1.1.0.0')))
+ # test exception raised when first and last aren't IP addresses
+ self.assertRaises(TypeError, list,
+ summarize(ipaddress.ip_network('1.1.1.0'),
+ ipaddress.ip_network('1.1.0.0')))
+ self.assertRaises(TypeError, list,
+ summarize(ipaddress.ip_network('1.1.1.0'),
+ ipaddress.ip_network('1.1.0.0')))
+ # test exception raised when first and last are not same version
+ self.assertRaises(TypeError, list,
+ summarize(ipaddress.ip_address('::'),
+ ipaddress.ip_network('1.1.0.0')))
+
+ def testAddressComparison(self):
+ self.assertTrue(ipaddress.ip_address('1.1.1.1') <=
+ ipaddress.ip_address('1.1.1.1'))
+ self.assertTrue(ipaddress.ip_address('1.1.1.1') <=
+ ipaddress.ip_address('1.1.1.2'))
+ self.assertTrue(ipaddress.ip_address('::1') <=
+ ipaddress.ip_address('::1'))
+ self.assertTrue(ipaddress.ip_address('::1') <=
+ ipaddress.ip_address('::2'))
+
+ def testInterfaceComparison(self):
+ self.assertTrue(ipaddress.ip_interface('1.1.1.1') <=
+ ipaddress.ip_interface('1.1.1.1'))
+ self.assertTrue(ipaddress.ip_interface('1.1.1.1') <=
+ ipaddress.ip_interface('1.1.1.2'))
+ self.assertTrue(ipaddress.ip_interface('::1') <=
+ ipaddress.ip_interface('::1'))
+ self.assertTrue(ipaddress.ip_interface('::1') <=
+ ipaddress.ip_interface('::2'))
+
+ def testNetworkComparison(self):
+ # ip1 and ip2 have the same network address
+ ip1 = ipaddress.IPv4Network('1.1.1.0/24')
+ ip2 = ipaddress.IPv4Network('1.1.1.0/32')
+ ip3 = ipaddress.IPv4Network('1.1.2.0/24')
+
+ self.assertTrue(ip1 < ip3)
+ self.assertTrue(ip3 > ip2)
+
+ self.assertEqual(ip1.compare_networks(ip1), 0)
+
+ # if addresses are the same, sort by netmask
+ self.assertEqual(ip1.compare_networks(ip2), -1)
+ self.assertEqual(ip2.compare_networks(ip1), 1)
+
+ self.assertEqual(ip1.compare_networks(ip3), -1)
+ self.assertEqual(ip3.compare_networks(ip1), 1)
+ self.assertTrue(ip1._get_networks_key() < ip3._get_networks_key())
+
+ ip1 = ipaddress.IPv6Network('2001:2000::/96')
+ ip2 = ipaddress.IPv6Network('2001:2001::/96')
+ ip3 = ipaddress.IPv6Network('2001:ffff:2000::/96')
+
+ self.assertTrue(ip1 < ip3)
+ self.assertTrue(ip3 > ip2)
+ self.assertEqual(ip1.compare_networks(ip3), -1)
+ self.assertTrue(ip1._get_networks_key() < ip3._get_networks_key())
+
+ # Test comparing different protocols.
+ # Should always raise a TypeError.
+ self.assertRaises(TypeError,
+ self.ipv4_network.compare_networks,
+ self.ipv6_network)
+ ipv6 = ipaddress.IPv6Interface('::/0')
+ ipv4 = ipaddress.IPv4Interface('0.0.0.0/0')
+ self.assertRaises(TypeError, ipv4.__lt__, ipv6)
+ self.assertRaises(TypeError, ipv4.__gt__, ipv6)
+ self.assertRaises(TypeError, ipv6.__lt__, ipv4)
+ self.assertRaises(TypeError, ipv6.__gt__, ipv4)
+
+ # Regression test for issue 19.
+ ip1 = ipaddress.ip_network('10.1.2.128/25')
+ self.assertFalse(ip1 < ip1)
+ self.assertFalse(ip1 > ip1)
+ ip2 = ipaddress.ip_network('10.1.3.0/24')
+ self.assertTrue(ip1 < ip2)
+ self.assertFalse(ip2 < ip1)
+ self.assertFalse(ip1 > ip2)
+ self.assertTrue(ip2 > ip1)
+ ip3 = ipaddress.ip_network('10.1.3.0/25')
+ self.assertTrue(ip2 < ip3)
+ self.assertFalse(ip3 < ip2)
+ self.assertFalse(ip2 > ip3)
+ self.assertTrue(ip3 > ip2)
+
+ # Regression test for issue 28.
+ ip1 = ipaddress.ip_network('10.10.10.0/31')
+ ip2 = ipaddress.ip_network('10.10.10.0')
+ ip3 = ipaddress.ip_network('10.10.10.2/31')
+ ip4 = ipaddress.ip_network('10.10.10.2')
+ sorted = [ip1, ip2, ip3, ip4]
+ unsorted = [ip2, ip4, ip1, ip3]
+ unsorted.sort()
+ self.assertEqual(sorted, unsorted)
+ unsorted = [ip4, ip1, ip3, ip2]
+ unsorted.sort()
+ self.assertEqual(sorted, unsorted)
+ self.assertRaises(TypeError, ip1.__lt__,
+ ipaddress.ip_address('10.10.10.0'))
+ self.assertRaises(TypeError, ip2.__lt__,
+ ipaddress.ip_address('10.10.10.0'))
+
+ # <=, >=
+ self.assertTrue(ipaddress.ip_network('1.1.1.1') <=
+ ipaddress.ip_network('1.1.1.1'))
+ self.assertTrue(ipaddress.ip_network('1.1.1.1') <=
+ ipaddress.ip_network('1.1.1.2'))
+ self.assertFalse(ipaddress.ip_network('1.1.1.2') <=
+ ipaddress.ip_network('1.1.1.1'))
+ self.assertTrue(ipaddress.ip_network('::1') <=
+ ipaddress.ip_network('::1'))
+ self.assertTrue(ipaddress.ip_network('::1') <=
+ ipaddress.ip_network('::2'))
+ self.assertFalse(ipaddress.ip_network('::2') <=
+ ipaddress.ip_network('::1'))
+
+ def testStrictNetworks(self):
+ self.assertRaises(ValueError, ipaddress.ip_network, '192.168.1.1/24')
+ self.assertRaises(ValueError, ipaddress.ip_network, '::1/120')
+
+ def testOverlaps(self):
+ other = ipaddress.IPv4Network('1.2.3.0/30')
+ other2 = ipaddress.IPv4Network('1.2.2.0/24')
+ other3 = ipaddress.IPv4Network('1.2.2.64/26')
+ self.assertTrue(self.ipv4_network.overlaps(other))
+ self.assertFalse(self.ipv4_network.overlaps(other2))
+ self.assertTrue(other2.overlaps(other3))
+
+ def testEmbeddedIpv4(self):
+ ipv4_string = '192.168.0.1'
+ ipv4 = ipaddress.IPv4Interface(ipv4_string)
+ v4compat_ipv6 = ipaddress.IPv6Interface('::%s' % ipv4_string)
+ self.assertEqual(int(v4compat_ipv6.ip), int(ipv4.ip))
+ v4mapped_ipv6 = ipaddress.IPv6Interface('::ffff:%s' % ipv4_string)
+ self.assertNotEqual(v4mapped_ipv6.ip, ipv4.ip)
+ self.assertRaises(ipaddress.AddressValueError, ipaddress.IPv6Interface,
+ '2001:1.1.1.1:1.1.1.1')
+
+ # Issue 67: IPv6 with embedded IPv4 address not recognized.
+ def testIPv6AddressTooLarge(self):
+ # RFC4291 2.5.5.2
+ self.assertEqual(ipaddress.ip_address('::FFFF:192.0.2.1'),
+ ipaddress.ip_address('::FFFF:c000:201'))
+ # RFC4291 2.2 (part 3) x::d.d.d.d
+ self.assertEqual(ipaddress.ip_address('FFFF::192.0.2.1'),
+ ipaddress.ip_address('FFFF::c000:201'))
+
+ def testIPVersion(self):
+ self.assertEqual(self.ipv4_address.version, 4)
+ self.assertEqual(self.ipv6_address.version, 6)
+
+ def testMaxPrefixLength(self):
+ self.assertEqual(self.ipv4_interface.max_prefixlen, 32)
+ self.assertEqual(self.ipv6_interface.max_prefixlen, 128)
+
+ def testPacked(self):
+ self.assertEqual(self.ipv4_address.packed,
+ b'\x01\x02\x03\x04')
+ self.assertEqual(ipaddress.IPv4Interface('255.254.253.252').packed,
+ b'\xff\xfe\xfd\xfc')
+ self.assertEqual(self.ipv6_address.packed,
+ b'\x20\x01\x06\x58\x02\x2a\xca\xfe'
+ b'\x02\x00\x00\x00\x00\x00\x00\x01')
+ self.assertEqual(ipaddress.IPv6Interface('ffff:2:3:4:ffff::').packed,
+ b'\xff\xff\x00\x02\x00\x03\x00\x04\xff\xff'
+ + b'\x00' * 6)
+ self.assertEqual(ipaddress.IPv6Interface('::1:0:0:0:0').packed,
+ b'\x00' * 6 + b'\x00\x01' + b'\x00' * 8)
+
+ def testIpStrFromPrefixlen(self):
+ ipv4 = ipaddress.IPv4Interface('1.2.3.4/24')
+ self.assertEqual(ipv4._ip_string_from_prefix(), '255.255.255.0')
+ self.assertEqual(ipv4._ip_string_from_prefix(28), '255.255.255.240')
+
+ def testIpType(self):
+ ipv4net = ipaddress.ip_network('1.2.3.4')
+ ipv4addr = ipaddress.ip_address('1.2.3.4')
+ ipv6net = ipaddress.ip_network('::1.2.3.4')
+ ipv6addr = ipaddress.ip_address('::1.2.3.4')
+ self.assertEqual(ipaddress.IPv4Network, type(ipv4net))
+ self.assertEqual(ipaddress.IPv4Address, type(ipv4addr))
+ self.assertEqual(ipaddress.IPv6Network, type(ipv6net))
+ self.assertEqual(ipaddress.IPv6Address, type(ipv6addr))
+
+ def testReservedIpv4(self):
+ # test networks
+ self.assertEqual(True, ipaddress.ip_interface(
+ '224.1.1.1/31').is_multicast)
+ self.assertEqual(False, ipaddress.ip_network('240.0.0.0').is_multicast)
+ self.assertEqual(True, ipaddress.ip_network('240.0.0.0').is_reserved)
+
+ self.assertEqual(True, ipaddress.ip_interface(
+ '192.168.1.1/17').is_private)
+ self.assertEqual(False, ipaddress.ip_network('192.169.0.0').is_private)
+ self.assertEqual(True, ipaddress.ip_network(
+ '10.255.255.255').is_private)
+ self.assertEqual(False, ipaddress.ip_network('11.0.0.0').is_private)
+ self.assertEqual(False, ipaddress.ip_network('11.0.0.0').is_reserved)
+ self.assertEqual(True, ipaddress.ip_network(
+ '172.31.255.255').is_private)
+ self.assertEqual(False, ipaddress.ip_network('172.32.0.0').is_private)
+ self.assertEqual(True,
+ ipaddress.ip_network('169.254.1.0/24').is_link_local)
+
+ self.assertEqual(True,
+ ipaddress.ip_interface(
+ '169.254.100.200/24').is_link_local)
+ self.assertEqual(False,
+ ipaddress.ip_interface(
+ '169.255.100.200/24').is_link_local)
+
+ self.assertEqual(True,
+ ipaddress.ip_network(
+ '127.100.200.254/32').is_loopback)
+ self.assertEqual(True, ipaddress.ip_network(
+ '127.42.0.0/16').is_loopback)
+ self.assertEqual(False, ipaddress.ip_network('128.0.0.0').is_loopback)
+
+ # test addresses
+ self.assertEqual(True, ipaddress.ip_address('0.0.0.0').is_unspecified)
+ self.assertEqual(True, ipaddress.ip_address('224.1.1.1').is_multicast)
+ self.assertEqual(False, ipaddress.ip_address('240.0.0.0').is_multicast)
+ self.assertEqual(True, ipaddress.ip_address('240.0.0.1').is_reserved)
+ self.assertEqual(False,
+ ipaddress.ip_address('239.255.255.255').is_reserved)
+
+ self.assertEqual(True, ipaddress.ip_address('192.168.1.1').is_private)
+ self.assertEqual(False, ipaddress.ip_address('192.169.0.0').is_private)
+ self.assertEqual(True, ipaddress.ip_address(
+ '10.255.255.255').is_private)
+ self.assertEqual(False, ipaddress.ip_address('11.0.0.0').is_private)
+ self.assertEqual(True, ipaddress.ip_address(
+ '172.31.255.255').is_private)
+ self.assertEqual(False, ipaddress.ip_address('172.32.0.0').is_private)
+
+ self.assertEqual(True,
+ ipaddress.ip_address('169.254.100.200').is_link_local)
+ self.assertEqual(False,
+ ipaddress.ip_address('169.255.100.200').is_link_local)
+
+ self.assertEqual(True,
+ ipaddress.ip_address('127.100.200.254').is_loopback)
+ self.assertEqual(True, ipaddress.ip_address('127.42.0.0').is_loopback)
+ self.assertEqual(False, ipaddress.ip_address('128.0.0.0').is_loopback)
+ self.assertEqual(True, ipaddress.ip_network('0.0.0.0').is_unspecified)
+
+ def testReservedIpv6(self):
+
+ self.assertEqual(True, ipaddress.ip_network('ffff::').is_multicast)
+ self.assertEqual(True, ipaddress.ip_network(2**128 - 1).is_multicast)
+ self.assertEqual(True, ipaddress.ip_network('ff00::').is_multicast)
+ self.assertEqual(False, ipaddress.ip_network('fdff::').is_multicast)
+
+ self.assertEqual(True, ipaddress.ip_network('fecf::').is_site_local)
+ self.assertEqual(True, ipaddress.ip_network(
+ 'feff:ffff:ffff:ffff::').is_site_local)
+ self.assertEqual(False, ipaddress.ip_network(
+ 'fbf:ffff::').is_site_local)
+ self.assertEqual(False, ipaddress.ip_network('ff00::').is_site_local)
+
+ self.assertEqual(True, ipaddress.ip_network('fc00::').is_private)
+ self.assertEqual(True, ipaddress.ip_network(
+ 'fc00:ffff:ffff:ffff::').is_private)
+ self.assertEqual(False, ipaddress.ip_network('fbff:ffff::').is_private)
+ self.assertEqual(False, ipaddress.ip_network('fe00::').is_private)
+
+ self.assertEqual(True, ipaddress.ip_network('fea0::').is_link_local)
+ self.assertEqual(True, ipaddress.ip_network(
+ 'febf:ffff::').is_link_local)
+ self.assertEqual(False, ipaddress.ip_network(
+ 'fe7f:ffff::').is_link_local)
+ self.assertEqual(False, ipaddress.ip_network('fec0::').is_link_local)
+
+ self.assertEqual(True, ipaddress.ip_interface('0:0::0:01').is_loopback)
+ self.assertEqual(False, ipaddress.ip_interface('::1/127').is_loopback)
+ self.assertEqual(False, ipaddress.ip_network('::').is_loopback)
+ self.assertEqual(False, ipaddress.ip_network('::2').is_loopback)
+
+ self.assertEqual(True, ipaddress.ip_network('0::0').is_unspecified)
+ self.assertEqual(False, ipaddress.ip_network('::1').is_unspecified)
+ self.assertEqual(False, ipaddress.ip_network('::/127').is_unspecified)
+
+ # test addresses
+ self.assertEqual(True, ipaddress.ip_address('ffff::').is_multicast)
+ self.assertEqual(True, ipaddress.ip_address(2**128 - 1).is_multicast)
+ self.assertEqual(True, ipaddress.ip_address('ff00::').is_multicast)
+ self.assertEqual(False, ipaddress.ip_address('fdff::').is_multicast)
+
+ self.assertEqual(True, ipaddress.ip_address('fecf::').is_site_local)
+ self.assertEqual(True, ipaddress.ip_address(
+ 'feff:ffff:ffff:ffff::').is_site_local)
+ self.assertEqual(False, ipaddress.ip_address(
+ 'fbf:ffff::').is_site_local)
+ self.assertEqual(False, ipaddress.ip_address('ff00::').is_site_local)
+
+ self.assertEqual(True, ipaddress.ip_address('fc00::').is_private)
+ self.assertEqual(True, ipaddress.ip_address(
+ 'fc00:ffff:ffff:ffff::').is_private)
+ self.assertEqual(False, ipaddress.ip_address('fbff:ffff::').is_private)
+ self.assertEqual(False, ipaddress.ip_address('fe00::').is_private)
+
+ self.assertEqual(True, ipaddress.ip_address('fea0::').is_link_local)
+ self.assertEqual(True, ipaddress.ip_address(
+ 'febf:ffff::').is_link_local)
+ self.assertEqual(False, ipaddress.ip_address(
+ 'fe7f:ffff::').is_link_local)
+ self.assertEqual(False, ipaddress.ip_address('fec0::').is_link_local)
+
+ self.assertEqual(True, ipaddress.ip_address('0:0::0:01').is_loopback)
+ self.assertEqual(True, ipaddress.ip_address('::1').is_loopback)
+ self.assertEqual(False, ipaddress.ip_address('::2').is_loopback)
+
+ self.assertEqual(True, ipaddress.ip_address('0::0').is_unspecified)
+ self.assertEqual(False, ipaddress.ip_address('::1').is_unspecified)
+
+ # some generic IETF reserved addresses
+ self.assertEqual(True, ipaddress.ip_address('100::').is_reserved)
+ self.assertEqual(True, ipaddress.ip_network('4000::1/128').is_reserved)
+
+ def testIpv4Mapped(self):
+ self.assertEqual(
+ ipaddress.ip_address('::ffff:192.168.1.1').ipv4_mapped,
+ ipaddress.ip_address('192.168.1.1'))
+ self.assertEqual(ipaddress.ip_address('::c0a8:101').ipv4_mapped, None)
+ self.assertEqual(ipaddress.ip_address('::ffff:c0a8:101').ipv4_mapped,
+ ipaddress.ip_address('192.168.1.1'))
+
+ def testAddrExclude(self):
+ addr1 = ipaddress.ip_network('10.1.1.0/24')
+ addr2 = ipaddress.ip_network('10.1.1.0/26')
+ addr3 = ipaddress.ip_network('10.2.1.0/24')
+ addr4 = ipaddress.ip_address('10.1.1.0')
+ addr5 = ipaddress.ip_network('2001:db8::0/32')
+ self.assertEqual(sorted(list(addr1.address_exclude(addr2))),
+ [ipaddress.ip_network('10.1.1.64/26'),
+ ipaddress.ip_network('10.1.1.128/25')])
+ self.assertRaises(ValueError, list, addr1.address_exclude(addr3))
+ self.assertRaises(TypeError, list, addr1.address_exclude(addr4))
+ self.assertRaises(TypeError, list, addr1.address_exclude(addr5))
+ self.assertEqual(list(addr1.address_exclude(addr1)), [])
+
+ def testHash(self):
+ self.assertEqual(hash(ipaddress.ip_interface('10.1.1.0/24')),
+ hash(ipaddress.ip_interface('10.1.1.0/24')))
+ self.assertEqual(hash(ipaddress.ip_network('10.1.1.0/24')),
+ hash(ipaddress.ip_network('10.1.1.0/24')))
+ self.assertEqual(hash(ipaddress.ip_address('10.1.1.0')),
+ hash(ipaddress.ip_address('10.1.1.0')))
+ # i70
+ self.assertEqual(hash(ipaddress.ip_address('1.2.3.4')),
+ hash(ipaddress.ip_address(
+ int(ipaddress.ip_address('1.2.3.4')._ip))))
+ ip1 = ipaddress.ip_address('10.1.1.0')
+ ip2 = ipaddress.ip_address('1::')
+ dummy = {}
+ dummy[self.ipv4_address] = None
+ dummy[self.ipv6_address] = None
+ dummy[ip1] = None
+ dummy[ip2] = None
+ self.assertTrue(self.ipv4_address in dummy)
+ self.assertTrue(ip2 in dummy)
+
+ def testIPBases(self):
+ net = self.ipv4_network
+ self.assertEqual('1.2.3.0/24', net.compressed)
+ self.assertEqual(
+ net._ip_int_from_prefix(24),
+ net._ip_int_from_prefix(None))
+ net = self.ipv6_network
+ self.assertRaises(ValueError, net._string_from_ip_int, 2**128 + 1)
+ self.assertEqual(
+ self.ipv6_address._string_from_ip_int(self.ipv6_address._ip),
+ self.ipv6_address._string_from_ip_int(None))
+
+ def testIPv6NetworkHelpers(self):
+ net = self.ipv6_network
+ self.assertEqual('2001:658:22a:cafe::/64', net.with_prefixlen)
+ self.assertEqual('2001:658:22a:cafe::/ffff:ffff:ffff:ffff::',
+ net.with_netmask)
+ self.assertEqual('2001:658:22a:cafe::/::ffff:ffff:ffff:ffff',
+ net.with_hostmask)
+ self.assertEqual('2001:658:22a:cafe::/64', str(net))
+
+ def testIPv4NetworkHelpers(self):
+ net = self.ipv4_network
+ self.assertEqual('1.2.3.0/24', net.with_prefixlen)
+ self.assertEqual('1.2.3.0/255.255.255.0', net.with_netmask)
+ self.assertEqual('1.2.3.0/0.0.0.255', net.with_hostmask)
+ self.assertEqual('1.2.3.0/24', str(net))
+
+ def testCopyConstructor(self):
+ addr1 = ipaddress.ip_network('10.1.1.0/24')
+ addr2 = ipaddress.ip_network(addr1)
+ addr3 = ipaddress.ip_interface('2001:658:22a:cafe:200::1/64')
+ addr4 = ipaddress.ip_interface(addr3)
+ addr5 = ipaddress.IPv4Address('1.1.1.1')
+ addr6 = ipaddress.IPv6Address('2001:658:22a:cafe:200::1')
+
+ self.assertEqual(addr1, addr2)
+ self.assertEqual(addr3, addr4)
+ self.assertEqual(addr5, ipaddress.IPv4Address(addr5))
+ self.assertEqual(addr6, ipaddress.IPv6Address(addr6))
+
+ def testCompressIPv6Address(self):
+ test_addresses = {
+ '1:2:3:4:5:6:7:8': '1:2:3:4:5:6:7:8/128',
+ '2001:0:0:4:0:0:0:8': '2001:0:0:4::8/128',
+ '2001:0:0:4:5:6:7:8': '2001::4:5:6:7:8/128',
+ '2001:0:3:4:5:6:7:8': '2001:0:3:4:5:6:7:8/128',
+ '2001:0:3:4:5:6:7:8': '2001:0:3:4:5:6:7:8/128',
+ '0:0:3:0:0:0:0:ffff': '0:0:3::ffff/128',
+ '0:0:0:4:0:0:0:ffff': '::4:0:0:0:ffff/128',
+ '0:0:0:0:5:0:0:ffff': '::5:0:0:ffff/128',
+ '1:0:0:4:0:0:7:8': '1::4:0:0:7:8/128',
+ '0:0:0:0:0:0:0:0': '::/128',
+ '0:0:0:0:0:0:0:0/0': '::/0',
+ '0:0:0:0:0:0:0:1': '::1/128',
+ '2001:0658:022a:cafe:0000:0000:0000:0000/66':
+ '2001:658:22a:cafe::/66',
+ '::1.2.3.4': '::102:304/128',
+ '1:2:3:4:5:ffff:1.2.3.4': '1:2:3:4:5:ffff:102:304/128',
+ '::7:6:5:4:3:2:1': '0:7:6:5:4:3:2:1/128',
+ '::7:6:5:4:3:2:0': '0:7:6:5:4:3:2:0/128',
+ '7:6:5:4:3:2:1::': '7:6:5:4:3:2:1:0/128',
+ '0:6:5:4:3:2:1::': '0:6:5:4:3:2:1:0/128',
+ }
+ for uncompressed, compressed in list(test_addresses.items()):
+ self.assertEqual(compressed, str(ipaddress.IPv6Interface(
+ uncompressed)))
+
+ def testExplodeShortHandIpStr(self):
+ addr1 = ipaddress.IPv6Interface('2001::1')
+ addr2 = ipaddress.IPv6Address('2001:0:5ef5:79fd:0:59d:a0e5:ba1')
+ addr3 = ipaddress.IPv6Network('2001::/96')
+ addr4 = ipaddress.IPv4Address('192.168.178.1')
+ self.assertEqual('2001:0000:0000:0000:0000:0000:0000:0001/128',
+ addr1.exploded)
+ self.assertEqual('0000:0000:0000:0000:0000:0000:0000:0001/128',
+ ipaddress.IPv6Interface('::1/128').exploded)
+ # issue 77
+ self.assertEqual('2001:0000:5ef5:79fd:0000:059d:a0e5:0ba1',
+ addr2.exploded)
+ self.assertEqual('2001:0000:0000:0000:0000:0000:0000:0000/96',
+ addr3.exploded)
+ self.assertEqual('192.168.178.1', addr4.exploded)
+
+ def testIntRepresentation(self):
+ self.assertEqual(16909060, int(self.ipv4_address))
+ self.assertEqual(42540616829182469433547762482097946625,
+ int(self.ipv6_address))
+
+ def testForceVersion(self):
+ self.assertEqual(ipaddress.ip_network(1).version, 4)
+ self.assertEqual(ipaddress.IPv6Network(1).version, 6)
+
+ def testWithStar(self):
+ self.assertEqual(self.ipv4_interface.with_prefixlen, "1.2.3.4/24")
+ self.assertEqual(self.ipv4_interface.with_netmask,
+ "1.2.3.4/255.255.255.0")
+ self.assertEqual(self.ipv4_interface.with_hostmask,
+ "1.2.3.4/0.0.0.255")
+
+ self.assertEqual(self.ipv6_interface.with_prefixlen,
+ '2001:658:22a:cafe:200::1/64')
+ self.assertEqual(self.ipv6_interface.with_netmask,
+ '2001:658:22a:cafe:200::1/ffff:ffff:ffff:ffff::')
+ # this probably don't make much sense, but it's included for
+ # compatibility with ipv4
+ self.assertEqual(self.ipv6_interface.with_hostmask,
+ '2001:658:22a:cafe:200::1/::ffff:ffff:ffff:ffff')
+
+ def testNetworkElementCaching(self):
+ # V4 - make sure we're empty
+ self.assertFalse('network_address' in self.ipv4_network._cache)
+ self.assertFalse('broadcast_address' in self.ipv4_network._cache)
+ self.assertFalse('hostmask' in self.ipv4_network._cache)
+
+ # V4 - populate and test
+ self.assertEqual(self.ipv4_network.network_address,
+ ipaddress.IPv4Address('1.2.3.0'))
+ self.assertEqual(self.ipv4_network.broadcast_address,
+ ipaddress.IPv4Address('1.2.3.255'))
+ self.assertEqual(self.ipv4_network.hostmask,
+ ipaddress.IPv4Address('0.0.0.255'))
+
+ # V4 - check we're cached
+ self.assertTrue('broadcast_address' in self.ipv4_network._cache)
+ self.assertTrue('hostmask' in self.ipv4_network._cache)
+
+ # V6 - make sure we're empty
+ self.assertFalse('broadcast_address' in self.ipv6_network._cache)
+ self.assertFalse('hostmask' in self.ipv6_network._cache)
+
+ # V6 - populate and test
+ self.assertEqual(self.ipv6_network.network_address,
+ ipaddress.IPv6Address('2001:658:22a:cafe::'))
+ self.assertEqual(self.ipv6_interface.network.network_address,
+ ipaddress.IPv6Address('2001:658:22a:cafe::'))
+
+ self.assertEqual(
+ self.ipv6_network.broadcast_address,
+ ipaddress.IPv6Address('2001:658:22a:cafe:ffff:ffff:ffff:ffff'))
+ self.assertEqual(self.ipv6_network.hostmask,
+ ipaddress.IPv6Address('::ffff:ffff:ffff:ffff'))
+ self.assertEqual(
+ self.ipv6_interface.network.broadcast_address,
+ ipaddress.IPv6Address('2001:658:22a:cafe:ffff:ffff:ffff:ffff'))
+ self.assertEqual(self.ipv6_interface.network.hostmask,
+ ipaddress.IPv6Address('::ffff:ffff:ffff:ffff'))
+
+ # V6 - check we're cached
+ self.assertTrue('broadcast_address' in self.ipv6_network._cache)
+ self.assertTrue('hostmask' in self.ipv6_network._cache)
+ self.assertTrue(
+ 'broadcast_address' in self.ipv6_interface.network._cache)
+ self.assertTrue('hostmask' in self.ipv6_interface.network._cache)
+
+ def testTeredo(self):
+ # stolen from wikipedia
+ server = ipaddress.IPv4Address('65.54.227.120')
+ client = ipaddress.IPv4Address('192.0.2.45')
+ teredo_addr = '2001:0000:4136:e378:8000:63bf:3fff:fdd2'
+ self.assertEqual((server, client),
+ ipaddress.ip_address(teredo_addr).teredo)
+ bad_addr = '2000::4136:e378:8000:63bf:3fff:fdd2'
+ self.assertFalse(ipaddress.ip_address(bad_addr).teredo)
+ bad_addr = '2001:0001:4136:e378:8000:63bf:3fff:fdd2'
+ self.assertFalse(ipaddress.ip_address(bad_addr).teredo)
+
+ # i77
+ teredo_addr = ipaddress.IPv6Address('2001:0:5ef5:79fd:0:59d:a0e5:ba1')
+ self.assertEqual((ipaddress.IPv4Address('94.245.121.253'),
+ ipaddress.IPv4Address('95.26.244.94')),
+ teredo_addr.teredo)
+
+ def testsixtofour(self):
+ sixtofouraddr = ipaddress.ip_address('2002:ac1d:2d64::1')
+ bad_addr = ipaddress.ip_address('2000:ac1d:2d64::1')
+ self.assertEqual(ipaddress.IPv4Address('172.29.45.100'),
+ sixtofouraddr.sixtofour)
+ self.assertFalse(bad_addr.sixtofour)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_isinstance.py b/Lib/test/test_isinstance.py
index dc2d074..7a6730e 100644
--- a/Lib/test/test_isinstance.py
+++ b/Lib/test/test_isinstance.py
@@ -15,7 +15,7 @@ class TestIsInstanceExceptions(unittest.TestCase):
# (leading to an "undetected error" in the debug build). Set up is,
# isinstance(inst, cls) where:
#
- # - cls isn't a a type, or a tuple
+ # - cls isn't a type, or a tuple
# - cls has a __bases__ attribute
# - inst has a __class__ attribute
# - inst.__class__ as no __bases__ attribute
diff --git a/Lib/test/test_iter.py b/Lib/test/test_iter.py
index 4d97183..43f8e15 100644
--- a/Lib/test/test_iter.py
+++ b/Lib/test/test_iter.py
@@ -2,6 +2,8 @@
import unittest
from test.support import run_unittest, TESTFN, unlink, cpython_only
+import pickle
+import collections.abc
# Test result of triple loop (too big to inline)
TRIPLETS = [(0, 0, 0), (0, 0, 1), (0, 0, 2),
@@ -28,6 +30,8 @@ class BasicIterClass:
raise StopIteration
self.i = res + 1
return res
+ def __iter__(self):
+ return self
class IteratingSequenceClass:
def __init__(self, n):
@@ -49,7 +53,9 @@ class SequenceClass:
class TestCase(unittest.TestCase):
# Helper to check that an iterator returns a given sequence
- def check_iterator(self, it, seq):
+ def check_iterator(self, it, seq, pickle=True):
+ if pickle:
+ self.check_pickle(it, seq)
res = []
while 1:
try:
@@ -60,12 +66,33 @@ class TestCase(unittest.TestCase):
self.assertEqual(res, seq)
# Helper to check that a for loop generates a given sequence
- def check_for_loop(self, expr, seq):
+ def check_for_loop(self, expr, seq, pickle=True):
+ if pickle:
+ self.check_pickle(iter(expr), seq)
res = []
for val in expr:
res.append(val)
self.assertEqual(res, seq)
+ # Helper to check picklability
+ def check_pickle(self, itorg, seq):
+ d = pickle.dumps(itorg)
+ it = pickle.loads(d)
+ # Cannot assert type equality because dict iterators unpickle as list
+ # iterators.
+ # self.assertEqual(type(itorg), type(it))
+ self.assertTrue(isinstance(it, collections.abc.Iterator))
+ self.assertEqual(list(it), seq)
+
+ it = pickle.loads(d)
+ try:
+ next(it)
+ except StopIteration:
+ return
+ d = pickle.dumps(it)
+ it = pickle.loads(d)
+ self.assertEqual(list(it), seq[1:])
+
# Test basic use of iter() function
def test_iter_basic(self):
self.check_iterator(iter(range(10)), list(range(10)))
@@ -138,7 +165,7 @@ class TestCase(unittest.TestCase):
if i > 100:
raise IndexError # Emergency stop
return i
- self.check_iterator(iter(C(), 10), list(range(10)))
+ self.check_iterator(iter(C(), 10), list(range(10)), pickle=False)
# Test two-argument iter() with function
def test_iter_function(self):
@@ -146,7 +173,7 @@ class TestCase(unittest.TestCase):
i = state[0]
state[0] = i+1
return i
- self.check_iterator(iter(spam, 10), list(range(10)))
+ self.check_iterator(iter(spam, 10), list(range(10)), pickle=False)
# Test two-argument iter() with function that raises StopIteration
def test_iter_function_stop(self):
@@ -156,7 +183,7 @@ class TestCase(unittest.TestCase):
raise StopIteration
state[0] = i+1
return i
- self.check_iterator(iter(spam, 20), list(range(10)))
+ self.check_iterator(iter(spam, 20), list(range(10)), pickle=False)
# Test exception propagation through function iterator
def test_exception_function(self):
@@ -198,7 +225,7 @@ class TestCase(unittest.TestCase):
if i == 10:
raise StopIteration
return SequenceClass.__getitem__(self, i)
- self.check_for_loop(MySequenceClass(20), list(range(10)))
+ self.check_for_loop(MySequenceClass(20), list(range(10)), pickle=False)
# Test a big range
def test_iter_big_range(self):
@@ -237,8 +264,8 @@ class TestCase(unittest.TestCase):
f.close()
f = open(TESTFN, "r")
try:
- self.check_for_loop(f, ["0\n", "1\n", "2\n", "3\n", "4\n"])
- self.check_for_loop(f, [])
+ self.check_for_loop(f, ["0\n", "1\n", "2\n", "3\n", "4\n"], pickle=False)
+ self.check_for_loop(f, [], pickle=False)
finally:
f.close()
try:
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index dfa371e..53926a9 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -37,6 +37,13 @@ def isOdd(x):
'Test predicate'
return x%2==1
+def tupleize(*args):
+ return args
+
+def irange(n):
+ for i in range(n):
+ yield i
+
class StopNow:
'Class emulating an empty iterable.'
def __iter__(self):
@@ -55,8 +62,59 @@ def fact(n):
'Factorial'
return prod(range(1, n+1))
+# root level methods for pickling ability
+def testR(r):
+ return r[0]
+
+def testR2(r):
+ return r[2]
+
+def underten(x):
+ return x<10
+
class TestBasicOps(unittest.TestCase):
+ def pickletest(self, it, stop=4, take=1, compare=None):
+ """Test that an iterator is the same after pickling, also when part-consumed"""
+ def expand(it, i=0):
+ # Recursively expand iterables, within sensible bounds
+ if i > 10:
+ raise RuntimeError("infinite recursion encountered")
+ if isinstance(it, str):
+ return it
+ try:
+ l = list(islice(it, stop))
+ except TypeError:
+ return it # can't expand it
+ return [expand(e, i+1) for e in l]
+
+ # Test the initial copy against the original
+ dump = pickle.dumps(it)
+ i2 = pickle.loads(dump)
+ self.assertEqual(type(it), type(i2))
+ a, b = expand(it), expand(i2)
+ self.assertEqual(a, b)
+ if compare:
+ c = expand(compare)
+ self.assertEqual(a, c)
+
+ # Take from the copy, and create another copy and compare them.
+ i3 = pickle.loads(dump)
+ took = 0
+ try:
+ for i in range(take):
+ next(i3)
+ took += 1
+ except StopIteration:
+ pass #in case there is less data than 'take'
+ dump = pickle.dumps(i3)
+ i4 = pickle.loads(dump)
+ a, b = expand(i3), expand(i4)
+ self.assertEqual(a, b)
+ if compare:
+ c = expand(compare[took:])
+ self.assertEqual(a, c);
+
def test_accumulate(self):
self.assertEqual(list(accumulate(range(10))), # one positional arg
[0, 1, 3, 6, 10, 15, 21, 28, 36, 45])
@@ -69,11 +127,22 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(list(accumulate('abc')), ['a', 'ab', 'abc']) # works with non-numeric
self.assertEqual(list(accumulate([])), []) # empty iterable
self.assertEqual(list(accumulate([7])), [7]) # iterable of length one
- self.assertRaises(TypeError, accumulate, range(10), 5) # too many args
+ self.assertRaises(TypeError, accumulate, range(10), 5, 6) # too many args
self.assertRaises(TypeError, accumulate) # too few args
self.assertRaises(TypeError, accumulate, x=range(10)) # unexpected kwd arg
self.assertRaises(TypeError, list, accumulate([1, []])) # args that don't add
+ s = [2, 8, 9, 5, 7, 0, 3, 4, 1, 6]
+ self.assertEqual(list(accumulate(s, min)),
+ [2, 2, 2, 2, 2, 0, 0, 0, 0, 0])
+ self.assertEqual(list(accumulate(s, max)),
+ [2, 8, 9, 9, 9, 9, 9, 9, 9, 9])
+ self.assertEqual(list(accumulate(s, operator.mul)),
+ [2, 16, 144, 720, 5040, 0, 0, 0, 0, 0])
+ with self.assertRaises(TypeError):
+ list(accumulate(s, chr)) # unary-operation
+ self.pickletest(accumulate(range(10))) # test pickling
+
def test_chain(self):
def chain2(*iterables):
@@ -96,14 +165,43 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(take(4, chain.from_iterable(['abc', 'def'])), list('abcd'))
self.assertRaises(TypeError, list, chain.from_iterable([2, 3]))
+ def test_chain_reducible(self):
+ operators = [copy.deepcopy,
+ lambda s: pickle.loads(pickle.dumps(s))]
+ for oper in operators:
+ it = chain('abc', 'def')
+ self.assertEqual(list(oper(it)), list('abcdef'))
+ self.assertEqual(next(it), 'a')
+ self.assertEqual(list(oper(it)), list('bcdef'))
+
+ self.assertEqual(list(oper(chain(''))), [])
+ self.assertEqual(take(4, oper(chain('abc', 'def'))), list('abcd'))
+ self.assertRaises(TypeError, list, oper(chain(2, 3)))
+ self.pickletest(chain('abc', 'def'), compare=list('abcdef'))
+
def test_combinations(self):
self.assertRaises(TypeError, combinations, 'abc') # missing r argument
self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments
self.assertRaises(TypeError, combinations, None) # pool is not iterable
self.assertRaises(ValueError, combinations, 'abc', -2) # r is negative
- self.assertEqual(list(combinations('abc', 32)), []) # r > n
- self.assertEqual(list(combinations(range(4), 3)),
- [(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
+
+ for op in (lambda a:a, lambda a:pickle.loads(pickle.dumps(a))):
+ self.assertEqual(list(op(combinations('abc', 32))), []) # r > n
+
+ self.assertEqual(list(op(combinations('ABCD', 2))),
+ [('A','B'), ('A','C'), ('A','D'), ('B','C'), ('B','D'), ('C','D')])
+ testIntermediate = combinations('ABCD', 2)
+ next(testIntermediate)
+ self.assertEqual(list(op(testIntermediate)),
+ [('A','C'), ('A','D'), ('B','C'), ('B','D'), ('C','D')])
+
+ self.assertEqual(list(op(combinations(range(4), 3))),
+ [(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
+ testIntermediate = combinations(range(4), 3)
+ next(testIntermediate)
+ self.assertEqual(list(op(testIntermediate)),
+ [(0,1,3), (0,2,3), (1,2,3)])
+
def combinations1(iterable, r):
'Pure python version shown in the docs'
@@ -158,7 +256,11 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(result, list(combinations2(values, r))) # matches second pure python version
self.assertEqual(result, list(combinations3(values, r))) # matches second pure python version
+ self.pickletest(combinations(values, r)) # test pickling
+
# Test implementation detail: tuple re-use
+ @support.impl_detail("tuple reuse is specific to CPython")
+ def test_combinations_tuple_reuse(self):
self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1)
self.assertNotEqual(len(set(map(id, list(combinations('abcde', 3))))), 1)
@@ -168,8 +270,15 @@ class TestBasicOps(unittest.TestCase):
self.assertRaises(TypeError, cwr, 'abc', 2, 1) # too many arguments
self.assertRaises(TypeError, cwr, None) # pool is not iterable
self.assertRaises(ValueError, cwr, 'abc', -2) # r is negative
- self.assertEqual(list(cwr('ABC', 2)),
- [('A','A'), ('A','B'), ('A','C'), ('B','B'), ('B','C'), ('C','C')])
+
+ for op in (lambda a:a, lambda a:pickle.loads(pickle.dumps(a))):
+ self.assertEqual(list(op(cwr('ABC', 2))),
+ [('A','A'), ('A','B'), ('A','C'), ('B','B'), ('B','C'), ('C','C')])
+ testIntermediate = cwr('ABC', 2)
+ next(testIntermediate)
+ self.assertEqual(list(op(testIntermediate)),
+ [('A','B'), ('A','C'), ('B','B'), ('B','C'), ('C','C')])
+
def cwr1(iterable, r):
'Pure python version shown in the docs'
@@ -228,7 +337,13 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(result, list(cwr1(values, r))) # matches first pure python version
self.assertEqual(result, list(cwr2(values, r))) # matches second pure python version
+ self.pickletest(cwr(values,r)) # test pickling
+
# Test implementation detail: tuple re-use
+
+ @support.impl_detail("tuple reuse is specific to CPython")
+ def test_combinations_with_replacement_tuple_reuse(self):
+ cwr = combinations_with_replacement
self.assertEqual(len(set(map(id, cwr('abcde', 3)))), 1)
self.assertNotEqual(len(set(map(id, list(cwr('abcde', 3))))), 1)
@@ -292,7 +407,10 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(result, list(permutations(values, None))) # test r as None
self.assertEqual(result, list(permutations(values))) # test default r
- # Test implementation detail: tuple re-use
+ self.pickletest(permutations(values, r)) # test pickling
+
+ @support.impl_detail("tuple resuse is CPython specific")
+ def test_permutations_tuple_reuse(self):
self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1)
self.assertNotEqual(len(set(map(id, list(permutations('abcde', 3))))), 1)
@@ -345,6 +463,24 @@ class TestBasicOps(unittest.TestCase):
self.assertRaises(TypeError, compress, range(6)) # too few args
self.assertRaises(TypeError, compress, range(6), None) # too many args
+ # check copy, deepcopy, pickle
+ for op in (lambda a:copy.copy(a), lambda a:copy.deepcopy(a), lambda a:pickle.loads(pickle.dumps(a))):
+ for data, selectors, result1, result2 in [
+ ('ABCDEF', [1,0,1,0,1,1], 'ACEF', 'CEF'),
+ ('ABCDEF', [0,0,0,0,0,0], '', ''),
+ ('ABCDEF', [1,1,1,1,1,1], 'ABCDEF', 'BCDEF'),
+ ('ABCDEF', [1,0,1], 'AC', 'C'),
+ ('ABC', [0,1,1,1,1,1], 'BC', 'C'),
+ ]:
+
+ self.assertEqual(list(op(compress(data=data, selectors=selectors))), list(result1))
+ self.assertEqual(list(op(compress(data, selectors))), list(result1))
+ testIntermediate = compress(data, selectors)
+ if result1:
+ next(testIntermediate)
+ self.assertEqual(list(op(testIntermediate)), list(result2))
+
+
def test_count(self):
self.assertEqual(lzip('abc',count()), [('a', 0), ('b', 1), ('c', 2)])
self.assertEqual(lzip('abc',count(3)), [('a', 3), ('b', 4), ('c', 5)])
@@ -379,7 +515,7 @@ class TestBasicOps(unittest.TestCase):
c = count(value)
self.assertEqual(next(copy.copy(c)), value)
self.assertEqual(next(copy.deepcopy(c)), value)
- self.assertEqual(next(pickle.loads(pickle.dumps(c))), value)
+ self.pickletest(count(value))
#check proper internal error handling for large "step' sizes
count(1, maxsize+5); sys.exc_info()
@@ -426,6 +562,7 @@ class TestBasicOps(unittest.TestCase):
else:
r2 = ('count(%r, %r)' % (i, j)).replace('L', '')
self.assertEqual(r1, r2)
+ self.pickletest(count(i, j))
def test_cycle(self):
self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
@@ -434,6 +571,18 @@ class TestBasicOps(unittest.TestCase):
self.assertRaises(TypeError, cycle, 5)
self.assertEqual(list(islice(cycle(gen3()),10)), [0,1,2,0,1,2,0,1,2,0])
+ # check copy, deepcopy, pickle
+ c = cycle('abc')
+ self.assertEqual(next(c), 'a')
+ #simple copy currently not supported, because __reduce__ returns
+ #an internal iterator
+ #self.assertEqual(take(10, copy.copy(c)), list('bcabcabcab'))
+ self.assertEqual(take(10, copy.deepcopy(c)), list('bcabcabcab'))
+ self.assertEqual(take(10, pickle.loads(pickle.dumps(c))), list('bcabcabcab'))
+ next(c)
+ self.assertEqual(take(10, pickle.loads(pickle.dumps(c))), list('cabcabcabc'))
+ self.pickletest(cycle('abc'))
+
def test_groupby(self):
# Check whether it accepts arguments correctly
self.assertEqual([], list(groupby([])))
@@ -452,18 +601,37 @@ class TestBasicOps(unittest.TestCase):
dup.append(elem)
self.assertEqual(s, dup)
+ # Check normal pickled
+ dup = []
+ for k, g in pickle.loads(pickle.dumps(groupby(s, testR))):
+ for elem in g:
+ self.assertEqual(k, elem[0])
+ dup.append(elem)
+ self.assertEqual(s, dup)
+
# Check nested case
dup = []
- for k, g in groupby(s, lambda r:r[0]):
- for ik, ig in groupby(g, lambda r:r[2]):
+ for k, g in groupby(s, testR):
+ for ik, ig in groupby(g, testR2):
+ for elem in ig:
+ self.assertEqual(k, elem[0])
+ self.assertEqual(ik, elem[2])
+ dup.append(elem)
+ self.assertEqual(s, dup)
+
+ # Check nested and pickled
+ dup = []
+ for k, g in pickle.loads(pickle.dumps(groupby(s, testR))):
+ for ik, ig in pickle.loads(pickle.dumps(groupby(g, testR2))):
for elem in ig:
self.assertEqual(k, elem[0])
self.assertEqual(ik, elem[2])
dup.append(elem)
self.assertEqual(s, dup)
+
# Check case where inner iterator is not used
- keys = [k for k, g in groupby(s, lambda r:r[0])]
+ keys = [k for k, g in groupby(s, testR)]
expectedkeys = set([r[0] for r in s])
self.assertEqual(set(keys), expectedkeys)
self.assertEqual(len(keys), len(expectedkeys))
@@ -534,6 +702,20 @@ class TestBasicOps(unittest.TestCase):
self.assertRaises(TypeError, filter, isEven, 3)
self.assertRaises(TypeError, next, filter(range(6), range(6)))
+ # check copy, deepcopy, pickle
+ ans = [0,2,4]
+
+ c = filter(isEven, range(6))
+ self.assertEqual(list(copy.copy(c)), ans)
+ c = filter(isEven, range(6))
+ self.assertEqual(list(copy.deepcopy(c)), ans)
+ c = filter(isEven, range(6))
+ self.assertEqual(list(pickle.loads(pickle.dumps(c))), ans)
+ next(c)
+ self.assertEqual(list(pickle.loads(pickle.dumps(c))), ans[1:])
+ c = filter(isEven, range(6))
+ self.pickletest(c)
+
def test_filterfalse(self):
self.assertEqual(list(filterfalse(isEven, range(6))), [1,3,5])
self.assertEqual(list(filterfalse(None, [0,1,0,2,0])), [0,0,0])
@@ -544,6 +726,7 @@ class TestBasicOps(unittest.TestCase):
self.assertRaises(TypeError, filterfalse, lambda x:x, range(6), 7)
self.assertRaises(TypeError, filterfalse, isEven, 3)
self.assertRaises(TypeError, next, filterfalse(range(6), range(6)))
+ self.pickletest(filterfalse(isEven, range(6)))
def test_zip(self):
# XXX This is rather silly now that builtin zip() calls zip()...
@@ -556,16 +739,35 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(list(zip()), lzip())
self.assertRaises(TypeError, zip, 3)
self.assertRaises(TypeError, zip, range(3), 3)
- # Check tuple re-use (implementation detail)
self.assertEqual([tuple(list(pair)) for pair in zip('abc', 'def')],
lzip('abc', 'def'))
self.assertEqual([pair for pair in zip('abc', 'def')],
lzip('abc', 'def'))
+
+ @support.impl_detail("tuple reuse is specific to CPython")
+ def test_zip_tuple_reuse(self):
ids = list(map(id, zip('abc', 'def')))
self.assertEqual(min(ids), max(ids))
ids = list(map(id, list(zip('abc', 'def'))))
self.assertEqual(len(dict.fromkeys(ids)), len(ids))
+ # check copy, deepcopy, pickle
+ ans = [(x,y) for x, y in copy.copy(zip('abc',count()))]
+ self.assertEqual(ans, [('a', 0), ('b', 1), ('c', 2)])
+
+ ans = [(x,y) for x, y in copy.deepcopy(zip('abc',count()))]
+ self.assertEqual(ans, [('a', 0), ('b', 1), ('c', 2)])
+
+ ans = [(x,y) for x, y in pickle.loads(pickle.dumps(zip('abc',count())))]
+ self.assertEqual(ans, [('a', 0), ('b', 1), ('c', 2)])
+
+ testIntermediate = zip('abc',count())
+ next(testIntermediate)
+ ans = [(x,y) for x, y in pickle.loads(pickle.dumps(testIntermediate))]
+ self.assertEqual(ans, [('b', 1), ('c', 2)])
+
+ self.pickletest(zip('abc', count()))
+
def test_ziplongest(self):
for args in [
['abc', range(6)],
@@ -603,16 +805,24 @@ class TestBasicOps(unittest.TestCase):
else:
self.fail('Did not raise Type in: ' + stmt)
- # Check tuple re-use (implementation detail)
self.assertEqual([tuple(list(pair)) for pair in zip_longest('abc', 'def')],
list(zip('abc', 'def')))
self.assertEqual([pair for pair in zip_longest('abc', 'def')],
list(zip('abc', 'def')))
+
+ @support.impl_detail("tuple reuse is specific to CPython")
+ def test_zip_longest_tuple_reuse(self):
ids = list(map(id, zip_longest('abc', 'def')))
self.assertEqual(min(ids), max(ids))
ids = list(map(id, list(zip_longest('abc', 'def'))))
self.assertEqual(len(dict.fromkeys(ids)), len(ids))
+ def test_zip_longest_pickling(self):
+ self.pickletest(zip_longest("abc", "def"))
+ self.pickletest(zip_longest("abc", "defgh"))
+ self.pickletest(zip_longest("abc", "defgh", fillvalue=1))
+ self.pickletest(zip_longest("", "defgh"))
+
def test_bug_7244(self):
class Repeater:
@@ -711,10 +921,25 @@ class TestBasicOps(unittest.TestCase):
args = map(iter, args)
self.assertEqual(len(list(product(*args))), expected_len)
- # Test implementation detail: tuple re-use
+ @support.impl_detail("tuple reuse is specific to CPython")
+ def test_product_tuple_reuse(self):
self.assertEqual(len(set(map(id, product('abc', 'def')))), 1)
self.assertNotEqual(len(set(map(id, list(product('abc', 'def'))))), 1)
+ def test_product_pickling(self):
+ # check copy, deepcopy, pickle
+ for args, result in [
+ ([], [()]), # zero iterables
+ (['ab'], [('a',), ('b',)]), # one iterable
+ ([range(2), range(3)], [(0,0), (0,1), (0,2), (1,0), (1,1), (1,2)]), # two iterables
+ ([range(0), range(2), range(3)], []), # first iterable with zero length
+ ([range(2), range(0), range(3)], []), # middle iterable with zero length
+ ([range(2), range(3), range(0)], []), # last iterable with zero length
+ ]:
+ self.assertEqual(list(copy.copy(product(*args))), result)
+ self.assertEqual(list(copy.deepcopy(product(*args))), result)
+ self.pickletest(product(*args))
+
def test_repeat(self):
self.assertEqual(list(repeat(object='a', times=3)), ['a', 'a', 'a'])
self.assertEqual(lzip(range(3),repeat('a')),
@@ -733,11 +958,16 @@ class TestBasicOps(unittest.TestCase):
list(r)
self.assertEqual(repr(r), 'repeat((1+0j), 0)')
+ # check copy, deepcopy, pickle
+ c = repeat(object='a', times=10)
+ self.assertEqual(next(c), 'a')
+ self.assertEqual(take(2, copy.copy(c)), list('a' * 2))
+ self.assertEqual(take(2, copy.deepcopy(c)), list('a' * 2))
+ self.pickletest(repeat(object='a', times=10))
+
def test_map(self):
self.assertEqual(list(map(operator.pow, range(3), range(1,7))),
[0**1, 1**2, 2**3])
- def tupleize(*args):
- return args
self.assertEqual(list(map(tupleize, 'abc', range(5))),
[('a',0),('b',1),('c',2)])
self.assertEqual(list(map(tupleize, 'abc', count())),
@@ -752,6 +982,18 @@ class TestBasicOps(unittest.TestCase):
self.assertRaises(ValueError, next, map(errfunc, [4], [5]))
self.assertRaises(TypeError, next, map(onearg, [4], [5]))
+ # check copy, deepcopy, pickle
+ ans = [('a',0),('b',1),('c',2)]
+
+ c = map(tupleize, 'abc', count())
+ self.assertEqual(list(copy.copy(c)), ans)
+
+ c = map(tupleize, 'abc', count())
+ self.assertEqual(list(copy.deepcopy(c)), ans)
+
+ c = map(tupleize, 'abc', count())
+ self.pickletest(c)
+
def test_starmap(self):
self.assertEqual(list(starmap(operator.pow, zip(range(3), range(1,7)))),
[0**1, 1**2, 2**3])
@@ -766,6 +1008,18 @@ class TestBasicOps(unittest.TestCase):
self.assertRaises(ValueError, next, starmap(errfunc, [(4,5)]))
self.assertRaises(TypeError, next, starmap(onearg, [(4,5)]))
+ # check copy, deepcopy, pickle
+ ans = [0**1, 1**2, 2**3]
+
+ c = starmap(operator.pow, zip(range(3), range(1,7)))
+ self.assertEqual(list(copy.copy(c)), ans)
+
+ c = starmap(operator.pow, zip(range(3), range(1,7)))
+ self.assertEqual(list(copy.deepcopy(c)), ans)
+
+ c = starmap(operator.pow, zip(range(3), range(1,7)))
+ self.pickletest(c)
+
def test_islice(self):
for args in [ # islice(args) should agree with range(args)
(10, 20, 3),
@@ -798,17 +1052,18 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(list(it), list(range(3, 10)))
# Test invalid arguments
- self.assertRaises(TypeError, islice, range(10))
- self.assertRaises(TypeError, islice, range(10), 1, 2, 3, 4)
- self.assertRaises(ValueError, islice, range(10), -5, 10, 1)
- self.assertRaises(ValueError, islice, range(10), 1, -5, -1)
- self.assertRaises(ValueError, islice, range(10), 1, 10, -1)
- self.assertRaises(ValueError, islice, range(10), 1, 10, 0)
- self.assertRaises(ValueError, islice, range(10), 'a')
- self.assertRaises(ValueError, islice, range(10), 'a', 1)
- self.assertRaises(ValueError, islice, range(10), 1, 'a')
- self.assertRaises(ValueError, islice, range(10), 'a', 1, 1)
- self.assertRaises(ValueError, islice, range(10), 1, 'a', 1)
+ ra = range(10)
+ self.assertRaises(TypeError, islice, ra)
+ self.assertRaises(TypeError, islice, ra, 1, 2, 3, 4)
+ self.assertRaises(ValueError, islice, ra, -5, 10, 1)
+ self.assertRaises(ValueError, islice, ra, 1, -5, -1)
+ self.assertRaises(ValueError, islice, ra, 1, 10, -1)
+ self.assertRaises(ValueError, islice, ra, 1, 10, 0)
+ self.assertRaises(ValueError, islice, ra, 'a')
+ self.assertRaises(ValueError, islice, ra, 'a', 1)
+ self.assertRaises(ValueError, islice, ra, 1, 'a')
+ self.assertRaises(ValueError, islice, ra, 'a', 1, 1)
+ self.assertRaises(ValueError, islice, ra, 1, 'a', 1)
self.assertEqual(len(list(islice(count(), 1, 10, maxsize))), 1)
# Issue #10323: Less islice in a predictable state
@@ -816,9 +1071,22 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(list(islice(c, 1, 3, 50)), [1])
self.assertEqual(next(c), 3)
+ # check copy, deepcopy, pickle
+ for args in [ # islice(args) should agree with range(args)
+ (10, 20, 3),
+ (10, 3, 20),
+ (10, 20),
+ (10, 3),
+ (20,)
+ ]:
+ self.assertEqual(list(copy.copy(islice(range(100), *args))),
+ list(range(*args)))
+ self.assertEqual(list(copy.deepcopy(islice(range(100), *args))),
+ list(range(*args)))
+ self.pickletest(islice(range(100), *args))
+
def test_takewhile(self):
data = [1, 3, 5, 20, 2, 4, 6, 8]
- underten = lambda x: x<10
self.assertEqual(list(takewhile(underten, data)), [1, 3, 5])
self.assertEqual(list(takewhile(underten, [])), [])
self.assertRaises(TypeError, takewhile)
@@ -830,9 +1098,14 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(list(t), [1, 1, 1])
self.assertRaises(StopIteration, next, t)
+ # check copy, deepcopy, pickle
+ self.assertEqual(list(copy.copy(takewhile(underten, data))), [1, 3, 5])
+ self.assertEqual(list(copy.deepcopy(takewhile(underten, data))),
+ [1, 3, 5])
+ self.pickletest(takewhile(underten, data))
+
def test_dropwhile(self):
data = [1, 3, 5, 20, 2, 4, 6, 8]
- underten = lambda x: x<10
self.assertEqual(list(dropwhile(underten, data)), [20, 2, 4, 6, 8])
self.assertEqual(list(dropwhile(underten, [])), [])
self.assertRaises(TypeError, dropwhile)
@@ -841,11 +1114,14 @@ class TestBasicOps(unittest.TestCase):
self.assertRaises(TypeError, next, dropwhile(10, [(4,5)]))
self.assertRaises(ValueError, next, dropwhile(errfunc, [(4,5)]))
+ # check copy, deepcopy, pickle
+ self.assertEqual(list(copy.copy(dropwhile(underten, data))), [20, 2, 4, 6, 8])
+ self.assertEqual(list(copy.deepcopy(dropwhile(underten, data))),
+ [20, 2, 4, 6, 8])
+ self.pickletest(dropwhile(underten, data))
+
def test_tee(self):
n = 200
- def irange(n):
- for i in range(n):
- yield i
a, b = tee([]) # test empty iterator
self.assertEqual(list(a), [])
@@ -930,6 +1206,67 @@ class TestBasicOps(unittest.TestCase):
del a
self.assertRaises(ReferenceError, getattr, p, '__class__')
+ ans = list('abc')
+ long_ans = list(range(10000))
+
+ # check copy
+ a, b = tee('abc')
+ self.assertEqual(list(copy.copy(a)), ans)
+ self.assertEqual(list(copy.copy(b)), ans)
+ a, b = tee(list(range(10000)))
+ self.assertEqual(list(copy.copy(a)), long_ans)
+ self.assertEqual(list(copy.copy(b)), long_ans)
+
+ # check partially consumed copy
+ a, b = tee('abc')
+ take(2, a)
+ take(1, b)
+ self.assertEqual(list(copy.copy(a)), ans[2:])
+ self.assertEqual(list(copy.copy(b)), ans[1:])
+ self.assertEqual(list(a), ans[2:])
+ self.assertEqual(list(b), ans[1:])
+ a, b = tee(range(10000))
+ take(100, a)
+ take(60, b)
+ self.assertEqual(list(copy.copy(a)), long_ans[100:])
+ self.assertEqual(list(copy.copy(b)), long_ans[60:])
+ self.assertEqual(list(a), long_ans[100:])
+ self.assertEqual(list(b), long_ans[60:])
+
+ # check deepcopy
+ a, b = tee('abc')
+ self.assertEqual(list(copy.deepcopy(a)), ans)
+ self.assertEqual(list(copy.deepcopy(b)), ans)
+ self.assertEqual(list(a), ans)
+ self.assertEqual(list(b), ans)
+ a, b = tee(range(10000))
+ self.assertEqual(list(copy.deepcopy(a)), long_ans)
+ self.assertEqual(list(copy.deepcopy(b)), long_ans)
+ self.assertEqual(list(a), long_ans)
+ self.assertEqual(list(b), long_ans)
+
+ # check partially consumed deepcopy
+ a, b = tee('abc')
+ take(2, a)
+ take(1, b)
+ self.assertEqual(list(copy.deepcopy(a)), ans[2:])
+ self.assertEqual(list(copy.deepcopy(b)), ans[1:])
+ self.assertEqual(list(a), ans[2:])
+ self.assertEqual(list(b), ans[1:])
+ a, b = tee(range(10000))
+ take(100, a)
+ take(60, b)
+ self.assertEqual(list(copy.deepcopy(a)), long_ans[100:])
+ self.assertEqual(list(copy.deepcopy(b)), long_ans[60:])
+ self.assertEqual(list(a), long_ans[100:])
+ self.assertEqual(list(b), long_ans[60:])
+
+ # check pickle
+ self.pickletest(iter(tee('abc')))
+ a, b = tee('abc')
+ self.pickletest(a, compare=ans)
+ self.pickletest(b, compare=ans)
+
# Issue 13454: Crash when deleting backward iterator from tee()
def test_tee_del_backward(self):
forward, backward = tee(repeat(None, 20000000))
@@ -961,9 +1298,21 @@ class TestBasicOps(unittest.TestCase):
class TestExamples(unittest.TestCase):
- def test_accumlate(self):
+ def test_accumulate(self):
self.assertEqual(list(accumulate([1,2,3,4,5])), [1, 3, 6, 10, 15])
+ def test_accumulate_reducible(self):
+ # check copy, deepcopy, pickle
+ data = [1, 2, 3, 4, 5]
+ accumulated = [1, 3, 6, 10, 15]
+ it = accumulate(data)
+
+ self.assertEqual(list(pickle.loads(pickle.dumps(it))), accumulated[:])
+ self.assertEqual(next(it), 1)
+ self.assertEqual(list(pickle.loads(pickle.dumps(it))), accumulated[1:])
+ self.assertEqual(list(copy.deepcopy(it)), accumulated[1:])
+ self.assertEqual(list(copy.copy(it)), accumulated[1:])
+
def test_chain(self):
self.assertEqual(''.join(chain('ABC', 'DEF')), 'ABCDEF')
diff --git a/Lib/test/test_keywordonlyarg.py b/Lib/test/test_keywordonlyarg.py
index 108ed18..0503a7f 100644
--- a/Lib/test/test_keywordonlyarg.py
+++ b/Lib/test/test_keywordonlyarg.py
@@ -78,7 +78,7 @@ class KeywordOnlyArgTestCase(unittest.TestCase):
pass
with self.assertRaises(TypeError) as exc:
f(1, 2, 3)
- expected = "f() takes at most 2 positional arguments (3 given)"
+ expected = "f() takes from 1 to 2 positional arguments but 3 were given"
self.assertEqual(str(exc.exception), expected)
def testSyntaxErrorForFunctionCall(self):
diff --git a/Lib/test/test_lib2to3.py b/Lib/test/test_lib2to3.py
index 1afaf70..df4c37b 100644
--- a/Lib/test/test_lib2to3.py
+++ b/Lib/test/test_lib2to3.py
@@ -9,8 +9,8 @@ from test.support import run_unittest
def suite():
tests = unittest.TestSuite()
loader = unittest.TestLoader()
- for m in (test_fixers, test_pytree,test_util, test_refactor,
- test_parser, test_main_):
+ for m in (test_fixers, test_pytree, test_util, test_refactor, test_parser,
+ test_main_):
tests.addTests(loader.loadTestsFromModule(m))
return tests
diff --git a/Lib/test/test_list.py b/Lib/test/test_list.py
index 84b8afe..5df27d3 100644
--- a/Lib/test/test_list.py
+++ b/Lib/test/test_list.py
@@ -1,5 +1,6 @@
import sys
from test import support, list_tests
+import pickle
class ListTest(list_tests.CommonTest):
type2test = list
@@ -69,6 +70,33 @@ class ListTest(list_tests.CommonTest):
check(10) # check our checking code
check(1000000)
+ def test_iterator_pickle(self):
+ # Userlist iterators don't support pickling yet since
+ # they are based on generators.
+ data = self.type2test([4, 5, 6, 7])
+ it = itorg = iter(data)
+ d = pickle.dumps(it)
+ it = pickle.loads(d)
+ self.assertEqual(type(itorg), type(it))
+ self.assertEqual(self.type2test(it), self.type2test(data))
+
+ it = pickle.loads(d)
+ next(it)
+ d = pickle.dumps(it)
+ self.assertEqual(self.type2test(it), self.type2test(data)[1:])
+
+ def test_reversed_pickle(self):
+ data = self.type2test([4, 5, 6, 7])
+ it = itorg = reversed(data)
+ d = pickle.dumps(it)
+ it = pickle.loads(d)
+ self.assertEqual(type(itorg), type(it))
+ self.assertEqual(self.type2test(it), self.type2test(reversed(data)))
+
+ it = pickle.loads(d)
+ next(it)
+ d = pickle.dumps(it)
+ self.assertEqual(self.type2test(it), self.type2test(reversed(data))[1:])
def test_no_comdat_folding(self):
# Issue 8847: In the PGO build, the MSVC linker's COMDAT folding
diff --git a/Lib/test/test_locale.py b/Lib/test/test_locale.py
index 8f7574b..51a7bca 100644
--- a/Lib/test/test_locale.py
+++ b/Lib/test/test_locale.py
@@ -11,7 +11,7 @@ def get_enUS_locale():
if sys.platform == 'darwin':
import os
tlocs = ("en_US.UTF-8", "en_US.ISO8859-1", "en_US")
- if int(os.uname()[2].split('.')[0]) < 10:
+ if int(os.uname().release.split('.')[0]) < 10:
# The locale test work fine on OSX 10.6, I (ronaldoussoren)
# haven't had time yet to verify if tests work on OSX 10.5
# (10.4 is known to be bad)
@@ -401,6 +401,8 @@ class TestMiscellaneous(unittest.TestCase):
# Unsupported locale on this system
self.skipTest('test needs Turkish locale')
loc = locale.getlocale(locale.LC_CTYPE)
+ if verbose:
+ print('got locale %a' % (loc,))
locale.setlocale(locale.LC_CTYPE, loc)
self.assertEqual(loc, locale.getlocale(locale.LC_CTYPE))
diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py
index 2a3c780..cb908fb 100644
--- a/Lib/test/test_logging.py
+++ b/Lib/test/test_logging.py
@@ -37,12 +37,11 @@ import random
import re
import select
import socket
-from socketserver import ThreadingTCPServer, StreamRequestHandler
import struct
import sys
import tempfile
-from test.support import captured_stdout, run_with_locale, run_unittest, patch
-from test.support import TestHandler, Matcher
+from test.support import (captured_stdout, run_with_locale, run_unittest,
+ patch, requires_zlib, TestHandler, Matcher)
import textwrap
import time
import unittest
@@ -50,8 +49,31 @@ import warnings
import weakref
try:
import threading
+ # The following imports are needed only for tests which
+ # require threading
+ import asynchat
+ import asyncore
+ import errno
+ from http.server import HTTPServer, BaseHTTPRequestHandler
+ import smtpd
+ from urllib.parse import urlparse, parse_qs
+ from socketserver import (ThreadingUDPServer, DatagramRequestHandler,
+ ThreadingTCPServer, StreamRequestHandler)
except ImportError:
threading = None
+try:
+ import win32evtlog
+except ImportError:
+ win32evtlog = None
+try:
+ import win32evtlogutil
+except ImportError:
+ win32evtlogutil = None
+ win32evtlog = None
+try:
+ import zlib
+except ImportError:
+ pass
class BaseTest(unittest.TestCase):
@@ -79,9 +101,7 @@ class BaseTest(unittest.TestCase):
finally:
logging._releaseLock()
- # Set two unused loggers: one non-ASCII and one Unicode.
- # This is to test correct operation when sorting existing
- # loggers in the configuration code. See issue 8201.
+ # Set two unused loggers
self.logger1 = logging.getLogger("\xab\xd7\xbb")
self.logger2 = logging.getLogger("\u013f\u00d6\u0047")
@@ -142,8 +162,7 @@ class BaseTest(unittest.TestCase):
except AttributeError:
# StringIO.StringIO lacks a reset() method.
actual_lines = stream.getvalue().splitlines()
- self.assertEqual(len(actual_lines), len(expected_values),
- '%s vs. %s' % (actual_lines, expected_values))
+ self.assertEqual(len(actual_lines), len(expected_values))
for actual, expected in zip(actual_lines, expected_values):
match = pat.search(actual)
if not match:
@@ -181,17 +200,17 @@ class BuiltinLevelsTest(BaseTest):
INF.log(logging.CRITICAL, m())
INF.error(m())
- INF.warn(m())
+ INF.warning(m())
INF.info(m())
DEB.log(logging.CRITICAL, m())
DEB.error(m())
- DEB.warn (m())
- DEB.info (m())
+ DEB.warning(m())
+ DEB.info(m())
DEB.debug(m())
# These should not log.
- ERR.warn(m())
+ ERR.warning(m())
ERR.info(m())
ERR.debug(m())
@@ -225,7 +244,7 @@ class BuiltinLevelsTest(BaseTest):
INF_ERR.error(m())
# These should not log.
- INF_ERR.warn(m())
+ INF_ERR.warning(m())
INF_ERR.info(m())
INF_ERR.debug(m())
@@ -249,14 +268,14 @@ class BuiltinLevelsTest(BaseTest):
# These should log.
INF_UNDEF.log(logging.CRITICAL, m())
INF_UNDEF.error(m())
- INF_UNDEF.warn(m())
+ INF_UNDEF.warning(m())
INF_UNDEF.info(m())
INF_ERR_UNDEF.log(logging.CRITICAL, m())
INF_ERR_UNDEF.error(m())
# These should not log.
INF_UNDEF.debug(m())
- INF_ERR_UNDEF.warn(m())
+ INF_ERR_UNDEF.warning(m())
INF_ERR_UNDEF.info(m())
INF_ERR_UNDEF.debug(m())
@@ -295,8 +314,6 @@ class BuiltinLevelsTest(BaseTest):
('INF.BADPARENT', 'INFO', '4'),
])
- def test_invalid_name(self):
- self.assertRaises(TypeError, logging.getLogger, any)
class BasicFilterTest(BaseTest):
@@ -355,6 +372,10 @@ class BasicFilterTest(BaseTest):
finally:
handler.removeFilter(filterfunc)
+ def test_empty_filter(self):
+ f = logging.Filter()
+ r = logging.makeLogRecord({'name': 'spam.eggs'})
+ self.assertTrue(f.filter(r))
#
# First, we define our levels. There can be as many as you want - the only
@@ -498,6 +519,478 @@ class CustomLevelsAndFiltersTest(BaseTest):
handler.removeFilter(garr)
+class HandlerTest(BaseTest):
+ def test_name(self):
+ h = logging.Handler()
+ h.name = 'generic'
+ self.assertEqual(h.name, 'generic')
+ h.name = 'anothergeneric'
+ self.assertEqual(h.name, 'anothergeneric')
+ self.assertRaises(NotImplementedError, h.emit, None)
+
+ def test_builtin_handlers(self):
+ # We can't actually *use* too many handlers in the tests,
+ # but we can try instantiating them with various options
+ if sys.platform in ('linux', 'darwin'):
+ for existing in (True, False):
+ fd, fn = tempfile.mkstemp()
+ os.close(fd)
+ if not existing:
+ os.unlink(fn)
+ h = logging.handlers.WatchedFileHandler(fn, delay=True)
+ if existing:
+ dev, ino = h.dev, h.ino
+ self.assertEqual(dev, -1)
+ self.assertEqual(ino, -1)
+ r = logging.makeLogRecord({'msg': 'Test'})
+ h.handle(r)
+ # Now remove the file.
+ os.unlink(fn)
+ self.assertFalse(os.path.exists(fn))
+ # The next call should recreate the file.
+ h.handle(r)
+ self.assertTrue(os.path.exists(fn))
+ else:
+ self.assertEqual(h.dev, -1)
+ self.assertEqual(h.ino, -1)
+ h.close()
+ if existing:
+ os.unlink(fn)
+ if sys.platform == 'darwin':
+ sockname = '/var/run/syslog'
+ else:
+ sockname = '/dev/log'
+ try:
+ h = logging.handlers.SysLogHandler(sockname)
+ self.assertEqual(h.facility, h.LOG_USER)
+ self.assertTrue(h.unixsocket)
+ h.close()
+ except socket.error: # syslogd might not be available
+ pass
+ for method in ('GET', 'POST', 'PUT'):
+ if method == 'PUT':
+ self.assertRaises(ValueError, logging.handlers.HTTPHandler,
+ 'localhost', '/log', method)
+ else:
+ h = logging.handlers.HTTPHandler('localhost', '/log', method)
+ h.close()
+ h = logging.handlers.BufferingHandler(0)
+ r = logging.makeLogRecord({})
+ self.assertTrue(h.shouldFlush(r))
+ h.close()
+ h = logging.handlers.BufferingHandler(1)
+ self.assertFalse(h.shouldFlush(r))
+ h.close()
+
+ @unittest.skipIf(os.name == 'nt', 'WatchedFileHandler not appropriate for Windows.')
+ @unittest.skipUnless(threading, 'Threading required for this test.')
+ def test_race(self):
+ # Issue #14632 refers.
+ def remove_loop(fname, tries):
+ for _ in range(tries):
+ try:
+ os.unlink(fname)
+ except OSError:
+ pass
+ time.sleep(0.004 * random.randint(0, 4))
+
+ del_count = 500
+ log_count = 500
+
+ for delay in (False, True):
+ fd, fn = tempfile.mkstemp('.log', 'test_logging-3-')
+ os.close(fd)
+ remover = threading.Thread(target=remove_loop, args=(fn, del_count))
+ remover.daemon = True
+ remover.start()
+ h = logging.handlers.WatchedFileHandler(fn, delay=delay)
+ f = logging.Formatter('%(asctime)s: %(levelname)s: %(message)s')
+ h.setFormatter(f)
+ try:
+ for _ in range(log_count):
+ time.sleep(0.005)
+ r = logging.makeLogRecord({'msg': 'testing' })
+ h.handle(r)
+ finally:
+ remover.join()
+ h.close()
+ if os.path.exists(fn):
+ os.unlink(fn)
+
+
+class BadStream(object):
+ def write(self, data):
+ raise RuntimeError('deliberate mistake')
+
+class TestStreamHandler(logging.StreamHandler):
+ def handleError(self, record):
+ self.error_record = record
+
+class StreamHandlerTest(BaseTest):
+ def test_error_handling(self):
+ h = TestStreamHandler(BadStream())
+ r = logging.makeLogRecord({})
+ old_raise = logging.raiseExceptions
+ old_stderr = sys.stderr
+ try:
+ h.handle(r)
+ self.assertIs(h.error_record, r)
+ h = logging.StreamHandler(BadStream())
+ sys.stderr = sio = io.StringIO()
+ h.handle(r)
+ self.assertIn('\nRuntimeError: deliberate mistake\n',
+ sio.getvalue())
+ logging.raiseExceptions = False
+ sys.stderr = sio = io.StringIO()
+ h.handle(r)
+ self.assertEqual('', sio.getvalue())
+ finally:
+ logging.raiseExceptions = old_raise
+ sys.stderr = old_stderr
+
+# -- The following section could be moved into a server_helper.py module
+# -- if it proves to be of wider utility than just test_logging
+
+if threading:
+ class TestSMTPChannel(smtpd.SMTPChannel):
+ """
+ This derived class has had to be created because smtpd does not
+ support use of custom channel maps, although they are allowed by
+ asyncore's design. Issue #11959 has been raised to address this,
+ and if resolved satisfactorily, some of this code can be removed.
+ """
+ def __init__(self, server, conn, addr, sockmap):
+ asynchat.async_chat.__init__(self, conn, sockmap)
+ self.smtp_server = server
+ self.conn = conn
+ self.addr = addr
+ self.data_size_limit = None
+ self.received_lines = []
+ self.smtp_state = self.COMMAND
+ self.seen_greeting = ''
+ self.mailfrom = None
+ self.rcpttos = []
+ self.received_data = ''
+ self.fqdn = socket.getfqdn()
+ self.num_bytes = 0
+ try:
+ self.peer = conn.getpeername()
+ except socket.error as err:
+ # a race condition may occur if the other end is closing
+ # before we can get the peername
+ self.close()
+ if err.args[0] != errno.ENOTCONN:
+ raise
+ return
+ self.push('220 %s %s' % (self.fqdn, smtpd.__version__))
+ self.set_terminator(b'\r\n')
+ self.extended_smtp = False
+
+
+ class TestSMTPServer(smtpd.SMTPServer):
+ """
+ This class implements a test SMTP server.
+
+ :param addr: A (host, port) tuple which the server listens on.
+ You can specify a port value of zero: the server's
+ *port* attribute will hold the actual port number
+ used, which can be used in client connections.
+ :param handler: A callable which will be called to process
+ incoming messages. The handler will be passed
+ the client address tuple, who the message is from,
+ a list of recipients and the message data.
+ :param poll_interval: The interval, in seconds, used in the underlying
+ :func:`select` or :func:`poll` call by
+ :func:`asyncore.loop`.
+ :param sockmap: A dictionary which will be used to hold
+ :class:`asyncore.dispatcher` instances used by
+ :func:`asyncore.loop`. This avoids changing the
+ :mod:`asyncore` module's global state.
+ """
+ channel_class = TestSMTPChannel
+
+ def __init__(self, addr, handler, poll_interval, sockmap):
+ self._localaddr = addr
+ self._remoteaddr = None
+ self.data_size_limit = None
+ self.sockmap = sockmap
+ asyncore.dispatcher.__init__(self, map=sockmap)
+ try:
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.setblocking(0)
+ self.set_socket(sock, map=sockmap)
+ # try to re-use a server port if possible
+ self.set_reuse_addr()
+ self.bind(addr)
+ self.port = sock.getsockname()[1]
+ self.listen(5)
+ except:
+ self.close()
+ raise
+ self._handler = handler
+ self._thread = None
+ self.poll_interval = poll_interval
+
+ def handle_accepted(self, conn, addr):
+ """
+ Redefined only because the base class does not pass in a
+ map, forcing use of a global in :mod:`asyncore`.
+ """
+ channel = self.channel_class(self, conn, addr, self.sockmap)
+
+ def process_message(self, peer, mailfrom, rcpttos, data):
+ """
+ Delegates to the handler passed in to the server's constructor.
+
+ Typically, this will be a test case method.
+ :param peer: The client (host, port) tuple.
+ :param mailfrom: The address of the sender.
+ :param rcpttos: The addresses of the recipients.
+ :param data: The message.
+ """
+ self._handler(peer, mailfrom, rcpttos, data)
+
+ def start(self):
+ """
+ Start the server running on a separate daemon thread.
+ """
+ self._thread = t = threading.Thread(target=self.serve_forever,
+ args=(self.poll_interval,))
+ t.setDaemon(True)
+ t.start()
+
+ def serve_forever(self, poll_interval):
+ """
+ Run the :mod:`asyncore` loop until normal termination
+ conditions arise.
+ :param poll_interval: The interval, in seconds, used in the underlying
+ :func:`select` or :func:`poll` call by
+ :func:`asyncore.loop`.
+ """
+ try:
+ asyncore.loop(poll_interval, map=self.sockmap)
+ except select.error:
+ # On FreeBSD 8, closing the server repeatably
+ # raises this error. We swallow it if the
+ # server has been closed.
+ if self.connected or self.accepting:
+ raise
+
+ def stop(self, timeout=None):
+ """
+ Stop the thread by closing the server instance.
+ Wait for the server thread to terminate.
+
+ :param timeout: How long to wait for the server thread
+ to terminate.
+ """
+ self.close()
+ self._thread.join(timeout)
+ self._thread = None
+
+ class ControlMixin(object):
+ """
+ This mixin is used to start a server on a separate thread, and
+ shut it down programmatically. Request handling is simplified - instead
+ of needing to derive a suitable RequestHandler subclass, you just
+ provide a callable which will be passed each received request to be
+ processed.
+
+ :param handler: A handler callable which will be called with a
+ single parameter - the request - in order to
+ process the request. This handler is called on the
+ server thread, effectively meaning that requests are
+ processed serially. While not quite Web scale ;-),
+ this should be fine for testing applications.
+ :param poll_interval: The polling interval in seconds.
+ """
+ def __init__(self, handler, poll_interval):
+ self._thread = None
+ self.poll_interval = poll_interval
+ self._handler = handler
+ self.ready = threading.Event()
+
+ def start(self):
+ """
+ Create a daemon thread to run the server, and start it.
+ """
+ self._thread = t = threading.Thread(target=self.serve_forever,
+ args=(self.poll_interval,))
+ t.setDaemon(True)
+ t.start()
+
+ def serve_forever(self, poll_interval):
+ """
+ Run the server. Set the ready flag before entering the
+ service loop.
+ """
+ self.ready.set()
+ super(ControlMixin, self).serve_forever(poll_interval)
+
+ def stop(self, timeout=None):
+ """
+ Tell the server thread to stop, and wait for it to do so.
+
+ :param timeout: How long to wait for the server thread
+ to terminate.
+ """
+ self.shutdown()
+ if self._thread is not None:
+ self._thread.join(timeout)
+ self._thread = None
+ self.server_close()
+ self.ready.clear()
+
+ class TestHTTPServer(ControlMixin, HTTPServer):
+ """
+ An HTTP server which is controllable using :class:`ControlMixin`.
+
+ :param addr: A tuple with the IP address and port to listen on.
+ :param handler: A handler callable which will be called with a
+ single parameter - the request - in order to
+ process the request.
+ :param poll_interval: The polling interval in seconds.
+ :param log: Pass ``True`` to enable log messages.
+ """
+ def __init__(self, addr, handler, poll_interval=0.5,
+ log=False, sslctx=None):
+ class DelegatingHTTPRequestHandler(BaseHTTPRequestHandler):
+ def __getattr__(self, name, default=None):
+ if name.startswith('do_'):
+ return self.process_request
+ raise AttributeError(name)
+
+ def process_request(self):
+ self.server._handler(self)
+
+ def log_message(self, format, *args):
+ if log:
+ super(DelegatingHTTPRequestHandler,
+ self).log_message(format, *args)
+ HTTPServer.__init__(self, addr, DelegatingHTTPRequestHandler)
+ ControlMixin.__init__(self, handler, poll_interval)
+ self.sslctx = sslctx
+
+ def get_request(self):
+ try:
+ sock, addr = self.socket.accept()
+ if self.sslctx:
+ sock = self.sslctx.wrap_socket(sock, server_side=True)
+ except socket.error as e:
+ # socket errors are silenced by the caller, print them here
+ sys.stderr.write("Got an error:\n%s\n" % e)
+ raise
+ return sock, addr
+
+ class TestTCPServer(ControlMixin, ThreadingTCPServer):
+ """
+ A TCP server which is controllable using :class:`ControlMixin`.
+
+ :param addr: A tuple with the IP address and port to listen on.
+ :param handler: A handler callable which will be called with a single
+ parameter - the request - in order to process the request.
+ :param poll_interval: The polling interval in seconds.
+ :bind_and_activate: If True (the default), binds the server and starts it
+ listening. If False, you need to call
+ :meth:`server_bind` and :meth:`server_activate` at
+ some later time before calling :meth:`start`, so that
+ the server will set up the socket and listen on it.
+ """
+
+ allow_reuse_address = True
+
+ def __init__(self, addr, handler, poll_interval=0.5,
+ bind_and_activate=True):
+ class DelegatingTCPRequestHandler(StreamRequestHandler):
+
+ def handle(self):
+ self.server._handler(self)
+ ThreadingTCPServer.__init__(self, addr, DelegatingTCPRequestHandler,
+ bind_and_activate)
+ ControlMixin.__init__(self, handler, poll_interval)
+
+ def server_bind(self):
+ super(TestTCPServer, self).server_bind()
+ self.port = self.socket.getsockname()[1]
+
+ class TestUDPServer(ControlMixin, ThreadingUDPServer):
+ """
+ A UDP server which is controllable using :class:`ControlMixin`.
+
+ :param addr: A tuple with the IP address and port to listen on.
+ :param handler: A handler callable which will be called with a
+ single parameter - the request - in order to
+ process the request.
+ :param poll_interval: The polling interval for shutdown requests,
+ in seconds.
+ :bind_and_activate: If True (the default), binds the server and
+ starts it listening. If False, you need to
+ call :meth:`server_bind` and
+ :meth:`server_activate` at some later time
+ before calling :meth:`start`, so that the server will
+ set up the socket and listen on it.
+ """
+ def __init__(self, addr, handler, poll_interval=0.5,
+ bind_and_activate=True):
+ class DelegatingUDPRequestHandler(DatagramRequestHandler):
+
+ def handle(self):
+ self.server._handler(self)
+
+ def finish(self):
+ data = self.wfile.getvalue()
+ if data:
+ try:
+ super(DelegatingUDPRequestHandler, self).finish()
+ except socket.error:
+ if not self.server._closed:
+ raise
+
+ ThreadingUDPServer.__init__(self, addr,
+ DelegatingUDPRequestHandler,
+ bind_and_activate)
+ ControlMixin.__init__(self, handler, poll_interval)
+ self._closed = False
+
+ def server_bind(self):
+ super(TestUDPServer, self).server_bind()
+ self.port = self.socket.getsockname()[1]
+
+ def server_close(self):
+ super(TestUDPServer, self).server_close()
+ self._closed = True
+
+# - end of server_helper section
+
+@unittest.skipUnless(threading, 'Threading required for this test.')
+class SMTPHandlerTest(BaseTest):
+ def test_basic(self):
+ sockmap = {}
+ server = TestSMTPServer(('localhost', 0), self.process_message, 0.001,
+ sockmap)
+ server.start()
+ addr = ('localhost', server.port)
+ h = logging.handlers.SMTPHandler(addr, 'me', 'you', 'Log', timeout=5.0)
+ self.assertEqual(h.toaddrs, ['you'])
+ self.messages = []
+ r = logging.makeLogRecord({'msg': 'Hello'})
+ self.handled = threading.Event()
+ h.handle(r)
+ self.handled.wait(5.0) # 14314: don't wait forever
+ server.stop()
+ self.assertTrue(self.handled.is_set())
+ self.assertEqual(len(self.messages), 1)
+ peer, mailfrom, rcpttos, data = self.messages[0]
+ self.assertEqual(mailfrom, 'me')
+ self.assertEqual(rcpttos, ['you'])
+ self.assertIn('\nSubject: Log\n', data)
+ self.assertTrue(data.endswith('\n\nHello'))
+ h.close()
+
+ def process_message(self, *args):
+ self.messages.append(args)
+ self.handled.set()
+
class MemoryHandlerTest(BaseTest):
"""Tests for the MemoryHandler."""
@@ -525,7 +1018,7 @@ class MemoryHandlerTest(BaseTest):
self.mem_logger.info(self.next_message())
self.assert_log_lines([])
# This will flush because the level is >= logging.WARNING
- self.mem_logger.warn(self.next_message())
+ self.mem_logger.warning(self.next_message())
lines = [
('DEBUG', '1'),
('INFO', '2'),
@@ -870,116 +1363,280 @@ class ConfigFileTest(BaseTest):
# Original logger output is empty.
self.assert_log_lines([])
-class LogRecordStreamHandler(StreamRequestHandler):
- """Handler for a streaming logging request. It saves the log message in the
- TCP server's 'log_output' attribute."""
+@unittest.skipUnless(threading, 'Threading required for this test.')
+class SocketHandlerTest(BaseTest):
+
+ """Test for SocketHandler objects."""
+
+ def setUp(self):
+ """Set up a TCP server to receive log messages, and a SocketHandler
+ pointing to that server's address and port."""
+ BaseTest.setUp(self)
+ addr = ('localhost', 0)
+ self.server = server = TestTCPServer(addr, self.handle_socket,
+ 0.01)
+ server.start()
+ server.ready.wait()
+ self.sock_hdlr = logging.handlers.SocketHandler('localhost',
+ server.port)
+ self.log_output = ''
+ self.root_logger.removeHandler(self.root_logger.handlers[0])
+ self.root_logger.addHandler(self.sock_hdlr)
+ self.handled = threading.Semaphore(0)
- TCP_LOG_END = "!!!END!!!"
+ def tearDown(self):
+ """Shutdown the TCP server."""
+ try:
+ self.server.stop(2.0)
+ self.root_logger.removeHandler(self.sock_hdlr)
+ self.sock_hdlr.close()
+ finally:
+ BaseTest.tearDown(self)
- def handle(self):
- """Handle multiple requests - each expected to be of 4-byte length,
- followed by the LogRecord in pickle format. Logs the record
- according to whatever policy is configured locally."""
+ def handle_socket(self, request):
+ conn = request.connection
while True:
- chunk = self.connection.recv(4)
+ chunk = conn.recv(4)
if len(chunk) < 4:
break
slen = struct.unpack(">L", chunk)[0]
- chunk = self.connection.recv(slen)
+ chunk = conn.recv(slen)
while len(chunk) < slen:
- chunk = chunk + self.connection.recv(slen - len(chunk))
- obj = self.unpickle(chunk)
+ chunk = chunk + conn.recv(slen - len(chunk))
+ obj = pickle.loads(chunk)
record = logging.makeLogRecord(obj)
- self.handle_log_record(record)
+ self.log_output += record.msg + '\n'
+ self.handled.release()
+
+ def test_output(self):
+ # The log message sent to the SocketHandler is properly received.
+ logger = logging.getLogger("tcp")
+ logger.error("spam")
+ self.handled.acquire()
+ logger.debug("eggs")
+ self.handled.acquire()
+ self.assertEqual(self.log_output, "spam\neggs\n")
- def unpickle(self, data):
- return pickle.loads(data)
+ def test_noserver(self):
+ # Kill the server
+ self.server.stop(2.0)
+ #The logging call should try to connect, which should fail
+ try:
+ raise RuntimeError('Deliberate mistake')
+ except RuntimeError:
+ self.root_logger.exception('Never sent')
+ self.root_logger.error('Never sent, either')
+ now = time.time()
+ self.assertTrue(self.sock_hdlr.retryTime > now)
+ time.sleep(self.sock_hdlr.retryTime - now + 0.001)
+ self.root_logger.error('Nor this')
- def handle_log_record(self, record):
- # If the end-of-messages sentinel is seen, tell the server to
- # terminate.
- if self.TCP_LOG_END in record.msg:
- self.server.abort = 1
- return
- self.server.log_output += record.msg + "\n"
+@unittest.skipUnless(threading, 'Threading required for this test.')
+class DatagramHandlerTest(BaseTest):
-class LogRecordSocketReceiver(ThreadingTCPServer):
+ """Test for DatagramHandler."""
- """A simple-minded TCP socket-based logging receiver suitable for test
- purposes."""
+ def setUp(self):
+ """Set up a UDP server to receive log messages, and a DatagramHandler
+ pointing to that server's address and port."""
+ BaseTest.setUp(self)
+ addr = ('localhost', 0)
+ self.server = server = TestUDPServer(addr, self.handle_datagram, 0.01)
+ server.start()
+ server.ready.wait()
+ self.sock_hdlr = logging.handlers.DatagramHandler('localhost',
+ server.port)
+ self.log_output = ''
+ self.root_logger.removeHandler(self.root_logger.handlers[0])
+ self.root_logger.addHandler(self.sock_hdlr)
+ self.handled = threading.Event()
- allow_reuse_address = 1
- log_output = ""
+ def tearDown(self):
+ """Shutdown the UDP server."""
+ try:
+ self.server.stop(2.0)
+ self.root_logger.removeHandler(self.sock_hdlr)
+ self.sock_hdlr.close()
+ finally:
+ BaseTest.tearDown(self)
- def __init__(self, host='localhost',
- port=logging.handlers.DEFAULT_TCP_LOGGING_PORT,
- handler=LogRecordStreamHandler):
- ThreadingTCPServer.__init__(self, (host, port), handler)
- self.abort = False
- self.timeout = 0.1
- self.finished = threading.Event()
+ def handle_datagram(self, request):
+ slen = struct.pack('>L', 0) # length of prefix
+ packet = request.packet[len(slen):]
+ obj = pickle.loads(packet)
+ record = logging.makeLogRecord(obj)
+ self.log_output += record.msg + '\n'
+ self.handled.set()
- def serve_until_stopped(self):
- while not self.abort:
- rd, wr, ex = select.select([self.socket.fileno()], [], [],
- self.timeout)
- if rd:
- self.handle_request()
- # Notify the main thread that we're about to exit
- self.finished.set()
- # close the listen socket
- self.server_close()
+ def test_output(self):
+ # The log message sent to the DatagramHandler is properly received.
+ logger = logging.getLogger("udp")
+ logger.error("spam")
+ self.handled.wait()
+ self.handled.clear()
+ logger.error("eggs")
+ self.handled.wait()
+ self.assertEqual(self.log_output, "spam\neggs\n")
@unittest.skipUnless(threading, 'Threading required for this test.')
-class SocketHandlerTest(BaseTest):
+class SysLogHandlerTest(BaseTest):
- """Test for SocketHandler objects."""
+ """Test for SysLogHandler using UDP."""
def setUp(self):
- """Set up a TCP server to receive log messages, and a SocketHandler
+ """Set up a UDP server to receive log messages, and a SysLogHandler
pointing to that server's address and port."""
BaseTest.setUp(self)
- self.tcpserver = LogRecordSocketReceiver(port=0)
- self.port = self.tcpserver.socket.getsockname()[1]
- self.threads = [
- threading.Thread(target=self.tcpserver.serve_until_stopped)]
- for thread in self.threads:
- thread.start()
-
- self.sock_hdlr = logging.handlers.SocketHandler('localhost', self.port)
- self.sock_hdlr.setFormatter(self.root_formatter)
+ addr = ('localhost', 0)
+ self.server = server = TestUDPServer(addr, self.handle_datagram,
+ 0.01)
+ server.start()
+ server.ready.wait()
+ self.sl_hdlr = logging.handlers.SysLogHandler(('localhost',
+ server.port))
+ self.log_output = ''
self.root_logger.removeHandler(self.root_logger.handlers[0])
- self.root_logger.addHandler(self.sock_hdlr)
+ self.root_logger.addHandler(self.sl_hdlr)
+ self.handled = threading.Event()
def tearDown(self):
- """Shutdown the TCP server."""
+ """Shutdown the UDP server."""
try:
- self.tcpserver.abort = True
- del self.tcpserver
- self.root_logger.removeHandler(self.sock_hdlr)
- self.sock_hdlr.close()
- for thread in self.threads:
- thread.join(2.0)
+ self.server.stop(2.0)
+ self.root_logger.removeHandler(self.sl_hdlr)
+ self.sl_hdlr.close()
finally:
BaseTest.tearDown(self)
- def get_output(self):
- """Get the log output as received by the TCP server."""
- # Signal the TCP receiver and wait for it to terminate.
- self.root_logger.critical(LogRecordStreamHandler.TCP_LOG_END)
- self.tcpserver.finished.wait(2.0)
- return self.tcpserver.log_output
+ def handle_datagram(self, request):
+ self.log_output = request.packet
+ self.handled.set()
def test_output(self):
- # The log message sent to the SocketHandler is properly received.
- logger = logging.getLogger("tcp")
- logger.error("spam")
- logger.debug("eggs")
- self.assertEqual(self.get_output(), "spam\neggs\n")
+ # The log message sent to the SysLogHandler is properly received.
+ logger = logging.getLogger("slh")
+ logger.error("sp\xe4m")
+ self.handled.wait()
+ self.assertEqual(self.log_output, b'<11>sp\xc3\xa4m\x00')
+ self.handled.clear()
+ self.sl_hdlr.append_nul = False
+ logger.error("sp\xe4m")
+ self.handled.wait()
+ self.assertEqual(self.log_output, b'<11>sp\xc3\xa4m')
+ self.handled.clear()
+ self.sl_hdlr.ident = "h\xe4m-"
+ logger.error("sp\xe4m")
+ self.handled.wait()
+ self.assertEqual(self.log_output, b'<11>h\xc3\xa4m-sp\xc3\xa4m')
+
+
+@unittest.skipUnless(threading, 'Threading required for this test.')
+class HTTPHandlerTest(BaseTest):
+ """Test for HTTPHandler."""
+
+ PEMFILE = """-----BEGIN RSA PRIVATE KEY-----
+MIICXQIBAAKBgQDGT4xS5r91rbLJQK2nUDenBhBG6qFk+bVOjuAGC/LSHlAoBnvG
+zQG3agOG+e7c5z2XT8m2ktORLqG3E4mYmbxgyhDrzP6ei2Anc+pszmnxPoK3Puh5
+aXV+XKt0bU0C1m2+ACmGGJ0t3P408art82nOxBw8ZHgIg9Dtp6xIUCyOqwIDAQAB
+AoGBAJFTnFboaKh5eUrIzjmNrKsG44jEyy+vWvHN/FgSC4l103HxhmWiuL5Lv3f7
+0tMp1tX7D6xvHwIG9VWvyKb/Cq9rJsDibmDVIOslnOWeQhG+XwJyitR0pq/KlJIB
+5LjORcBw795oKWOAi6RcOb1ON59tysEFYhAGQO9k6VL621gRAkEA/Gb+YXULLpbs
+piXN3q4zcHzeaVANo69tUZ6TjaQqMeTxE4tOYM0G0ZoSeHEdaP59AOZGKXXNGSQy
+2z/MddcYGQJBAMkjLSYIpOLJY11ja8OwwswFG2hEzHe0cS9bzo++R/jc1bHA5R0Y
+i6vA5iPi+wopPFvpytdBol7UuEBe5xZrxWMCQQCWxELRHiP2yWpEeLJ3gGDzoXMN
+PydWjhRju7Bx3AzkTtf+D6lawz1+eGTuEss5i0JKBkMEwvwnN2s1ce+EuF4JAkBb
+E96h1lAzkVW5OAfYOPY8RCPA90ZO/hoyg7PpSxR0ECuDrgERR8gXIeYUYfejBkEa
+rab4CfRoVJKKM28Yq/xZAkBvuq670JRCwOgfUTdww7WpdOQBYPkzQccsKNCslQW8
+/DyW6y06oQusSENUvynT6dr3LJxt/NgZPhZX2+k1eYDV
+-----END RSA PRIVATE KEY-----
+-----BEGIN CERTIFICATE-----
+MIICGzCCAYSgAwIBAgIJAIq84a2Q/OvlMA0GCSqGSIb3DQEBBQUAMBQxEjAQBgNV
+BAMTCWxvY2FsaG9zdDAeFw0xMTA1MjExMDIzMzNaFw03NTAzMjEwMzU1MTdaMBQx
+EjAQBgNVBAMTCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA
+xk+MUua/da2yyUCtp1A3pwYQRuqhZPm1To7gBgvy0h5QKAZ7xs0Bt2oDhvnu3Oc9
+l0/JtpLTkS6htxOJmJm8YMoQ68z+notgJ3PqbM5p8T6Ctz7oeWl1flyrdG1NAtZt
+vgAphhidLdz+NPGq7fNpzsQcPGR4CIPQ7aesSFAsjqsCAwEAAaN1MHMwHQYDVR0O
+BBYEFLWaUPO6N7efGiuoS9i3DVYcUwn0MEQGA1UdIwQ9MDuAFLWaUPO6N7efGiuo
+S9i3DVYcUwn0oRikFjAUMRIwEAYDVQQDEwlsb2NhbGhvc3SCCQCKvOGtkPzr5TAM
+BgNVHRMEBTADAQH/MA0GCSqGSIb3DQEBBQUAA4GBAMK5whPjLNQK1Ivvk88oqJqq
+4f889OwikGP0eUhOBhbFlsZs+jq5YZC2UzHz+evzKBlgAP1u4lP/cB85CnjvWqM+
+1c/lywFHQ6HOdDeQ1L72tSYMrNOG4XNmLn0h7rx6GoTU7dcFRfseahBCq8mv0IDt
+IRbTpvlHWPjsSvHz0ZOH
+-----END CERTIFICATE-----"""
+ def setUp(self):
+ """Set up an HTTP server to receive log messages, and a HTTPHandler
+ pointing to that server's address and port."""
+ BaseTest.setUp(self)
+ self.handled = threading.Event()
+
+ def handle_request(self, request):
+ self.command = request.command
+ self.log_data = urlparse(request.path)
+ if self.command == 'POST':
+ try:
+ rlen = int(request.headers['Content-Length'])
+ self.post_data = request.rfile.read(rlen)
+ except:
+ self.post_data = None
+ request.send_response(200)
+ request.end_headers()
+ self.handled.set()
+
+ def test_output(self):
+ # The log message sent to the HTTPHandler is properly received.
+ logger = logging.getLogger("http")
+ root_logger = self.root_logger
+ root_logger.removeHandler(self.root_logger.handlers[0])
+ for secure in (False, True):
+ addr = ('localhost', 0)
+ if secure:
+ try:
+ import ssl
+ fd, fn = tempfile.mkstemp()
+ os.close(fd)
+ with open(fn, 'w') as f:
+ f.write(self.PEMFILE)
+ sslctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ sslctx.load_cert_chain(fn)
+ os.unlink(fn)
+ except ImportError:
+ sslctx = None
+ else:
+ sslctx = None
+ self.server = server = TestHTTPServer(addr, self.handle_request,
+ 0.01, sslctx=sslctx)
+ server.start()
+ server.ready.wait()
+ host = 'localhost:%d' % server.server_port
+ secure_client = secure and sslctx
+ self.h_hdlr = logging.handlers.HTTPHandler(host, '/frob',
+ secure=secure_client)
+ self.log_data = None
+ root_logger.addHandler(self.h_hdlr)
+
+ for method in ('GET', 'POST'):
+ self.h_hdlr.method = method
+ self.handled.clear()
+ msg = "sp\xe4m"
+ logger.error(msg)
+ self.handled.wait()
+ self.assertEqual(self.log_data.path, '/frob')
+ self.assertEqual(self.command, method)
+ if method == 'GET':
+ d = parse_qs(self.log_data.query)
+ else:
+ d = parse_qs(self.post_data.decode('utf-8'))
+ self.assertEqual(d['name'], ['http'])
+ self.assertEqual(d['funcName'], ['test_output'])
+ self.assertEqual(d['msg'], [msg])
+
+ self.server.stop(2.0)
+ self.root_logger.removeHandler(self.h_hdlr)
+ self.h_hdlr.close()
class MemoryTest(BaseTest):
@@ -1087,28 +1744,39 @@ class WarningsTest(BaseTest):
def test_warnings(self):
with warnings.catch_warnings():
logging.captureWarnings(True)
- try:
- warnings.filterwarnings("always", category=UserWarning)
- file = io.StringIO()
- h = logging.StreamHandler(file)
- logger = logging.getLogger("py.warnings")
- logger.addHandler(h)
- warnings.warn("I'm warning you...")
- logger.removeHandler(h)
- s = file.getvalue()
- h.close()
- self.assertTrue(s.find("UserWarning: I'm warning you...\n") > 0)
-
- #See if an explicit file uses the original implementation
- file = io.StringIO()
- warnings.showwarning("Explicit", UserWarning, "dummy.py", 42,
- file, "Dummy line")
- s = file.getvalue()
- file.close()
- self.assertEqual(s,
- "dummy.py:42: UserWarning: Explicit\n Dummy line\n")
- finally:
- logging.captureWarnings(False)
+ self.addCleanup(logging.captureWarnings, False)
+ warnings.filterwarnings("always", category=UserWarning)
+ stream = io.StringIO()
+ h = logging.StreamHandler(stream)
+ logger = logging.getLogger("py.warnings")
+ logger.addHandler(h)
+ warnings.warn("I'm warning you...")
+ logger.removeHandler(h)
+ s = stream.getvalue()
+ h.close()
+ self.assertTrue(s.find("UserWarning: I'm warning you...\n") > 0)
+
+ #See if an explicit file uses the original implementation
+ a_file = io.StringIO()
+ warnings.showwarning("Explicit", UserWarning, "dummy.py", 42,
+ a_file, "Dummy line")
+ s = a_file.getvalue()
+ a_file.close()
+ self.assertEqual(s,
+ "dummy.py:42: UserWarning: Explicit\n Dummy line\n")
+
+ def test_warnings_no_handlers(self):
+ with warnings.catch_warnings():
+ logging.captureWarnings(True)
+ self.addCleanup(logging.captureWarnings, False)
+
+ # confirm our assumption: no loggers are set
+ logger = logging.getLogger("py.warnings")
+ self.assertEqual(logger.handlers, [])
+
+ warnings.showwarning("Explicit", UserWarning, "dummy.py", 42)
+ self.assertEqual(len(logger.handlers), 1)
+ self.assertIsInstance(logger.handlers[0], logging.NullHandler)
def formatFunc(format, datefmt=None):
@@ -1961,6 +2629,7 @@ class ConfigDictTest(BaseTest):
logging.config.stopListening()
t.join(2.0)
+ @unittest.skipUnless(threading, 'Threading required for this test.')
def test_listen_config_10_ok(self):
with captured_stdout() as output:
self.setup_via_listener(json.dumps(self.config10))
@@ -1980,6 +2649,7 @@ class ConfigDictTest(BaseTest):
('ERROR', '4'),
], stream=output)
+ @unittest.skipUnless(threading, 'Threading required for this test.')
def test_listen_config_1_ok(self):
with captured_stdout() as output:
self.setup_via_listener(textwrap.dedent(ConfigFileTest.config1))
@@ -1994,6 +2664,27 @@ class ConfigDictTest(BaseTest):
# Original logger output is empty.
self.assert_log_lines([])
+ def test_baseconfig(self):
+ d = {
+ 'atuple': (1, 2, 3),
+ 'alist': ['a', 'b', 'c'],
+ 'adict': {'d': 'e', 'f': 3 },
+ 'nest1': ('g', ('h', 'i'), 'j'),
+ 'nest2': ['k', ['l', 'm'], 'n'],
+ 'nest3': ['o', 'cfg://alist', 'p'],
+ }
+ bc = logging.config.BaseConfigurator(d)
+ self.assertEqual(bc.convert('cfg://atuple[1]'), 2)
+ self.assertEqual(bc.convert('cfg://alist[1]'), 'b')
+ self.assertEqual(bc.convert('cfg://nest1[1][0]'), 'h')
+ self.assertEqual(bc.convert('cfg://nest2[1][1]'), 'm')
+ self.assertEqual(bc.convert('cfg://adict.d'), 'e')
+ self.assertEqual(bc.convert('cfg://adict[f]'), 3)
+ v = bc.convert('cfg://nest3')
+ self.assertEqual(v.pop(1), ['a', 'b', 'c'])
+ self.assertRaises(KeyError, bc.convert, 'cfg://nosuch')
+ self.assertRaises(ValueError, bc.convert, 'cfg://!')
+ self.assertRaises(KeyError, bc.convert, 'cfg://adict[2]')
class ManagerTest(BaseTest):
def test_manager_loggerclass(self):
@@ -2012,6 +2703,11 @@ class ManagerTest(BaseTest):
self.assertEqual(logged, ['should appear in logged'])
+ def test_set_log_record_factory(self):
+ man = logging.Manager(None)
+ expected = object()
+ man.setLogRecordFactory(expected)
+ self.assertEqual(man.logRecordFactory, expected)
class ChildLoggerTest(BaseTest):
def test_child_loggers(self):
@@ -2113,6 +2809,18 @@ class QueueHandlerTest(BaseTest):
self.assertTrue(handler.matches(levelno=logging.ERROR, message='2'))
self.assertTrue(handler.matches(levelno=logging.CRITICAL, message='3'))
+ZERO = datetime.timedelta(0)
+
+class UTC(datetime.tzinfo):
+ def utcoffset(self, dt):
+ return ZERO
+
+ dst = utcoffset
+
+ def tzname(self, dt):
+ return 'UTC'
+
+utc = UTC()
class FormatterTest(unittest.TestCase):
def setUp(self):
@@ -2186,6 +2894,71 @@ class FormatterTest(unittest.TestCase):
f = logging.Formatter('asctime', style='$')
self.assertFalse(f.usesTime())
+ def test_invalid_style(self):
+ self.assertRaises(ValueError, logging.Formatter, None, None, 'x')
+
+ def test_time(self):
+ r = self.get_record()
+ dt = datetime.datetime(1993, 4, 21, 8, 3, 0, 0, utc)
+ # We use None to indicate we want the local timezone
+ # We're essentially converting a UTC time to local time
+ r.created = time.mktime(dt.astimezone(None).timetuple())
+ r.msecs = 123
+ f = logging.Formatter('%(asctime)s %(message)s')
+ f.converter = time.gmtime
+ self.assertEqual(f.formatTime(r), '1993-04-21 08:03:00,123')
+ self.assertEqual(f.formatTime(r, '%Y:%d'), '1993:21')
+ f.format(r)
+ self.assertEqual(r.asctime, '1993-04-21 08:03:00,123')
+
+class TestBufferingFormatter(logging.BufferingFormatter):
+ def formatHeader(self, records):
+ return '[(%d)' % len(records)
+
+ def formatFooter(self, records):
+ return '(%d)]' % len(records)
+
+class BufferingFormatterTest(unittest.TestCase):
+ def setUp(self):
+ self.records = [
+ logging.makeLogRecord({'msg': 'one'}),
+ logging.makeLogRecord({'msg': 'two'}),
+ ]
+
+ def test_default(self):
+ f = logging.BufferingFormatter()
+ self.assertEqual('', f.format([]))
+ self.assertEqual('onetwo', f.format(self.records))
+
+ def test_custom(self):
+ f = TestBufferingFormatter()
+ self.assertEqual('[(2)onetwo(2)]', f.format(self.records))
+ lf = logging.Formatter('<%(message)s>')
+ f = TestBufferingFormatter(lf)
+ self.assertEqual('[(2)<one><two>(2)]', f.format(self.records))
+
+class ExceptionTest(BaseTest):
+ def test_formatting(self):
+ r = self.root_logger
+ h = RecordingHandler()
+ r.addHandler(h)
+ try:
+ raise RuntimeError('deliberate mistake')
+ except:
+ logging.exception('failed', stack_info=True)
+ r.removeHandler(h)
+ h.close()
+ r = h.records[0]
+ self.assertTrue(r.exc_text.startswith('Traceback (most recent '
+ 'call last):\n'))
+ self.assertTrue(r.exc_text.endswith('\nRuntimeError: '
+ 'deliberate mistake'))
+ self.assertTrue(r.stack_info.startswith('Stack (most recent '
+ 'call last):\n'))
+ self.assertTrue(r.stack_info.endswith('logging.exception(\'failed\', '
+ 'stack_info=True)'))
+
+
class LastResortTest(BaseTest):
def test_last_resort(self):
# Test the last resort handler
@@ -2196,6 +2969,8 @@ class LastResortTest(BaseTest):
old_raise_exceptions = logging.raiseExceptions
try:
sys.stderr = sio = io.StringIO()
+ root.debug('This should not appear')
+ self.assertEqual(sio.getvalue(), '')
root.warning('This is your final chance!')
self.assertEqual(sio.getvalue(), 'This is your final chance!\n')
#No handlers and no last resort, so 'No handlers' message
@@ -2220,152 +2995,236 @@ class LastResortTest(BaseTest):
logging.raiseExceptions = old_raise_exceptions
-class BaseFileTest(BaseTest):
- "Base class for handler tests that write log files"
+class FakeHandler:
+
+ def __init__(self, identifier, called):
+ for method in ('acquire', 'flush', 'close', 'release'):
+ setattr(self, method, self.record_call(identifier, method, called))
+
+ def record_call(self, identifier, method_name, called):
+ def inner():
+ called.append('{} - {}'.format(identifier, method_name))
+ return inner
+
+
+class RecordingHandler(logging.NullHandler):
+
+ def __init__(self, *args, **kwargs):
+ super(RecordingHandler, self).__init__(*args, **kwargs)
+ self.records = []
+
+ def handle(self, record):
+ """Keep track of all the emitted records."""
+ self.records.append(record)
+
+
+class ShutdownTest(BaseTest):
+
+ """Test suite for the shutdown method."""
def setUp(self):
- BaseTest.setUp(self)
- fd, self.fn = tempfile.mkstemp(".log", "test_logging-2-")
- os.close(fd)
- self.rmfiles = []
+ super(ShutdownTest, self).setUp()
+ self.called = []
- def tearDown(self):
- for fn in self.rmfiles:
- os.unlink(fn)
- if os.path.exists(self.fn):
- os.unlink(self.fn)
- BaseTest.tearDown(self)
+ raise_exceptions = logging.raiseExceptions
+ self.addCleanup(setattr, logging, 'raiseExceptions', raise_exceptions)
- def assertLogFile(self, filename):
- "Assert a log file is there and register it for deletion"
- self.assertTrue(os.path.exists(filename),
- msg="Log file %r does not exist")
- self.rmfiles.append(filename)
+ def raise_error(self, error):
+ def inner():
+ raise error()
+ return inner
+ def test_no_failure(self):
+ # create some fake handlers
+ handler0 = FakeHandler(0, self.called)
+ handler1 = FakeHandler(1, self.called)
+ handler2 = FakeHandler(2, self.called)
-class RotatingFileHandlerTest(BaseFileTest):
- def next_rec(self):
- return logging.LogRecord('n', logging.DEBUG, 'p', 1,
- self.next_message(), None, None, None)
+ # create live weakref to those handlers
+ handlers = map(logging.weakref.ref, [handler0, handler1, handler2])
- def test_should_not_rollover(self):
- # If maxbytes is zero rollover never occurs
- rh = logging.handlers.RotatingFileHandler(self.fn, maxBytes=0)
- self.assertFalse(rh.shouldRollover(None))
- rh.close()
+ logging.shutdown(handlerList=list(handlers))
- def test_should_rollover(self):
- rh = logging.handlers.RotatingFileHandler(self.fn, maxBytes=1)
- self.assertTrue(rh.shouldRollover(self.next_rec()))
- rh.close()
+ expected = ['2 - acquire', '2 - flush', '2 - close', '2 - release',
+ '1 - acquire', '1 - flush', '1 - close', '1 - release',
+ '0 - acquire', '0 - flush', '0 - close', '0 - release']
+ self.assertEqual(expected, self.called)
- def test_file_created(self):
- # checks that the file is created and assumes it was created
- # by us
- rh = logging.handlers.RotatingFileHandler(self.fn)
- rh.emit(self.next_rec())
- self.assertLogFile(self.fn)
- rh.close()
+ def _test_with_failure_in_method(self, method, error):
+ handler = FakeHandler(0, self.called)
+ setattr(handler, method, self.raise_error(error))
+ handlers = [logging.weakref.ref(handler)]
- def test_rollover_filenames(self):
- rh = logging.handlers.RotatingFileHandler(
- self.fn, backupCount=2, maxBytes=1)
- rh.emit(self.next_rec())
- self.assertLogFile(self.fn)
- rh.emit(self.next_rec())
- self.assertLogFile(self.fn + ".1")
- rh.emit(self.next_rec())
- self.assertLogFile(self.fn + ".2")
- self.assertFalse(os.path.exists(self.fn + ".3"))
- rh.close()
+ logging.shutdown(handlerList=list(handlers))
-class TimedRotatingFileHandlerTest(BaseFileTest):
- # test methods added below
- pass
+ self.assertEqual('0 - release', self.called[-1])
-def secs(**kw):
- return datetime.timedelta(**kw) // datetime.timedelta(seconds=1)
+ def test_with_ioerror_in_acquire(self):
+ self._test_with_failure_in_method('acquire', IOError)
-for when, exp in (('S', 1),
- ('M', 60),
- ('H', 60 * 60),
- ('D', 60 * 60 * 24),
- ('MIDNIGHT', 60 * 60 * 24),
- # current time (epoch start) is a Thursday, W0 means Monday
- ('W0', secs(days=4, hours=24)),
- ):
- def test_compute_rollover(self, when=when, exp=exp):
- rh = logging.handlers.TimedRotatingFileHandler(
- self.fn, when=when, interval=1, backupCount=0, utc=True)
- currentTime = 0.0
- actual = rh.computeRollover(currentTime)
- if exp != actual:
- # Failures occur on some systems for MIDNIGHT and W0.
- # Print detailed calculation for MIDNIGHT so we can try to see
- # what's going on
- if when == 'MIDNIGHT':
- try:
- if rh.utc:
- t = time.gmtime(currentTime)
- else:
- t = time.localtime(currentTime)
- currentHour = t[3]
- currentMinute = t[4]
- currentSecond = t[5]
- # r is the number of seconds left between now and midnight
- r = logging.handlers._MIDNIGHT - ((currentHour * 60 +
- currentMinute) * 60 +
- currentSecond)
- result = currentTime + r
- print('t: %s (%s)' % (t, rh.utc), file=sys.stderr)
- print('currentHour: %s' % currentHour, file=sys.stderr)
- print('currentMinute: %s' % currentMinute, file=sys.stderr)
- print('currentSecond: %s' % currentSecond, file=sys.stderr)
- print('r: %s' % r, file=sys.stderr)
- print('result: %s' % result, file=sys.stderr)
- except Exception:
- print('exception in diagnostic code: %s' % sys.exc_info()[1], file=sys.stderr)
- self.assertEqual(exp, actual)
- rh.close()
- setattr(TimedRotatingFileHandlerTest, "test_compute_rollover_%s" % when, test_compute_rollover)
+ def test_with_ioerror_in_flush(self):
+ self._test_with_failure_in_method('flush', IOError)
-class HandlerTest(BaseTest):
+ def test_with_ioerror_in_close(self):
+ self._test_with_failure_in_method('close', IOError)
- @unittest.skipIf(os.name == 'nt', 'WatchedFileHandler not appropriate for Windows.')
- @unittest.skipUnless(threading, 'Threading required for this test.')
- def test_race(self):
- # Issue #14632 refers.
- def remove_loop(fname, tries):
- for _ in range(tries):
- try:
- os.unlink(fname)
- except OSError:
- pass
- time.sleep(0.004 * random.randint(0, 4))
+ def test_with_valueerror_in_acquire(self):
+ self._test_with_failure_in_method('acquire', ValueError)
- del_count = 500
- log_count = 500
+ def test_with_valueerror_in_flush(self):
+ self._test_with_failure_in_method('flush', ValueError)
- for delay in (False, True):
- fd, fn = tempfile.mkstemp('.log', 'test_logging-3-')
- os.close(fd)
- remover = threading.Thread(target=remove_loop, args=(fn, del_count))
- remover.daemon = True
- remover.start()
- h = logging.handlers.WatchedFileHandler(fn, delay=delay)
- f = logging.Formatter('%(asctime)s: %(levelname)s: %(message)s')
- h.setFormatter(f)
- try:
- for _ in range(log_count):
- time.sleep(0.005)
- r = logging.makeLogRecord({'msg': 'testing' })
- h.handle(r)
- finally:
- remover.join()
- h.close()
- if os.path.exists(fn):
- os.unlink(fn)
+ def test_with_valueerror_in_close(self):
+ self._test_with_failure_in_method('close', ValueError)
+
+ def test_with_other_error_in_acquire_without_raise(self):
+ logging.raiseExceptions = False
+ self._test_with_failure_in_method('acquire', IndexError)
+
+ def test_with_other_error_in_flush_without_raise(self):
+ logging.raiseExceptions = False
+ self._test_with_failure_in_method('flush', IndexError)
+
+ def test_with_other_error_in_close_without_raise(self):
+ logging.raiseExceptions = False
+ self._test_with_failure_in_method('close', IndexError)
+
+ def test_with_other_error_in_acquire_with_raise(self):
+ logging.raiseExceptions = True
+ self.assertRaises(IndexError, self._test_with_failure_in_method,
+ 'acquire', IndexError)
+ def test_with_other_error_in_flush_with_raise(self):
+ logging.raiseExceptions = True
+ self.assertRaises(IndexError, self._test_with_failure_in_method,
+ 'flush', IndexError)
+
+ def test_with_other_error_in_close_with_raise(self):
+ logging.raiseExceptions = True
+ self.assertRaises(IndexError, self._test_with_failure_in_method,
+ 'close', IndexError)
+
+
+class ModuleLevelMiscTest(BaseTest):
+
+ """Test suite for some module level methods."""
+
+ def test_disable(self):
+ old_disable = logging.root.manager.disable
+ # confirm our assumptions are correct
+ self.assertEqual(old_disable, 0)
+ self.addCleanup(logging.disable, old_disable)
+
+ logging.disable(83)
+ self.assertEqual(logging.root.manager.disable, 83)
+
+ def _test_log(self, method, level=None):
+ called = []
+ patch(self, logging, 'basicConfig',
+ lambda *a, **kw: called.append((a, kw)))
+
+ recording = RecordingHandler()
+ logging.root.addHandler(recording)
+
+ log_method = getattr(logging, method)
+ if level is not None:
+ log_method(level, "test me: %r", recording)
+ else:
+ log_method("test me: %r", recording)
+
+ self.assertEqual(len(recording.records), 1)
+ record = recording.records[0]
+ self.assertEqual(record.getMessage(), "test me: %r" % recording)
+
+ expected_level = level if level is not None else getattr(logging, method.upper())
+ self.assertEqual(record.levelno, expected_level)
+
+ # basicConfig was not called!
+ self.assertEqual(called, [])
+
+ def test_log(self):
+ self._test_log('log', logging.ERROR)
+
+ def test_debug(self):
+ self._test_log('debug')
+
+ def test_info(self):
+ self._test_log('info')
+
+ def test_warning(self):
+ self._test_log('warning')
+
+ def test_error(self):
+ self._test_log('error')
+
+ def test_critical(self):
+ self._test_log('critical')
+
+ def test_set_logger_class(self):
+ self.assertRaises(TypeError, logging.setLoggerClass, object)
+
+ class MyLogger(logging.Logger):
+ pass
+
+ logging.setLoggerClass(MyLogger)
+ self.assertEqual(logging.getLoggerClass(), MyLogger)
+
+ logging.setLoggerClass(logging.Logger)
+ self.assertEqual(logging.getLoggerClass(), logging.Logger)
+
+class LogRecordTest(BaseTest):
+ def test_str_rep(self):
+ r = logging.makeLogRecord({})
+ s = str(r)
+ self.assertTrue(s.startswith('<LogRecord: '))
+ self.assertTrue(s.endswith('>'))
+
+ def test_dict_arg(self):
+ h = RecordingHandler()
+ r = logging.getLogger()
+ r.addHandler(h)
+ d = {'less' : 'more' }
+ logging.warning('less is %(less)s', d)
+ self.assertIs(h.records[0].args, d)
+ self.assertEqual(h.records[0].message, 'less is more')
+ r.removeHandler(h)
+ h.close()
+
+ def test_multiprocessing(self):
+ r = logging.makeLogRecord({})
+ self.assertEqual(r.processName, 'MainProcess')
+ try:
+ import multiprocessing as mp
+ r = logging.makeLogRecord({})
+ self.assertEqual(r.processName, mp.current_process().name)
+ except ImportError:
+ pass
+
+ def test_optional(self):
+ r = logging.makeLogRecord({})
+ NOT_NONE = self.assertIsNotNone
+ if threading:
+ NOT_NONE(r.thread)
+ NOT_NONE(r.threadName)
+ NOT_NONE(r.process)
+ NOT_NONE(r.processName)
+ log_threads = logging.logThreads
+ log_processes = logging.logProcesses
+ log_multiprocessing = logging.logMultiprocessing
+ try:
+ logging.logThreads = False
+ logging.logProcesses = False
+ logging.logMultiprocessing = False
+ r = logging.makeLogRecord({})
+ NONE = self.assertIsNone
+ NONE(r.thread)
+ NONE(r.threadName)
+ NONE(r.process)
+ NONE(r.processName)
+ finally:
+ logging.logThreads = log_threads
+ logging.logProcesses = log_processes
+ logging.logMultiprocessing = log_multiprocessing
class BasicConfigTest(unittest.TestCase):
@@ -2471,6 +3330,17 @@ class BasicConfigTest(unittest.TestCase):
logging.basicConfig(level=58)
self.assertEqual(logging.root.level, 57)
+ def test_incompatible(self):
+ assertRaises = self.assertRaises
+ handlers = [logging.StreamHandler()]
+ stream = sys.stderr
+ assertRaises(ValueError, logging.basicConfig, filename='test.log',
+ stream=stream)
+ assertRaises(ValueError, logging.basicConfig, filename='test.log',
+ handlers=handlers)
+ assertRaises(ValueError, logging.basicConfig, stream=stream,
+ handlers=handlers)
+
def test_handlers(self):
handlers = [
logging.StreamHandler(),
@@ -2479,8 +3349,14 @@ class BasicConfigTest(unittest.TestCase):
]
f = logging.Formatter()
handlers[2].setFormatter(f)
- self.assertRaises(ValueError, logging.basicConfig, level=logging.DEBUG,
- format='%(asctime)s %(message)s', handlers=handlers)
+ logging.basicConfig(handlers=handlers)
+ self.assertIs(handlers[0], logging.root.handlers[0])
+ self.assertIs(handlers[1], logging.root.handlers[1])
+ self.assertIs(handlers[2], logging.root.handlers[2])
+ self.assertIsNotNone(handlers[0].formatter)
+ self.assertIsNotNone(handlers[1].formatter)
+ self.assertIs(handlers[2].formatter, f)
+ self.assertIs(handlers[0].formatter, handlers[1].formatter)
def _test_log(self, method, level=None):
# logging.root has no handlers so basicConfig should be called
@@ -2523,20 +3399,445 @@ class BasicConfigTest(unittest.TestCase):
def test_critical(self):
self._test_log('critical')
+
+class LoggerAdapterTest(unittest.TestCase):
+
+ def setUp(self):
+ super(LoggerAdapterTest, self).setUp()
+ old_handler_list = logging._handlerList[:]
+
+ self.recording = RecordingHandler()
+ self.logger = logging.root
+ self.logger.addHandler(self.recording)
+ self.addCleanup(self.logger.removeHandler, self.recording)
+ self.addCleanup(self.recording.close)
+
+ def cleanup():
+ logging._handlerList[:] = old_handler_list
+
+ self.addCleanup(cleanup)
+ self.addCleanup(logging.shutdown)
+ self.adapter = logging.LoggerAdapter(logger=self.logger, extra=None)
+
+ def test_exception(self):
+ msg = 'testing exception: %r'
+ exc = None
+ try:
+ 1 / 0
+ except ZeroDivisionError as e:
+ exc = e
+ self.adapter.exception(msg, self.recording)
+
+ self.assertEqual(len(self.recording.records), 1)
+ record = self.recording.records[0]
+ self.assertEqual(record.levelno, logging.ERROR)
+ self.assertEqual(record.msg, msg)
+ self.assertEqual(record.args, (self.recording,))
+ self.assertEqual(record.exc_info,
+ (exc.__class__, exc, exc.__traceback__))
+
+ def test_critical(self):
+ msg = 'critical test! %r'
+ self.adapter.critical(msg, self.recording)
+
+ self.assertEqual(len(self.recording.records), 1)
+ record = self.recording.records[0]
+ self.assertEqual(record.levelno, logging.CRITICAL)
+ self.assertEqual(record.msg, msg)
+ self.assertEqual(record.args, (self.recording,))
+
+ def test_is_enabled_for(self):
+ old_disable = self.adapter.logger.manager.disable
+ self.adapter.logger.manager.disable = 33
+ self.addCleanup(setattr, self.adapter.logger.manager, 'disable',
+ old_disable)
+ self.assertFalse(self.adapter.isEnabledFor(32))
+
+ def test_has_handlers(self):
+ self.assertTrue(self.adapter.hasHandlers())
+
+ for handler in self.logger.handlers:
+ self.logger.removeHandler(handler)
+
+ self.assertFalse(self.logger.hasHandlers())
+ self.assertFalse(self.adapter.hasHandlers())
+
+
+class LoggerTest(BaseTest):
+
+ def setUp(self):
+ super(LoggerTest, self).setUp()
+ self.recording = RecordingHandler()
+ self.logger = logging.Logger(name='blah')
+ self.logger.addHandler(self.recording)
+ self.addCleanup(self.logger.removeHandler, self.recording)
+ self.addCleanup(self.recording.close)
+ self.addCleanup(logging.shutdown)
+
+ def test_set_invalid_level(self):
+ self.assertRaises(TypeError, self.logger.setLevel, object())
+
+ def test_exception(self):
+ msg = 'testing exception: %r'
+ exc = None
+ try:
+ 1 / 0
+ except ZeroDivisionError as e:
+ exc = e
+ self.logger.exception(msg, self.recording)
+
+ self.assertEqual(len(self.recording.records), 1)
+ record = self.recording.records[0]
+ self.assertEqual(record.levelno, logging.ERROR)
+ self.assertEqual(record.msg, msg)
+ self.assertEqual(record.args, (self.recording,))
+ self.assertEqual(record.exc_info,
+ (exc.__class__, exc, exc.__traceback__))
+
+ def test_log_invalid_level_with_raise(self):
+ old_raise = logging.raiseExceptions
+ self.addCleanup(setattr, logging, 'raiseExecptions', old_raise)
+
+ logging.raiseExceptions = True
+ self.assertRaises(TypeError, self.logger.log, '10', 'test message')
+
+ def test_log_invalid_level_no_raise(self):
+ old_raise = logging.raiseExceptions
+ self.addCleanup(setattr, logging, 'raiseExecptions', old_raise)
+
+ logging.raiseExceptions = False
+ self.logger.log('10', 'test message') # no exception happens
+
+ def test_find_caller_with_stack_info(self):
+ called = []
+ patch(self, logging.traceback, 'print_stack',
+ lambda f, file: called.append(file.getvalue()))
+
+ self.logger.findCaller(stack_info=True)
+
+ self.assertEqual(len(called), 1)
+ self.assertEqual('Stack (most recent call last):\n', called[0])
+
+ def test_make_record_with_extra_overwrite(self):
+ name = 'my record'
+ level = 13
+ fn = lno = msg = args = exc_info = func = sinfo = None
+ rv = logging._logRecordFactory(name, level, fn, lno, msg, args,
+ exc_info, func, sinfo)
+
+ for key in ('message', 'asctime') + tuple(rv.__dict__.keys()):
+ extra = {key: 'some value'}
+ self.assertRaises(KeyError, self.logger.makeRecord, name, level,
+ fn, lno, msg, args, exc_info,
+ extra=extra, sinfo=sinfo)
+
+ def test_make_record_with_extra_no_overwrite(self):
+ name = 'my record'
+ level = 13
+ fn = lno = msg = args = exc_info = func = sinfo = None
+ extra = {'valid_key': 'some value'}
+ result = self.logger.makeRecord(name, level, fn, lno, msg, args,
+ exc_info, extra=extra, sinfo=sinfo)
+ self.assertIn('valid_key', result.__dict__)
+
+ def test_has_handlers(self):
+ self.assertTrue(self.logger.hasHandlers())
+
+ for handler in self.logger.handlers:
+ self.logger.removeHandler(handler)
+ self.assertFalse(self.logger.hasHandlers())
+
+ def test_has_handlers_no_propagate(self):
+ child_logger = logging.getLogger('blah.child')
+ child_logger.propagate = False
+ self.assertFalse(child_logger.hasHandlers())
+
+ def test_is_enabled_for(self):
+ old_disable = self.logger.manager.disable
+ self.logger.manager.disable = 23
+ self.addCleanup(setattr, self.logger.manager, 'disable', old_disable)
+ self.assertFalse(self.logger.isEnabledFor(22))
+
+ def test_root_logger_aliases(self):
+ root = logging.getLogger()
+ self.assertIs(root, logging.root)
+ self.assertIs(root, logging.getLogger(None))
+ self.assertIs(root, logging.getLogger(''))
+ self.assertIs(root, logging.getLogger('foo').root)
+ self.assertIs(root, logging.getLogger('foo.bar').root)
+ self.assertIs(root, logging.getLogger('foo').parent)
+
+ self.assertIsNot(root, logging.getLogger('\0'))
+ self.assertIsNot(root, logging.getLogger('foo.bar').parent)
+
+ def test_invalid_names(self):
+ self.assertRaises(TypeError, logging.getLogger, any)
+ self.assertRaises(TypeError, logging.getLogger, b'foo')
+
+
+class BaseFileTest(BaseTest):
+ "Base class for handler tests that write log files"
+
+ def setUp(self):
+ BaseTest.setUp(self)
+ fd, self.fn = tempfile.mkstemp(".log", "test_logging-2-")
+ os.close(fd)
+ self.rmfiles = []
+
+ def tearDown(self):
+ for fn in self.rmfiles:
+ os.unlink(fn)
+ if os.path.exists(self.fn):
+ os.unlink(self.fn)
+ BaseTest.tearDown(self)
+
+ def assertLogFile(self, filename):
+ "Assert a log file is there and register it for deletion"
+ self.assertTrue(os.path.exists(filename),
+ msg="Log file %r does not exist" % filename)
+ self.rmfiles.append(filename)
+
+
+class FileHandlerTest(BaseFileTest):
+ def test_delay(self):
+ os.unlink(self.fn)
+ fh = logging.FileHandler(self.fn, delay=True)
+ self.assertIsNone(fh.stream)
+ self.assertFalse(os.path.exists(self.fn))
+ fh.handle(logging.makeLogRecord({}))
+ self.assertIsNotNone(fh.stream)
+ self.assertTrue(os.path.exists(self.fn))
+ fh.close()
+
+class RotatingFileHandlerTest(BaseFileTest):
+ def next_rec(self):
+ return logging.LogRecord('n', logging.DEBUG, 'p', 1,
+ self.next_message(), None, None, None)
+
+ def test_should_not_rollover(self):
+ # If maxbytes is zero rollover never occurs
+ rh = logging.handlers.RotatingFileHandler(self.fn, maxBytes=0)
+ self.assertFalse(rh.shouldRollover(None))
+ rh.close()
+
+ def test_should_rollover(self):
+ rh = logging.handlers.RotatingFileHandler(self.fn, maxBytes=1)
+ self.assertTrue(rh.shouldRollover(self.next_rec()))
+ rh.close()
+
+ def test_file_created(self):
+ # checks that the file is created and assumes it was created
+ # by us
+ rh = logging.handlers.RotatingFileHandler(self.fn)
+ rh.emit(self.next_rec())
+ self.assertLogFile(self.fn)
+ rh.close()
+
+ def test_rollover_filenames(self):
+ def namer(name):
+ return name + ".test"
+ rh = logging.handlers.RotatingFileHandler(
+ self.fn, backupCount=2, maxBytes=1)
+ rh.namer = namer
+ rh.emit(self.next_rec())
+ self.assertLogFile(self.fn)
+ rh.emit(self.next_rec())
+ self.assertLogFile(namer(self.fn + ".1"))
+ rh.emit(self.next_rec())
+ self.assertLogFile(namer(self.fn + ".2"))
+ self.assertFalse(os.path.exists(namer(self.fn + ".3")))
+ rh.close()
+
+ @requires_zlib
+ def test_rotator(self):
+ def namer(name):
+ return name + ".gz"
+
+ def rotator(source, dest):
+ with open(source, "rb") as sf:
+ data = sf.read()
+ compressed = zlib.compress(data, 9)
+ with open(dest, "wb") as df:
+ df.write(compressed)
+ os.remove(source)
+
+ rh = logging.handlers.RotatingFileHandler(
+ self.fn, backupCount=2, maxBytes=1)
+ rh.rotator = rotator
+ rh.namer = namer
+ m1 = self.next_rec()
+ rh.emit(m1)
+ self.assertLogFile(self.fn)
+ m2 = self.next_rec()
+ rh.emit(m2)
+ fn = namer(self.fn + ".1")
+ self.assertLogFile(fn)
+ newline = os.linesep
+ with open(fn, "rb") as f:
+ compressed = f.read()
+ data = zlib.decompress(compressed)
+ self.assertEqual(data.decode("ascii"), m1.msg + newline)
+ rh.emit(self.next_rec())
+ fn = namer(self.fn + ".2")
+ self.assertLogFile(fn)
+ with open(fn, "rb") as f:
+ compressed = f.read()
+ data = zlib.decompress(compressed)
+ self.assertEqual(data.decode("ascii"), m1.msg + newline)
+ rh.emit(self.next_rec())
+ fn = namer(self.fn + ".2")
+ with open(fn, "rb") as f:
+ compressed = f.read()
+ data = zlib.decompress(compressed)
+ self.assertEqual(data.decode("ascii"), m2.msg + newline)
+ self.assertFalse(os.path.exists(namer(self.fn + ".3")))
+ rh.close()
+
+class TimedRotatingFileHandlerTest(BaseFileTest):
+ # other test methods added below
+ def test_rollover(self):
+ fh = logging.handlers.TimedRotatingFileHandler(self.fn, 'S',
+ backupCount=1)
+ fmt = logging.Formatter('%(asctime)s %(message)s')
+ fh.setFormatter(fmt)
+ r1 = logging.makeLogRecord({'msg': 'testing - initial'})
+ fh.emit(r1)
+ self.assertLogFile(self.fn)
+ time.sleep(1.1) # a little over a second ...
+ r2 = logging.makeLogRecord({'msg': 'testing - after delay'})
+ fh.emit(r2)
+ fh.close()
+ # At this point, we should have a recent rotated file which we
+ # can test for the existence of. However, in practice, on some
+ # machines which run really slowly, we don't know how far back
+ # in time to go to look for the log file. So, we go back a fair
+ # bit, and stop as soon as we see a rotated file. In theory this
+ # could of course still fail, but the chances are lower.
+ found = False
+ now = datetime.datetime.now()
+ GO_BACK = 5 * 60 # seconds
+ for secs in range(GO_BACK):
+ prev = now - datetime.timedelta(seconds=secs)
+ fn = self.fn + prev.strftime(".%Y-%m-%d_%H-%M-%S")
+ found = os.path.exists(fn)
+ if found:
+ self.rmfiles.append(fn)
+ break
+ msg = 'No rotated files found, went back %d seconds' % GO_BACK
+ if not found:
+ #print additional diagnostics
+ dn, fn = os.path.split(self.fn)
+ files = [f for f in os.listdir(dn) if f.startswith(fn)]
+ print('Test time: %s' % now.strftime("%Y-%m-%d %H-%M-%S"), file=sys.stderr)
+ print('The only matching files are: %s' % files, file=sys.stderr)
+ for f in files:
+ print('Contents of %s:' % f)
+ path = os.path.join(dn, f)
+ with open(path, 'r') as tf:
+ print(tf.read())
+ self.assertTrue(found, msg=msg)
+
+ def test_invalid(self):
+ assertRaises = self.assertRaises
+ assertRaises(ValueError, logging.handlers.TimedRotatingFileHandler,
+ self.fn, 'X', delay=True)
+ assertRaises(ValueError, logging.handlers.TimedRotatingFileHandler,
+ self.fn, 'W', delay=True)
+ assertRaises(ValueError, logging.handlers.TimedRotatingFileHandler,
+ self.fn, 'W7', delay=True)
+
+def secs(**kw):
+ return datetime.timedelta(**kw) // datetime.timedelta(seconds=1)
+
+for when, exp in (('S', 1),
+ ('M', 60),
+ ('H', 60 * 60),
+ ('D', 60 * 60 * 24),
+ ('MIDNIGHT', 60 * 60 * 24),
+ # current time (epoch start) is a Thursday, W0 means Monday
+ ('W0', secs(days=4, hours=24)),
+ ):
+ def test_compute_rollover(self, when=when, exp=exp):
+ rh = logging.handlers.TimedRotatingFileHandler(
+ self.fn, when=when, interval=1, backupCount=0, utc=True)
+ currentTime = 0.0
+ actual = rh.computeRollover(currentTime)
+ if exp != actual:
+ # Failures occur on some systems for MIDNIGHT and W0.
+ # Print detailed calculation for MIDNIGHT so we can try to see
+ # what's going on
+ if when == 'MIDNIGHT':
+ try:
+ if rh.utc:
+ t = time.gmtime(currentTime)
+ else:
+ t = time.localtime(currentTime)
+ currentHour = t[3]
+ currentMinute = t[4]
+ currentSecond = t[5]
+ # r is the number of seconds left between now and midnight
+ r = logging.handlers._MIDNIGHT - ((currentHour * 60 +
+ currentMinute) * 60 +
+ currentSecond)
+ result = currentTime + r
+ print('t: %s (%s)' % (t, rh.utc), file=sys.stderr)
+ print('currentHour: %s' % currentHour, file=sys.stderr)
+ print('currentMinute: %s' % currentMinute, file=sys.stderr)
+ print('currentSecond: %s' % currentSecond, file=sys.stderr)
+ print('r: %s' % r, file=sys.stderr)
+ print('result: %s' % result, file=sys.stderr)
+ except Exception:
+ print('exception in diagnostic code: %s' % sys.exc_info()[1], file=sys.stderr)
+ self.assertEqual(exp, actual)
+ rh.close()
+ setattr(TimedRotatingFileHandlerTest, "test_compute_rollover_%s" % when, test_compute_rollover)
+
+
+@unittest.skipUnless(win32evtlog, 'win32evtlog/win32evtlogutil required for this test.')
+class NTEventLogHandlerTest(BaseTest):
+ def test_basic(self):
+ logtype = 'Application'
+ elh = win32evtlog.OpenEventLog(None, logtype)
+ num_recs = win32evtlog.GetNumberOfEventLogRecords(elh)
+ h = logging.handlers.NTEventLogHandler('test_logging')
+ r = logging.makeLogRecord({'msg': 'Test Log Message'})
+ h.handle(r)
+ h.close()
+ # Now see if the event is recorded
+ self.assertTrue(num_recs < win32evtlog.GetNumberOfEventLogRecords(elh))
+ flags = win32evtlog.EVENTLOG_BACKWARDS_READ | \
+ win32evtlog.EVENTLOG_SEQUENTIAL_READ
+ found = False
+ GO_BACK = 100
+ events = win32evtlog.ReadEventLog(elh, flags, GO_BACK)
+ for e in events:
+ if e.SourceName != 'test_logging':
+ continue
+ msg = win32evtlogutil.SafeFormatMessage(e, logtype)
+ if msg != 'Test Log Message\r\n':
+ continue
+ found = True
+ break
+ msg = 'Record not found in event log, went back %d records' % GO_BACK
+ self.assertTrue(found, msg=msg)
+
# Set the locale to the platform-dependent default. I have no idea
# why the test does this, but in any case we save the current locale
# first and restore it at the end.
@run_with_locale('LC_ALL', '')
def test_main():
run_unittest(BuiltinLevelsTest, BasicFilterTest,
- CustomLevelsAndFiltersTest, MemoryHandlerTest,
- ConfigFileTest, SocketHandlerTest, MemoryTest,
- EncodingTest, WarningsTest, ConfigDictTest, ManagerTest,
- FormatterTest,
- LogRecordFactoryTest, ChildLoggerTest, QueueHandlerTest,
- RotatingFileHandlerTest,
- LastResortTest,
- TimedRotatingFileHandlerTest, HandlerTest, BasicConfigTest,
+ CustomLevelsAndFiltersTest, HandlerTest, MemoryHandlerTest,
+ ConfigFileTest, SocketHandlerTest, DatagramHandlerTest,
+ MemoryTest, EncodingTest, WarningsTest, ConfigDictTest,
+ ManagerTest, FormatterTest, BufferingFormatterTest,
+ StreamHandlerTest, LogRecordFactoryTest, ChildLoggerTest,
+ QueueHandlerTest, ShutdownTest, ModuleLevelMiscTest,
+ BasicConfigTest, LoggerAdapterTest, LoggerTest,
+ SMTPHandlerTest, FileHandlerTest, RotatingFileHandlerTest,
+ LastResortTest, LogRecordTest, ExceptionTest,
+ SysLogHandlerTest, HTTPHandlerTest, NTEventLogHandlerTest,
+ TimedRotatingFileHandlerTest
)
if __name__ == "__main__":
diff --git a/Lib/test/test_long.py b/Lib/test/test_long.py
index 3ebc1f7..b417bea 100644
--- a/Lib/test/test_long.py
+++ b/Lib/test/test_long.py
@@ -43,6 +43,53 @@ DBL_MIN_EXP = sys.float_info.min_exp
DBL_MANT_DIG = sys.float_info.mant_dig
DBL_MIN_OVERFLOW = 2**DBL_MAX_EXP - 2**(DBL_MAX_EXP - DBL_MANT_DIG - 1)
+
+# Pure Python version of correctly-rounded integer-to-float conversion.
+def int_to_float(n):
+ """
+ Correctly-rounded integer-to-float conversion.
+
+ """
+ # Constants, depending only on the floating-point format in use.
+ # We use an extra 2 bits of precision for rounding purposes.
+ PRECISION = sys.float_info.mant_dig + 2
+ SHIFT_MAX = sys.float_info.max_exp - PRECISION
+ Q_MAX = 1 << PRECISION
+ ROUND_HALF_TO_EVEN_CORRECTION = [0, -1, -2, 1, 0, -1, 2, 1]
+
+ # Reduce to the case where n is positive.
+ if n == 0:
+ return 0.0
+ elif n < 0:
+ return -int_to_float(-n)
+
+ # Convert n to a 'floating-point' number q * 2**shift, where q is an
+ # integer with 'PRECISION' significant bits. When shifting n to create q,
+ # the least significant bit of q is treated as 'sticky'. That is, the
+ # least significant bit of q is set if either the corresponding bit of n
+ # was already set, or any one of the bits of n lost in the shift was set.
+ shift = n.bit_length() - PRECISION
+ q = n << -shift if shift < 0 else (n >> shift) | bool(n & ~(-1 << shift))
+
+ # Round half to even (actually rounds to the nearest multiple of 4,
+ # rounding ties to a multiple of 8).
+ q += ROUND_HALF_TO_EVEN_CORRECTION[q & 7]
+
+ # Detect overflow.
+ if shift + (q == Q_MAX) > SHIFT_MAX:
+ raise OverflowError("integer too large to convert to float")
+
+ # Checks: q is exactly representable, and q**2**shift doesn't overflow.
+ assert q % 4 == 0 and q // 4 <= 2**(sys.float_info.mant_dig)
+ assert q * 2**shift <= sys.float_info.max
+
+ # Some circularity here, since float(q) is doing an int-to-float
+ # conversion. But here q is of bounded size, and is exactly representable
+ # as a float. In a low-level C-like language, this operation would be a
+ # simple cast (e.g., from unsigned long long to double).
+ return math.ldexp(float(q), shift)
+
+
# pure Python version of correctly-rounded true division
def truediv(a, b):
"""Correctly-rounded true division for integers."""
@@ -367,6 +414,23 @@ class LongTest(unittest.TestCase):
return 1729
self.assertEqual(int(LongTrunc()), 1729)
+ def check_float_conversion(self, n):
+ # Check that int -> float conversion behaviour matches
+ # that of the pure Python version above.
+ try:
+ actual = float(n)
+ except OverflowError:
+ actual = 'overflow'
+
+ try:
+ expected = int_to_float(n)
+ except OverflowError:
+ expected = 'overflow'
+
+ msg = ("Error in conversion of integer {} to float. "
+ "Got {}, expected {}.".format(n, actual, expected))
+ self.assertEqual(actual, expected, msg)
+
@support.requires_IEEE_754
def test_float_conversion(self):
@@ -421,6 +485,22 @@ class LongTest(unittest.TestCase):
y = 2**p * 2**53
self.assertEqual(int(float(x)), y)
+ # Compare builtin float conversion with pure Python int_to_float
+ # function above.
+ test_values = [
+ int_dbl_max-1, int_dbl_max, int_dbl_max+1,
+ halfway-1, halfway, halfway + 1,
+ top_power-1, top_power, top_power+1,
+ 2*top_power-1, 2*top_power, top_power*top_power,
+ ]
+ test_values.extend(exact_values)
+ for p in range(-4, 8):
+ for x in range(-128, 128):
+ test_values.append(2**(p+53) + x)
+ for value in test_values:
+ self.check_float_conversion(value)
+ self.check_float_conversion(-value)
+
def test_float_overflow(self):
for x in -2.0, -1.0, 0.0, 1.0, 2.0:
self.assertEqual(float(int(x)), x)
diff --git a/Lib/test/test_lzma.py b/Lib/test/test_lzma.py
new file mode 100644
index 0000000..20d8582
--- /dev/null
+++ b/Lib/test/test_lzma.py
@@ -0,0 +1,1531 @@
+from io import BytesIO, UnsupportedOperation
+import os
+import random
+import unittest
+
+from test.support import (
+ _4G, TESTFN, import_module, bigmemtest, run_unittest, unlink
+)
+
+lzma = import_module("lzma")
+from lzma import LZMACompressor, LZMADecompressor, LZMAError, LZMAFile
+
+
+class CompressorDecompressorTestCase(unittest.TestCase):
+
+ # Test error cases.
+
+ def test_simple_bad_args(self):
+ self.assertRaises(TypeError, LZMACompressor, [])
+ self.assertRaises(TypeError, LZMACompressor, format=3.45)
+ self.assertRaises(TypeError, LZMACompressor, check="")
+ self.assertRaises(TypeError, LZMACompressor, preset="asdf")
+ self.assertRaises(TypeError, LZMACompressor, filters=3)
+ # Can't specify FORMAT_AUTO when compressing.
+ self.assertRaises(ValueError, LZMACompressor, format=lzma.FORMAT_AUTO)
+ # Can't specify a preset and a custom filter chain at the same time.
+ with self.assertRaises(ValueError):
+ LZMACompressor(preset=7, filters=[{"id": lzma.FILTER_LZMA2}])
+
+ self.assertRaises(TypeError, LZMADecompressor, ())
+ self.assertRaises(TypeError, LZMADecompressor, memlimit=b"qw")
+ with self.assertRaises(TypeError):
+ LZMADecompressor(lzma.FORMAT_RAW, filters="zzz")
+ # Cannot specify a memory limit with FILTER_RAW.
+ with self.assertRaises(ValueError):
+ LZMADecompressor(lzma.FORMAT_RAW, memlimit=0x1000000)
+ # Can only specify a custom filter chain with FILTER_RAW.
+ self.assertRaises(ValueError, LZMADecompressor, filters=FILTERS_RAW_1)
+ with self.assertRaises(ValueError):
+ LZMADecompressor(format=lzma.FORMAT_XZ, filters=FILTERS_RAW_1)
+ with self.assertRaises(ValueError):
+ LZMADecompressor(format=lzma.FORMAT_ALONE, filters=FILTERS_RAW_1)
+
+ lzc = LZMACompressor()
+ self.assertRaises(TypeError, lzc.compress)
+ self.assertRaises(TypeError, lzc.compress, b"foo", b"bar")
+ self.assertRaises(TypeError, lzc.flush, b"blah")
+ empty = lzc.flush()
+ self.assertRaises(ValueError, lzc.compress, b"quux")
+ self.assertRaises(ValueError, lzc.flush)
+
+ lzd = LZMADecompressor()
+ self.assertRaises(TypeError, lzd.decompress)
+ self.assertRaises(TypeError, lzd.decompress, b"foo", b"bar")
+ lzd.decompress(empty)
+ self.assertRaises(EOFError, lzd.decompress, b"quux")
+
+ def test_bad_filter_spec(self):
+ self.assertRaises(TypeError, LZMACompressor, filters=[b"wobsite"])
+ self.assertRaises(ValueError, LZMACompressor, filters=[{"xyzzy": 3}])
+ self.assertRaises(ValueError, LZMACompressor, filters=[{"id": 98765}])
+ with self.assertRaises(ValueError):
+ LZMACompressor(filters=[{"id": lzma.FILTER_LZMA2, "foo": 0}])
+ with self.assertRaises(ValueError):
+ LZMACompressor(filters=[{"id": lzma.FILTER_DELTA, "foo": 0}])
+ with self.assertRaises(ValueError):
+ LZMACompressor(filters=[{"id": lzma.FILTER_X86, "foo": 0}])
+
+ def test_decompressor_after_eof(self):
+ lzd = LZMADecompressor()
+ lzd.decompress(COMPRESSED_XZ)
+ self.assertRaises(EOFError, lzd.decompress, b"nyan")
+
+ def test_decompressor_memlimit(self):
+ lzd = LZMADecompressor(memlimit=1024)
+ self.assertRaises(LZMAError, lzd.decompress, COMPRESSED_XZ)
+
+ lzd = LZMADecompressor(lzma.FORMAT_XZ, memlimit=1024)
+ self.assertRaises(LZMAError, lzd.decompress, COMPRESSED_XZ)
+
+ lzd = LZMADecompressor(lzma.FORMAT_ALONE, memlimit=1024)
+ self.assertRaises(LZMAError, lzd.decompress, COMPRESSED_ALONE)
+
+ # Test LZMADecompressor on known-good input data.
+
+ def _test_decompressor(self, lzd, data, check, unused_data=b""):
+ self.assertFalse(lzd.eof)
+ out = lzd.decompress(data)
+ self.assertEqual(out, INPUT)
+ self.assertEqual(lzd.check, check)
+ self.assertTrue(lzd.eof)
+ self.assertEqual(lzd.unused_data, unused_data)
+
+ def test_decompressor_auto(self):
+ lzd = LZMADecompressor()
+ self._test_decompressor(lzd, COMPRESSED_XZ, lzma.CHECK_CRC64)
+
+ lzd = LZMADecompressor()
+ self._test_decompressor(lzd, COMPRESSED_ALONE, lzma.CHECK_NONE)
+
+ def test_decompressor_xz(self):
+ lzd = LZMADecompressor(lzma.FORMAT_XZ)
+ self._test_decompressor(lzd, COMPRESSED_XZ, lzma.CHECK_CRC64)
+
+ def test_decompressor_alone(self):
+ lzd = LZMADecompressor(lzma.FORMAT_ALONE)
+ self._test_decompressor(lzd, COMPRESSED_ALONE, lzma.CHECK_NONE)
+
+ def test_decompressor_raw_1(self):
+ lzd = LZMADecompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_1)
+ self._test_decompressor(lzd, COMPRESSED_RAW_1, lzma.CHECK_NONE)
+
+ def test_decompressor_raw_2(self):
+ lzd = LZMADecompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_2)
+ self._test_decompressor(lzd, COMPRESSED_RAW_2, lzma.CHECK_NONE)
+
+ def test_decompressor_raw_3(self):
+ lzd = LZMADecompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_3)
+ self._test_decompressor(lzd, COMPRESSED_RAW_3, lzma.CHECK_NONE)
+
+ def test_decompressor_raw_4(self):
+ lzd = LZMADecompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_4)
+ self._test_decompressor(lzd, COMPRESSED_RAW_4, lzma.CHECK_NONE)
+
+ def test_decompressor_chunks(self):
+ lzd = LZMADecompressor()
+ out = []
+ for i in range(0, len(COMPRESSED_XZ), 10):
+ self.assertFalse(lzd.eof)
+ out.append(lzd.decompress(COMPRESSED_XZ[i:i+10]))
+ out = b"".join(out)
+ self.assertEqual(out, INPUT)
+ self.assertEqual(lzd.check, lzma.CHECK_CRC64)
+ self.assertTrue(lzd.eof)
+ self.assertEqual(lzd.unused_data, b"")
+
+ def test_decompressor_unused_data(self):
+ lzd = LZMADecompressor()
+ extra = b"fooblibar"
+ self._test_decompressor(lzd, COMPRESSED_XZ + extra, lzma.CHECK_CRC64,
+ unused_data=extra)
+
+ def test_decompressor_bad_input(self):
+ lzd = LZMADecompressor()
+ self.assertRaises(LZMAError, lzd.decompress, COMPRESSED_RAW_1)
+
+ lzd = LZMADecompressor(lzma.FORMAT_XZ)
+ self.assertRaises(LZMAError, lzd.decompress, COMPRESSED_ALONE)
+
+ lzd = LZMADecompressor(lzma.FORMAT_ALONE)
+ self.assertRaises(LZMAError, lzd.decompress, COMPRESSED_XZ)
+
+ lzd = LZMADecompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_1)
+ self.assertRaises(LZMAError, lzd.decompress, COMPRESSED_XZ)
+
+ # Test that LZMACompressor->LZMADecompressor preserves the input data.
+
+ def test_roundtrip_xz(self):
+ lzc = LZMACompressor()
+ cdata = lzc.compress(INPUT) + lzc.flush()
+ lzd = LZMADecompressor()
+ self._test_decompressor(lzd, cdata, lzma.CHECK_CRC64)
+
+ def test_roundtrip_alone(self):
+ lzc = LZMACompressor(lzma.FORMAT_ALONE)
+ cdata = lzc.compress(INPUT) + lzc.flush()
+ lzd = LZMADecompressor()
+ self._test_decompressor(lzd, cdata, lzma.CHECK_NONE)
+
+ def test_roundtrip_raw(self):
+ lzc = LZMACompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_4)
+ cdata = lzc.compress(INPUT) + lzc.flush()
+ lzd = LZMADecompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_4)
+ self._test_decompressor(lzd, cdata, lzma.CHECK_NONE)
+
+ def test_roundtrip_chunks(self):
+ lzc = LZMACompressor()
+ cdata = []
+ for i in range(0, len(INPUT), 10):
+ cdata.append(lzc.compress(INPUT[i:i+10]))
+ cdata.append(lzc.flush())
+ cdata = b"".join(cdata)
+ lzd = LZMADecompressor()
+ self._test_decompressor(lzd, cdata, lzma.CHECK_CRC64)
+
+ # LZMADecompressor intentionally does not handle concatenated streams.
+
+ def test_decompressor_multistream(self):
+ lzd = LZMADecompressor()
+ self._test_decompressor(lzd, COMPRESSED_XZ + COMPRESSED_ALONE,
+ lzma.CHECK_CRC64, unused_data=COMPRESSED_ALONE)
+
+ # Test with inputs larger than 4GiB.
+
+ @bigmemtest(size=_4G + 100, memuse=2)
+ def test_compressor_bigmem(self, size):
+ lzc = LZMACompressor()
+ cdata = lzc.compress(b"x" * size) + lzc.flush()
+ ddata = lzma.decompress(cdata)
+ try:
+ self.assertEqual(len(ddata), size)
+ self.assertEqual(len(ddata.strip(b"x")), 0)
+ finally:
+ ddata = None
+
+ @bigmemtest(size=_4G + 100, memuse=3)
+ def test_decompressor_bigmem(self, size):
+ lzd = LZMADecompressor()
+ blocksize = 10 * 1024 * 1024
+ block = random.getrandbits(blocksize * 8).to_bytes(blocksize, "little")
+ try:
+ input = block * (size // blocksize + 1)
+ cdata = lzma.compress(input)
+ ddata = lzd.decompress(cdata)
+ self.assertEqual(ddata, input)
+ finally:
+ input = cdata = ddata = None
+
+
+class CompressDecompressFunctionTestCase(unittest.TestCase):
+
+ # Test error cases:
+
+ def test_bad_args(self):
+ self.assertRaises(TypeError, lzma.compress)
+ self.assertRaises(TypeError, lzma.compress, [])
+ self.assertRaises(TypeError, lzma.compress, b"", format="xz")
+ self.assertRaises(TypeError, lzma.compress, b"", check="none")
+ self.assertRaises(TypeError, lzma.compress, b"", preset="blah")
+ self.assertRaises(TypeError, lzma.compress, b"", filters=1024)
+ # Can't specify a preset and a custom filter chain at the same time.
+ with self.assertRaises(ValueError):
+ lzma.compress(b"", preset=3, filters=[{"id": lzma.FILTER_LZMA2}])
+
+ self.assertRaises(TypeError, lzma.decompress)
+ self.assertRaises(TypeError, lzma.decompress, [])
+ self.assertRaises(TypeError, lzma.decompress, b"", format="lzma")
+ self.assertRaises(TypeError, lzma.decompress, b"", memlimit=7.3e9)
+ with self.assertRaises(TypeError):
+ lzma.decompress(b"", format=lzma.FORMAT_RAW, filters={})
+ # Cannot specify a memory limit with FILTER_RAW.
+ with self.assertRaises(ValueError):
+ lzma.decompress(b"", format=lzma.FORMAT_RAW, memlimit=0x1000000)
+ # Can only specify a custom filter chain with FILTER_RAW.
+ with self.assertRaises(ValueError):
+ lzma.decompress(b"", filters=FILTERS_RAW_1)
+ with self.assertRaises(ValueError):
+ lzma.decompress(b"", format=lzma.FORMAT_XZ, filters=FILTERS_RAW_1)
+ with self.assertRaises(ValueError):
+ lzma.decompress(
+ b"", format=lzma.FORMAT_ALONE, filters=FILTERS_RAW_1)
+
+ def test_decompress_memlimit(self):
+ with self.assertRaises(LZMAError):
+ lzma.decompress(COMPRESSED_XZ, memlimit=1024)
+ with self.assertRaises(LZMAError):
+ lzma.decompress(
+ COMPRESSED_XZ, format=lzma.FORMAT_XZ, memlimit=1024)
+ with self.assertRaises(LZMAError):
+ lzma.decompress(
+ COMPRESSED_ALONE, format=lzma.FORMAT_ALONE, memlimit=1024)
+
+ # Test LZMADecompressor on known-good input data.
+
+ def test_decompress_good_input(self):
+ ddata = lzma.decompress(COMPRESSED_XZ)
+ self.assertEqual(ddata, INPUT)
+
+ ddata = lzma.decompress(COMPRESSED_ALONE)
+ self.assertEqual(ddata, INPUT)
+
+ ddata = lzma.decompress(COMPRESSED_XZ, lzma.FORMAT_XZ)
+ self.assertEqual(ddata, INPUT)
+
+ ddata = lzma.decompress(COMPRESSED_ALONE, lzma.FORMAT_ALONE)
+ self.assertEqual(ddata, INPUT)
+
+ ddata = lzma.decompress(
+ COMPRESSED_RAW_1, lzma.FORMAT_RAW, filters=FILTERS_RAW_1)
+ self.assertEqual(ddata, INPUT)
+
+ ddata = lzma.decompress(
+ COMPRESSED_RAW_2, lzma.FORMAT_RAW, filters=FILTERS_RAW_2)
+ self.assertEqual(ddata, INPUT)
+
+ ddata = lzma.decompress(
+ COMPRESSED_RAW_3, lzma.FORMAT_RAW, filters=FILTERS_RAW_3)
+ self.assertEqual(ddata, INPUT)
+
+ ddata = lzma.decompress(
+ COMPRESSED_RAW_4, lzma.FORMAT_RAW, filters=FILTERS_RAW_4)
+ self.assertEqual(ddata, INPUT)
+
+ def test_decompress_incomplete_input(self):
+ self.assertRaises(LZMAError, lzma.decompress, COMPRESSED_XZ[:128])
+ self.assertRaises(LZMAError, lzma.decompress, COMPRESSED_ALONE[:128])
+ self.assertRaises(LZMAError, lzma.decompress, COMPRESSED_RAW_1[:128],
+ format=lzma.FORMAT_RAW, filters=FILTERS_RAW_1)
+ self.assertRaises(LZMAError, lzma.decompress, COMPRESSED_RAW_2[:128],
+ format=lzma.FORMAT_RAW, filters=FILTERS_RAW_2)
+ self.assertRaises(LZMAError, lzma.decompress, COMPRESSED_RAW_3[:128],
+ format=lzma.FORMAT_RAW, filters=FILTERS_RAW_3)
+ self.assertRaises(LZMAError, lzma.decompress, COMPRESSED_RAW_4[:128],
+ format=lzma.FORMAT_RAW, filters=FILTERS_RAW_4)
+
+ def test_decompress_bad_input(self):
+ with self.assertRaises(LZMAError):
+ lzma.decompress(COMPRESSED_RAW_1)
+ with self.assertRaises(LZMAError):
+ lzma.decompress(COMPRESSED_ALONE, format=lzma.FORMAT_XZ)
+ with self.assertRaises(LZMAError):
+ lzma.decompress(COMPRESSED_XZ, format=lzma.FORMAT_ALONE)
+ with self.assertRaises(LZMAError):
+ lzma.decompress(COMPRESSED_XZ, format=lzma.FORMAT_RAW,
+ filters=FILTERS_RAW_1)
+
+ # Test that compress()->decompress() preserves the input data.
+
+ def test_roundtrip(self):
+ cdata = lzma.compress(INPUT)
+ ddata = lzma.decompress(cdata)
+ self.assertEqual(ddata, INPUT)
+
+ cdata = lzma.compress(INPUT, lzma.FORMAT_XZ)
+ ddata = lzma.decompress(cdata)
+ self.assertEqual(ddata, INPUT)
+
+ cdata = lzma.compress(INPUT, lzma.FORMAT_ALONE)
+ ddata = lzma.decompress(cdata)
+ self.assertEqual(ddata, INPUT)
+
+ cdata = lzma.compress(INPUT, lzma.FORMAT_RAW, filters=FILTERS_RAW_4)
+ ddata = lzma.decompress(cdata, lzma.FORMAT_RAW, filters=FILTERS_RAW_4)
+ self.assertEqual(ddata, INPUT)
+
+ # Unlike LZMADecompressor, decompress() *does* handle concatenated streams.
+
+ def test_decompress_multistream(self):
+ ddata = lzma.decompress(COMPRESSED_XZ + COMPRESSED_ALONE)
+ self.assertEqual(ddata, INPUT * 2)
+
+
+class TempFile:
+ """Context manager - creates a file, and deletes it on __exit__."""
+
+ def __init__(self, filename, data=b""):
+ self.filename = filename
+ self.data = data
+
+ def __enter__(self):
+ with open(self.filename, "wb") as f:
+ f.write(self.data)
+
+ def __exit__(self, *args):
+ unlink(self.filename)
+
+
+class FileTestCase(unittest.TestCase):
+
+ def test_init(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ pass
+ with LZMAFile(BytesIO(), "w") as f:
+ pass
+ with LZMAFile(BytesIO(), "a") as f:
+ pass
+
+ def test_init_with_filename(self):
+ with TempFile(TESTFN, COMPRESSED_XZ):
+ with LZMAFile(TESTFN) as f:
+ pass
+ with LZMAFile(TESTFN, "w") as f:
+ pass
+ with LZMAFile(TESTFN, "a") as f:
+ pass
+
+ def test_init_mode(self):
+ with TempFile(TESTFN):
+ with LZMAFile(TESTFN, "r"):
+ pass
+ with LZMAFile(TESTFN, "rb"):
+ pass
+ with LZMAFile(TESTFN, "w"):
+ pass
+ with LZMAFile(TESTFN, "wb"):
+ pass
+ with LZMAFile(TESTFN, "a"):
+ pass
+ with LZMAFile(TESTFN, "ab"):
+ pass
+
+ def test_init_bad_mode(self):
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(COMPRESSED_XZ), (3, "x"))
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(COMPRESSED_XZ), "")
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(COMPRESSED_XZ), "x")
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(COMPRESSED_XZ), "rt")
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(COMPRESSED_XZ), "r+")
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(COMPRESSED_XZ), "wt")
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(COMPRESSED_XZ), "w+")
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(COMPRESSED_XZ), "rw")
+
+ def test_init_bad_check(self):
+ with self.assertRaises(TypeError):
+ LZMAFile(BytesIO(), "w", check=b"asd")
+ # CHECK_UNKNOWN and anything above CHECK_ID_MAX should be invalid.
+ with self.assertRaises(LZMAError):
+ LZMAFile(BytesIO(), "w", check=lzma.CHECK_UNKNOWN)
+ with self.assertRaises(LZMAError):
+ LZMAFile(BytesIO(), "w", check=lzma.CHECK_ID_MAX + 3)
+ # Cannot specify a check with mode="r".
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(COMPRESSED_XZ), check=lzma.CHECK_NONE)
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(COMPRESSED_XZ), check=lzma.CHECK_CRC32)
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(COMPRESSED_XZ), check=lzma.CHECK_CRC64)
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(COMPRESSED_XZ), check=lzma.CHECK_SHA256)
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(COMPRESSED_XZ), check=lzma.CHECK_UNKNOWN)
+
+ def test_init_bad_preset(self):
+ with self.assertRaises(TypeError):
+ LZMAFile(BytesIO(), "w", preset=4.39)
+ with self.assertRaises(LZMAError):
+ LZMAFile(BytesIO(), "w", preset=10)
+ with self.assertRaises(LZMAError):
+ LZMAFile(BytesIO(), "w", preset=23)
+ with self.assertRaises(OverflowError):
+ LZMAFile(BytesIO(), "w", preset=-1)
+ with self.assertRaises(OverflowError):
+ LZMAFile(BytesIO(), "w", preset=-7)
+ with self.assertRaises(TypeError):
+ LZMAFile(BytesIO(), "w", preset="foo")
+ # Cannot specify a preset with mode="r".
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(COMPRESSED_XZ), preset=3)
+
+ def test_init_bad_filter_spec(self):
+ with self.assertRaises(TypeError):
+ LZMAFile(BytesIO(), "w", filters=[b"wobsite"])
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(), "w", filters=[{"xyzzy": 3}])
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(), "w", filters=[{"id": 98765}])
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(), "w",
+ filters=[{"id": lzma.FILTER_LZMA2, "foo": 0}])
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(), "w",
+ filters=[{"id": lzma.FILTER_DELTA, "foo": 0}])
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(), "w",
+ filters=[{"id": lzma.FILTER_X86, "foo": 0}])
+
+ def test_init_with_preset_and_filters(self):
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(), "w", format=lzma.FORMAT_RAW,
+ preset=6, filters=FILTERS_RAW_1)
+
+ def test_close(self):
+ with BytesIO(COMPRESSED_XZ) as src:
+ f = LZMAFile(src)
+ f.close()
+ # LZMAFile.close() should not close the underlying file object.
+ self.assertFalse(src.closed)
+ # Try closing an already-closed LZMAFile.
+ f.close()
+ self.assertFalse(src.closed)
+
+ # Test with a real file on disk, opened directly by LZMAFile.
+ with TempFile(TESTFN, COMPRESSED_XZ):
+ f = LZMAFile(TESTFN)
+ fp = f._fp
+ f.close()
+ # Here, LZMAFile.close() *should* close the underlying file object.
+ self.assertTrue(fp.closed)
+ # Try closing an already-closed LZMAFile.
+ f.close()
+
+ def test_closed(self):
+ f = LZMAFile(BytesIO(COMPRESSED_XZ))
+ try:
+ self.assertFalse(f.closed)
+ f.read()
+ self.assertFalse(f.closed)
+ finally:
+ f.close()
+ self.assertTrue(f.closed)
+
+ f = LZMAFile(BytesIO(), "w")
+ try:
+ self.assertFalse(f.closed)
+ finally:
+ f.close()
+ self.assertTrue(f.closed)
+
+ def test_fileno(self):
+ f = LZMAFile(BytesIO(COMPRESSED_XZ))
+ try:
+ self.assertRaises(UnsupportedOperation, f.fileno)
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.fileno)
+ with TempFile(TESTFN, COMPRESSED_XZ):
+ f = LZMAFile(TESTFN)
+ try:
+ self.assertEqual(f.fileno(), f._fp.fileno())
+ self.assertIsInstance(f.fileno(), int)
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.fileno)
+
+ def test_seekable(self):
+ f = LZMAFile(BytesIO(COMPRESSED_XZ))
+ try:
+ self.assertTrue(f.seekable())
+ f.read()
+ self.assertTrue(f.seekable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.seekable)
+
+ f = LZMAFile(BytesIO(), "w")
+ try:
+ self.assertFalse(f.seekable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.seekable)
+
+ src = BytesIO(COMPRESSED_XZ)
+ src.seekable = lambda: False
+ f = LZMAFile(src)
+ try:
+ self.assertFalse(f.seekable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.seekable)
+
+ def test_readable(self):
+ f = LZMAFile(BytesIO(COMPRESSED_XZ))
+ try:
+ self.assertTrue(f.readable())
+ f.read()
+ self.assertTrue(f.readable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.readable)
+
+ f = LZMAFile(BytesIO(), "w")
+ try:
+ self.assertFalse(f.readable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.readable)
+
+ def test_writable(self):
+ f = LZMAFile(BytesIO(COMPRESSED_XZ))
+ try:
+ self.assertFalse(f.writable())
+ f.read()
+ self.assertFalse(f.writable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.writable)
+
+ f = LZMAFile(BytesIO(), "w")
+ try:
+ self.assertTrue(f.writable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.writable)
+
+ def test_read(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ self.assertEqual(f.read(), INPUT)
+ self.assertEqual(f.read(), b"")
+ with LZMAFile(BytesIO(COMPRESSED_ALONE)) as f:
+ self.assertEqual(f.read(), INPUT)
+ with LZMAFile(BytesIO(COMPRESSED_XZ), format=lzma.FORMAT_XZ) as f:
+ self.assertEqual(f.read(), INPUT)
+ self.assertEqual(f.read(), b"")
+ with LZMAFile(BytesIO(COMPRESSED_ALONE), format=lzma.FORMAT_ALONE) as f:
+ self.assertEqual(f.read(), INPUT)
+ self.assertEqual(f.read(), b"")
+ with LZMAFile(BytesIO(COMPRESSED_RAW_1),
+ format=lzma.FORMAT_RAW, filters=FILTERS_RAW_1) as f:
+ self.assertEqual(f.read(), INPUT)
+ self.assertEqual(f.read(), b"")
+ with LZMAFile(BytesIO(COMPRESSED_RAW_2),
+ format=lzma.FORMAT_RAW, filters=FILTERS_RAW_2) as f:
+ self.assertEqual(f.read(), INPUT)
+ self.assertEqual(f.read(), b"")
+ with LZMAFile(BytesIO(COMPRESSED_RAW_3),
+ format=lzma.FORMAT_RAW, filters=FILTERS_RAW_3) as f:
+ self.assertEqual(f.read(), INPUT)
+ self.assertEqual(f.read(), b"")
+ with LZMAFile(BytesIO(COMPRESSED_RAW_4),
+ format=lzma.FORMAT_RAW, filters=FILTERS_RAW_4) as f:
+ self.assertEqual(f.read(), INPUT)
+ self.assertEqual(f.read(), b"")
+
+ def test_read_0(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ self.assertEqual(f.read(0), b"")
+ with LZMAFile(BytesIO(COMPRESSED_ALONE)) as f:
+ self.assertEqual(f.read(0), b"")
+ with LZMAFile(BytesIO(COMPRESSED_XZ), format=lzma.FORMAT_XZ) as f:
+ self.assertEqual(f.read(0), b"")
+ with LZMAFile(BytesIO(COMPRESSED_ALONE), format=lzma.FORMAT_ALONE) as f:
+ self.assertEqual(f.read(0), b"")
+
+ def test_read_10(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ chunks = []
+ while True:
+ result = f.read(10)
+ if not result:
+ break
+ self.assertLessEqual(len(result), 10)
+ chunks.append(result)
+ self.assertEqual(b"".join(chunks), INPUT)
+
+ def test_read_multistream(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ * 5)) as f:
+ self.assertEqual(f.read(), INPUT * 5)
+ with LZMAFile(BytesIO(COMPRESSED_XZ + COMPRESSED_ALONE)) as f:
+ self.assertEqual(f.read(), INPUT * 2)
+ with LZMAFile(BytesIO(COMPRESSED_RAW_3 * 4),
+ format=lzma.FORMAT_RAW, filters=FILTERS_RAW_3) as f:
+ self.assertEqual(f.read(), INPUT * 4)
+
+ def test_read_multistream_buffer_size_aligned(self):
+ # Test the case where a stream boundary coincides with the end
+ # of the raw read buffer.
+ saved_buffer_size = lzma._BUFFER_SIZE
+ lzma._BUFFER_SIZE = len(COMPRESSED_XZ)
+ try:
+ with LZMAFile(BytesIO(COMPRESSED_XZ * 5)) as f:
+ self.assertEqual(f.read(), INPUT * 5)
+ finally:
+ lzma._BUFFER_SIZE = saved_buffer_size
+
+ def test_read_from_file(self):
+ with TempFile(TESTFN, COMPRESSED_XZ):
+ with LZMAFile(TESTFN) as f:
+ self.assertEqual(f.read(), INPUT)
+ self.assertEqual(f.read(), b"")
+
+ def test_read_from_file_with_bytes_filename(self):
+ try:
+ bytes_filename = TESTFN.encode("ascii")
+ except UnicodeEncodeError:
+ self.skipTest("Temporary file name needs to be ASCII")
+ with TempFile(TESTFN, COMPRESSED_XZ):
+ with LZMAFile(bytes_filename) as f:
+ self.assertEqual(f.read(), INPUT)
+ self.assertEqual(f.read(), b"")
+
+ def test_read_incomplete(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ[:128])) as f:
+ self.assertRaises(EOFError, f.read)
+
+ def test_read_truncated(self):
+ # Drop stream footer: CRC (4 bytes), index size (4 bytes),
+ # flags (2 bytes) and magic number (2 bytes).
+ truncated = COMPRESSED_XZ[:-12]
+ with LZMAFile(BytesIO(truncated)) as f:
+ self.assertRaises(EOFError, f.read)
+ with LZMAFile(BytesIO(truncated)) as f:
+ self.assertEqual(f.read(len(INPUT)), INPUT)
+ self.assertRaises(EOFError, f.read, 1)
+ # Incomplete 12-byte header.
+ for i in range(12):
+ with LZMAFile(BytesIO(truncated[:i])) as f:
+ self.assertRaises(EOFError, f.read, 1)
+
+ def test_read_bad_args(self):
+ f = LZMAFile(BytesIO(COMPRESSED_XZ))
+ f.close()
+ self.assertRaises(ValueError, f.read)
+ with LZMAFile(BytesIO(), "w") as f:
+ self.assertRaises(ValueError, f.read)
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ self.assertRaises(TypeError, f.read, None)
+
+ def test_read1(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ blocks = []
+ while True:
+ result = f.read1()
+ if not result:
+ break
+ blocks.append(result)
+ self.assertEqual(b"".join(blocks), INPUT)
+ self.assertEqual(f.read1(), b"")
+
+ def test_read1_0(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ self.assertEqual(f.read1(0), b"")
+
+ def test_read1_10(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ blocks = []
+ while True:
+ result = f.read1(10)
+ if not result:
+ break
+ blocks.append(result)
+ self.assertEqual(b"".join(blocks), INPUT)
+ self.assertEqual(f.read1(), b"")
+
+ def test_read1_multistream(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ * 5)) as f:
+ blocks = []
+ while True:
+ result = f.read1()
+ if not result:
+ break
+ blocks.append(result)
+ self.assertEqual(b"".join(blocks), INPUT * 5)
+ self.assertEqual(f.read1(), b"")
+
+ def test_read1_bad_args(self):
+ f = LZMAFile(BytesIO(COMPRESSED_XZ))
+ f.close()
+ self.assertRaises(ValueError, f.read1)
+ with LZMAFile(BytesIO(), "w") as f:
+ self.assertRaises(ValueError, f.read1)
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ self.assertRaises(TypeError, f.read1, None)
+
+ def test_peek(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ result = f.peek()
+ self.assertGreater(len(result), 0)
+ self.assertTrue(INPUT.startswith(result))
+ self.assertEqual(f.read(), INPUT)
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ result = f.peek(10)
+ self.assertGreater(len(result), 0)
+ self.assertTrue(INPUT.startswith(result))
+ self.assertEqual(f.read(), INPUT)
+
+ def test_peek_bad_args(self):
+ with LZMAFile(BytesIO(), "w") as f:
+ self.assertRaises(ValueError, f.peek)
+
+ def test_iterator(self):
+ with BytesIO(INPUT) as f:
+ lines = f.readlines()
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ self.assertListEqual(list(iter(f)), lines)
+ with LZMAFile(BytesIO(COMPRESSED_ALONE)) as f:
+ self.assertListEqual(list(iter(f)), lines)
+ with LZMAFile(BytesIO(COMPRESSED_XZ), format=lzma.FORMAT_XZ) as f:
+ self.assertListEqual(list(iter(f)), lines)
+ with LZMAFile(BytesIO(COMPRESSED_ALONE), format=lzma.FORMAT_ALONE) as f:
+ self.assertListEqual(list(iter(f)), lines)
+ with LZMAFile(BytesIO(COMPRESSED_RAW_2),
+ format=lzma.FORMAT_RAW, filters=FILTERS_RAW_2) as f:
+ self.assertListEqual(list(iter(f)), lines)
+
+ def test_readline(self):
+ with BytesIO(INPUT) as f:
+ lines = f.readlines()
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ for line in lines:
+ self.assertEqual(f.readline(), line)
+
+ def test_readlines(self):
+ with BytesIO(INPUT) as f:
+ lines = f.readlines()
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ self.assertListEqual(f.readlines(), lines)
+
+ def test_write(self):
+ with BytesIO() as dst:
+ with LZMAFile(dst, "w") as f:
+ f.write(INPUT)
+ expected = lzma.compress(INPUT)
+ self.assertEqual(dst.getvalue(), expected)
+ with BytesIO() as dst:
+ with LZMAFile(dst, "w", format=lzma.FORMAT_XZ) as f:
+ f.write(INPUT)
+ expected = lzma.compress(INPUT, format=lzma.FORMAT_XZ)
+ self.assertEqual(dst.getvalue(), expected)
+ with BytesIO() as dst:
+ with LZMAFile(dst, "w", format=lzma.FORMAT_ALONE) as f:
+ f.write(INPUT)
+ expected = lzma.compress(INPUT, format=lzma.FORMAT_ALONE)
+ self.assertEqual(dst.getvalue(), expected)
+ with BytesIO() as dst:
+ with LZMAFile(dst, "w", format=lzma.FORMAT_RAW,
+ filters=FILTERS_RAW_2) as f:
+ f.write(INPUT)
+ expected = lzma.compress(INPUT, format=lzma.FORMAT_RAW,
+ filters=FILTERS_RAW_2)
+ self.assertEqual(dst.getvalue(), expected)
+
+ def test_write_10(self):
+ with BytesIO() as dst:
+ with LZMAFile(dst, "w") as f:
+ for start in range(0, len(INPUT), 10):
+ f.write(INPUT[start:start+10])
+ expected = lzma.compress(INPUT)
+ self.assertEqual(dst.getvalue(), expected)
+
+ def test_write_append(self):
+ part1 = INPUT[:1024]
+ part2 = INPUT[1024:1536]
+ part3 = INPUT[1536:]
+ expected = b"".join(lzma.compress(x) for x in (part1, part2, part3))
+ with BytesIO() as dst:
+ with LZMAFile(dst, "w") as f:
+ f.write(part1)
+ with LZMAFile(dst, "a") as f:
+ f.write(part2)
+ with LZMAFile(dst, "a") as f:
+ f.write(part3)
+ self.assertEqual(dst.getvalue(), expected)
+
+ def test_write_to_file(self):
+ try:
+ with LZMAFile(TESTFN, "w") as f:
+ f.write(INPUT)
+ expected = lzma.compress(INPUT)
+ with open(TESTFN, "rb") as f:
+ self.assertEqual(f.read(), expected)
+ finally:
+ unlink(TESTFN)
+
+ def test_write_to_file_with_bytes_filename(self):
+ try:
+ bytes_filename = TESTFN.encode("ascii")
+ except UnicodeEncodeError:
+ self.skipTest("Temporary file name needs to be ASCII")
+ try:
+ with LZMAFile(bytes_filename, "w") as f:
+ f.write(INPUT)
+ expected = lzma.compress(INPUT)
+ with open(TESTFN, "rb") as f:
+ self.assertEqual(f.read(), expected)
+ finally:
+ unlink(TESTFN)
+
+ def test_write_append_to_file(self):
+ part1 = INPUT[:1024]
+ part2 = INPUT[1024:1536]
+ part3 = INPUT[1536:]
+ expected = b"".join(lzma.compress(x) for x in (part1, part2, part3))
+ try:
+ with LZMAFile(TESTFN, "w") as f:
+ f.write(part1)
+ with LZMAFile(TESTFN, "a") as f:
+ f.write(part2)
+ with LZMAFile(TESTFN, "a") as f:
+ f.write(part3)
+ with open(TESTFN, "rb") as f:
+ self.assertEqual(f.read(), expected)
+ finally:
+ unlink(TESTFN)
+
+ def test_write_bad_args(self):
+ f = LZMAFile(BytesIO(), "w")
+ f.close()
+ self.assertRaises(ValueError, f.write, b"foo")
+ with LZMAFile(BytesIO(COMPRESSED_XZ), "r") as f:
+ self.assertRaises(ValueError, f.write, b"bar")
+ with LZMAFile(BytesIO(), "w") as f:
+ self.assertRaises(TypeError, f.write, None)
+ self.assertRaises(TypeError, f.write, "text")
+ self.assertRaises(TypeError, f.write, 789)
+
+ def test_writelines(self):
+ with BytesIO(INPUT) as f:
+ lines = f.readlines()
+ with BytesIO() as dst:
+ with LZMAFile(dst, "w") as f:
+ f.writelines(lines)
+ expected = lzma.compress(INPUT)
+ self.assertEqual(dst.getvalue(), expected)
+
+ def test_seek_forward(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ f.seek(555)
+ self.assertEqual(f.read(), INPUT[555:])
+
+ def test_seek_forward_across_streams(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ * 2)) as f:
+ f.seek(len(INPUT) + 123)
+ self.assertEqual(f.read(), INPUT[123:])
+
+ def test_seek_forward_relative_to_current(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ f.read(100)
+ f.seek(1236, 1)
+ self.assertEqual(f.read(), INPUT[1336:])
+
+ def test_seek_forward_relative_to_end(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ f.seek(-555, 2)
+ self.assertEqual(f.read(), INPUT[-555:])
+
+ def test_seek_backward(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ f.read(1001)
+ f.seek(211)
+ self.assertEqual(f.read(), INPUT[211:])
+
+ def test_seek_backward_across_streams(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ * 2)) as f:
+ f.read(len(INPUT) + 333)
+ f.seek(737)
+ self.assertEqual(f.read(), INPUT[737:] + INPUT)
+
+ def test_seek_backward_relative_to_end(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ f.seek(-150, 2)
+ self.assertEqual(f.read(), INPUT[-150:])
+
+ def test_seek_past_end(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ f.seek(len(INPUT) + 9001)
+ self.assertEqual(f.tell(), len(INPUT))
+ self.assertEqual(f.read(), b"")
+
+ def test_seek_past_start(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ f.seek(-88)
+ self.assertEqual(f.tell(), 0)
+ self.assertEqual(f.read(), INPUT)
+
+ def test_seek_bad_args(self):
+ f = LZMAFile(BytesIO(COMPRESSED_XZ))
+ f.close()
+ self.assertRaises(ValueError, f.seek, 0)
+ with LZMAFile(BytesIO(), "w") as f:
+ self.assertRaises(ValueError, f.seek, 0)
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ self.assertRaises(ValueError, f.seek, 0, 3)
+ self.assertRaises(ValueError, f.seek, 9, ())
+ self.assertRaises(TypeError, f.seek, None)
+ self.assertRaises(TypeError, f.seek, b"derp")
+
+ def test_tell(self):
+ with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
+ pos = 0
+ while True:
+ self.assertEqual(f.tell(), pos)
+ result = f.read(183)
+ if not result:
+ break
+ pos += len(result)
+ self.assertEqual(f.tell(), len(INPUT))
+ with LZMAFile(BytesIO(), "w") as f:
+ for pos in range(0, len(INPUT), 144):
+ self.assertEqual(f.tell(), pos)
+ f.write(INPUT[pos:pos+144])
+ self.assertEqual(f.tell(), len(INPUT))
+
+ def test_tell_bad_args(self):
+ f = LZMAFile(BytesIO(COMPRESSED_XZ))
+ f.close()
+ self.assertRaises(ValueError, f.tell)
+
+
+class OpenTestCase(unittest.TestCase):
+
+ def test_binary_modes(self):
+ with lzma.open(BytesIO(COMPRESSED_XZ), "rb") as f:
+ self.assertEqual(f.read(), INPUT)
+ with BytesIO() as bio:
+ with lzma.open(bio, "wb") as f:
+ f.write(INPUT)
+ file_data = lzma.decompress(bio.getvalue())
+ self.assertEqual(file_data, INPUT)
+ with lzma.open(bio, "ab") as f:
+ f.write(INPUT)
+ file_data = lzma.decompress(bio.getvalue())
+ self.assertEqual(file_data, INPUT * 2)
+
+ def test_text_modes(self):
+ uncompressed = INPUT.decode("ascii")
+ uncompressed_raw = uncompressed.replace("\n", os.linesep)
+ with lzma.open(BytesIO(COMPRESSED_XZ), "rt") as f:
+ self.assertEqual(f.read(), uncompressed)
+ with BytesIO() as bio:
+ with lzma.open(bio, "wt") as f:
+ f.write(uncompressed)
+ file_data = lzma.decompress(bio.getvalue()).decode("ascii")
+ self.assertEqual(file_data, uncompressed_raw)
+ with lzma.open(bio, "at") as f:
+ f.write(uncompressed)
+ file_data = lzma.decompress(bio.getvalue()).decode("ascii")
+ self.assertEqual(file_data, uncompressed_raw * 2)
+
+ def test_filename(self):
+ with TempFile(TESTFN):
+ with lzma.open(TESTFN, "wb") as f:
+ f.write(INPUT)
+ with open(TESTFN, "rb") as f:
+ file_data = lzma.decompress(f.read())
+ self.assertEqual(file_data, INPUT)
+ with lzma.open(TESTFN, "rb") as f:
+ self.assertEqual(f.read(), INPUT)
+ with lzma.open(TESTFN, "ab") as f:
+ f.write(INPUT)
+ with lzma.open(TESTFN, "rb") as f:
+ self.assertEqual(f.read(), INPUT * 2)
+
+ def test_bad_params(self):
+ # Test invalid parameter combinations.
+ with self.assertRaises(ValueError):
+ lzma.open(TESTFN, "")
+ with self.assertRaises(ValueError):
+ lzma.open(TESTFN, "x")
+ with self.assertRaises(ValueError):
+ lzma.open(TESTFN, "rbt")
+ with self.assertRaises(ValueError):
+ lzma.open(TESTFN, "rb", encoding="utf-8")
+ with self.assertRaises(ValueError):
+ lzma.open(TESTFN, "rb", errors="ignore")
+ with self.assertRaises(ValueError):
+ lzma.open(TESTFN, "rb", newline="\n")
+
+ def test_format_and_filters(self):
+ # Test non-default format and filter chain.
+ options = {"format": lzma.FORMAT_RAW, "filters": FILTERS_RAW_1}
+ with lzma.open(BytesIO(COMPRESSED_RAW_1), "rb", **options) as f:
+ self.assertEqual(f.read(), INPUT)
+ with BytesIO() as bio:
+ with lzma.open(bio, "wb", **options) as f:
+ f.write(INPUT)
+ file_data = lzma.decompress(bio.getvalue(), **options)
+ self.assertEqual(file_data, INPUT)
+
+ def test_encoding(self):
+ # Test non-default encoding.
+ uncompressed = INPUT.decode("ascii")
+ uncompressed_raw = uncompressed.replace("\n", os.linesep)
+ with BytesIO() as bio:
+ with lzma.open(bio, "wt", encoding="utf-16-le") as f:
+ f.write(uncompressed)
+ file_data = lzma.decompress(bio.getvalue()).decode("utf-16-le")
+ self.assertEqual(file_data, uncompressed_raw)
+ bio.seek(0)
+ with lzma.open(bio, "rt", encoding="utf-16-le") as f:
+ self.assertEqual(f.read(), uncompressed)
+
+ def test_encoding_error_handler(self):
+ # Test wih non-default encoding error handler.
+ with BytesIO(lzma.compress(b"foo\xffbar")) as bio:
+ with lzma.open(bio, "rt", encoding="ascii", errors="ignore") as f:
+ self.assertEqual(f.read(), "foobar")
+
+ def test_newline(self):
+ # Test with explicit newline (universal newline mode disabled).
+ text = INPUT.decode("ascii")
+ with BytesIO() as bio:
+ with lzma.open(bio, "wt", newline="\n") as f:
+ f.write(text)
+ bio.seek(0)
+ with lzma.open(bio, "rt", newline="\r") as f:
+ self.assertEqual(f.readlines(), [text])
+
+
+class MiscellaneousTestCase(unittest.TestCase):
+
+ def test_is_check_supported(self):
+ # CHECK_NONE and CHECK_CRC32 should always be supported,
+ # regardless of the options liblzma was compiled with.
+ self.assertTrue(lzma.is_check_supported(lzma.CHECK_NONE))
+ self.assertTrue(lzma.is_check_supported(lzma.CHECK_CRC32))
+
+ # The .xz format spec cannot store check IDs above this value.
+ self.assertFalse(lzma.is_check_supported(lzma.CHECK_ID_MAX + 1))
+
+ # This value should not be a valid check ID.
+ self.assertFalse(lzma.is_check_supported(lzma.CHECK_UNKNOWN))
+
+ def test__encode_filter_properties(self):
+ with self.assertRaises(TypeError):
+ lzma._encode_filter_properties(b"not a dict")
+ with self.assertRaises(ValueError):
+ lzma._encode_filter_properties({"id": 0x100})
+ with self.assertRaises(ValueError):
+ lzma._encode_filter_properties({"id": lzma.FILTER_LZMA2, "junk": 12})
+ with self.assertRaises(lzma.LZMAError):
+ lzma._encode_filter_properties({"id": lzma.FILTER_DELTA,
+ "dist": 9001})
+
+ # Test with parameters used by zipfile module.
+ props = lzma._encode_filter_properties({
+ "id": lzma.FILTER_LZMA1,
+ "pb": 2,
+ "lp": 0,
+ "lc": 3,
+ "dict_size": 8 << 20,
+ })
+ self.assertEqual(props, b"]\x00\x00\x80\x00")
+
+ def test__decode_filter_properties(self):
+ with self.assertRaises(TypeError):
+ lzma._decode_filter_properties(lzma.FILTER_X86, {"should be": bytes})
+ with self.assertRaises(lzma.LZMAError):
+ lzma._decode_filter_properties(lzma.FILTER_DELTA, b"too long")
+
+ # Test with parameters used by zipfile module.
+ filterspec = lzma._decode_filter_properties(
+ lzma.FILTER_LZMA1, b"]\x00\x00\x80\x00")
+ self.assertEqual(filterspec["id"], lzma.FILTER_LZMA1)
+ self.assertEqual(filterspec["pb"], 2)
+ self.assertEqual(filterspec["lp"], 0)
+ self.assertEqual(filterspec["lc"], 3)
+ self.assertEqual(filterspec["dict_size"], 8 << 20)
+
+ def test_filter_properties_roundtrip(self):
+ spec1 = lzma._decode_filter_properties(
+ lzma.FILTER_LZMA1, b"]\x00\x00\x80\x00")
+ reencoded = lzma._encode_filter_properties(spec1)
+ spec2 = lzma._decode_filter_properties(lzma.FILTER_LZMA1, reencoded)
+ self.assertEqual(spec1, spec2)
+
+
+# Test data:
+
+INPUT = b"""
+LAERTES
+
+ O, fear me not.
+ I stay too long: but here my father comes.
+
+ Enter POLONIUS
+
+ A double blessing is a double grace,
+ Occasion smiles upon a second leave.
+
+LORD POLONIUS
+
+ Yet here, Laertes! aboard, aboard, for shame!
+ The wind sits in the shoulder of your sail,
+ And you are stay'd for. There; my blessing with thee!
+ And these few precepts in thy memory
+ See thou character. Give thy thoughts no tongue,
+ Nor any unproportioned thought his act.
+ Be thou familiar, but by no means vulgar.
+ Those friends thou hast, and their adoption tried,
+ Grapple them to thy soul with hoops of steel;
+ But do not dull thy palm with entertainment
+ Of each new-hatch'd, unfledged comrade. Beware
+ Of entrance to a quarrel, but being in,
+ Bear't that the opposed may beware of thee.
+ Give every man thy ear, but few thy voice;
+ Take each man's censure, but reserve thy judgment.
+ Costly thy habit as thy purse can buy,
+ But not express'd in fancy; rich, not gaudy;
+ For the apparel oft proclaims the man,
+ And they in France of the best rank and station
+ Are of a most select and generous chief in that.
+ Neither a borrower nor a lender be;
+ For loan oft loses both itself and friend,
+ And borrowing dulls the edge of husbandry.
+ This above all: to thine ownself be true,
+ And it must follow, as the night the day,
+ Thou canst not then be false to any man.
+ Farewell: my blessing season this in thee!
+
+LAERTES
+
+ Most humbly do I take my leave, my lord.
+
+LORD POLONIUS
+
+ The time invites you; go; your servants tend.
+
+LAERTES
+
+ Farewell, Ophelia; and remember well
+ What I have said to you.
+
+OPHELIA
+
+ 'Tis in my memory lock'd,
+ And you yourself shall keep the key of it.
+
+LAERTES
+
+ Farewell.
+"""
+
+COMPRESSED_XZ = (
+ b"\xfd7zXZ\x00\x00\x04\xe6\xd6\xb4F\x02\x00!\x01\x16\x00\x00\x00t/\xe5\xa3"
+ b"\xe0\x07\x80\x03\xdf]\x00\x05\x14\x07bX\x19\xcd\xddn\x98\x15\xe4\xb4\x9d"
+ b"o\x1d\xc4\xe5\n\x03\xcc2h\xc7\\\x86\xff\xf8\xe2\xfc\xe7\xd9\xfe6\xb8("
+ b"\xa8wd\xc2\"u.n\x1e\xc3\xf2\x8e\x8d\x8f\x02\x17/\xa6=\xf0\xa2\xdf/M\x89"
+ b"\xbe\xde\xa7\x1cz\x18-]\xd5\xef\x13\x8frZ\x15\x80\x8c\xf8\x8do\xfa\x12"
+ b"\x9b#z/\xef\xf0\xfaF\x01\x82\xa3M\x8e\xa1t\xca6 BF$\xe5Q\xa4\x98\xee\xde"
+ b"l\xe8\x7f\xf0\x9d,bn\x0b\x13\xd4\xa8\x81\xe4N\xc8\x86\x153\xf5x2\xa2O"
+ b"\x13@Q\xa1\x00/\xa5\xd0O\x97\xdco\xae\xf7z\xc4\xcdS\xb6t<\x16\xf2\x9cI#"
+ b"\x89ud\xc66Y\xd9\xee\xe6\xce\x12]\xe5\xf0\xaa\x96-Pe\xade:\x04\t\x1b\xf7"
+ b"\xdb7\n\x86\x1fp\xc8J\xba\xf4\xf0V\xa9\xdc\xf0\x02%G\xf9\xdf=?\x15\x1b"
+ b"\xe1(\xce\x82=\xd6I\xac3\x12\x0cR\xb7\xae\r\xb1i\x03\x95\x01\xbd\xbe\xfa"
+ b"\x02s\x01P\x9d\x96X\xb12j\xc8L\xa8\x84b\xf6\xc3\xd4c-H\x93oJl\xd0iQ\xe4k"
+ b"\x84\x0b\xc1\xb7\xbc\xb1\x17\x88\xb1\xca?@\xf6\x07\xea\xe6x\xf1H12P\x0f"
+ b"\x8a\xc9\xeauw\xe3\xbe\xaai\xa9W\xd0\x80\xcd#cb5\x99\xd8]\xa9d\x0c\xbd"
+ b"\xa2\xdcWl\xedUG\xbf\x89yF\xf77\x81v\xbd5\x98\xbeh8\x18W\x08\xf0\x1b\x99"
+ b"5:\x1a?rD\x96\xa1\x04\x0f\xae\xba\x85\xeb\x9d5@\xf5\x83\xd37\x83\x8ac"
+ b"\x06\xd4\x97i\xcdt\x16S\x82k\xf6K\x01vy\x88\x91\x9b6T\xdae\r\xfd]:k\xbal"
+ b"\xa9\xbba\xc34\xf9r\xeb}r\xdb\xc7\xdb*\x8f\x03z\xdc8h\xcc\xc9\xd3\xbcl"
+ b"\xa5-\xcb\xeaK\xa2\xc5\x15\xc0\xe3\xc1\x86Z\xfb\xebL\xe13\xcf\x9c\xe3"
+ b"\x1d\xc9\xed\xc2\x06\xcc\xce!\x92\xe5\xfe\x9c^\xa59w \x9bP\xa3PK\x08d"
+ b"\xf9\xe2Z}\xa7\xbf\xed\xeb%$\x0c\x82\xb8/\xb0\x01\xa9&,\xf7qh{Q\x96)\xf2"
+ b"q\x96\xc3\x80\xb4\x12\xb0\xba\xe6o\xf4!\xb4[\xd4\x8aw\x10\xf7t\x0c\xb3"
+ b"\xd9\xd5\xc3`^\x81\x11??\\\xa4\x99\x85R\xd4\x8e\x83\xc9\x1eX\xbfa\xf1"
+ b"\xac\xb0\xea\xea\xd7\xd0\xab\x18\xe2\xf2\xed\xe1\xb7\xc9\x18\xcbS\xe4>"
+ b"\xc9\x95H\xe8\xcb\t\r%\xeb\xc7$.o\xf1\xf3R\x17\x1db\xbb\xd8U\xa5^\xccS"
+ b"\x16\x01\x87\xf3/\x93\xd1\xf0v\xc0r\xd7\xcc\xa2Gkz\xca\x80\x0e\xfd\xd0"
+ b"\x8b\xbb\xd2Ix\xb3\x1ey\xca-0\xe3z^\xd6\xd6\x8f_\xf1\x9dP\x9fi\xa7\xd1"
+ b"\xe8\x90\x84\xdc\xbf\xcdky\x8e\xdc\x81\x7f\xa3\xb2+\xbf\x04\xef\xd8\\"
+ b"\xc4\xdf\xe1\xb0\x01\xe9\x93\xe3Y\xf1\x1dY\xe8h\x81\xcf\xf1w\xcc\xb4\xef"
+ b" \x8b|\x04\xea\x83ej\xbe\x1f\xd4z\x9c`\xd3\x1a\x92A\x06\xe5\x8f\xa9\x13"
+ b"\t\x9e=\xfa\x1c\xe5_\x9f%v\x1bo\x11ZO\xd8\xf4\t\xddM\x16-\x04\xfc\x18<\""
+ b"CM\xddg~b\xf6\xef\x8e\x0c\xd0\xde|\xa0'\x8a\x0c\xd6x\xae!J\xa6F\x88\x15u"
+ b"\x008\x17\xbc7y\xb3\xd8u\xac_\x85\x8d\xe7\xc1@\x9c\xecqc\xa3#\xad\xf1"
+ b"\x935\xb5)_\r\xec3]\x0fo]5\xd0my\x07\x9b\xee\x81\xb5\x0f\xcfK+\x00\xc0"
+ b"\xe4b\x10\xe4\x0c\x1a \x9b\xe0\x97t\xf6\xa1\x9e\x850\xba\x0c\x9a\x8d\xc8"
+ b"\x8f\x07\xd7\xae\xc8\xf9+i\xdc\xb9k\xb0>f\x19\xb8\r\xa8\xf8\x1f$\xa5{p"
+ b"\xc6\x880\xce\xdb\xcf\xca_\x86\xac\x88h6\x8bZ%'\xd0\n\xbf\x0f\x9c\"\xba"
+ b"\xe5\x86\x9f\x0f7X=mNX[\xcc\x19FU\xc9\x860\xbc\x90a+* \xae_$\x03\x1e\xd3"
+ b"\xcd_\xa0\x9c\xde\xaf46q\xa5\xc9\x92\xd7\xca\xe3`\x9d\x85}\xb4\xff\xb3"
+ b"\x83\xfb\xb6\xca\xae`\x0bw\x7f\xfc\xd8\xacVe\x19\xc8\x17\x0bZ\xad\x88"
+ b"\xeb#\x97\x03\x13\xb1d\x0f{\x0c\x04w\x07\r\x97\xbd\xd6\xc1\xc3B:\x95\x08"
+ b"^\x10V\xaeaH\x02\xd9\xe3\n\\\x01X\xf6\x9c\x8a\x06u#%\xbe*\xa1\x18v\x85"
+ b"\xec!\t4\x00\x00\x00\x00Vj?uLU\xf3\xa6\x00\x01\xfb\x07\x81\x0f\x00\x00tw"
+ b"\x99P\xb1\xc4g\xfb\x02\x00\x00\x00\x00\x04YZ"
+)
+
+COMPRESSED_ALONE = (
+ b"]\x00\x00\x80\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x05\x14\x07bX\x19"
+ b"\xcd\xddn\x98\x15\xe4\xb4\x9do\x1d\xc4\xe5\n\x03\xcc2h\xc7\\\x86\xff\xf8"
+ b"\xe2\xfc\xe7\xd9\xfe6\xb8(\xa8wd\xc2\"u.n\x1e\xc3\xf2\x8e\x8d\x8f\x02"
+ b"\x17/\xa6=\xf0\xa2\xdf/M\x89\xbe\xde\xa7\x1cz\x18-]\xd5\xef\x13\x8frZ"
+ b"\x15\x80\x8c\xf8\x8do\xfa\x12\x9b#z/\xef\xf0\xfaF\x01\x82\xa3M\x8e\xa1t"
+ b"\xca6 BF$\xe5Q\xa4\x98\xee\xdel\xe8\x7f\xf0\x9d,bn\x0b\x13\xd4\xa8\x81"
+ b"\xe4N\xc8\x86\x153\xf5x2\xa2O\x13@Q\xa1\x00/\xa5\xd0O\x97\xdco\xae\xf7z"
+ b"\xc4\xcdS\xb6t<\x16\xf2\x9cI#\x89ud\xc66Y\xd9\xee\xe6\xce\x12]\xe5\xf0"
+ b"\xaa\x96-Pe\xade:\x04\t\x1b\xf7\xdb7\n\x86\x1fp\xc8J\xba\xf4\xf0V\xa9"
+ b"\xdc\xf0\x02%G\xf9\xdf=?\x15\x1b\xe1(\xce\x82=\xd6I\xac3\x12\x0cR\xb7"
+ b"\xae\r\xb1i\x03\x95\x01\xbd\xbe\xfa\x02s\x01P\x9d\x96X\xb12j\xc8L\xa8"
+ b"\x84b\xf8\x1epl\xeajr\xd1=\t\x03\xdd\x13\x1b3!E\xf9vV\xdaF\xf3\xd7\xb4"
+ b"\x0c\xa9P~\xec\xdeE\xe37\xf6\x1d\xc6\xbb\xddc%\xb6\x0fI\x07\xf0;\xaf\xe7"
+ b"\xa0\x8b\xa7Z\x99(\xe9\xe2\xf0o\x18>`\xe1\xaa\xa8\xd9\xa1\xb2}\xe7\x8d"
+ b"\x834T\xb6\xef\xc1\xde\xe3\x98\xbcD\x03MA@\xd8\xed\xdc\xc8\x93\x03\x1a"
+ b"\x93\x0b\x7f\x94\x12\x0b\x02Sa\x18\xc9\xc5\x9bTJE}\xf6\xc8g\x17#ZV\x01"
+ b"\xc9\x9dc\x83\x0e>0\x16\x90S\xb8/\x03y_\x18\xfa(\xd7\x0br\xa2\xb0\xba?"
+ b"\x8c\xe6\x83@\x84\xdf\x02:\xc5z\x9e\xa6\x84\xc9\xf5BeyX\x83\x1a\xf1 :\t"
+ b"\xf7\x19\xfexD\\&G\xf3\x85Y\xa2J\xf9\x0bv{\x89\xf6\xe7)A\xaf\x04o\x00"
+ b"\x075\xd3\xe0\x7f\x97\x98F\x0f?v\x93\xedVtTf\xb5\x97\x83\xed\x19\xd7\x1a"
+ b"'k\xd7\xd9\xc5\\Y\xd1\xdc\x07\x15|w\xbc\xacd\x87\x08d\xec\xa7\xf6\x82"
+ b"\xfc\xb3\x93\xeb\xb9 \x8d\xbc ,\xb3X\xb0\xd2s\xd7\xd1\xffv\x05\xdf}\xa2"
+ b"\x96\xfb%\n\xdf\xa2\x7f\x08.\xa16\n\xe0\x19\x93\x7fh\n\x1c\x8c\x0f \x11"
+ b"\xc6Bl\x95\x19U}\xe4s\xb5\x10H\xea\x86pB\xe88\x95\xbe\x8cZ\xdb\xe4\x94A"
+ b"\x92\xb9;z\xaa\xa7{\x1c5!\xc0\xaf\xc1A\xf9\xda\xf0$\xb0\x02qg\xc8\xc7/|"
+ b"\xafr\x99^\x91\x88\xbf\x03\xd9=\xd7n\xda6{>8\n\xc7:\xa9'\xba.\x0b\xe2"
+ b"\xb5\x1d\x0e\n\x9a\x8e\x06\x8f:\xdd\x82'[\xc3\"wD$\xa7w\xecq\x8c,1\x93"
+ b"\xd0,\xae2w\x93\x12$Jd\x19mg\x02\x93\x9cA\x95\x9d&\xca8i\x9c\xb0;\xe7NQ"
+ b"\x1frh\x8beL;\xb0m\xee\x07Q\x9b\xc6\xd8\x03\xb5\xdeN\xd4\xfe\x98\xd0\xdc"
+ b"\x1a[\x04\xde\x1a\xf6\x91j\xf8EOli\x8eB^\x1d\x82\x07\xb2\xb5R]\xb7\xd7"
+ b"\xe9\xa6\xc3.\xfb\xf0-\xb4e\x9b\xde\x03\x88\xc6\xc1iN\x0e\x84wbQ\xdf~"
+ b"\xe9\xa4\x884\x96kM\xbc)T\xf3\x89\x97\x0f\x143\xe7)\xa0\xb3B\x00\xa8\xaf"
+ b"\x82^\xcb\xc7..\xdb\xc7\t\x9dH\xee5\xe9#\xe6NV\x94\xcb$Kk\xe3\x7f\r\xe3t"
+ b"\x12\xcf'\xefR\x8b\xf42\xcf-LH\xac\xe5\x1f0~?SO\xeb\xc1E\x1a\x1c]\xf2"
+ b"\xc4<\x11\x02\x10Z0a*?\xe4r\xff\xfb\xff\xf6\x14nG\xead^\xd6\xef8\xb6uEI"
+ b"\x99\nV\xe2\xb3\x95\x8e\x83\xf6i!\xb5&1F\xb1DP\xf4 SO3D!w\x99_G\x7f+\x90"
+ b".\xab\xbb]\x91>\xc9#h;\x0f5J\x91K\xf4^-[\x9e\x8a\\\x94\xca\xaf\xf6\x19"
+ b"\xd4\xa1\x9b\xc4\xb8p\xa1\xae\x15\xe9r\x84\xe0\xcar.l []\x8b\xaf+0\xf2g"
+ b"\x01aKY\xdfI\xcf,\n\xe8\xf0\xe7V\x80_#\xb2\xf2\xa9\x06\x8c>w\xe2W,\xf4"
+ b"\x8c\r\xf963\xf5J\xcc2\x05=kT\xeaUti\xe5_\xce\x1b\xfa\x8dl\x02h\xef\xa8"
+ b"\xfbf\x7f\xff\xf0\x19\xeax"
+)
+
+FILTERS_RAW_1 = [{"id": lzma.FILTER_LZMA2, "preset": 3}]
+COMPRESSED_RAW_1 = (
+ b"\xe0\x07\x80\x03\xfd]\x00\x05\x14\x07bX\x19\xcd\xddn\x96cyq\xa1\xdd\xee"
+ b"\xf8\xfam\xe3'\x88\xd3\xff\xe4\x9e \xceQ\x91\xa4\x14I\xf6\xb9\x9dVL8\x15"
+ b"_\x0e\x12\xc3\xeb\xbc\xa5\xcd\nW\x1d$=R;\x1d\xf8k8\t\xb1{\xd4\xc5+\x9d"
+ b"\x87c\xe5\xef\x98\xb4\xd7S3\xcd\xcc\xd2\xed\xa4\x0em\xe5\xf4\xdd\xd0b"
+ b"\xbe4*\xaa\x0b\xc5\x08\x10\x85+\x81.\x17\xaf9\xc9b\xeaZrA\xe20\x7fs\"r"
+ b"\xdaG\x81\xde\x90cu\xa5\xdb\xa9.A\x08l\xb0<\xf6\x03\xddOi\xd0\xc5\xb4"
+ b"\xec\xecg4t6\"\xa6\xb8o\xb5?\x18^}\xb6}\x03[:\xeb\x03\xa9\n[\x89l\x19g"
+ b"\x16\xc82\xed\x0b\xfb\x86n\xa2\x857@\x93\xcd6T\xc3u\xb0\t\xf9\x1b\x918"
+ b"\xfc[\x1b\x1e4\xb3\x14\x06PCV\xa8\"\xf5\x81x~\xe9\xb5N\x9cK\x9f\xc6\xc3%"
+ b"\xc8k:{6\xe7\xf7\xbd\x05\x02\xb4\xc4\xc3\xd3\xfd\xc3\xa8\\\xfc@\xb1F_"
+ b"\xc8\x90\xd9sU\x98\xad8\x05\x07\xde7J\x8bM\xd0\xb3;X\xec\x87\xef\xae\xb3"
+ b"eO,\xb1z,d\x11y\xeejlB\x02\x1d\xf28\x1f#\x896\xce\x0b\xf0\xf5\xa9PK\x0f"
+ b"\xb3\x13P\xd8\x88\xd2\xa1\x08\x04C?\xdb\x94_\x9a\"\xe9\xe3e\x1d\xde\x9b"
+ b"\xa1\xe8>H\x98\x10;\xc5\x03#\xb5\x9d4\x01\xe7\xc5\xba%v\xa49\x97A\xe0\""
+ b"\x8c\xc22\xe3i\xc1\x9d\xab3\xdf\xbe\xfdDm7\x1b\x9d\xab\xb5\x15o:J\x92"
+ b"\xdb\x816\x17\xc2O\x99\x1b\x0e\x8d\xf3\tQ\xed\x8e\x95S/\x16M\xb2S\x04"
+ b"\x0f\xc3J\xc6\xc7\xe4\xcb\xc5\xf4\xe7d\x14\xe4=^B\xfb\xd3E\xd3\x1e\xcd"
+ b"\x91\xa5\xd0G\x8f.\xf6\xf9\x0bb&\xd9\x9f\xc2\xfdj\xa2\x9e\xc4\\\x0e\x1dC"
+ b"v\xe8\xd2\x8a?^H\xec\xae\xeb>\xfe\xb8\xab\xd4IqY\x8c\xd4K7\x11\xf4D\xd0W"
+ b"\xa5\xbe\xeaO\xbf\xd0\x04\xfdl\x10\xae5\xd4U\x19\x06\xf9{\xaa\xe0\x81"
+ b"\x0f\xcf\xa3k{\x95\xbd\x19\xa2\xf8\xe4\xa3\x08O*\xf1\xf1B-\xc7(\x0eR\xfd"
+ b"@E\x9f\xd3\x1e:\xfdV\xb7\x04Y\x94\xeb]\x83\xc4\xa5\xd7\xc0gX\x98\xcf\x0f"
+ b"\xcd3\x00]n\x17\xec\xbd\xa3Y\x86\xc5\xf3u\xf6*\xbdT\xedA$A\xd9A\xe7\x98"
+ b"\xef\x14\x02\x9a\xfdiw\xec\xa0\x87\x11\xd9%\xc5\xeb\x8a=\xae\xc0\xc4\xc6"
+ b"D\x80\x8f\xa8\xd1\xbbq\xb2\xc0\xa0\xf5Cqp\xeeL\xe3\xe5\xdc \x84\"\xe9"
+ b"\x80t\x83\x05\xba\xf1\xc5~\x93\xc9\xf0\x01c\xceix\x9d\xed\xc5)l\x16)\xd1"
+ b"\x03@l\x04\x7f\x87\xa5yn\x1b\x01D\xaa:\xd2\x96\xb4\xb3?\xb0\xf9\xce\x07"
+ b"\xeb\x81\x00\xe4\xc3\xf5%_\xae\xd4\xf9\xeb\xe2\rh\xb2#\xd67Q\x16D\x82hn"
+ b"\xd1\xa3_?q\xf0\xe2\xac\xf317\x9e\xd0_\x83|\xf1\xca\xb7\x95S\xabW\x12"
+ b"\xff\xddt\xf69L\x01\xf2|\xdaW\xda\xees\x98L\x18\xb8_\xe8$\x82\xea\xd6"
+ b"\xd1F\xd4\x0b\xcdk\x01vf\x88h\xc3\xae\xb91\xc7Q\x9f\xa5G\xd9\xcc\x1f\xe3"
+ b"5\xb1\xdcy\x7fI\x8bcw\x8e\x10rIp\x02:\x19p_\xc8v\xcea\"\xc1\xd9\x91\x03"
+ b"\xbfe\xbe\xa6\xb3\xa8\x14\x18\xc3\xabH*m}\xc2\xc1\x9a}>l%\xce\x84\x99"
+ b"\xb3d\xaf\xd3\x82\x15\xdf\xc1\xfc5fOg\x9b\xfc\x8e^&\t@\xce\x9f\x06J\xb8"
+ b"\xb5\x86\x1d\xda{\x9f\xae\xb0\xff\x02\x81r\x92z\x8cM\xb7ho\xc9^\x9c\xb6"
+ b"\x9c\xae\xd1\xc9\xf4\xdfU7\xd6\\!\xea\x0b\x94k\xb9Ud~\x98\xe7\x86\x8az"
+ b"\x10;\xe3\x1d\xe5PG\xf8\xa4\x12\x05w\x98^\xc4\xb1\xbb\xfb\xcf\xe0\x7f"
+ b"\x033Sf\x0c \xb1\xf6@\x94\xe5\xa3\xb2\xa7\x10\x9a\xc0\x14\xc3s\xb5xRD"
+ b"\xf4`W\xd9\xe5\xd3\xcf\x91\rTZ-X\xbe\xbf\xb5\xe2\xee|\x1a\xbf\xfb\x08"
+ b"\x91\xe1\xfc\x9a\x18\xa3\x8b\xd6^\x89\xf5[\xef\x87\xd1\x06\x1c7\xd6\xa2"
+ b"\t\tQ5/@S\xc05\xd2VhAK\x03VC\r\x9b\x93\xd6M\xf1xO\xaaO\xed\xb9<\x0c\xdae"
+ b"*\xd0\x07Hk6\x9fG+\xa1)\xcd\x9cl\x87\xdb\xe1\xe7\xefK}\x875\xab\xa0\x19u"
+ b"\xf6*F\xb32\x00\x00\x00"
+)
+
+FILTERS_RAW_2 = [{"id": lzma.FILTER_DELTA, "dist": 2},
+ {"id": lzma.FILTER_LZMA2,
+ "preset": lzma.PRESET_DEFAULT | lzma.PRESET_EXTREME}]
+COMPRESSED_RAW_2 = (
+ b"\xe0\x07\x80\x05\x91]\x00\x05\x14\x06-\xd4\xa8d?\xef\xbe\xafH\xee\x042"
+ b"\xcb.\xb5g\x8f\xfb\x14\xab\xa5\x9f\x025z\xa4\xdd\xd8\t[}W\xf8\x0c\x1dmH"
+ b"\xfa\x05\xfcg\xba\xe5\x01Q\x0b\x83R\xb6A\x885\xc0\xba\xee\n\x1cv~\xde:o"
+ b"\x06:J\xa7\x11Cc\xea\xf7\xe5*o\xf7\x83\\l\xbdE\x19\x1f\r\xa8\x10\xb42"
+ b"\x0caU{\xd7\xb8w\xdc\xbe\x1b\xfc8\xb4\xcc\xd38\\\xf6\x13\xf6\xe7\x98\xfa"
+ b"\xc7[\x17_9\x86%\xa8\xf8\xaa\xb8\x8dfs#\x1e=\xed<\x92\x10\\t\xff\x86\xfb"
+ b"=\x9e7\x18\x1dft\\\xb5\x01\x95Q\xc5\x19\xb38\xe0\xd4\xaa\x07\xc3\x7f\xd8"
+ b"\xa2\x00>-\xd3\x8e\xa1#\xfa\x83ArAm\xdbJ~\x93\xa3B\x82\xe0\xc7\xcc(\x08`"
+ b"WK\xad\x1b\x94kaj\x04 \xde\xfc\xe1\xed\xb0\x82\x91\xefS\x84%\x86\xfbi"
+ b"\x99X\xf1B\xe7\x90;E\xfde\x98\xda\xca\xd6T\xb4bg\xa4\n\x9aj\xd1\x83\x9e]"
+ b"\"\x7fM\xb5\x0fr\xd2\\\xa5j~P\x10GH\xbfN*Z\x10.\x81\tpE\x8a\x08\xbe1\xbd"
+ b"\xcd\xa9\xe1\x8d\x1f\x04\xf9\x0eH\xb9\xae\xd6\xc3\xc1\xa5\xa9\x95P\xdc~"
+ b"\xff\x01\x930\xa9\x04\xf6\x03\xfe\xb5JK\xc3]\xdd9\xb1\xd3\xd7F\xf5\xd1"
+ b"\x1e\xa0\x1c_\xed[\x0c\xae\xd4\x8b\x946\xeb\xbf\xbb\xe3$kS{\xb5\x80,f:Sj"
+ b"\x0f\x08z\x1c\xf5\xe8\xe6\xae\x98\xb0Q~r\x0f\xb0\x05?\xb6\x90\x19\x02&"
+ b"\xcb\x80\t\xc4\xea\x9c|x\xce\x10\x9c\xc5|\xcbdhh+\x0c'\xc5\x81\xc33\xb5"
+ b"\x14q\xd6\xc5\xe3`Z#\xdc\x8a\xab\xdd\xea\x08\xc2I\xe7\x02l{\xec\x196\x06"
+ b"\x91\x8d\xdc\xd5\xb3x\xe1hz%\xd1\xf8\xa5\xdd\x98!\x8c\x1c\xc1\x17RUa\xbb"
+ b"\x95\x0f\xe4X\xea1\x0c\xf1=R\xbe\xc60\xe3\xa4\x9a\x90bd\x97$]B\x01\xdd"
+ b"\x1f\xe3h2c\x1e\xa0L`4\xc6x\xa3Z\x8a\r\x14]T^\xd8\x89\x1b\x92\r;\xedY"
+ b"\x0c\xef\x8d9z\xf3o\xb6)f\xa9]$n\rp\x93\xd0\x10\xa4\x08\xb8\xb2\x8b\xb6"
+ b"\x8f\x80\xae;\xdcQ\xf1\xfa\x9a\x06\x8e\xa5\x0e\x8cK\x9c @\xaa:UcX\n!\xc6"
+ b"\x02\x12\xcb\x1b\"=\x16.\x1f\x176\xf2g=\xe1Wn\xe9\xe1\xd4\xf1O\xad\x15"
+ b"\x86\xe9\xa3T\xaf\xa9\xd7D\xb5\xd1W3pnt\x11\xc7VOj\xb7M\xc4i\xa1\xf1$3"
+ b"\xbb\xdc\x8af\xb0\xc5Y\r\xd1\xfb\xf2\xe7K\xe6\xc5hwO\xfe\x8c2^&\x07\xd5"
+ b"\x1fV\x19\xfd\r\x14\xd2i=yZ\xe6o\xaf\xc6\xb6\x92\x9d\xc4\r\xb3\xafw\xac%"
+ b"\xcfc\x1a\xf1`]\xf2\x1a\x9e\x808\xedm\xedQ\xb2\xfe\xe4h`[q\xae\xe0\x0f"
+ b"\xba0g\xb6\"N\xc3\xfb\xcfR\x11\xc5\x18)(\xc40\\\xa3\x02\xd9G!\xce\x1b"
+ b"\xc1\x96x\xb5\xc8z\x1f\x01\xb4\xaf\xde\xc2\xcd\x07\xe7H\xb3y\xa8M\n\\A\t"
+ b"ar\xddM\x8b\x9a\xea\x84\x9b!\xf1\x8d\xb1\xf1~\x1e\r\xa5H\xba\xf1\x84o"
+ b"\xda\x87\x01h\xe9\xa2\xbe\xbeqN\x9d\x84\x0b!WG\xda\xa1\xa5A\xb7\xc7`j"
+ b"\x15\xf2\xe9\xdd?\x015B\xd2~E\x06\x11\xe0\x91!\x05^\x80\xdd\xa8y\x15}"
+ b"\xa1)\xb1)\x81\x18\xf4\xf4\xf8\xc0\xefD\xe3\xdb2f\x1e\x12\xabu\xc9\x97"
+ b"\xcd\x1e\xa7\x0c\x02x4_6\x03\xc4$t\xf39\x94\x1d=\xcb\xbfv\\\xf5\xa3\x1d"
+ b"\x9d8jk\x95\x13)ff\xf9n\xc4\xa9\xe3\x01\xb8\xda\xfb\xab\xdfM\x99\xfb\x05"
+ b"\xe0\xe9\xb0I\xf4E\xab\xe2\x15\xa3\x035\xe7\xdeT\xee\x82p\xb4\x88\xd3"
+ b"\x893\x9c/\xc0\xd6\x8fou;\xf6\x95PR\xa9\xb2\xc1\xefFj\xe2\xa7$\xf7h\xf1"
+ b"\xdfK(\xc9c\xba7\xe8\xe3)\xdd\xb2,\x83\xfb\x84\x18.y\x18Qi\x88\xf8`h-"
+ b"\xef\xd5\xed\x8c\t\xd8\xc3^\x0f\x00\xb7\xd0[!\xafM\x9b\xd7.\x07\xd8\xfb"
+ b"\xd9\xe2-S+\xaa8,\xa0\x03\x1b \xea\xa8\x00\xc3\xab~\xd0$e\xa5\x7f\xf7"
+ b"\x95P]\x12\x19i\xd9\x7fo\x0c\xd8g^\rE\xa5\x80\x18\xc5\x01\x80\xaek`\xff~"
+ b"\xb6y\xe7+\xe5\x11^D\xa7\x85\x18\"!\xd6\xd2\xa7\xf4\x1eT\xdb\x02\xe15"
+ b"\x02Y\xbc\x174Z\xe7\x9cH\x1c\xbf\x0f\xc6\xe9f]\xcf\x8cx\xbc\xe5\x15\x94"
+ b"\xfc3\xbc\xa7TUH\xf1\x84\x1b\xf7\xa9y\xc07\x84\xf8X\xd8\xef\xfc \x1c\xd8"
+ b"( /\xf2\xb7\xec\xc1\\\x8c\xf6\x95\xa1\x03J\x83vP8\xe1\xe3\xbb~\xc24kA"
+ b"\x98y\xa1\xf2P\xe9\x9d\xc9J\xf8N\x99\xb4\xceaO\xde\x16\x1e\xc2\x19\xa7"
+ b"\x03\xd2\xe0\x8f:\x15\xf3\x84\x9e\xee\xe6e\xb8\x02q\xc7AC\x1emw\xfd\t"
+ b"\x9a\x1eu\xc1\xa9\xcaCwUP\x00\xa5\xf78L4w!\x91L2 \x87\xd0\xf2\x06\x81j"
+ b"\x80;\x03V\x06\x87\x92\xcb\x90lv@E\x8d\x8d\xa5\xa6\xe7Z[\xdf\xd6E\x03`>"
+ b"\x8f\xde\xa1bZ\x84\xd0\xa9`\x05\x0e{\x80;\xe3\xbef\x8d\x1d\xebk1.\xe3"
+ b"\xe9N\x15\xf7\xd4(\xfa\xbb\x15\xbdu\xf7\x7f\x86\xae!\x03L\x1d\xb5\xc1"
+ b"\xb9\x11\xdb\xd0\x93\xe4\x02\xe1\xd2\xcbBjc_\xe8}d\xdb\xc3\xa0Y\xbe\xc9/"
+ b"\x95\x01\xa3,\xe6bl@\x01\xdbp\xc2\xce\x14\x168\xc2q\xe3uH\x89X\xa4\xa9"
+ b"\x19\x1d\xc1}\x7fOX\x19\x9f\xdd\xbe\x85\x83\xff\x96\x1ee\x82O`CF=K\xeb$I"
+ b"\x17_\xefX\x8bJ'v\xde\x1f+\xd9.v\xf8Tv\x17\xf2\x9f5\x19\xe1\xb9\x91\xa8S"
+ b"\x86\xbd\x1a\"(\xa5x\x8dC\x03X\x81\x91\xa8\x11\xc4pS\x13\xbc\xf2'J\xae!"
+ b"\xef\xef\x84G\t\x8d\xc4\x10\x132\x00oS\x9e\xe0\xe4d\x8f\xb8y\xac\xa6\x9f"
+ b",\xb8f\x87\r\xdf\x9eE\x0f\xe1\xd0\\L\x00\xb2\xe1h\x84\xef}\x98\xa8\x11"
+ b"\xccW#\\\x83\x7fo\xbbz\x8f\x00"
+)
+
+FILTERS_RAW_3 = [{"id": lzma.FILTER_IA64, "start_offset": 0x100},
+ {"id": lzma.FILTER_LZMA2}]
+COMPRESSED_RAW_3 = (
+ b"\xe0\x07\x80\x03\xdf]\x00\x05\x14\x07bX\x19\xcd\xddn\x98\x15\xe4\xb4\x9d"
+ b"o\x1d\xc4\xe5\n\x03\xcc2h\xc7\\\x86\xff\xf8\xe2\xfc\xe7\xd9\xfe6\xb8("
+ b"\xa8wd\xc2\"u.n\x1e\xc3\xf2\x8e\x8d\x8f\x02\x17/\xa6=\xf0\xa2\xdf/M\x89"
+ b"\xbe\xde\xa7\x1cz\x18-]\xd5\xef\x13\x8frZ\x15\x80\x8c\xf8\x8do\xfa\x12"
+ b"\x9b#z/\xef\xf0\xfaF\x01\x82\xa3M\x8e\xa1t\xca6 BF$\xe5Q\xa4\x98\xee\xde"
+ b"l\xe8\x7f\xf0\x9d,bn\x0b\x13\xd4\xa8\x81\xe4N\xc8\x86\x153\xf5x2\xa2O"
+ b"\x13@Q\xa1\x00/\xa5\xd0O\x97\xdco\xae\xf7z\xc4\xcdS\xb6t<\x16\xf2\x9cI#"
+ b"\x89ud\xc66Y\xd9\xee\xe6\xce\x12]\xe5\xf0\xaa\x96-Pe\xade:\x04\t\x1b\xf7"
+ b"\xdb7\n\x86\x1fp\xc8J\xba\xf4\xf0V\xa9\xdc\xf0\x02%G\xf9\xdf=?\x15\x1b"
+ b"\xe1(\xce\x82=\xd6I\xac3\x12\x0cR\xb7\xae\r\xb1i\x03\x95\x01\xbd\xbe\xfa"
+ b"\x02s\x01P\x9d\x96X\xb12j\xc8L\xa8\x84b\xf6\xc3\xd4c-H\x93oJl\xd0iQ\xe4k"
+ b"\x84\x0b\xc1\xb7\xbc\xb1\x17\x88\xb1\xca?@\xf6\x07\xea\xe6x\xf1H12P\x0f"
+ b"\x8a\xc9\xeauw\xe3\xbe\xaai\xa9W\xd0\x80\xcd#cb5\x99\xd8]\xa9d\x0c\xbd"
+ b"\xa2\xdcWl\xedUG\xbf\x89yF\xf77\x81v\xbd5\x98\xbeh8\x18W\x08\xf0\x1b\x99"
+ b"5:\x1a?rD\x96\xa1\x04\x0f\xae\xba\x85\xeb\x9d5@\xf5\x83\xd37\x83\x8ac"
+ b"\x06\xd4\x97i\xcdt\x16S\x82k\xf6K\x01vy\x88\x91\x9b6T\xdae\r\xfd]:k\xbal"
+ b"\xa9\xbba\xc34\xf9r\xeb}r\xdb\xc7\xdb*\x8f\x03z\xdc8h\xcc\xc9\xd3\xbcl"
+ b"\xa5-\xcb\xeaK\xa2\xc5\x15\xc0\xe3\xc1\x86Z\xfb\xebL\xe13\xcf\x9c\xe3"
+ b"\x1d\xc9\xed\xc2\x06\xcc\xce!\x92\xe5\xfe\x9c^\xa59w \x9bP\xa3PK\x08d"
+ b"\xf9\xe2Z}\xa7\xbf\xed\xeb%$\x0c\x82\xb8/\xb0\x01\xa9&,\xf7qh{Q\x96)\xf2"
+ b"q\x96\xc3\x80\xb4\x12\xb0\xba\xe6o\xf4!\xb4[\xd4\x8aw\x10\xf7t\x0c\xb3"
+ b"\xd9\xd5\xc3`^\x81\x11??\\\xa4\x99\x85R\xd4\x8e\x83\xc9\x1eX\xbfa\xf1"
+ b"\xac\xb0\xea\xea\xd7\xd0\xab\x18\xe2\xf2\xed\xe1\xb7\xc9\x18\xcbS\xe4>"
+ b"\xc9\x95H\xe8\xcb\t\r%\xeb\xc7$.o\xf1\xf3R\x17\x1db\xbb\xd8U\xa5^\xccS"
+ b"\x16\x01\x87\xf3/\x93\xd1\xf0v\xc0r\xd7\xcc\xa2Gkz\xca\x80\x0e\xfd\xd0"
+ b"\x8b\xbb\xd2Ix\xb3\x1ey\xca-0\xe3z^\xd6\xd6\x8f_\xf1\x9dP\x9fi\xa7\xd1"
+ b"\xe8\x90\x84\xdc\xbf\xcdky\x8e\xdc\x81\x7f\xa3\xb2+\xbf\x04\xef\xd8\\"
+ b"\xc4\xdf\xe1\xb0\x01\xe9\x93\xe3Y\xf1\x1dY\xe8h\x81\xcf\xf1w\xcc\xb4\xef"
+ b" \x8b|\x04\xea\x83ej\xbe\x1f\xd4z\x9c`\xd3\x1a\x92A\x06\xe5\x8f\xa9\x13"
+ b"\t\x9e=\xfa\x1c\xe5_\x9f%v\x1bo\x11ZO\xd8\xf4\t\xddM\x16-\x04\xfc\x18<\""
+ b"CM\xddg~b\xf6\xef\x8e\x0c\xd0\xde|\xa0'\x8a\x0c\xd6x\xae!J\xa6F\x88\x15u"
+ b"\x008\x17\xbc7y\xb3\xd8u\xac_\x85\x8d\xe7\xc1@\x9c\xecqc\xa3#\xad\xf1"
+ b"\x935\xb5)_\r\xec3]\x0fo]5\xd0my\x07\x9b\xee\x81\xb5\x0f\xcfK+\x00\xc0"
+ b"\xe4b\x10\xe4\x0c\x1a \x9b\xe0\x97t\xf6\xa1\x9e\x850\xba\x0c\x9a\x8d\xc8"
+ b"\x8f\x07\xd7\xae\xc8\xf9+i\xdc\xb9k\xb0>f\x19\xb8\r\xa8\xf8\x1f$\xa5{p"
+ b"\xc6\x880\xce\xdb\xcf\xca_\x86\xac\x88h6\x8bZ%'\xd0\n\xbf\x0f\x9c\"\xba"
+ b"\xe5\x86\x9f\x0f7X=mNX[\xcc\x19FU\xc9\x860\xbc\x90a+* \xae_$\x03\x1e\xd3"
+ b"\xcd_\xa0\x9c\xde\xaf46q\xa5\xc9\x92\xd7\xca\xe3`\x9d\x85}\xb4\xff\xb3"
+ b"\x83\xfb\xb6\xca\xae`\x0bw\x7f\xfc\xd8\xacVe\x19\xc8\x17\x0bZ\xad\x88"
+ b"\xeb#\x97\x03\x13\xb1d\x0f{\x0c\x04w\x07\r\x97\xbd\xd6\xc1\xc3B:\x95\x08"
+ b"^\x10V\xaeaH\x02\xd9\xe3\n\\\x01X\xf6\x9c\x8a\x06u#%\xbe*\xa1\x18v\x85"
+ b"\xec!\t4\x00\x00\x00"
+)
+
+FILTERS_RAW_4 = [{"id": lzma.FILTER_DELTA, "dist": 4},
+ {"id": lzma.FILTER_X86, "start_offset": 0x40},
+ {"id": lzma.FILTER_LZMA2, "preset": 4, "lc": 2}]
+COMPRESSED_RAW_4 = (
+ b"\xe0\x07\x80\x06\x0e\\\x00\x05\x14\x07bW\xaah\xdd\x10\xdc'\xd6\x90,\xc6v"
+ b"Jq \x14l\xb7\x83xB\x0b\x97f=&fx\xba\n>Tn\xbf\x8f\xfb\x1dF\xca\xc3v_\xca?"
+ b"\xfbV<\x92#\xd4w\xa6\x8a\xeb\xf6\x03\xc0\x01\x94\xd8\x9e\x13\x12\x98\xd1"
+ b"*\xfa]c\xe8\x1e~\xaf\xb5]Eg\xfb\x9e\x01\"8\xb2\x90\x06=~\xe9\x91W\xcd"
+ b"\xecD\x12\xc7\xfa\xe1\x91\x06\xc7\x99\xb9\xe3\x901\x87\x19u\x0f\x869\xff"
+ b"\xc1\xb0hw|\xb0\xdcl\xcck\xb16o7\x85\xee{Y_b\xbf\xbc$\xf3=\x8d\x8bw\xe5Z"
+ b"\x08@\xc4kmE\xad\xfb\xf6*\xd8\xad\xa1\xfb\xc5{\xdej,)\x1emB\x1f<\xaeca"
+ b"\x80(\xee\x07 \xdf\xe9\xf8\xeb\x0e-\x97\x86\x90c\xf9\xea'B\xf7`\xd7\xb0"
+ b"\x92\xbd\xa0\x82]\xbd\x0e\x0eB\x19\xdc\x96\xc6\x19\xd86D\xf0\xd5\x831"
+ b"\x03\xb7\x1c\xf7&5\x1a\x8f PZ&j\xf8\x98\x1bo\xcc\x86\x9bS\xd3\xa5\xcdu"
+ b"\xf9$\xcc\x97o\xe5V~\xfb\x97\xb5\x0b\x17\x9c\xfdxW\x10\xfep4\x80\xdaHDY"
+ b"\xfa)\xfet\xb5\"\xd4\xd3F\x81\xf4\x13\x1f\xec\xdf\xa5\x13\xfc\"\x91x\xb7"
+ b"\x99\xce\xc8\x92\n\xeb[\x10l*Y\xd8\xb1@\x06\xc8o\x8d7r\xebu\xfd5\x0e\x7f"
+ b"\xf1$U{\t}\x1fQ\xcfxN\x9d\x9fXX\xe9`\x83\xc1\x06\xf4\x87v-f\x11\xdb/\\"
+ b"\x06\xff\xd7)B\xf3g\x06\x88#2\x1eB244\x7f4q\t\xc893?mPX\x95\xa6a\xfb)d"
+ b"\x9b\xfc\x98\x9aj\x04\xae\x9b\x9d\x19w\xba\xf92\xfaA\x11\\\x17\x97C3\xa4"
+ b"\xbc!\x88\xcdo[\xec:\x030\x91.\x85\xe0@\\4\x16\x12\x9d\xcaJv\x97\xb04"
+ b"\xack\xcbkf\xa3ss\xfc\x16^\x8ce\x85a\xa5=&\xecr\xb3p\xd1E\xd5\x80y\xc7"
+ b"\xda\xf6\xfek\xbcT\xbfH\xee\x15o\xc5\x8c\x830\xec\x1d\x01\xae\x0c-e\\"
+ b"\x91\x90\x94\xb2\xf8\x88\x91\xe8\x0b\xae\xa7>\x98\xf6\x9ck\xd2\xc6\x08"
+ b"\xe6\xab\t\x98\xf2!\xa0\x8c^\xacqA\x99<\x1cEG\x97\xc8\xf1\xb6\xb9\x82"
+ b"\x8d\xf7\x08s\x98a\xff\xe3\xcc\x92\x0e\xd2\xb6U\xd7\xd9\x86\x7fa\xe5\x1c"
+ b"\x8dTG@\t\x1e\x0e7*\xfc\xde\xbc]6N\xf7\xf1\x84\x9e\x9f\xcf\xe9\x1e\xb5'"
+ b"\xf4<\xdf\x99sq\xd0\x9d\xbd\x99\x0b\xb4%p4\xbf{\xbb\x8a\xd2\x0b\xbc=M"
+ b"\x94H:\xf5\xa8\xd6\xa4\xc90\xc2D\xb9\xd3\xa8\xb0S\x87 `\xa2\xeb\xf3W\xce"
+ b" 7\xf9N#\r\xe6\xbe\t\x9d\xe7\x811\xf9\x10\xc1\xc2\x14\xf6\xfc\xcba\xb7"
+ b"\xb1\x7f\x95l\xe4\tjA\xec:\x10\xe5\xfe\xc2\\=D\xe2\x0c\x0b3]\xf7\xc1\xf7"
+ b"\xbceZ\xb1A\xea\x16\xe5\xfddgFQ\xed\xaf\x04\xa3\xd3\xf8\xa2q\x19B\xd4r"
+ b"\xc5\x0c\x9a\x14\x94\xea\x91\xc4o\xe4\xbb\xb4\x99\xf4@\xd1\xe6\x0c\xe3"
+ b"\xc6d\xa0Q\n\xf2/\xd8\xb8S5\x8a\x18:\xb5g\xac\x95D\xce\x17\x07\xd4z\xda"
+ b"\x90\xe65\x07\x19H!\t\xfdu\x16\x8e\x0eR\x19\xf4\x8cl\x0c\xf9Q\xf1\x80"
+ b"\xe3\xbf\xd7O\xf8\x8c\x18\x0b\x9c\xf1\x1fb\xe1\tR\xb2\xf1\xe1A\xea \xcf-"
+ b"IGE\xf1\x14\x98$\x83\x15\xc9\xd8j\xbf\x19\x0f\xd5\xd1\xaa\xb3\xf3\xa5I2s"
+ b"\x8d\x145\xca\xd5\xd93\x9c\xb8D0\xe6\xaa%\xd0\xc0P}JO^h\x8e\x08\xadlV."
+ b"\x18\x88\x13\x05o\xb0\x07\xeaw\xe0\xb6\xa4\xd5*\xe4r\xef\x07G+\xc1\xbei["
+ b"w\xe8\xab@_\xef\x15y\xe5\x12\xc9W\x1b.\xad\x85-\xc2\xf7\xe3mU6g\x8eSA"
+ b"\x01(\xd3\xdb\x16\x13=\xde\x92\xf9,D\xb8\x8a\xb2\xb4\xc9\xc3\xefnE\xe8\\"
+ b"\xa6\xe2Y\xd2\xcf\xcb\x8c\xb6\xd5\xe9\x1d\x1e\x9a\x8b~\xe2\xa6\rE\x84uV"
+ b"\xed\xc6\x99\xddm<\x10[\x0fu\x1f\xc1\x1d1\n\xcfw\xb2%!\xf0[\xce\x87\x83B"
+ b"\x08\xaa,\x08%d\xcef\x94\"\xd9g.\xc83\xcbXY+4\xec\x85qA\n\x1d=9\xf0*\xb1"
+ b"\x1f/\xf3s\xd61b\x7f@\xfb\x9d\xe3FQ\\\xbd\x82\x1e\x00\xf3\xce\xd3\xe1"
+ b"\xca,E\xfd7[\xab\xb6\xb7\xac!mA}\xbd\x9d3R5\x9cF\xabH\xeb\x92)cc\x13\xd0"
+ b"\xbd\xee\xe9n{\x1dIJB\xa5\xeb\x11\xe8`w&`\x8b}@Oxe\t\x8a\x07\x02\x95\xf2"
+ b"\xed\xda|\xb1e\xbe\xaa\xbbg\x19@\xe1Y\x878\x84\x0f\x8c\xe3\xc98\xf2\x9e"
+ b"\xd5N\xb5J\xef\xab!\xe2\x8dq\xe1\xe5q\xc5\xee\x11W\xb7\xe4k*\x027\xa0"
+ b"\xa3J\xf4\xd8m\xd0q\x94\xcb\x07\n:\xb6`.\xe4\x9c\x15+\xc0)\xde\x80X\xd4"
+ b"\xcfQm\x01\xc2cP\x1cA\x85'\xc9\xac\x8b\xe6\xb2)\xe6\x84t\x1c\x92\xe4Z"
+ b"\x1cR\xb0\x9e\x96\xd1\xfb\x1c\xa6\x8b\xcb`\x10\x12]\xf2gR\x9bFT\xe0\xc8H"
+ b"S\xfb\xac<\x04\xc7\xc1\xe8\xedP\xf4\x16\xdb\xc0\xd7e\xc2\x17J^\x1f\xab"
+ b"\xff[\x08\x19\xb4\xf5\xfb\x19\xb4\x04\xe5c~']\xcb\xc2A\xec\x90\xd0\xed"
+ b"\x06,\xc5K{\x86\x03\xb1\xcdMx\xdeQ\x8c3\xf9\x8a\xea=\x89\xaba\xd2\xc89a"
+ b"\xd72\xf0\xc3\x19\x8a\xdfs\xd4\xfd\xbb\x81b\xeaE\"\xd8\xf4d\x0cD\xf7IJ!"
+ b"\xe5d\xbbG\xe9\xcam\xaa\x0f_r\x95\x91NBq\xcaP\xce\xa7\xa9\xb5\x10\x94eP!"
+ b"|\x856\xcd\xbfIir\xb8e\x9bjP\x97q\xabwS7\x1a\x0ehM\xe7\xca\x86?\xdeP}y~"
+ b"\x0f\x95I\xfc\x13\xe1<Q\x1b\x868\x1d\x11\xdf\x94\xf4\x82>r\xa9k\x88\xcb"
+ b"\xfd\xc3v\xe2\xb9\x8a\x02\x8eq\x92I\xf8\xf6\xf1\x03s\x9b\xb8\xe3\"\xe3"
+ b"\xa9\xa5>D\xb8\x96;\xe7\x92\xd133\xe8\xdd'e\xc9.\xdc;\x17\x1f\xf5H\x13q"
+ b"\xa4W\x0c\xdb~\x98\x01\xeb\xdf\xe32\x13\x0f\xddx\n6\xa0\t\x10\xb6\xbb"
+ b"\xb0\xc3\x18\xb6;\x9fj[\xd9\xd5\xc9\x06\x8a\x87\xcd\xe5\xee\xfc\x9c-%@"
+ b"\xee\xe0\xeb\xd2\xe3\xe8\xfb\xc0\x122\\\xc7\xaf\xc2\xa1Oth\xb3\x8f\x82"
+ b"\xb3\x18\xa8\x07\xd5\xee_\xbe\xe0\x1cA\x1e_\r\x9a\xb0\x17W&\xa2D\x91\x94"
+ b"\x1a\xb2\xef\xf2\xdc\x85;X\xb0,\xeb>-7S\xe5\xca\x07)\x1fp\x7f\xcaQBL\xca"
+ b"\xf3\xb9d\xfc\xb5su\xb0\xc8\x95\x90\xeb*)\xa0v\xe4\x9a{FW\xf4l\xde\xcdj"
+ b"\x00"
+)
+
+
+def test_main():
+ run_unittest(
+ CompressorDecompressorTestCase,
+ CompressDecompressFunctionTestCase,
+ FileTestCase,
+ OpenTestCase,
+ MiscellaneousTestCase,
+ )
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_macpath.py b/Lib/test/test_macpath.py
index d732e14..ae4e564 100644
--- a/Lib/test/test_macpath.py
+++ b/Lib/test/test_macpath.py
@@ -115,13 +115,9 @@ class MacPathTestCase(unittest.TestCase):
self.assertEqual(normpath(b"a:b:"), b"a:b")
-class MacCommonTest(test_genericpath.CommonTest):
+class MacCommonTest(test_genericpath.CommonTest, unittest.TestCase):
pathmodule = macpath
-def test_main():
- support.run_unittest(MacPathTestCase, MacCommonTest)
-
-
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_mailbox.py b/Lib/test/test_mailbox.py
index 7a84292..39e8643 100644
--- a/Lib/test/test_mailbox.py
+++ b/Lib/test/test_mailbox.py
@@ -22,6 +22,10 @@ except ImportError:
class TestBase:
+ all_mailbox_types = (mailbox.Message, mailbox.MaildirMessage,
+ mailbox.mboxMessage, mailbox.MHMessage,
+ mailbox.BabylMessage, mailbox.MMDFMessage)
+
def _check_sample(self, msg):
# Inspect a mailbox.Message representation of the sample message
self.assertIsInstance(msg, email.message.Message)
@@ -91,14 +95,14 @@ class TestMailbox(TestBase):
""")
def test_add_invalid_8bit_bytes_header(self):
- key = self._box.add(self._nonascii_msg.encode('latin1'))
+ key = self._box.add(self._nonascii_msg.encode('latin-1'))
self.assertEqual(len(self._box), 1)
self.assertEqual(self._box.get_bytes(key),
- self._nonascii_msg.encode('latin1'))
+ self._nonascii_msg.encode('latin-1'))
def test_invalid_nonascii_header_as_string(self):
subj = self._nonascii_msg.splitlines()[1]
- key = self._box.add(subj.encode('latin1'))
+ key = self._box.add(subj.encode('latin-1'))
self.assertEqual(self._box.get_string(key),
'Subject: =?unknown-8bit?b?RmFsaW5hcHThciBo4Xpob3pzeuFsbO104XNz'
'YWwuIE3hciByZW5kZWx06Ww/?=\n\n')
@@ -930,8 +934,7 @@ class TestMaildir(TestMailbox, unittest.TestCase):
# the mtime and should cause a re-read. Note that "sleep
# emulation" is still in effect, as skewfactor is -3.
filename = os.path.join(self._path, 'cur', 'stray-file')
- f = open(filename, 'w')
- f.close()
+ support.create_empty_file(filename)
os.unlink(filename)
self._box._refresh()
self.assertTrue(refreshed())
@@ -1397,6 +1400,14 @@ class TestMessage(TestBase, unittest.TestCase):
# Initialize with invalid argument
self.assertRaises(TypeError, lambda: self._factory(object()))
+ def test_all_eMM_attribues_exist(self):
+ # Issue 12537
+ eMM = email.message_from_string(_sample_message)
+ msg = self._factory(_sample_message)
+ for attr in eMM.__dict__:
+ self.assertTrue(attr in msg.__dict__,
+ '{} attribute does not exist'.format(attr))
+
def test_become_message(self):
# Take on the state of another message
eMM = email.message_from_string(_sample_message)
@@ -1408,9 +1419,7 @@ class TestMessage(TestBase, unittest.TestCase):
# Copy self's format-specific data to other message formats.
# This test is superficial; better ones are in TestMessageConversion.
msg = self._factory()
- for class_ in (mailbox.Message, mailbox.MaildirMessage,
- mailbox.mboxMessage, mailbox.MHMessage,
- mailbox.BabylMessage, mailbox.MMDFMessage):
+ for class_ in self.all_mailbox_types:
other_msg = class_()
msg._explain_to(other_msg)
other_msg = email.message.Message()
@@ -1642,37 +1651,44 @@ class TestMessageConversion(TestBase, unittest.TestCase):
def test_plain_to_x(self):
# Convert Message to all formats
- for class_ in (mailbox.Message, mailbox.MaildirMessage,
- mailbox.mboxMessage, mailbox.MHMessage,
- mailbox.BabylMessage, mailbox.MMDFMessage):
+ for class_ in self.all_mailbox_types:
msg_plain = mailbox.Message(_sample_message)
msg = class_(msg_plain)
self._check_sample(msg)
def test_x_to_plain(self):
# Convert all formats to Message
- for class_ in (mailbox.Message, mailbox.MaildirMessage,
- mailbox.mboxMessage, mailbox.MHMessage,
- mailbox.BabylMessage, mailbox.MMDFMessage):
+ for class_ in self.all_mailbox_types:
msg = class_(_sample_message)
msg_plain = mailbox.Message(msg)
self._check_sample(msg_plain)
def test_x_from_bytes(self):
# Convert all formats to Message
- for class_ in (mailbox.Message, mailbox.MaildirMessage,
- mailbox.mboxMessage, mailbox.MHMessage,
- mailbox.BabylMessage, mailbox.MMDFMessage):
+ for class_ in self.all_mailbox_types:
msg = class_(_bytes_sample_message)
self._check_sample(msg)
def test_x_to_invalid(self):
# Convert all formats to an invalid format
- for class_ in (mailbox.Message, mailbox.MaildirMessage,
- mailbox.mboxMessage, mailbox.MHMessage,
- mailbox.BabylMessage, mailbox.MMDFMessage):
+ for class_ in self.all_mailbox_types:
self.assertRaises(TypeError, lambda: class_(False))
+ def test_type_specific_attributes_removed_on_conversion(self):
+ reference = {class_: class_(_sample_message).__dict__
+ for class_ in self.all_mailbox_types}
+ for class1 in self.all_mailbox_types:
+ for class2 in self.all_mailbox_types:
+ if class1 is class2:
+ continue
+ source = class1(_sample_message)
+ target = class2(source)
+ type_specific = [a for a in reference[class1]
+ if a not in reference[class2]]
+ for attr in type_specific:
+ self.assertNotIn(attr, target.__dict__,
+ "while converting {} to {}".format(class1, class2))
+
def test_maildir_to_maildir(self):
# Convert MaildirMessage to MaildirMessage
msg_maildir = mailbox.MaildirMessage(_sample_message)
diff --git a/Lib/test/test_mailcap.py b/Lib/test/test_mailcap.py
new file mode 100644
index 0000000..a4cd09c
--- /dev/null
+++ b/Lib/test/test_mailcap.py
@@ -0,0 +1,221 @@
+import mailcap
+import os
+import shutil
+import test.support
+import unittest
+
+# Location of mailcap file
+MAILCAPFILE = test.support.findfile("mailcap.txt")
+
+# Dict to act as mock mailcap entry for this test
+# The keys and values should match the contents of MAILCAPFILE
+MAILCAPDICT = {
+ 'application/x-movie':
+ [{'compose': 'moviemaker %s',
+ 'x11-bitmap': '"/usr/lib/Zmail/bitmaps/movie.xbm"',
+ 'description': '"Movie"',
+ 'view': 'movieplayer %s'}],
+ 'application/*':
+ [{'copiousoutput': '',
+ 'view': 'echo "This is \\"%t\\" but is 50 \\% Greek to me" \\; cat %s'}],
+ 'audio/basic':
+ [{'edit': 'audiocompose %s',
+ 'compose': 'audiocompose %s',
+ 'description': '"An audio fragment"',
+ 'view': 'showaudio %s'}],
+ 'video/mpeg':
+ [{'view': 'mpeg_play %s'}],
+ 'application/postscript':
+ [{'needsterminal': '', 'view': 'ps-to-terminal %s'},
+ {'compose': 'idraw %s', 'view': 'ps-to-terminal %s'}],
+ 'application/x-dvi':
+ [{'view': 'xdvi %s'}],
+ 'message/external-body':
+ [{'composetyped': 'extcompose %s',
+ 'description': '"A reference to data stored in an external location"',
+ 'needsterminal': '',
+ 'view': 'showexternal %s %{access-type} %{name} %{site} %{directory} %{mode} %{server}'}],
+ 'text/richtext':
+ [{'test': 'test "`echo %{charset} | tr \'[A-Z]\' \'[a-z]\'`" = iso-8859-8',
+ 'copiousoutput': '',
+ 'view': 'shownonascii iso-8859-8 -e richtext -p %s'}],
+ 'image/x-xwindowdump':
+ [{'view': 'display %s'}],
+ 'audio/*':
+ [{'view': '/usr/local/bin/showaudio %t'}],
+ 'video/*':
+ [{'view': 'animate %s'}],
+ 'application/frame':
+ [{'print': '"cat %s | lp"', 'view': 'showframe %s'}],
+ 'image/rgb':
+ [{'view': 'display %s'}]
+}
+
+
+class HelperFunctionTest(unittest.TestCase):
+
+ def test_listmailcapfiles(self):
+ # The return value for listmailcapfiles() will vary by system.
+ # So verify that listmailcapfiles() returns a list of strings that is of
+ # non-zero length.
+ mcfiles = mailcap.listmailcapfiles()
+ self.assertIsInstance(mcfiles, list)
+ for m in mcfiles:
+ self.assertIsInstance(m, str)
+ with test.support.EnvironmentVarGuard() as env:
+ # According to RFC 1524, if MAILCAPS env variable exists, use that
+ # and only that.
+ if "MAILCAPS" in env:
+ env_mailcaps = env["MAILCAPS"].split(os.pathsep)
+ else:
+ env_mailcaps = ["/testdir1/.mailcap", "/testdir2/mailcap"]
+ env["MAILCAPS"] = os.pathsep.join(env_mailcaps)
+ mcfiles = mailcap.listmailcapfiles()
+ self.assertEqual(env_mailcaps, mcfiles)
+
+ def test_readmailcapfile(self):
+ # Test readmailcapfile() using test file. It should match MAILCAPDICT.
+ with open(MAILCAPFILE, 'r') as mcf:
+ d = mailcap.readmailcapfile(mcf)
+ self.assertDictEqual(d, MAILCAPDICT)
+
+ def test_lookup(self):
+ # Test without key
+ expected = [{'view': 'mpeg_play %s'}, {'view': 'animate %s'}]
+ actual = mailcap.lookup(MAILCAPDICT, 'video/mpeg')
+ self.assertListEqual(expected, actual)
+
+ # Test with key
+ key = 'compose'
+ expected = [{'edit': 'audiocompose %s',
+ 'compose': 'audiocompose %s',
+ 'description': '"An audio fragment"',
+ 'view': 'showaudio %s'}]
+ actual = mailcap.lookup(MAILCAPDICT, 'audio/basic', key)
+ self.assertListEqual(expected, actual)
+
+ def test_subst(self):
+ plist = ['id=1', 'number=2', 'total=3']
+ # test case: ([field, MIMEtype, filename, plist=[]], <expected string>)
+ test_cases = [
+ (["", "audio/*", "foo.txt"], ""),
+ (["echo foo", "audio/*", "foo.txt"], "echo foo"),
+ (["echo %s", "audio/*", "foo.txt"], "echo foo.txt"),
+ (["echo %t", "audio/*", "foo.txt"], "echo audio/*"),
+ (["echo \%t", "audio/*", "foo.txt"], "echo %t"),
+ (["echo foo", "audio/*", "foo.txt", plist], "echo foo"),
+ (["echo %{total}", "audio/*", "foo.txt", plist], "echo 3")
+ ]
+ for tc in test_cases:
+ self.assertEqual(mailcap.subst(*tc[0]), tc[1])
+
+
+class GetcapsTest(unittest.TestCase):
+
+ def test_mock_getcaps(self):
+ # Test mailcap.getcaps() using mock mailcap file in this dir.
+ # Temporarily override any existing system mailcap file by pointing the
+ # MAILCAPS environment variable to our mock file.
+ with test.support.EnvironmentVarGuard() as env:
+ env["MAILCAPS"] = MAILCAPFILE
+ caps = mailcap.getcaps()
+ self.assertDictEqual(caps, MAILCAPDICT)
+
+ def test_system_mailcap(self):
+ # Test mailcap.getcaps() with mailcap file(s) on system, if any.
+ caps = mailcap.getcaps()
+ self.assertIsInstance(caps, dict)
+ mailcapfiles = mailcap.listmailcapfiles()
+ existingmcfiles = [mcf for mcf in mailcapfiles if os.path.exists(mcf)]
+ if existingmcfiles:
+ # At least 1 mailcap file exists, so test that.
+ for (k, v) in caps.items():
+ self.assertIsInstance(k, str)
+ self.assertIsInstance(v, list)
+ for e in v:
+ self.assertIsInstance(e, dict)
+ else:
+ # No mailcap files on system. getcaps() should return empty dict.
+ self.assertEqual({}, caps)
+
+
+class FindmatchTest(unittest.TestCase):
+
+ def test_findmatch(self):
+
+ # default findmatch arguments
+ c = MAILCAPDICT
+ fname = "foo.txt"
+ plist = ["access-type=default", "name=john", "site=python.org",
+ "directory=/tmp", "mode=foo", "server=bar"]
+ audio_basic_entry = {
+ 'edit': 'audiocompose %s',
+ 'compose': 'audiocompose %s',
+ 'description': '"An audio fragment"',
+ 'view': 'showaudio %s'
+ }
+ audio_entry = {"view": "/usr/local/bin/showaudio %t"}
+ video_entry = {'view': 'animate %s'}
+ message_entry = {
+ 'composetyped': 'extcompose %s',
+ 'description': '"A reference to data stored in an external location"', 'needsterminal': '',
+ 'view': 'showexternal %s %{access-type} %{name} %{site} %{directory} %{mode} %{server}'
+ }
+
+ # test case: (findmatch args, findmatch keyword args, expected output)
+ # positional args: caps, MIMEtype
+ # keyword args: key="view", filename="/dev/null", plist=[]
+ # output: (command line, mailcap entry)
+ cases = [
+ ([{}, "video/mpeg"], {}, (None, None)),
+ ([c, "foo/bar"], {}, (None, None)),
+ ([c, "video/mpeg"], {}, ('mpeg_play /dev/null', {'view': 'mpeg_play %s'})),
+ ([c, "audio/basic", "edit"], {}, ("audiocompose /dev/null", audio_basic_entry)),
+ ([c, "audio/basic", "compose"], {}, ("audiocompose /dev/null", audio_basic_entry)),
+ ([c, "audio/basic", "description"], {}, ('"An audio fragment"', audio_basic_entry)),
+ ([c, "audio/basic", "foobar"], {}, (None, None)),
+ ([c, "video/*"], {"filename": fname}, ("animate %s" % fname, video_entry)),
+ ([c, "audio/basic", "compose"],
+ {"filename": fname},
+ ("audiocompose %s" % fname, audio_basic_entry)),
+ ([c, "audio/basic"],
+ {"key": "description", "filename": fname},
+ ('"An audio fragment"', audio_basic_entry)),
+ ([c, "audio/*"],
+ {"filename": fname},
+ ("/usr/local/bin/showaudio audio/*", audio_entry)),
+ ([c, "message/external-body"],
+ {"plist": plist},
+ ("showexternal /dev/null default john python.org /tmp foo bar", message_entry))
+ ]
+ self._run_cases(cases)
+
+ @unittest.skipUnless(os.name == "posix", "Requires 'test' command on system")
+ def test_test(self):
+ # findmatch() will automatically check any "test" conditions and skip
+ # the entry if the check fails.
+ caps = {"test/pass": [{"test": "test 1 -eq 1"}],
+ "test/fail": [{"test": "test 1 -eq 0"}]}
+ # test case: (findmatch args, findmatch keyword args, expected output)
+ # positional args: caps, MIMEtype, key ("test")
+ # keyword args: N/A
+ # output: (command line, mailcap entry)
+ cases = [
+ # findmatch will return the mailcap entry for test/pass because it evaluates to true
+ ([caps, "test/pass", "test"], {}, ("test 1 -eq 1", {"test": "test 1 -eq 1"})),
+ # findmatch will return None because test/fail evaluates to false
+ ([caps, "test/fail", "test"], {}, (None, None))
+ ]
+ self._run_cases(cases)
+
+ def _run_cases(self, cases):
+ for c in cases:
+ self.assertEqual(mailcap.findmatch(*c[0], **c[1]), c[2])
+
+
+def test_main():
+ test.support.run_unittest(HelperFunctionTest, GetcapsTest, FindmatchTest)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_marshal.py b/Lib/test/test_marshal.py
index 02ef62a..025b53c 100644
--- a/Lib/test/test_marshal.py
+++ b/Lib/test/test_marshal.py
@@ -6,6 +6,7 @@ import marshal
import sys
import unittest
import os
+import types
class HelperMixin:
def helper(self, sample, *extra):
@@ -114,6 +115,22 @@ class CodeTestCase(unittest.TestCase):
codes = (ExceptionTestCase.test_exceptions.__code__,) * count
marshal.loads(marshal.dumps(codes))
+ def test_different_filenames(self):
+ co1 = compile("x", "f1", "exec")
+ co2 = compile("y", "f2", "exec")
+ co1, co2 = marshal.loads(marshal.dumps((co1, co2)))
+ self.assertEqual(co1.co_filename, "f1")
+ self.assertEqual(co2.co_filename, "f2")
+
+ @support.cpython_only
+ def test_same_filename_used(self):
+ s = """def f(): pass\ndef g(): pass"""
+ co = compile(s, "myfile", "exec")
+ co = marshal.loads(marshal.dumps(co))
+ for obj in co.co_consts:
+ if isinstance(obj, types.CodeType):
+ self.assertIs(co.co_filename, obj.co_filename)
+
class ContainerTestCase(unittest.TestCase, HelperMixin):
d = {'astring': 'foo@bar.baz.spam',
'afloat': 7283.43,
@@ -263,7 +280,6 @@ class BugsTestCase(unittest.TestCase):
self.assertRaises(TypeError, marshal.loads, unicode_string)
LARGE_SIZE = 2**31
-character_size = 4 if sys.maxunicode > 0xFFFF else 2
pointer_size = 8 if sys.maxsize > 0xFFFFFFFF else 4
class NullWriter:
@@ -279,7 +295,7 @@ class LargeValuesTestCase(unittest.TestCase):
def test_bytes(self, size):
self.check_unmarshallable(b'x' * size)
- @support.bigmemtest(size=LARGE_SIZE, memuse=character_size, dry_run=False)
+ @support.bigmemtest(size=LARGE_SIZE, memuse=1, dry_run=False)
def test_str(self, size):
self.check_unmarshallable('x' * size)
diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py
index dddc889..715003a 100644
--- a/Lib/test/test_math.py
+++ b/Lib/test/test_math.py
@@ -2,11 +2,12 @@
# XXXX Should not do tests around zero only
from test.support import run_unittest, verbose, requires_IEEE_754
+from test import support
import unittest
import math
import os
+import platform
import sys
-import random
import struct
import sysconfig
@@ -457,12 +458,12 @@ class MathTests(unittest.TestCase):
def testFmod(self):
self.assertRaises(TypeError, math.fmod)
- self.ftest('fmod(10,1)', math.fmod(10,1), 0)
- self.ftest('fmod(10,0.5)', math.fmod(10,0.5), 0)
- self.ftest('fmod(10,1.5)', math.fmod(10,1.5), 1)
- self.ftest('fmod(-10,1)', math.fmod(-10,1), 0)
- self.ftest('fmod(-10,0.5)', math.fmod(-10,0.5), 0)
- self.ftest('fmod(-10,1.5)', math.fmod(-10,1.5), -1)
+ self.ftest('fmod(10, 1)', math.fmod(10, 1), 0.0)
+ self.ftest('fmod(10, 0.5)', math.fmod(10, 0.5), 0.0)
+ self.ftest('fmod(10, 1.5)', math.fmod(10, 1.5), 1.0)
+ self.ftest('fmod(-10, 1)', math.fmod(-10, 1), -0.0)
+ self.ftest('fmod(-10, 0.5)', math.fmod(-10, 0.5), -0.0)
+ self.ftest('fmod(-10, 1.5)', math.fmod(-10, 1.5), -1.0)
self.assertTrue(math.isnan(math.fmod(NAN, 1.)))
self.assertTrue(math.isnan(math.fmod(1., NAN)))
self.assertTrue(math.isnan(math.fmod(NAN, NAN)))
@@ -650,6 +651,33 @@ class MathTests(unittest.TestCase):
n= 2**90
self.assertAlmostEqual(math.log1p(n), math.log1p(float(n)))
+ @requires_IEEE_754
+ def testLog2(self):
+ self.assertRaises(TypeError, math.log2)
+
+ # Check some integer values
+ self.assertEqual(math.log2(1), 0.0)
+ self.assertEqual(math.log2(2), 1.0)
+ self.assertEqual(math.log2(4), 2.0)
+
+ # Large integer values
+ self.assertEqual(math.log2(2**1023), 1023.0)
+ self.assertEqual(math.log2(2**1024), 1024.0)
+ self.assertEqual(math.log2(2**2000), 2000.0)
+
+ self.assertRaises(ValueError, math.log2, -1.5)
+ self.assertRaises(ValueError, math.log2, NINF)
+ self.assertTrue(math.isnan(math.log2(NAN)))
+
+ @requires_IEEE_754
+ # log2() is not accurate enough on Mac OS X Tiger (10.4)
+ @support.requires_mac_ver(10, 5)
+ def testLog2Exact(self):
+ # Check that we get exact equality for log2 of powers of 2.
+ actual = [math.log2(math.ldexp(1.0, n)) for n in range(-1074, 1024)]
+ expected = [float(n) for n in range(-1074, 1024)]
+ self.assertEqual(actual, expected)
+
def testLog10(self):
self.assertRaises(TypeError, math.log10)
self.ftest('log10(0.1)', math.log10(0.1), -1)
@@ -1010,7 +1038,6 @@ class MathTests(unittest.TestCase):
@requires_IEEE_754
def test_mtestfile(self):
- ALLOWED_ERROR = 20 # permitted error, in ulps
fail_fmt = "{}:{}({!r}): expected {!r}, got {!r}"
failures = []
diff --git a/Lib/test/test_memoryio.py b/Lib/test/test_memoryio.py
index 7c0c84f..e5db945 100644
--- a/Lib/test/test_memoryio.py
+++ b/Lib/test/test_memoryio.py
@@ -658,7 +658,7 @@ class CBytesIOTest(PyBytesIOTest):
@support.cpython_only
def test_sizeof(self):
- basesize = support.calcobjsize('P2PP2PP')
+ basesize = support.calcobjsize('P2nN2Pn')
check = self.check_sizeof
self.assertEqual(object.__sizeof__(io.BytesIO()), basesize)
check(io.BytesIO(), basesize )
diff --git a/Lib/test/test_memoryview.py b/Lib/test/test_memoryview.py
index 0bfddd9..ee6b15a 100644
--- a/Lib/test/test_memoryview.py
+++ b/Lib/test/test_memoryview.py
@@ -1,6 +1,7 @@
"""Unit tests for the memoryview
-XXX We need more tests! Some tests are in test_bytes
+ Some tests are in test_bytes. Many tests that require _testbuffer.ndarray
+ are in test_buffer.
"""
import unittest
@@ -24,15 +25,14 @@ class AbstractMemoryTests:
return filter(None, [self.ro_type, self.rw_type])
def check_getitem_with_type(self, tp):
- item = self.getitem_type
b = tp(self._source)
oldrefcount = sys.getrefcount(b)
m = self._view(b)
- self.assertEqual(m[0], item(b"a"))
- self.assertIsInstance(m[0], bytes)
- self.assertEqual(m[5], item(b"f"))
- self.assertEqual(m[-1], item(b"f"))
- self.assertEqual(m[-6], item(b"a"))
+ self.assertEqual(m[0], ord(b"a"))
+ self.assertIsInstance(m[0], int)
+ self.assertEqual(m[5], ord(b"f"))
+ self.assertEqual(m[-1], ord(b"f"))
+ self.assertEqual(m[-6], ord(b"a"))
# Bounds checking
self.assertRaises(IndexError, lambda: m[6])
self.assertRaises(IndexError, lambda: m[-7])
@@ -76,7 +76,9 @@ class AbstractMemoryTests:
b = self.rw_type(self._source)
oldrefcount = sys.getrefcount(b)
m = self._view(b)
- m[0] = tp(b"0")
+ m[0] = ord(b'1')
+ self._check_contents(tp, b, b"1bcdef")
+ m[0:1] = tp(b"0")
self._check_contents(tp, b, b"0bcdef")
m[1:3] = tp(b"12")
self._check_contents(tp, b, b"012def")
@@ -102,10 +104,17 @@ class AbstractMemoryTests:
# Wrong index/slice types
self.assertRaises(TypeError, setitem, 0.0, b"a")
self.assertRaises(TypeError, setitem, (0,), b"a")
+ self.assertRaises(TypeError, setitem, (slice(0,1,1), 0), b"a")
+ self.assertRaises(TypeError, setitem, (0, slice(0,1,1)), b"a")
+ self.assertRaises(TypeError, setitem, (0,), b"a")
self.assertRaises(TypeError, setitem, "a", b"a")
+ # Not implemented: multidimensional slices
+ slices = (slice(0,1,1), slice(0,1,2))
+ self.assertRaises(NotImplementedError, setitem, slices, b"a")
# Trying to resize the memory object
- self.assertRaises(ValueError, setitem, 0, b"")
- self.assertRaises(ValueError, setitem, 0, b"ab")
+ exc = ValueError if m.format == 'c' else TypeError
+ self.assertRaises(exc, setitem, 0, b"")
+ self.assertRaises(exc, setitem, 0, b"ab")
self.assertRaises(ValueError, setitem, slice(1,1), b"a")
self.assertRaises(ValueError, setitem, slice(0,2), b"a")
@@ -175,7 +184,7 @@ class AbstractMemoryTests:
self.assertEqual(m.shape, (6,))
self.assertEqual(len(m), 6)
self.assertEqual(m.strides, (self.itemsize,))
- self.assertEqual(m.suboffsets, None)
+ self.assertEqual(m.suboffsets, ())
return m
def test_attributes_readonly(self):
@@ -209,12 +218,16 @@ class AbstractMemoryTests:
# If tp is a factory rather than a plain type, skip
continue
+ class MyView():
+ def __init__(self, base):
+ self.m = memoryview(base)
class MySource(tp):
pass
class MyObject:
pass
- # Create a reference cycle through a memoryview object
+ # Create a reference cycle through a memoryview object.
+ # This exercises mbuf_clear().
b = MySource(tp(b'abc'))
m = self._view(b)
o = MyObject()
@@ -226,6 +239,17 @@ class AbstractMemoryTests:
gc.collect()
self.assertTrue(wr() is None, wr())
+ # This exercises memory_clear().
+ m = MyView(tp(b'abc'))
+ o = MyObject()
+ m.x = m
+ m.o = o
+ wr = weakref.ref(o)
+ m = o = None
+ # The cycle must be broken
+ gc.collect()
+ self.assertTrue(wr() is None, wr())
+
def _check_released(self, m, tp):
check = self.assertRaisesRegex(ValueError, "released")
with check: bytes(m)
@@ -283,6 +307,51 @@ class AbstractMemoryTests:
i = io.BytesIO(b'ZZZZ')
self.assertRaises(TypeError, i.readinto, m)
+ def test_getbuf_fail(self):
+ self.assertRaises(TypeError, self._view, {})
+
+ def test_hash(self):
+ # Memoryviews of readonly (hashable) types are hashable, and they
+ # hash as hash(obj.tobytes()).
+ tp = self.ro_type
+ if tp is None:
+ self.skipTest("no read-only type to test")
+ b = tp(self._source)
+ m = self._view(b)
+ self.assertEqual(hash(m), hash(b"abcdef"))
+ # Releasing the memoryview keeps the stored hash value (as with weakrefs)
+ m.release()
+ self.assertEqual(hash(m), hash(b"abcdef"))
+ # Hashing a memoryview for the first time after it is released
+ # results in an error (as with weakrefs).
+ m = self._view(b)
+ m.release()
+ self.assertRaises(ValueError, hash, m)
+
+ def test_hash_writable(self):
+ # Memoryviews of writable types are unhashable
+ tp = self.rw_type
+ if tp is None:
+ self.skipTest("no writable type to test")
+ b = tp(self._source)
+ m = self._view(b)
+ self.assertRaises(ValueError, hash, m)
+
+ def test_weakref(self):
+ # Check memoryviews are weakrefable
+ for tp in self._types:
+ b = tp(self._source)
+ m = self._view(b)
+ L = []
+ def callback(wr, b=b):
+ L.append(b)
+ wr = weakref.ref(m, callback)
+ self.assertIs(wr(), m)
+ del m
+ test.support.gc_collect()
+ self.assertIs(wr(), None)
+ self.assertIs(L[0], b)
+
# Variations on source objects for the buffer: bytes-like objects, then arrays
# with itemsize > 1.
# NOTE: support for multi-dimensional objects is unimplemented.
diff --git a/Lib/test/test_metaclass.py b/Lib/test/test_metaclass.py
index 219ab99..e6fe20a 100644
--- a/Lib/test/test_metaclass.py
+++ b/Lib/test/test_metaclass.py
@@ -159,6 +159,7 @@ Use a __prepare__ method that returns an instrumented dict.
... bar = 123
...
d['__module__'] = 'test.test_metaclass'
+ d['__qualname__'] = 'C'
d['foo'] = 4
d['foo'] = 42
d['bar'] = 123
@@ -177,12 +178,12 @@ Use a metaclass that doesn't derive from type.
... b = 24
...
meta: C ()
- ns: [('__module__', 'test.test_metaclass'), ('a', 42), ('b', 24)]
+ ns: [('__module__', 'test.test_metaclass'), ('__qualname__', 'C'), ('a', 42), ('b', 24)]
kw: []
>>> type(C) is dict
True
>>> print(sorted(C.items()))
- [('__module__', 'test.test_metaclass'), ('a', 42), ('b', 24)]
+ [('__module__', 'test.test_metaclass'), ('__qualname__', 'C'), ('a', 42), ('b', 24)]
>>>
And again, with a __prepare__ attribute.
@@ -199,11 +200,12 @@ And again, with a __prepare__ attribute.
...
prepare: C () [('other', 'booh')]
d['__module__'] = 'test.test_metaclass'
+ d['__qualname__'] = 'C'
d['a'] = 1
d['a'] = 2
d['b'] = 3
meta: C ()
- ns: [('__module__', 'test.test_metaclass'), ('a', 2), ('b', 3)]
+ ns: [('__module__', 'test.test_metaclass'), ('__qualname__', 'C'), ('a', 2), ('b', 3)]
kw: [('other', 'booh')]
>>>
@@ -246,7 +248,13 @@ Test failures in looking up the __prepare__ method work.
"""
-__test__ = {'doctests' : doctests}
+import sys
+
+# Trace function introduces __locals__ which causes various tests to fail.
+if hasattr(sys, 'gettrace') and sys.gettrace():
+ __test__ = {}
+else:
+ __test__ = {'doctests' : doctests}
def test_main(verbose=False):
from test import support
diff --git a/Lib/test/test_minidom.py b/Lib/test/test_minidom.py
index 80812c8..5867b2d 100644
--- a/Lib/test/test_minidom.py
+++ b/Lib/test/test_minidom.py
@@ -4,9 +4,7 @@ import pickle
from test.support import verbose, run_unittest, findfile
import unittest
-import xml.dom
import xml.dom.minidom
-import xml.parsers.expat
from xml.dom.minidom import parse, Node, Document, parseString
from xml.dom.minidom import getDOMImplementation
@@ -14,7 +12,6 @@ from xml.dom.minidom import getDOMImplementation
tstfile = findfile("test.xml", subdir="xmltestdata")
-
# The tests of DocumentType importing use these helpers to construct
# the documents to work with, since not all DOM builders actually
# create the DocumentType nodes.
@@ -50,7 +47,7 @@ class MinidomTest(unittest.TestCase):
def checkWholeText(self, node, s):
t = node.wholeText
- self.confirm(t == s, "looking for %s, found %s" % (repr(s), repr(t)))
+ self.confirm(t == s, "looking for %r, found %r" % (s, t))
def testParseFromFile(self):
with open(tstfile) as file:
@@ -282,6 +279,7 @@ class MinidomTest(unittest.TestCase):
child.setAttribute("def", "ghi")
self.confirm(len(child.attributes) == 1)
+ self.assertRaises(xml.dom.NotFoundErr, child.removeAttribute, "foo")
child.removeAttribute("def")
self.confirm(len(child.attributes) == 0)
dom.unlink()
@@ -293,6 +291,8 @@ class MinidomTest(unittest.TestCase):
child.setAttributeNS("http://www.w3.org", "xmlns:python",
"http://www.python.org")
child.setAttributeNS("http://www.python.org", "python:abcattr", "foo")
+ self.assertRaises(xml.dom.NotFoundErr, child.removeAttributeNS,
+ "foo", "http://www.python.org")
self.confirm(len(child.attributes) == 2)
child.removeAttributeNS("http://www.python.org", "abcattr")
self.confirm(len(child.attributes) == 1)
@@ -304,11 +304,23 @@ class MinidomTest(unittest.TestCase):
child.setAttribute("spam", "jam")
self.confirm(len(child.attributes) == 1)
node = child.getAttributeNode("spam")
+ self.assertRaises(xml.dom.NotFoundErr, child.removeAttributeNode,
+ None)
child.removeAttributeNode(node)
self.confirm(len(child.attributes) == 0
and child.getAttributeNode("spam") is None)
+ dom2 = Document()
+ child2 = dom2.appendChild(dom.createElement("foo"))
+ self.assertRaises(xml.dom.NotFoundErr, child.removeAttributeNode,
+ node)
dom.unlink()
+ def testHasAttribute(self):
+ dom = Document()
+ child = dom.appendChild(dom.createElement("foo"))
+ child.setAttribute("spam", "jam")
+ self.confirm(child.hasAttribute("spam"))
+
def testChangeAttr(self):
dom = parseString("<abc/>")
el = dom.documentElement
@@ -592,7 +604,13 @@ class MinidomTest(unittest.TestCase):
def testFirstChild(self): pass
- def testHasChildNodes(self): pass
+ def testHasChildNodes(self):
+ dom = parseString("<doc><foo/></doc>")
+ doc = dom.documentElement
+ self.assertTrue(dom.hasChildNodes())
+ dom2 = parseString("<doc/>")
+ doc2 = dom2.documentElement
+ self.assertFalse(doc2.hasChildNodes())
def _testCloneElementCopiesAttributes(self, e1, e2, test):
attrs1 = e1.attributes
@@ -1041,41 +1059,6 @@ class MinidomTest(unittest.TestCase):
"test NodeList.item()")
doc.unlink()
- def testSAX2DOM(self):
- from xml.dom import pulldom
-
- sax2dom = pulldom.SAX2DOM()
- sax2dom.startDocument()
- sax2dom.startElement("doc", {})
- sax2dom.characters("text")
- sax2dom.startElement("subelm", {})
- sax2dom.characters("text")
- sax2dom.endElement("subelm")
- sax2dom.characters("text")
- sax2dom.endElement("doc")
- sax2dom.endDocument()
-
- doc = sax2dom.document
- root = doc.documentElement
- (text1, elm1, text2) = root.childNodes
- text3 = elm1.childNodes[0]
-
- self.confirm(text1.previousSibling is None and
- text1.nextSibling is elm1 and
- elm1.previousSibling is text1 and
- elm1.nextSibling is text2 and
- text2.previousSibling is elm1 and
- text2.nextSibling is None and
- text3.previousSibling is None and
- text3.nextSibling is None, "testSAX2DOM - siblings")
-
- self.confirm(root.parentNode is doc and
- text1.parentNode is root and
- elm1.parentNode is root and
- text2.parentNode is root and
- text3.parentNode is elm1, "testSAX2DOM - parents")
- doc.unlink()
-
def testEncodings(self):
doc = parseString('<foo>&#x20ac;</foo>')
self.assertEqual(doc.toxml(),
@@ -1084,6 +1067,11 @@ class MinidomTest(unittest.TestCase):
b'<?xml version="1.0" encoding="utf-8"?><foo>\xe2\x82\xac</foo>')
self.assertEqual(doc.toxml('iso-8859-15'),
b'<?xml version="1.0" encoding="iso-8859-15"?><foo>\xa4</foo>')
+ self.assertEqual(doc.toxml('us-ascii'),
+ b'<?xml version="1.0" encoding="us-ascii"?><foo>&#8364;</foo>')
+ self.assertEqual(doc.toxml('utf-16'),
+ '<?xml version="1.0" encoding="utf-16"?>'
+ '<foo>\u20ac</foo>'.encode('utf-16'))
# Verify that character decoding errors raise exceptions instead
# of crashing
@@ -1522,12 +1510,21 @@ class MinidomTest(unittest.TestCase):
doc.appendChild(doc.createComment("foo--bar"))
self.assertRaises(ValueError, doc.toxml)
+
def testEmptyXMLNSValue(self):
doc = parseString("<element xmlns=''>\n"
"<foo/>\n</element>")
doc2 = parseString(doc.toxml())
self.confirm(doc2.namespaceURI == xml.dom.EMPTY_NAMESPACE)
+ def testDocRemoveChild(self):
+ doc = parse(tstfile)
+ title_tag = doc.documentElement.getElementsByTagName("TITLE")[0]
+ self.assertRaises( xml.dom.NotFoundErr, doc.removeChild, title_tag)
+ num_children_before = len(doc.childNodes)
+ doc.removeChild(doc.childNodes[0])
+ num_children_after = len(doc.childNodes)
+ self.assertTrue(num_children_after == num_children_before - 1)
def test_main():
run_unittest(MinidomTest)
diff --git a/Lib/test/test_mmap.py b/Lib/test/test_mmap.py
index 4407c5c..d066368 100644
--- a/Lib/test/test_mmap.py
+++ b/Lib/test/test_mmap.py
@@ -417,6 +417,35 @@ class MmapTests(unittest.TestCase):
m[x] = b
self.assertEqual(m[x], b)
+ def test_read_all(self):
+ m = mmap.mmap(-1, 16)
+ self.addCleanup(m.close)
+
+ # With no parameters, or None or a negative argument, reads all
+ m.write(bytes(range(16)))
+ m.seek(0)
+ self.assertEqual(m.read(), bytes(range(16)))
+ m.seek(8)
+ self.assertEqual(m.read(), bytes(range(8, 16)))
+ m.seek(16)
+ self.assertEqual(m.read(), b'')
+ m.seek(3)
+ self.assertEqual(m.read(None), bytes(range(3, 16)))
+ m.seek(4)
+ self.assertEqual(m.read(-1), bytes(range(4, 16)))
+ m.seek(5)
+ self.assertEqual(m.read(-2), bytes(range(5, 16)))
+ m.seek(9)
+ self.assertEqual(m.read(-42), bytes(range(9, 16)))
+
+ def test_read_invalid_arg(self):
+ m = mmap.mmap(-1, 16)
+ self.addCleanup(m.close)
+
+ self.assertRaises(TypeError, m.read, 'foo')
+ self.assertRaises(TypeError, m.read, 5.5)
+ self.assertRaises(TypeError, m.read, [1, 2, 3])
+
def test_extended_getslice(self):
# Test extended slicing by comparing with list slicing.
s = bytes(reversed(range(256)))
@@ -543,8 +572,7 @@ class MmapTests(unittest.TestCase):
f.close()
def test_error(self):
- self.assertTrue(issubclass(mmap.error, EnvironmentError))
- self.assertIn("mmap.error", str(mmap.error))
+ self.assertIs(mmap.error, OSError)
def test_io_methods(self):
data = b"0123456789"
diff --git a/Lib/test/test_module.py b/Lib/test/test_module.py
index 15836ca..e5a2525 100644
--- a/Lib/test/test_module.py
+++ b/Lib/test/test_module.py
@@ -5,6 +5,15 @@ from test.support import run_unittest, gc_collect
import sys
ModuleType = type(sys)
+class FullLoader:
+ @classmethod
+ def module_repr(cls, m):
+ return "<module '{}' (crafted)>".format(m.__name__)
+
+class BareLoader:
+ pass
+
+
class ModuleTests(unittest.TestCase):
def test_uninitialized(self):
# An uninitialized module has no __dict__ or __name__,
@@ -70,16 +79,100 @@ class ModuleTests(unittest.TestCase):
m = ModuleType("foo")
m.destroyed = destroyed
s = """class A:
- def __del__(self, destroyed=destroyed):
- destroyed.append(1)
-a = A()"""
+ def __init__(self, l):
+ self.l = l
+ def __del__(self):
+ self.l.append(1)
+a = A(destroyed)"""
exec(s, m.__dict__)
del m
gc_collect()
self.assertEqual(destroyed, [1])
+ def test_module_repr_minimal(self):
+ # reprs when modules have no __file__, __name__, or __loader__
+ m = ModuleType('foo')
+ del m.__name__
+ self.assertEqual(repr(m), "<module '?'>")
+
+ def test_module_repr_with_name(self):
+ m = ModuleType('foo')
+ self.assertEqual(repr(m), "<module 'foo'>")
+
+ def test_module_repr_with_name_and_filename(self):
+ m = ModuleType('foo')
+ m.__file__ = '/tmp/foo.py'
+ self.assertEqual(repr(m), "<module 'foo' from '/tmp/foo.py'>")
+
+ def test_module_repr_with_filename_only(self):
+ m = ModuleType('foo')
+ del m.__name__
+ m.__file__ = '/tmp/foo.py'
+ self.assertEqual(repr(m), "<module '?' from '/tmp/foo.py'>")
+
+ def test_module_repr_with_bare_loader_but_no_name(self):
+ m = ModuleType('foo')
+ del m.__name__
+ # Yes, a class not an instance.
+ m.__loader__ = BareLoader
+ self.assertEqual(
+ repr(m), "<module '?' (<class 'test.test_module.BareLoader'>)>")
+
+ def test_module_repr_with_full_loader_but_no_name(self):
+ # m.__loader__.module_repr() will fail because the module has no
+ # m.__name__. This exception will get suppressed and instead the
+ # loader's repr will be used.
+ m = ModuleType('foo')
+ del m.__name__
+ # Yes, a class not an instance.
+ m.__loader__ = FullLoader
+ self.assertEqual(
+ repr(m), "<module '?' (<class 'test.test_module.FullLoader'>)>")
+
+ def test_module_repr_with_bare_loader(self):
+ m = ModuleType('foo')
+ # Yes, a class not an instance.
+ m.__loader__ = BareLoader
+ self.assertEqual(
+ repr(m), "<module 'foo' (<class 'test.test_module.BareLoader'>)>")
+
+ def test_module_repr_with_full_loader(self):
+ m = ModuleType('foo')
+ # Yes, a class not an instance.
+ m.__loader__ = FullLoader
+ self.assertEqual(
+ repr(m), "<module 'foo' (crafted)>")
+
+ def test_module_repr_with_bare_loader_and_filename(self):
+ # Because the loader has no module_repr(), use the file name.
+ m = ModuleType('foo')
+ # Yes, a class not an instance.
+ m.__loader__ = BareLoader
+ m.__file__ = '/tmp/foo.py'
+ self.assertEqual(repr(m), "<module 'foo' from '/tmp/foo.py'>")
+
+ def test_module_repr_with_full_loader_and_filename(self):
+ # Even though the module has an __file__, use __loader__.module_repr()
+ m = ModuleType('foo')
+ # Yes, a class not an instance.
+ m.__loader__ = FullLoader
+ m.__file__ = '/tmp/foo.py'
+ self.assertEqual(repr(m), "<module 'foo' (crafted)>")
+
+ def test_module_repr_builtin(self):
+ self.assertEqual(repr(sys), "<module 'sys' (built-in)>")
+
+ def test_module_repr_source(self):
+ r = repr(unittest)
+ self.assertEqual(r[:25], "<module 'unittest' from '")
+ self.assertEqual(r[-13:], "__init__.py'>")
+
+ # frozen and namespace module reprs are tested in importlib.
+
+
def test_main():
run_unittest(ModuleTests)
+
if __name__ == '__main__':
test_main()
diff --git a/Lib/test/test_modulefinder.py b/Lib/test/test_modulefinder.py
index a184217..53ea232 100644
--- a/Lib/test/test_modulefinder.py
+++ b/Lib/test/test_modulefinder.py
@@ -1,7 +1,7 @@
-import __future__
import os
+import errno
+import shutil
import unittest
-import distutils.dir_util
import tempfile
from test import support
@@ -9,7 +9,7 @@ from test import support
import modulefinder
TEST_DIR = tempfile.mkdtemp()
-TEST_PATH = [TEST_DIR, os.path.dirname(__future__.__file__)]
+TEST_PATH = [TEST_DIR, os.path.dirname(tempfile.__file__)]
# Each test description is a list of 5 items:
#
@@ -196,12 +196,29 @@ a/module.py
from . import bar
"""]
+relative_import_test_4 = [
+ "a.module",
+ ["a", "a.module"],
+ [],
+ [],
+ """\
+a/__init__.py
+ def foo(): pass
+a/module.py
+ from . import *
+"""]
+
+
def open_file(path):
- ##print "#", os.path.abspath(path)
dirname = os.path.dirname(path)
- distutils.dir_util.mkpath(dirname)
+ try:
+ os.makedirs(dirname)
+ except OSError as e:
+ if e.errno != errno.EEXIST:
+ raise
return open(path, "w")
+
def create_package(source):
ofi = None
try:
@@ -216,6 +233,7 @@ def create_package(source):
if ofi:
ofi.close()
+
class ModuleFinderTest(unittest.TestCase):
def _do_test(self, info, report=False):
import_this, modules, missing, maybe_missing, source = info
@@ -234,19 +252,17 @@ class ModuleFinderTest(unittest.TestCase):
## import traceback; traceback.print_exc()
## sys.path = opath
## return
- modules = set(modules)
- found = set(mf.modules.keys())
- more = list(found - modules)
- less = list(modules - found)
+ modules = sorted(set(modules))
+ found = sorted(mf.modules)
# check if we found what we expected, not more, not less
- self.assertEqual((more, less), ([], []))
+ self.assertEqual(found, modules)
# check for missing and maybe missing modules
bad, maybe = mf.any_missing_maybe()
self.assertEqual(bad, missing)
self.assertEqual(maybe, maybe_missing)
finally:
- distutils.dir_util.remove_tree(TEST_DIR)
+ shutil.rmtree(TEST_DIR)
def test_package(self):
self._do_test(package_test)
@@ -254,25 +270,26 @@ class ModuleFinderTest(unittest.TestCase):
def test_maybe(self):
self._do_test(maybe_test)
- if getattr(__future__, "absolute_import", None):
+ def test_maybe_new(self):
+ self._do_test(maybe_test_new)
+
+ def test_absolute_imports(self):
+ self._do_test(absolute_import_test)
- def test_maybe_new(self):
- self._do_test(maybe_test_new)
+ def test_relative_imports(self):
+ self._do_test(relative_import_test)
- def test_absolute_imports(self):
- self._do_test(absolute_import_test)
+ def test_relative_imports_2(self):
+ self._do_test(relative_import_test_2)
- def test_relative_imports(self):
- self._do_test(relative_import_test)
+ def test_relative_imports_3(self):
+ self._do_test(relative_import_test_3)
- def test_relative_imports_2(self):
- self._do_test(relative_import_test_2)
+ def test_relative_imports_4(self):
+ self._do_test(relative_import_test_4)
- def test_relative_imports_3(self):
- self._do_test(relative_import_test_3)
def test_main():
- distutils.log.set_threshold(distutils.log.WARN)
support.run_unittest(ModuleFinderTest)
if __name__ == "__main__":
diff --git a/Lib/test/test_multibytecodec.py b/Lib/test/test_multibytecodec.py
index 86c68dc..feb7bd5 100644
--- a/Lib/test/test_multibytecodec.py
+++ b/Lib/test/test_multibytecodec.py
@@ -247,20 +247,16 @@ class Test_ISO2022(unittest.TestCase):
self.assertFalse(any(x > 0x80 for x in e))
def test_bug1572832(self):
- if sys.maxunicode >= 0x10000:
- myunichr = chr
- else:
- myunichr = lambda x: chr(0xD7C0+(x>>10)) + chr(0xDC00+(x&0x3FF))
-
for x in range(0x10000, 0x110000):
# Any ISO 2022 codec will cause the segfault
- myunichr(x).encode('iso_2022_jp', 'ignore')
+ chr(x).encode('iso_2022_jp', 'ignore')
class TestStateful(unittest.TestCase):
text = '\u4E16\u4E16'
encoding = 'iso-2022-jp'
expected = b'\x1b$B@$@$'
- expected_reset = b'\x1b$B@$@$\x1b(B'
+ reset = b'\x1b(B'
+ expected_reset = expected + reset
def test_encode(self):
self.assertEqual(self.text.encode(self.encoding), self.expected_reset)
@@ -271,6 +267,8 @@ class TestStateful(unittest.TestCase):
encoder.encode(char)
for char in self.text)
self.assertEqual(output, self.expected)
+ self.assertEqual(encoder.encode('', final=True), self.reset)
+ self.assertEqual(encoder.encode('', final=True), b'')
def test_incrementalencoder_final(self):
encoder = codecs.getincrementalencoder(self.encoding)()
@@ -279,12 +277,14 @@ class TestStateful(unittest.TestCase):
encoder.encode(char, index == last_index)
for index, char in enumerate(self.text))
self.assertEqual(output, self.expected_reset)
+ self.assertEqual(encoder.encode('', final=True), b'')
class TestHZStateful(TestStateful):
text = '\u804a\u804a'
encoding = 'hz'
expected = b'~{ADAD'
- expected_reset = b'~{ADAD~}'
+ reset = b'~}'
+ expected_reset = expected + reset
def test_main():
support.run_unittest(__name__)
diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/test_multiprocessing.py
index fa4865b..c38b314 100644
--- a/Lib/test/test_multiprocessing.py
+++ b/Lib/test/test_multiprocessing.py
@@ -8,6 +8,7 @@ import unittest
import queue as pyqueue
import time
import io
+import itertools
import sys
import os
import gc
@@ -17,6 +18,7 @@ import array
import socket
import random
import logging
+import struct
import test.support
import test.script_helper
@@ -84,6 +86,13 @@ HAVE_GETVALUE = not getattr(_multiprocessing,
WIN32 = (sys.platform == "win32")
+from multiprocessing.connection import wait
+
+def wait_for_handle(handle, timeout):
+ if timeout is not None and timeout < 0.0:
+ timeout = None
+ return wait([handle], timeout)
+
try:
MAXFD = os.sysconf("SC_OPEN_MAX")
except:
@@ -197,6 +206,18 @@ class _TestProcess(BaseTestCase):
self.assertEqual(current.ident, os.getpid())
self.assertEqual(current.exitcode, None)
+ def test_daemon_argument(self):
+ if self.TYPE == "threads":
+ return
+
+ # By default uses the current process's daemon flag.
+ proc0 = self.Process(target=self._test)
+ self.assertEqual(proc0.daemon, self.current_process().daemon)
+ proc1 = self.Process(target=self._test, daemon=True)
+ self.assertTrue(proc1.daemon)
+ proc2 = self.Process(target=self._test, daemon=False)
+ self.assertFalse(proc2.daemon)
+
@classmethod
def _test(cls, q, *args, **kwds):
current = cls.current_process()
@@ -262,9 +283,18 @@ class _TestProcess(BaseTestCase):
self.assertIn(p, self.active_children())
self.assertEqual(p.exitcode, None)
+ join = TimingWrapper(p.join)
+
+ self.assertEqual(join(0), None)
+ self.assertTimingAlmostEqual(join.elapsed, 0.0)
+ self.assertEqual(p.is_alive(), True)
+
+ self.assertEqual(join(-1), None)
+ self.assertTimingAlmostEqual(join.elapsed, 0.0)
+ self.assertEqual(p.is_alive(), True)
+
p.terminate()
- join = TimingWrapper(p.join)
self.assertEqual(join(), None)
self.assertTimingAlmostEqual(join.elapsed, 0.0)
@@ -329,6 +359,26 @@ class _TestProcess(BaseTestCase):
]
self.assertEqual(result, expected)
+ @classmethod
+ def _test_sentinel(cls, event):
+ event.wait(10.0)
+
+ def test_sentinel(self):
+ if self.TYPE == "threads":
+ return
+ event = self.Event()
+ p = self.Process(target=self._test_sentinel, args=(event,))
+ with self.assertRaises(ValueError):
+ p.sentinel
+ p.start()
+ self.addCleanup(p.join)
+ sentinel = p.sentinel
+ self.assertIsInstance(sentinel, int)
+ self.assertFalse(wait_for_handle(sentinel, timeout=0.0))
+ event.set()
+ p.join()
+ self.assertTrue(wait_for_handle(sentinel, timeout=DELTA))
+
#
#
#
@@ -868,6 +918,104 @@ class _TestCondition(BaseTestCase):
self.assertEqual(res, False)
self.assertTimingAlmostEqual(wait.elapsed, TIMEOUT1)
+ @classmethod
+ def _test_waitfor_f(cls, cond, state):
+ with cond:
+ state.value = 0
+ cond.notify()
+ result = cond.wait_for(lambda : state.value==4)
+ if not result or state.value != 4:
+ sys.exit(1)
+
+ @unittest.skipUnless(HAS_SHAREDCTYPES, 'needs sharedctypes')
+ def test_waitfor(self):
+ # based on test in test/lock_tests.py
+ cond = self.Condition()
+ state = self.Value('i', -1)
+
+ p = self.Process(target=self._test_waitfor_f, args=(cond, state))
+ p.daemon = True
+ p.start()
+
+ with cond:
+ result = cond.wait_for(lambda : state.value==0)
+ self.assertTrue(result)
+ self.assertEqual(state.value, 0)
+
+ for i in range(4):
+ time.sleep(0.01)
+ with cond:
+ state.value += 1
+ cond.notify()
+
+ p.join(5)
+ self.assertFalse(p.is_alive())
+ self.assertEqual(p.exitcode, 0)
+
+ @classmethod
+ def _test_waitfor_timeout_f(cls, cond, state, success, sem):
+ sem.release()
+ with cond:
+ expected = 0.1
+ dt = time.time()
+ result = cond.wait_for(lambda : state.value==4, timeout=expected)
+ dt = time.time() - dt
+ # borrow logic in assertTimeout() from test/lock_tests.py
+ if not result and expected * 0.6 < dt < expected * 10.0:
+ success.value = True
+
+ @unittest.skipUnless(HAS_SHAREDCTYPES, 'needs sharedctypes')
+ def test_waitfor_timeout(self):
+ # based on test in test/lock_tests.py
+ cond = self.Condition()
+ state = self.Value('i', 0)
+ success = self.Value('i', False)
+ sem = self.Semaphore(0)
+
+ p = self.Process(target=self._test_waitfor_timeout_f,
+ args=(cond, state, success, sem))
+ p.daemon = True
+ p.start()
+ self.assertTrue(sem.acquire(timeout=10))
+
+ # Only increment 3 times, so state == 4 is never reached.
+ for i in range(3):
+ time.sleep(0.01)
+ with cond:
+ state.value += 1
+ cond.notify()
+
+ p.join(5)
+ self.assertTrue(success.value)
+
+ @classmethod
+ def _test_wait_result(cls, c, pid):
+ with c:
+ c.notify()
+ time.sleep(1)
+ if pid is not None:
+ os.kill(pid, signal.SIGINT)
+
+ def test_wait_result(self):
+ if isinstance(self, ProcessesMixin) and sys.platform != 'win32':
+ pid = os.getpid()
+ else:
+ pid = None
+
+ c = self.Condition()
+ with c:
+ self.assertFalse(c.wait(0))
+ self.assertFalse(c.wait(0.1))
+
+ p = self.Process(target=self._test_wait_result, args=(c, pid))
+ p.start()
+
+ self.assertTrue(c.wait(10))
+ if pid is not None:
+ self.assertRaises(KeyboardInterrupt, c.wait, 10)
+
+ p.join()
+
class _TestEvent(BaseTestCase):
@@ -911,6 +1059,340 @@ class _TestEvent(BaseTestCase):
self.assertEqual(wait(), True)
#
+# Tests for Barrier - adapted from tests in test/lock_tests.py
+#
+
+# Many of the tests for threading.Barrier use a list as an atomic
+# counter: a value is appended to increment the counter, and the
+# length of the list gives the value. We use the class DummyList
+# for the same purpose.
+
+class _DummyList(object):
+
+ def __init__(self):
+ wrapper = multiprocessing.heap.BufferWrapper(struct.calcsize('i'))
+ lock = multiprocessing.Lock()
+ self.__setstate__((wrapper, lock))
+ self._lengthbuf[0] = 0
+
+ def __setstate__(self, state):
+ (self._wrapper, self._lock) = state
+ self._lengthbuf = self._wrapper.create_memoryview().cast('i')
+
+ def __getstate__(self):
+ return (self._wrapper, self._lock)
+
+ def append(self, _):
+ with self._lock:
+ self._lengthbuf[0] += 1
+
+ def __len__(self):
+ with self._lock:
+ return self._lengthbuf[0]
+
+def _wait():
+ # A crude wait/yield function not relying on synchronization primitives.
+ time.sleep(0.01)
+
+
+class Bunch(object):
+ """
+ A bunch of threads.
+ """
+ def __init__(self, namespace, f, args, n, wait_before_exit=False):
+ """
+ Construct a bunch of `n` threads running the same function `f`.
+ If `wait_before_exit` is True, the threads won't terminate until
+ do_finish() is called.
+ """
+ self.f = f
+ self.args = args
+ self.n = n
+ self.started = namespace.DummyList()
+ self.finished = namespace.DummyList()
+ self._can_exit = namespace.Event()
+ if not wait_before_exit:
+ self._can_exit.set()
+ for i in range(n):
+ p = namespace.Process(target=self.task)
+ p.daemon = True
+ p.start()
+
+ def task(self):
+ pid = os.getpid()
+ self.started.append(pid)
+ try:
+ self.f(*self.args)
+ finally:
+ self.finished.append(pid)
+ self._can_exit.wait(30)
+ assert self._can_exit.is_set()
+
+ def wait_for_started(self):
+ while len(self.started) < self.n:
+ _wait()
+
+ def wait_for_finished(self):
+ while len(self.finished) < self.n:
+ _wait()
+
+ def do_finish(self):
+ self._can_exit.set()
+
+
+class AppendTrue(object):
+ def __init__(self, obj):
+ self.obj = obj
+ def __call__(self):
+ self.obj.append(True)
+
+
+class _TestBarrier(BaseTestCase):
+ """
+ Tests for Barrier objects.
+ """
+ N = 5
+ defaultTimeout = 30.0 # XXX Slow Windows buildbots need generous timeout
+
+ def setUp(self):
+ self.barrier = self.Barrier(self.N, timeout=self.defaultTimeout)
+
+ def tearDown(self):
+ self.barrier.abort()
+ self.barrier = None
+
+ def DummyList(self):
+ if self.TYPE == 'threads':
+ return []
+ elif self.TYPE == 'manager':
+ return self.manager.list()
+ else:
+ return _DummyList()
+
+ def run_threads(self, f, args):
+ b = Bunch(self, f, args, self.N-1)
+ f(*args)
+ b.wait_for_finished()
+
+ @classmethod
+ def multipass(cls, barrier, results, n):
+ m = barrier.parties
+ assert m == cls.N
+ for i in range(n):
+ results[0].append(True)
+ assert len(results[1]) == i * m
+ barrier.wait()
+ results[1].append(True)
+ assert len(results[0]) == (i + 1) * m
+ barrier.wait()
+ try:
+ assert barrier.n_waiting == 0
+ except NotImplementedError:
+ pass
+ assert not barrier.broken
+
+ def test_barrier(self, passes=1):
+ """
+ Test that a barrier is passed in lockstep
+ """
+ results = [self.DummyList(), self.DummyList()]
+ self.run_threads(self.multipass, (self.barrier, results, passes))
+
+ def test_barrier_10(self):
+ """
+ Test that a barrier works for 10 consecutive runs
+ """
+ return self.test_barrier(10)
+
+ @classmethod
+ def _test_wait_return_f(cls, barrier, queue):
+ res = barrier.wait()
+ queue.put(res)
+
+ def test_wait_return(self):
+ """
+ test the return value from barrier.wait
+ """
+ queue = self.Queue()
+ self.run_threads(self._test_wait_return_f, (self.barrier, queue))
+ results = [queue.get() for i in range(self.N)]
+ self.assertEqual(results.count(0), 1)
+
+ @classmethod
+ def _test_action_f(cls, barrier, results):
+ barrier.wait()
+ if len(results) != 1:
+ raise RuntimeError
+
+ def test_action(self):
+ """
+ Test the 'action' callback
+ """
+ results = self.DummyList()
+ barrier = self.Barrier(self.N, action=AppendTrue(results))
+ self.run_threads(self._test_action_f, (barrier, results))
+ self.assertEqual(len(results), 1)
+
+ @classmethod
+ def _test_abort_f(cls, barrier, results1, results2):
+ try:
+ i = barrier.wait()
+ if i == cls.N//2:
+ raise RuntimeError
+ barrier.wait()
+ results1.append(True)
+ except threading.BrokenBarrierError:
+ results2.append(True)
+ except RuntimeError:
+ barrier.abort()
+
+ def test_abort(self):
+ """
+ Test that an abort will put the barrier in a broken state
+ """
+ results1 = self.DummyList()
+ results2 = self.DummyList()
+ self.run_threads(self._test_abort_f,
+ (self.barrier, results1, results2))
+ self.assertEqual(len(results1), 0)
+ self.assertEqual(len(results2), self.N-1)
+ self.assertTrue(self.barrier.broken)
+
+ @classmethod
+ def _test_reset_f(cls, barrier, results1, results2, results3):
+ i = barrier.wait()
+ if i == cls.N//2:
+ # Wait until the other threads are all in the barrier.
+ while barrier.n_waiting < cls.N-1:
+ time.sleep(0.001)
+ barrier.reset()
+ else:
+ try:
+ barrier.wait()
+ results1.append(True)
+ except threading.BrokenBarrierError:
+ results2.append(True)
+ # Now, pass the barrier again
+ barrier.wait()
+ results3.append(True)
+
+ def test_reset(self):
+ """
+ Test that a 'reset' on a barrier frees the waiting threads
+ """
+ results1 = self.DummyList()
+ results2 = self.DummyList()
+ results3 = self.DummyList()
+ self.run_threads(self._test_reset_f,
+ (self.barrier, results1, results2, results3))
+ self.assertEqual(len(results1), 0)
+ self.assertEqual(len(results2), self.N-1)
+ self.assertEqual(len(results3), self.N)
+
+ @classmethod
+ def _test_abort_and_reset_f(cls, barrier, barrier2,
+ results1, results2, results3):
+ try:
+ i = barrier.wait()
+ if i == cls.N//2:
+ raise RuntimeError
+ barrier.wait()
+ results1.append(True)
+ except threading.BrokenBarrierError:
+ results2.append(True)
+ except RuntimeError:
+ barrier.abort()
+ # Synchronize and reset the barrier. Must synchronize first so
+ # that everyone has left it when we reset, and after so that no
+ # one enters it before the reset.
+ if barrier2.wait() == cls.N//2:
+ barrier.reset()
+ barrier2.wait()
+ barrier.wait()
+ results3.append(True)
+
+ def test_abort_and_reset(self):
+ """
+ Test that a barrier can be reset after being broken.
+ """
+ results1 = self.DummyList()
+ results2 = self.DummyList()
+ results3 = self.DummyList()
+ barrier2 = self.Barrier(self.N)
+
+ self.run_threads(self._test_abort_and_reset_f,
+ (self.barrier, barrier2, results1, results2, results3))
+ self.assertEqual(len(results1), 0)
+ self.assertEqual(len(results2), self.N-1)
+ self.assertEqual(len(results3), self.N)
+
+ @classmethod
+ def _test_timeout_f(cls, barrier, results):
+ i = barrier.wait()
+ if i == cls.N//2:
+ # One thread is late!
+ time.sleep(1.0)
+ try:
+ barrier.wait(0.5)
+ except threading.BrokenBarrierError:
+ results.append(True)
+
+ def test_timeout(self):
+ """
+ Test wait(timeout)
+ """
+ results = self.DummyList()
+ self.run_threads(self._test_timeout_f, (self.barrier, results))
+ self.assertEqual(len(results), self.barrier.parties)
+
+ @classmethod
+ def _test_default_timeout_f(cls, barrier, results):
+ i = barrier.wait(cls.defaultTimeout)
+ if i == cls.N//2:
+ # One thread is later than the default timeout
+ time.sleep(1.0)
+ try:
+ barrier.wait()
+ except threading.BrokenBarrierError:
+ results.append(True)
+
+ def test_default_timeout(self):
+ """
+ Test the barrier's default timeout
+ """
+ barrier = self.Barrier(self.N, timeout=0.5)
+ results = self.DummyList()
+ self.run_threads(self._test_default_timeout_f, (barrier, results))
+ self.assertEqual(len(results), barrier.parties)
+
+ def test_single_thread(self):
+ b = self.Barrier(1)
+ b.wait()
+ b.wait()
+
+ @classmethod
+ def _test_thousand_f(cls, barrier, passes, conn, lock):
+ for i in range(passes):
+ barrier.wait()
+ with lock:
+ conn.send(i)
+
+ def test_thousand(self):
+ if self.TYPE == 'manager':
+ return
+ passes = 1000
+ lock = self.Lock()
+ conn, child_conn = self.Pipe(False)
+ for j in range(self.N):
+ p = self.Process(target=self._test_thousand_f,
+ args=(self.barrier, passes, child_conn, lock))
+ p.start()
+
+ for i in range(passes):
+ for j in range(self.N):
+ self.assertEqual(conn.recv(), i)
+
+#
#
#
@@ -1130,6 +1612,9 @@ def sqr(x, wait=0.0):
time.sleep(wait)
return x*x
+def mul(x, y):
+ return x*y
+
class _TestPool(BaseTestCase):
def test_apply(self):
@@ -1143,6 +1628,37 @@ class _TestPool(BaseTestCase):
self.assertEqual(pmap(sqr, list(range(100)), chunksize=20),
list(map(sqr, list(range(100)))))
+ def test_starmap(self):
+ psmap = self.pool.starmap
+ tuples = list(zip(range(10), range(9,-1, -1)))
+ self.assertEqual(psmap(mul, tuples),
+ list(itertools.starmap(mul, tuples)))
+ tuples = list(zip(range(100), range(99,-1, -1)))
+ self.assertEqual(psmap(mul, tuples, chunksize=20),
+ list(itertools.starmap(mul, tuples)))
+
+ def test_starmap_async(self):
+ tuples = list(zip(range(100), range(99,-1, -1)))
+ self.assertEqual(self.pool.starmap_async(mul, tuples).get(),
+ list(itertools.starmap(mul, tuples)))
+
+ def test_map_async(self):
+ self.assertEqual(self.pool.map_async(sqr, list(range(10))).get(),
+ list(map(sqr, list(range(10)))))
+
+ def test_map_async_callbacks(self):
+ call_args = self.manager.list() if self.TYPE == 'manager' else []
+ self.pool.map_async(int, ['1'],
+ callback=call_args.append,
+ error_callback=call_args.append).wait()
+ self.assertEqual(1, len(call_args))
+ self.assertEqual([1], call_args[0])
+ self.pool.map_async(int, ['a'],
+ callback=call_args.append,
+ error_callback=call_args.append).wait()
+ self.assertEqual(2, len(call_args))
+ self.assertIsInstance(call_args[1], ValueError)
+
def test_map_chunksize(self):
try:
self.pool.map_async(sqr, [], chunksize=1).get(timeout=TIMEOUT1)
@@ -1221,6 +1737,16 @@ class _TestPool(BaseTestCase):
p.close()
p.join()
+ def test_context(self):
+ if self.TYPE == 'processes':
+ L = list(range(10))
+ expected = [sqr(i) for i in L]
+ with multiprocessing.Pool(2) as p:
+ r = p.map_async(sqr, L)
+ self.assertEqual(r.get(), expected)
+ print(p._state)
+ self.assertRaises(ValueError, p.map_async, sqr, L)
+
def raising():
raise KeyError("key")
@@ -1322,6 +1848,11 @@ class _TestZZZNumberOfObjects(BaseTestCase):
# run after all the other tests for the manager. It tests that
# there have been no "reference leaks" for the manager's shared
# objects. Note the comment in _TestPool.test_terminate().
+
+ # If some other test using ManagerMixin.manager fails, then the
+ # raised exception may keep alive a frame which holds a reference
+ # to a managed object. This will cause test_number_of_objects to
+ # also fail.
ALLOWED_TYPES = ('manager',)
def test_number_of_objects(self):
@@ -1376,7 +1907,27 @@ class _TestMyManager(BaseTestCase):
def test_mymanager(self):
manager = MyManager()
manager.start()
+ self.common(manager)
+ manager.shutdown()
+
+ # If the manager process exited cleanly then the exitcode
+ # will be zero. Otherwise (after a short timeout)
+ # terminate() is used, resulting in an exitcode of -SIGTERM.
+ self.assertEqual(manager._process.exitcode, 0)
+
+ def test_mymanager_context(self):
+ with MyManager() as manager:
+ self.common(manager)
+ self.assertEqual(manager._process.exitcode, 0)
+
+ def test_mymanager_context_prestarted(self):
+ manager = MyManager()
+ manager.start()
+ with manager:
+ self.common(manager)
+ self.assertEqual(manager._process.exitcode, 0)
+ def common(self, manager):
foo = manager.Foo()
bar = manager.Bar()
baz = manager.baz()
@@ -1399,7 +1950,6 @@ class _TestMyManager(BaseTestCase):
self.assertEqual(list(baz), [i*i for i in range(10)])
- manager.shutdown()
#
# Test of connecting to a remote server and using xmlrpclib for serialization
@@ -1570,6 +2120,9 @@ class _TestConnection(BaseTestCase):
self.assertEqual(poll(), False)
self.assertTimingAlmostEqual(poll.elapsed, 0)
+ self.assertEqual(poll(-1), False)
+ self.assertTimingAlmostEqual(poll.elapsed, 0)
+
self.assertEqual(poll(TIMEOUT1), False)
self.assertTimingAlmostEqual(poll.elapsed, TIMEOUT1)
@@ -1756,6 +2309,43 @@ class _TestConnection(BaseTestCase):
self.assertRaises(RuntimeError, reduction.recv_handle, conn)
p.join()
+ def test_context(self):
+ a, b = self.Pipe()
+
+ with a, b:
+ a.send(1729)
+ self.assertEqual(b.recv(), 1729)
+ if self.TYPE == 'processes':
+ self.assertFalse(a.closed)
+ self.assertFalse(b.closed)
+
+ if self.TYPE == 'processes':
+ self.assertTrue(a.closed)
+ self.assertTrue(b.closed)
+ self.assertRaises(IOError, a.recv)
+ self.assertRaises(IOError, b.recv)
+
+class _TestListener(BaseTestCase):
+
+ ALLOWED_TYPES = ('processes',)
+
+ def test_multiple_bind(self):
+ for family in self.connection.families:
+ l = self.connection.Listener(family=family)
+ self.addCleanup(l.close)
+ self.assertRaises(OSError, self.connection.Listener,
+ l.address, family)
+
+ def test_context(self):
+ with self.connection.Listener() as l:
+ with self.connection.Client(l.address) as c:
+ with l.accept() as d:
+ c.send(1729)
+ self.assertEqual(d.recv(), 1729)
+
+ if self.TYPE == 'processes':
+ self.assertRaises(IOError, l.accept)
+
class _TestListenerClient(BaseTestCase):
ALLOWED_TYPES = ('processes', 'threads')
@@ -1793,52 +2383,146 @@ class _TestListenerClient(BaseTestCase):
p.join()
l.close()
+ def test_issue16955(self):
+ for fam in self.connection.families:
+ l = self.connection.Listener(family=fam)
+ c = self.connection.Client(l.address)
+ a = l.accept()
+ a.send_bytes(b"hello")
+ self.assertTrue(c.poll(1))
+ a.close()
+ c.close()
+ l.close()
+
+class _TestPoll(unittest.TestCase):
+
+ ALLOWED_TYPES = ('processes', 'threads')
+
+ def test_empty_string(self):
+ a, b = self.Pipe()
+ self.assertEqual(a.poll(), False)
+ b.send_bytes(b'')
+ self.assertEqual(a.poll(), True)
+ self.assertEqual(a.poll(), True)
+
+ @classmethod
+ def _child_strings(cls, conn, strings):
+ for s in strings:
+ time.sleep(0.1)
+ conn.send_bytes(s)
+ conn.close()
+
+ def test_strings(self):
+ strings = (b'hello', b'', b'a', b'b', b'', b'bye', b'', b'lop')
+ a, b = self.Pipe()
+ p = self.Process(target=self._child_strings, args=(b, strings))
+ p.start()
+
+ for s in strings:
+ for i in range(200):
+ if a.poll(0.01):
+ break
+ x = a.recv_bytes()
+ self.assertEqual(s, x)
+
+ p.join()
+
+ @classmethod
+ def _child_boundaries(cls, r):
+ # Polling may "pull" a message in to the child process, but we
+ # don't want it to pull only part of a message, as that would
+ # corrupt the pipe for any other processes which might later
+ # read from it.
+ r.poll(5)
+
+ def test_boundaries(self):
+ r, w = self.Pipe(False)
+ p = self.Process(target=self._child_boundaries, args=(r,))
+ p.start()
+ time.sleep(2)
+ L = [b"first", b"second"]
+ for obj in L:
+ w.send_bytes(obj)
+ w.close()
+ p.join()
+ self.assertIn(r.recv_bytes(), L)
+
+ @classmethod
+ def _child_dont_merge(cls, b):
+ b.send_bytes(b'a')
+ b.send_bytes(b'b')
+ b.send_bytes(b'cd')
+
+ def test_dont_merge(self):
+ a, b = self.Pipe()
+ self.assertEqual(a.poll(0.0), False)
+ self.assertEqual(a.poll(0.1), False)
+
+ p = self.Process(target=self._child_dont_merge, args=(b,))
+ p.start()
+
+ self.assertEqual(a.recv_bytes(), b'a')
+ self.assertEqual(a.poll(1.0), True)
+ self.assertEqual(a.poll(1.0), True)
+ self.assertEqual(a.recv_bytes(), b'b')
+ self.assertEqual(a.poll(1.0), True)
+ self.assertEqual(a.poll(1.0), True)
+ self.assertEqual(a.poll(0.0), True)
+ self.assertEqual(a.recv_bytes(), b'cd')
+
+ p.join()
+
#
# Test of sending connection and socket objects between processes
#
-"""
+
+@unittest.skipUnless(HAS_REDUCTION, "test needs multiprocessing.reduction")
class _TestPicklingConnections(BaseTestCase):
ALLOWED_TYPES = ('processes',)
- def _listener(self, conn, families):
+ @classmethod
+ def tearDownClass(cls):
+ from multiprocessing.reduction import resource_sharer
+ resource_sharer.stop(timeout=5)
+
+ @classmethod
+ def _listener(cls, conn, families):
for fam in families:
- l = self.connection.Listener(family=fam)
+ l = cls.connection.Listener(family=fam)
conn.send(l.address)
new_conn = l.accept()
conn.send(new_conn)
+ new_conn.close()
+ l.close()
- if self.TYPE == 'processes':
- l = socket.socket()
- l.bind(('localhost', 0))
- conn.send(l.getsockname())
- l.listen(1)
- new_conn, addr = l.accept()
- conn.send(new_conn)
+ l = socket.socket()
+ l.bind(('localhost', 0))
+ l.listen(1)
+ conn.send(l.getsockname())
+ new_conn, addr = l.accept()
+ conn.send(new_conn)
+ new_conn.close()
+ l.close()
conn.recv()
- def _remote(self, conn):
+ @classmethod
+ def _remote(cls, conn):
for (address, msg) in iter(conn.recv, None):
- client = self.connection.Client(address)
+ client = cls.connection.Client(address)
client.send(msg.upper())
client.close()
- if self.TYPE == 'processes':
- address, msg = conn.recv()
- client = socket.socket()
- client.connect(address)
- client.sendall(msg.upper())
- client.close()
+ address, msg = conn.recv()
+ client = socket.socket()
+ client.connect(address)
+ client.sendall(msg.upper())
+ client.close()
conn.close()
def test_pickling(self):
- try:
- multiprocessing.allow_connection_pickling()
- except ImportError:
- return
-
families = self.connection.families
lconn, lconn0 = self.Pipe()
@@ -1862,16 +2546,19 @@ class _TestPicklingConnections(BaseTestCase):
rconn.send(None)
- if self.TYPE == 'processes':
- msg = latin('This connection uses a normal socket')
- address = lconn.recv()
- rconn.send((address, msg))
- if hasattr(socket, 'fromfd'):
- new_conn = lconn.recv()
- self.assertEqual(new_conn.recv(100), msg.upper())
- else:
- # XXX On Windows with Py2.6 need to backport fromfd()
- discard = lconn.recv_bytes()
+ msg = latin('This connection uses a normal socket')
+ address = lconn.recv()
+ rconn.send((address, msg))
+ new_conn = lconn.recv()
+ buf = []
+ while True:
+ s = new_conn.recv(100)
+ if not s:
+ break
+ buf.append(s)
+ buf = b''.join(buf)
+ self.assertEqual(buf, msg.upper())
+ new_conn.close()
lconn.send(None)
@@ -1880,7 +2567,46 @@ class _TestPicklingConnections(BaseTestCase):
lp.join()
rp.join()
-"""
+
+ @classmethod
+ def child_access(cls, conn):
+ w = conn.recv()
+ w.send('all is well')
+ w.close()
+
+ r = conn.recv()
+ msg = r.recv()
+ conn.send(msg*2)
+
+ conn.close()
+
+ def test_access(self):
+ # On Windows, if we do not specify a destination pid when
+ # using DupHandle then we need to be careful to use the
+ # correct access flags for DuplicateHandle(), or else
+ # DupHandle.detach() will raise PermissionError. For example,
+ # for a read only pipe handle we should use
+ # access=FILE_GENERIC_READ. (Unfortunately
+ # DUPLICATE_SAME_ACCESS does not work.)
+ conn, child_conn = self.Pipe()
+ p = self.Process(target=self.child_access, args=(child_conn,))
+ p.daemon = True
+ p.start()
+ child_conn.close()
+
+ r, w = self.Pipe(duplex=False)
+ conn.send(w)
+ w.close()
+ self.assertEqual(r.recv(), 'all is well')
+ r.close()
+
+ r, w = self.Pipe(duplex=False)
+ conn.send(r)
+ r.close()
+ w.send('foobar')
+ w.close()
+ self.assertEqual(conn.recv(), 'foobar'*2)
+
#
#
#
@@ -2175,9 +2901,15 @@ class TestInvalidHandle(unittest.TestCase):
@unittest.skipIf(WIN32, "skipped on Windows")
def test_invalid_handles(self):
- conn = _multiprocessing.Connection(44977608)
- self.assertRaises(IOError, conn.poll)
- self.assertRaises(IOError, _multiprocessing.Connection, -1)
+ conn = multiprocessing.connection.Connection(44977608)
+ try:
+ self.assertRaises((ValueError, IOError), conn.poll)
+ finally:
+ # Hack private attribute _handle to avoid printing an error
+ # in conn.__del__
+ conn._handle = None
+ self.assertRaises((ValueError, IOError),
+ multiprocessing.connection.Connection, -1)
#
# Functions used to create test cases from the base ones in this module
@@ -2196,10 +2928,12 @@ def create_test_cases(Mixin, type):
result = {}
glob = globals()
Type = type.capitalize()
+ ALL_TYPES = {'processes', 'threads', 'manager'}
for name in list(glob.keys()):
if name.startswith('_Test'):
base = glob[name]
+ assert set(base.ALLOWED_TYPES) <= ALL_TYPES, set(base.ALLOWED_TYPES)
if type in base.ALLOWED_TYPES:
newname = 'With' + Type + name[1:]
class Temp(base, unittest.TestCase, Mixin):
@@ -2218,7 +2952,7 @@ class ProcessesMixin(object):
Process = multiprocessing.Process
locals().update(get_attributes(multiprocessing, (
'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore',
- 'Condition', 'Event', 'Value', 'Array', 'RawValue',
+ 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'RawValue',
'RawArray', 'current_process', 'active_children', 'Pipe',
'connection', 'JoinableQueue', 'Pool'
)))
@@ -2233,7 +2967,7 @@ class ManagerMixin(object):
manager = object.__new__(multiprocessing.managers.SyncManager)
locals().update(get_attributes(manager, (
'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore',
- 'Condition', 'Event', 'Value', 'Array', 'list', 'dict',
+ 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'list', 'dict',
'Namespace', 'JoinableQueue', 'Pool'
)))
@@ -2246,7 +2980,7 @@ class ThreadsMixin(object):
Process = multiprocessing.dummy.Process
locals().update(get_attributes(multiprocessing.dummy, (
'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore',
- 'Condition', 'Event', 'Value', 'Array', 'current_process',
+ 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'current_process',
'active_children', 'Pipe', 'connection', 'dict', 'list',
'Namespace', 'JoinableQueue', 'Pool'
)))
@@ -2298,6 +3032,7 @@ class TestInitializers(unittest.TestCase):
def tearDown(self):
self.mgr.shutdown()
+ self.mgr.join()
def test_manager_initializer(self):
m = multiprocessing.managers.SyncManager()
@@ -2305,6 +3040,7 @@ class TestInitializers(unittest.TestCase):
m.start(initializer, (self.ns,))
self.assertEqual(self.ns.test, 1)
m.shutdown()
+ m.join()
def test_pool_initializer(self):
self.assertRaises(TypeError, multiprocessing.Pool, initializer=1)
@@ -2337,6 +3073,8 @@ def _afunc(x):
def pool_in_process():
pool = multiprocessing.Pool(processes=4)
x = pool.map(_afunc, [1, 2, 3, 4, 5, 6, 7])
+ pool.close()
+ pool.join()
class _file_like(object):
def __init__(self, delegate):
@@ -2381,6 +3119,181 @@ class TestStdinBadfiledescriptor(unittest.TestCase):
assert sio.getvalue() == 'foo'
+class TestWait(unittest.TestCase):
+
+ @classmethod
+ def _child_test_wait(cls, w, slow):
+ for i in range(10):
+ if slow:
+ time.sleep(random.random()*0.1)
+ w.send((i, os.getpid()))
+ w.close()
+
+ def test_wait(self, slow=False):
+ from multiprocessing.connection import wait
+ readers = []
+ procs = []
+ messages = []
+
+ for i in range(4):
+ r, w = multiprocessing.Pipe(duplex=False)
+ p = multiprocessing.Process(target=self._child_test_wait, args=(w, slow))
+ p.daemon = True
+ p.start()
+ w.close()
+ readers.append(r)
+ procs.append(p)
+ self.addCleanup(p.join)
+
+ while readers:
+ for r in wait(readers):
+ try:
+ msg = r.recv()
+ except EOFError:
+ readers.remove(r)
+ r.close()
+ else:
+ messages.append(msg)
+
+ messages.sort()
+ expected = sorted((i, p.pid) for i in range(10) for p in procs)
+ self.assertEqual(messages, expected)
+
+ @classmethod
+ def _child_test_wait_socket(cls, address, slow):
+ s = socket.socket()
+ s.connect(address)
+ for i in range(10):
+ if slow:
+ time.sleep(random.random()*0.1)
+ s.sendall(('%s\n' % i).encode('ascii'))
+ s.close()
+
+ def test_wait_socket(self, slow=False):
+ from multiprocessing.connection import wait
+ l = socket.socket()
+ l.bind(('', 0))
+ l.listen(4)
+ addr = ('localhost', l.getsockname()[1])
+ readers = []
+ procs = []
+ dic = {}
+
+ for i in range(4):
+ p = multiprocessing.Process(target=self._child_test_wait_socket,
+ args=(addr, slow))
+ p.daemon = True
+ p.start()
+ procs.append(p)
+ self.addCleanup(p.join)
+
+ for i in range(4):
+ r, _ = l.accept()
+ readers.append(r)
+ dic[r] = []
+ l.close()
+
+ while readers:
+ for r in wait(readers):
+ msg = r.recv(32)
+ if not msg:
+ readers.remove(r)
+ r.close()
+ else:
+ dic[r].append(msg)
+
+ expected = ''.join('%s\n' % i for i in range(10)).encode('ascii')
+ for v in dic.values():
+ self.assertEqual(b''.join(v), expected)
+
+ def test_wait_slow(self):
+ self.test_wait(True)
+
+ def test_wait_socket_slow(self):
+ self.test_wait_socket(True)
+
+ def test_wait_timeout(self):
+ from multiprocessing.connection import wait
+
+ expected = 5
+ a, b = multiprocessing.Pipe()
+
+ start = time.time()
+ res = wait([a, b], expected)
+ delta = time.time() - start
+
+ self.assertEqual(res, [])
+ self.assertLess(delta, expected * 2)
+ self.assertGreater(delta, expected * 0.5)
+
+ b.send(None)
+
+ start = time.time()
+ res = wait([a, b], 20)
+ delta = time.time() - start
+
+ self.assertEqual(res, [a])
+ self.assertLess(delta, 0.4)
+
+ @classmethod
+ def signal_and_sleep(cls, sem, period):
+ sem.release()
+ time.sleep(period)
+
+ def test_wait_integer(self):
+ from multiprocessing.connection import wait
+
+ expected = 3
+ sorted_ = lambda l: sorted(l, key=lambda x: id(x))
+ sem = multiprocessing.Semaphore(0)
+ a, b = multiprocessing.Pipe()
+ p = multiprocessing.Process(target=self.signal_and_sleep,
+ args=(sem, expected))
+
+ p.start()
+ self.assertIsInstance(p.sentinel, int)
+ self.assertTrue(sem.acquire(timeout=20))
+
+ start = time.time()
+ res = wait([a, p.sentinel, b], expected + 20)
+ delta = time.time() - start
+
+ self.assertEqual(res, [p.sentinel])
+ self.assertLess(delta, expected + 2)
+ self.assertGreater(delta, expected - 2)
+
+ a.send(None)
+
+ start = time.time()
+ res = wait([a, p.sentinel, b], 20)
+ delta = time.time() - start
+
+ self.assertEqual(sorted_(res), sorted_([p.sentinel, b]))
+ self.assertLess(delta, 0.4)
+
+ b.send(None)
+
+ start = time.time()
+ res = wait([a, p.sentinel, b], 20)
+ delta = time.time() - start
+
+ self.assertEqual(sorted_(res), sorted_([a, p.sentinel, b]))
+ self.assertLess(delta, 0.4)
+
+ p.terminate()
+ p.join()
+
+ def test_neg_timeout(self):
+ from multiprocessing.connection import wait
+ a, b = multiprocessing.Pipe()
+ t = time.time()
+ res = wait([a], timeout=-1)
+ t = time.time() - t
+ self.assertEqual(res, [])
+ self.assertLess(t, 1)
+ a.close()
+ b.close()
+
#
# Issue 14151: Test invalid family on invalid environment
#
@@ -2398,6 +3311,38 @@ class TestInvalidFamily(unittest.TestCase):
multiprocessing.connection.Listener('/var/test.pipe')
#
+# Issue 12098: check sys.flags of child matches that for parent
+#
+
+class TestFlags(unittest.TestCase):
+ @classmethod
+ def run_in_grandchild(cls, conn):
+ conn.send(tuple(sys.flags))
+
+ @classmethod
+ def run_in_child(cls):
+ import json
+ r, w = multiprocessing.Pipe(duplex=False)
+ p = multiprocessing.Process(target=cls.run_in_grandchild, args=(w,))
+ p.start()
+ grandchild_flags = r.recv()
+ p.join()
+ r.close()
+ w.close()
+ flags = (tuple(sys.flags), grandchild_flags)
+ print(json.dumps(flags))
+
+ def test_flags(self):
+ import json, subprocess
+ # start child process using unusual flags
+ prog = ('from test.test_multiprocessing import TestFlags; ' +
+ 'TestFlags.run_in_child()')
+ data = subprocess.check_output(
+ [sys.executable, '-E', '-S', '-O', '-c', prog])
+ child_flags, grandchild_flags = json.loads(data.decode('ascii'))
+ self.assertEqual(child_flags, grandchild_flags)
+
+#
# Test interaction with socket timeouts - see Issue #6056
#
@@ -2452,8 +3397,8 @@ class TestNoForkBomb(unittest.TestCase):
#
testcases_other = [OtherTest, TestInvalidHandle, TestInitializers,
- TestStdinBadfiledescriptor, TestInvalidFamily,
- TestTimeouts, TestNoForkBomb]
+ TestStdinBadfiledescriptor, TestWait, TestInvalidFamily,
+ TestFlags, TestTimeouts, TestNoForkBomb]
#
#
@@ -2490,14 +3435,18 @@ def test_main(run=None):
loadTestsFromTestCase = unittest.defaultTestLoader.loadTestsFromTestCase
suite = unittest.TestSuite(loadTestsFromTestCase(tc) for tc in testcases)
- run(suite)
-
- ThreadsMixin.pool.terminate()
- ProcessesMixin.pool.terminate()
- ManagerMixin.pool.terminate()
- ManagerMixin.manager.shutdown()
-
- del ProcessesMixin.pool, ThreadsMixin.pool, ManagerMixin.pool
+ try:
+ run(suite)
+ finally:
+ ThreadsMixin.pool.terminate()
+ ProcessesMixin.pool.terminate()
+ ManagerMixin.pool.terminate()
+ ManagerMixin.pool.join()
+ ManagerMixin.manager.shutdown()
+ ManagerMixin.manager.join()
+ ThreadsMixin.pool.join()
+ ProcessesMixin.pool.join()
+ del ProcessesMixin.pool, ThreadsMixin.pool, ManagerMixin.pool
def main():
test_main(unittest.TextTestRunner(verbosity=2).run)
diff --git a/Lib/test/test_mutants.py b/Lib/test/test_mutants.py
deleted file mode 100644
index b43fa47..0000000
--- a/Lib/test/test_mutants.py
+++ /dev/null
@@ -1,291 +0,0 @@
-from test.support import verbose, TESTFN
-import random
-import os
-
-# From SF bug #422121: Insecurities in dict comparison.
-
-# Safety of code doing comparisons has been an historical Python weak spot.
-# The problem is that comparison of structures written in C *naturally*
-# wants to hold on to things like the size of the container, or "the
-# biggest" containee so far, across a traversal of the container; but
-# code to do containee comparisons can call back into Python and mutate
-# the container in arbitrary ways while the C loop is in midstream. If the
-# C code isn't extremely paranoid about digging things out of memory on
-# each trip, and artificially boosting refcounts for the duration, anything
-# from infinite loops to OS crashes can result (yes, I use Windows <wink>).
-#
-# The other problem is that code designed to provoke a weakness is usually
-# white-box code, and so catches only the particular vulnerabilities the
-# author knew to protect against. For example, Python's list.sort() code
-# went thru many iterations as one "new" vulnerability after another was
-# discovered.
-#
-# So the dict comparison test here uses a black-box approach instead,
-# generating dicts of various sizes at random, and performing random
-# mutations on them at random times. This proved very effective,
-# triggering at least six distinct failure modes the first 20 times I
-# ran it. Indeed, at the start, the driver never got beyond 6 iterations
-# before the test died.
-
-# The dicts are global to make it easy to mutate tham from within functions.
-dict1 = {}
-dict2 = {}
-
-# The current set of keys in dict1 and dict2. These are materialized as
-# lists to make it easy to pick a dict key at random.
-dict1keys = []
-dict2keys = []
-
-# Global flag telling maybe_mutate() whether to *consider* mutating.
-mutate = 0
-
-# If global mutate is true, consider mutating a dict. May or may not
-# mutate a dict even if mutate is true. If it does decide to mutate a
-# dict, it picks one of {dict1, dict2} at random, and deletes a random
-# entry from it; or, more rarely, adds a random element.
-
-def maybe_mutate():
- global mutate
- if not mutate:
- return
- if random.random() < 0.5:
- return
-
- if random.random() < 0.5:
- target, keys = dict1, dict1keys
- else:
- target, keys = dict2, dict2keys
-
- if random.random() < 0.2:
- # Insert a new key.
- mutate = 0 # disable mutation until key inserted
- while 1:
- newkey = Horrid(random.randrange(100))
- if newkey not in target:
- break
- target[newkey] = Horrid(random.randrange(100))
- keys.append(newkey)
- mutate = 1
-
- elif keys:
- # Delete a key at random.
- mutate = 0 # disable mutation until key deleted
- i = random.randrange(len(keys))
- key = keys[i]
- del target[key]
- del keys[i]
- mutate = 1
-
-# A horrid class that triggers random mutations of dict1 and dict2 when
-# instances are compared.
-
-class Horrid:
- def __init__(self, i):
- # Comparison outcomes are determined by the value of i.
- self.i = i
-
- # An artificial hashcode is selected at random so that we don't
- # have any systematic relationship between comparison outcomes
- # (based on self.i and other.i) and relative position within the
- # hash vector (based on hashcode).
- # XXX This is no longer effective.
- ##self.hashcode = random.randrange(1000000000)
-
- def __hash__(self):
- return 42
- return self.hashcode
-
- def __eq__(self, other):
- maybe_mutate() # The point of the test.
- return self.i == other.i
-
- def __ne__(self, other):
- raise RuntimeError("I didn't expect some kind of Spanish inquisition!")
-
- __lt__ = __le__ = __gt__ = __ge__ = __ne__
-
- def __repr__(self):
- return "Horrid(%d)" % self.i
-
-# Fill dict d with numentries (Horrid(i), Horrid(j)) key-value pairs,
-# where i and j are selected at random from the candidates list.
-# Return d.keys() after filling.
-
-def fill_dict(d, candidates, numentries):
- d.clear()
- for i in range(numentries):
- d[Horrid(random.choice(candidates))] = \
- Horrid(random.choice(candidates))
- return list(d.keys())
-
-# Test one pair of randomly generated dicts, each with n entries.
-# Note that dict comparison is trivial if they don't have the same number
-# of entires (then the "shorter" dict is instantly considered to be the
-# smaller one, without even looking at the entries).
-
-def test_one(n):
- global mutate, dict1, dict2, dict1keys, dict2keys
-
- # Fill the dicts without mutating them.
- mutate = 0
- dict1keys = fill_dict(dict1, range(n), n)
- dict2keys = fill_dict(dict2, range(n), n)
-
- # Enable mutation, then compare the dicts so long as they have the
- # same size.
- mutate = 1
- if verbose:
- print("trying w/ lengths", len(dict1), len(dict2), end=' ')
- while dict1 and len(dict1) == len(dict2):
- if verbose:
- print(".", end=' ')
- c = dict1 == dict2
- if verbose:
- print()
-
-# Run test_one n times. At the start (before the bugs were fixed), 20
-# consecutive runs of this test each blew up on or before the sixth time
-# test_one was run. So n doesn't have to be large to get an interesting
-# test.
-# OTOH, calling with large n is also interesting, to ensure that the fixed
-# code doesn't hold on to refcounts *too* long (in which case memory would
-# leak).
-
-def test(n):
- for i in range(n):
- test_one(random.randrange(1, 100))
-
-# See last comment block for clues about good values for n.
-test(100)
-
-##########################################################################
-# Another segfault bug, distilled by Michael Hudson from a c.l.py post.
-
-class Child:
- def __init__(self, parent):
- self.__dict__['parent'] = parent
- def __getattr__(self, attr):
- self.parent.a = 1
- self.parent.b = 1
- self.parent.c = 1
- self.parent.d = 1
- self.parent.e = 1
- self.parent.f = 1
- self.parent.g = 1
- self.parent.h = 1
- self.parent.i = 1
- return getattr(self.parent, attr)
-
-class Parent:
- def __init__(self):
- self.a = Child(self)
-
-# Hard to say what this will print! May vary from time to time. But
-# we're specifically trying to test the tp_print slot here, and this is
-# the clearest way to do it. We print the result to a temp file so that
-# the expected-output file doesn't need to change.
-
-f = open(TESTFN, "w")
-print(Parent().__dict__, file=f)
-f.close()
-os.unlink(TESTFN)
-
-##########################################################################
-# And another core-dumper from Michael Hudson.
-
-dict = {}
-
-# Force dict to malloc its table.
-for i in range(1, 10):
- dict[i] = i
-
-f = open(TESTFN, "w")
-
-class Machiavelli:
- def __repr__(self):
- dict.clear()
-
- # Michael sez: "doesn't crash without this. don't know why."
- # Tim sez: "luck of the draw; crashes with or without for me."
- print(file=f)
-
- return repr("machiavelli")
-
- def __hash__(self):
- return 0
-
-dict[Machiavelli()] = Machiavelli()
-
-print(str(dict), file=f)
-f.close()
-os.unlink(TESTFN)
-del f, dict
-
-
-##########################################################################
-# And another core-dumper from Michael Hudson.
-
-dict = {}
-
-# let's force dict to malloc its table
-for i in range(1, 10):
- dict[i] = i
-
-class Machiavelli2:
- def __eq__(self, other):
- dict.clear()
- return 1
-
- def __hash__(self):
- return 0
-
-dict[Machiavelli2()] = Machiavelli2()
-
-try:
- dict[Machiavelli2()]
-except KeyError:
- pass
-
-del dict
-
-##########################################################################
-# And another core-dumper from Michael Hudson.
-
-dict = {}
-
-# let's force dict to malloc its table
-for i in range(1, 10):
- dict[i] = i
-
-class Machiavelli3:
- def __init__(self, id):
- self.id = id
-
- def __eq__(self, other):
- if self.id == other.id:
- dict.clear()
- return 1
- else:
- return 0
-
- def __repr__(self):
- return "%s(%s)"%(self.__class__.__name__, self.id)
-
- def __hash__(self):
- return 0
-
-dict[Machiavelli3(1)] = Machiavelli3(0)
-dict[Machiavelli3(2)] = Machiavelli3(0)
-
-f = open(TESTFN, "w")
-try:
- try:
- print(dict[Machiavelli3(2)], file=f)
- except KeyError:
- pass
-finally:
- f.close()
- os.unlink(TESTFN)
-
-del dict
-del dict1, dict2, dict1keys, dict2keys
diff --git a/Lib/test/test_namespace_pkgs.py b/Lib/test/test_namespace_pkgs.py
new file mode 100644
index 0000000..7067b12
--- /dev/null
+++ b/Lib/test/test_namespace_pkgs.py
@@ -0,0 +1,294 @@
+import sys
+import contextlib
+import unittest
+import os
+
+from test.test_importlib import util
+from test.support import run_unittest
+
+# needed tests:
+#
+# need to test when nested, so that the top-level path isn't sys.path
+# need to test dynamic path detection, both at top-level and nested
+# with dynamic path, check when a loader is returned on path reload (that is,
+# trying to switch from a namespace package to a regular package)
+
+
+@contextlib.contextmanager
+def sys_modules_context():
+ """
+ Make sure sys.modules is the same object and has the same content
+ when exiting the context as when entering.
+
+ Similar to importlib.test.util.uncache, but doesn't require explicit
+ names.
+ """
+ sys_modules_saved = sys.modules
+ sys_modules_copy = sys.modules.copy()
+ try:
+ yield
+ finally:
+ sys.modules = sys_modules_saved
+ sys.modules.clear()
+ sys.modules.update(sys_modules_copy)
+
+
+@contextlib.contextmanager
+def namespace_tree_context(**kwargs):
+ """
+ Save import state and sys.modules cache and restore it on exit.
+ Typical usage:
+
+ >>> with namespace_tree_context(path=['/tmp/xxyy/portion1',
+ ... '/tmp/xxyy/portion2']):
+ ... pass
+ """
+ # use default meta_path and path_hooks unless specified otherwise
+ kwargs.setdefault('meta_path', sys.meta_path)
+ kwargs.setdefault('path_hooks', sys.path_hooks)
+ import_context = util.import_state(**kwargs)
+ with import_context, sys_modules_context():
+ yield
+
+class NamespacePackageTest(unittest.TestCase):
+ """
+ Subclasses should define self.root and self.paths (under that root)
+ to be added to sys.path.
+ """
+ root = os.path.join(os.path.dirname(__file__), 'namespace_pkgs')
+
+ def setUp(self):
+ self.resolved_paths = [
+ os.path.join(self.root, path) for path in self.paths
+ ]
+ self.ctx = namespace_tree_context(path=self.resolved_paths)
+ self.ctx.__enter__()
+
+ def tearDown(self):
+ # TODO: will we ever want to pass exc_info to __exit__?
+ self.ctx.__exit__(None, None, None)
+
+class SingleNamespacePackage(NamespacePackageTest):
+ paths = ['portion1']
+
+ def test_simple_package(self):
+ import foo.one
+ self.assertEqual(foo.one.attr, 'portion1 foo one')
+
+ def test_cant_import_other(self):
+ with self.assertRaises(ImportError):
+ import foo.two
+
+ def test_module_repr(self):
+ import foo.one
+ self.assertEqual(repr(foo), "<module 'foo' (namespace)>")
+
+
+class DynamicPatheNamespacePackage(NamespacePackageTest):
+ paths = ['portion1']
+
+ def test_dynamic_path(self):
+ # Make sure only 'foo.one' can be imported
+ import foo.one
+ self.assertEqual(foo.one.attr, 'portion1 foo one')
+
+ with self.assertRaises(ImportError):
+ import foo.two
+
+ # Now modify sys.path
+ sys.path.append(os.path.join(self.root, 'portion2'))
+
+ # And make sure foo.two is now importable
+ import foo.two
+ self.assertEqual(foo.two.attr, 'portion2 foo two')
+
+
+class CombinedNamespacePackages(NamespacePackageTest):
+ paths = ['both_portions']
+
+ def test_imports(self):
+ import foo.one
+ import foo.two
+ self.assertEqual(foo.one.attr, 'both_portions foo one')
+ self.assertEqual(foo.two.attr, 'both_portions foo two')
+
+
+class SeparatedNamespacePackages(NamespacePackageTest):
+ paths = ['portion1', 'portion2']
+
+ def test_imports(self):
+ import foo.one
+ import foo.two
+ self.assertEqual(foo.one.attr, 'portion1 foo one')
+ self.assertEqual(foo.two.attr, 'portion2 foo two')
+
+
+class SeparatedOverlappingNamespacePackages(NamespacePackageTest):
+ paths = ['portion1', 'both_portions']
+
+ def test_first_path_wins(self):
+ import foo.one
+ import foo.two
+ self.assertEqual(foo.one.attr, 'portion1 foo one')
+ self.assertEqual(foo.two.attr, 'both_portions foo two')
+
+ def test_first_path_wins_again(self):
+ sys.path.reverse()
+ import foo.one
+ import foo.two
+ self.assertEqual(foo.one.attr, 'both_portions foo one')
+ self.assertEqual(foo.two.attr, 'both_portions foo two')
+
+ def test_first_path_wins_importing_second_first(self):
+ import foo.two
+ import foo.one
+ self.assertEqual(foo.one.attr, 'portion1 foo one')
+ self.assertEqual(foo.two.attr, 'both_portions foo two')
+
+
+class SingleZipNamespacePackage(NamespacePackageTest):
+ paths = ['top_level_portion1.zip']
+
+ def test_simple_package(self):
+ import foo.one
+ self.assertEqual(foo.one.attr, 'portion1 foo one')
+
+ def test_cant_import_other(self):
+ with self.assertRaises(ImportError):
+ import foo.two
+
+
+class SeparatedZipNamespacePackages(NamespacePackageTest):
+ paths = ['top_level_portion1.zip', 'portion2']
+
+ def test_imports(self):
+ import foo.one
+ import foo.two
+ self.assertEqual(foo.one.attr, 'portion1 foo one')
+ self.assertEqual(foo.two.attr, 'portion2 foo two')
+ self.assertIn('top_level_portion1.zip', foo.one.__file__)
+ self.assertNotIn('.zip', foo.two.__file__)
+
+
+class SingleNestedZipNamespacePackage(NamespacePackageTest):
+ paths = ['nested_portion1.zip/nested_portion1']
+
+ def test_simple_package(self):
+ import foo.one
+ self.assertEqual(foo.one.attr, 'portion1 foo one')
+
+ def test_cant_import_other(self):
+ with self.assertRaises(ImportError):
+ import foo.two
+
+
+class SeparatedNestedZipNamespacePackages(NamespacePackageTest):
+ paths = ['nested_portion1.zip/nested_portion1', 'portion2']
+
+ def test_imports(self):
+ import foo.one
+ import foo.two
+ self.assertEqual(foo.one.attr, 'portion1 foo one')
+ self.assertEqual(foo.two.attr, 'portion2 foo two')
+ fn = os.path.join('nested_portion1.zip', 'nested_portion1')
+ self.assertIn(fn, foo.one.__file__)
+ self.assertNotIn('.zip', foo.two.__file__)
+
+
+class LegacySupport(NamespacePackageTest):
+ paths = ['not_a_namespace_pkg', 'portion1', 'portion2', 'both_portions']
+
+ def test_non_namespace_package_takes_precedence(self):
+ import foo.one
+ with self.assertRaises(ImportError):
+ import foo.two
+ self.assertIn('__init__', foo.__file__)
+ self.assertNotIn('namespace', str(foo.__loader__).lower())
+
+
+class DynamicPathCalculation(NamespacePackageTest):
+ paths = ['project1', 'project2']
+
+ def test_project3_fails(self):
+ import parent.child.one
+ self.assertEqual(len(parent.__path__), 2)
+ self.assertEqual(len(parent.child.__path__), 2)
+ import parent.child.two
+ self.assertEqual(len(parent.__path__), 2)
+ self.assertEqual(len(parent.child.__path__), 2)
+
+ self.assertEqual(parent.child.one.attr, 'parent child one')
+ self.assertEqual(parent.child.two.attr, 'parent child two')
+
+ with self.assertRaises(ImportError):
+ import parent.child.three
+
+ self.assertEqual(len(parent.__path__), 2)
+ self.assertEqual(len(parent.child.__path__), 2)
+
+ def test_project3_succeeds(self):
+ import parent.child.one
+ self.assertEqual(len(parent.__path__), 2)
+ self.assertEqual(len(parent.child.__path__), 2)
+ import parent.child.two
+ self.assertEqual(len(parent.__path__), 2)
+ self.assertEqual(len(parent.child.__path__), 2)
+
+ self.assertEqual(parent.child.one.attr, 'parent child one')
+ self.assertEqual(parent.child.two.attr, 'parent child two')
+
+ with self.assertRaises(ImportError):
+ import parent.child.three
+
+ # now add project3
+ sys.path.append(os.path.join(self.root, 'project3'))
+ import parent.child.three
+
+ # the paths dynamically get longer, to include the new directories
+ self.assertEqual(len(parent.__path__), 3)
+ self.assertEqual(len(parent.child.__path__), 3)
+
+ self.assertEqual(parent.child.three.attr, 'parent child three')
+
+
+class ZipWithMissingDirectory(NamespacePackageTest):
+ paths = ['missing_directory.zip']
+
+ @unittest.expectedFailure
+ def test_missing_directory(self):
+ # This will fail because missing_directory.zip contains:
+ # Length Date Time Name
+ # --------- ---------- ----- ----
+ # 29 2012-05-03 18:13 foo/one.py
+ # 0 2012-05-03 20:57 bar/
+ # 38 2012-05-03 20:57 bar/two.py
+ # --------- -------
+ # 67 3 files
+
+ # Because there is no 'foo/', the zipimporter currently doesn't
+ # know that foo is a namespace package
+
+ import foo.one
+
+ def test_present_directory(self):
+ # This succeeds because there is a "bar/" in the zip file
+ import bar.two
+ self.assertEqual(bar.two.attr, 'missing_directory foo two')
+
+
+class ModuleAndNamespacePackageInSameDir(NamespacePackageTest):
+ paths = ['module_and_namespace_package']
+
+ def test_module_before_namespace_package(self):
+ # Make sure we find the module in preference to the
+ # namespace package.
+ import a_test
+ self.assertEqual(a_test.attr, 'in module')
+
+
+def test_main():
+ run_unittest(*NamespacePackageTest.__subclasses__())
+
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_nntplib.py b/Lib/test/test_nntplib.py
index 7335b23..19fb10a 100644
--- a/Lib/test/test_nntplib.py
+++ b/Lib/test/test_nntplib.py
@@ -1,4 +1,5 @@
import io
+import socket
import datetime
import textwrap
import unittest
@@ -257,6 +258,26 @@ class NetworkedNNTPTestsMixin:
# value
setattr(cls, name, wrap_meth(meth))
+ def test_with_statement(self):
+ def is_connected():
+ if not hasattr(server, 'file'):
+ return False
+ try:
+ server.help()
+ except (socket.error, EOFError):
+ return False
+ return True
+
+ with self.NNTP_CLASS(self.NNTP_HOST, timeout=TIMEOUT, usenetrc=False) as server:
+ self.assertTrue(is_connected())
+ self.assertTrue(server.help())
+ self.assertFalse(is_connected())
+
+ with self.NNTP_CLASS(self.NNTP_HOST, timeout=TIMEOUT, usenetrc=False) as server:
+ server.quit()
+ self.assertFalse(is_connected())
+
+
NetworkedNNTPTestsMixin.wrap_methods()
@@ -894,7 +915,7 @@ class NNTPv1v2TestsMixin:
def _check_article_body(self, lines):
self.assertEqual(len(lines), 4)
- self.assertEqual(lines[-1].decode('utf8'), "-- Signed by André.")
+ self.assertEqual(lines[-1].decode('utf-8'), "-- Signed by André.")
self.assertEqual(lines[-2], b"")
self.assertEqual(lines[-3], b".Here is a dot-starting line.")
self.assertEqual(lines[-4], b"This is just a test article.")
@@ -1133,12 +1154,12 @@ class NNTPv1v2TestsMixin:
self.assertEqual(resp, success_resp)
# With an iterable of terminated lines
def iterlines(b):
- return iter(b.splitlines(True))
+ return iter(b.splitlines(keepends=True))
resp = self._check_post_ihave_sub(func, *args, file_factory=iterlines)
self.assertEqual(resp, success_resp)
# With an iterable of non-terminated lines
def iterlines(b):
- return iter(b.splitlines(False))
+ return iter(b.splitlines(keepends=False))
resp = self._check_post_ihave_sub(func, *args, file_factory=iterlines)
self.assertEqual(resp, success_resp)
diff --git a/Lib/test/test_ntpath.py b/Lib/test/test_ntpath.py
index 6464950..f809876 100644
--- a/Lib/test/test_ntpath.py
+++ b/Lib/test/test_ntpath.py
@@ -1,10 +1,11 @@
import ntpath
import os
import sys
+import unittest
+import warnings
from test.support import TestFailed
from test import support, test_genericpath
from tempfile import TemporaryFile
-import unittest
def tester(fn, wantResult):
@@ -21,7 +22,9 @@ def tester(fn, wantResult):
fn = fn.replace('["', '[b"')
fn = fn.replace(", '", ", b'")
fn = fn.replace(', "', ', b"')
- gotResult = eval(fn)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", DeprecationWarning)
+ gotResult = eval(fn)
if isinstance(wantResult, str):
wantResult = wantResult.encode('ascii')
elif isinstance(wantResult, tuple):
@@ -254,14 +257,10 @@ class TestNtpath(unittest.TestCase):
ntpath.sameopenfile(-1, -1)
-class NtCommonTest(test_genericpath.CommonTest):
+class NtCommonTest(test_genericpath.CommonTest, unittest.TestCase):
pathmodule = ntpath
attributes = ['relpath', 'splitunc']
-def test_main():
- support.run_unittest(TestNtpath, NtCommonTest)
-
-
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_numeric_tower.py b/Lib/test/test_numeric_tower.py
index bef3d4c..3423d4e 100644
--- a/Lib/test/test_numeric_tower.py
+++ b/Lib/test/test_numeric_tower.py
@@ -150,7 +150,7 @@ class ComparisonTest(unittest.TestCase):
# int, float, Fraction, Decimal
test_values = [
float('-inf'),
- D('-1e999999999'),
+ D('-1e425000000'),
-1e308,
F(-22, 7),
-3.14,
diff --git a/Lib/test/test_optparse.py b/Lib/test/test_optparse.py
index 7b95612..78de278 100644
--- a/Lib/test/test_optparse.py
+++ b/Lib/test/test_optparse.py
@@ -318,6 +318,22 @@ class TestOptionChecks(BaseTest):
["-b"], {'action': 'store',
'callback_kwargs': 'foo'})
+ def test_no_single_dash(self):
+ self.assertOptionError(
+ "invalid long option string '-debug': "
+ "must start with --, followed by non-dash",
+ ["-debug"])
+
+ self.assertOptionError(
+ "option -d: invalid long option string '-debug': must start with"
+ " --, followed by non-dash",
+ ["-d", "-debug"])
+
+ self.assertOptionError(
+ "invalid long option string '-debug': "
+ "must start with --, followed by non-dash",
+ ["-debug", "--debug"])
+
class TestOptionParser(BaseTest):
def setUp(self):
self.parser = OptionParser()
@@ -631,7 +647,7 @@ class TestStandard(BaseTest):
option_list=options)
def test_required_value(self):
- self.assertParseFail(["-a"], "-a option requires an argument")
+ self.assertParseFail(["-a"], "-a option requires 1 argument")
def test_invalid_integer(self):
self.assertParseFail(["-b", "5x"],
@@ -1023,7 +1039,7 @@ class TestExtendAddTypes(BaseTest):
TYPE_CHECKER["file"] = check_file
def test_filetype_ok(self):
- open(support.TESTFN, "w").close()
+ support.create_empty_file(support.TESTFN)
self.assertParseOK(["--file", support.TESTFN, "-afoo"],
{'file': support.TESTFN, 'a': 'foo'},
[])
diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py
index 720e78b..184c9ae 100644
--- a/Lib/test/test_os.py
+++ b/Lib/test/test_os.py
@@ -14,20 +14,43 @@ import shutil
from test import support
import contextlib
import mmap
+import platform
+import re
import uuid
+import asyncore
+import asynchat
+import socket
+import itertools
import stat
+import locale
+import codecs
+try:
+ import threading
+except ImportError:
+ threading = None
from test.script_helper import assert_python_ok
+with warnings.catch_warnings():
+ warnings.simplefilter("ignore", DeprecationWarning)
+ os.stat_float_times(True)
+st = os.stat(__file__)
+stat_supports_subsecond = (
+ # check if float and int timestamps are different
+ (st.st_atime != st[7])
+ or (st.st_mtime != st[8])
+ or (st.st_ctime != st[9]))
+
# Detect whether we're on a Linux system that uses the (now outdated
# and unmaintained) linuxthreads threading library. There's an issue
# when combining linuxthreads with a failed execv call: see
# http://bugs.python.org/issue4970.
-if (hasattr(os, "confstr_names") and
- "CS_GNU_LIBPTHREAD_VERSION" in os.confstr_names):
- libpthread = os.confstr("CS_GNU_LIBPTHREAD_VERSION")
- USING_LINUXTHREADS= libpthread.startswith("linuxthreads")
+if hasattr(sys, 'thread_info') and sys.thread_info.version:
+ USING_LINUXTHREADS = sys.thread_info.version.startswith("linuxthreads")
else:
- USING_LINUXTHREADS= False
+ USING_LINUXTHREADS = False
+
+# Issue #14110: Some tests fail on FreeBSD if the user is in the wheel group.
+HAVE_WHEEL_GROUP = sys.platform.startswith('freebsd') and os.getgid() == 0
# Tests creating TESTFN
class FileTests(unittest.TestCase):
@@ -124,6 +147,18 @@ class FileTests(unittest.TestCase):
self.fdopen_helper('r')
self.fdopen_helper('r', 100)
+ def test_replace(self):
+ TESTFN2 = support.TESTFN + ".2"
+ with open(support.TESTFN, 'w') as f:
+ f.write("1")
+ with open(TESTFN2, 'w') as f:
+ f.write("2")
+ self.addCleanup(os.unlink, TESTFN2)
+ os.replace(support.TESTFN, TESTFN2)
+ self.assertRaises(FileNotFoundError, os.stat, support.TESTFN)
+ with open(TESTFN2, 'r') as f:
+ self.assertEqual(f.read(), "1")
+
# Test attributes on return values from os.*stat* family.
class StatAttributeTests(unittest.TestCase):
@@ -142,7 +177,6 @@ class StatAttributeTests(unittest.TestCase):
if not hasattr(os, "stat"):
return
- import stat
result = os.stat(fname)
# Make sure direct access works
@@ -162,6 +196,13 @@ class StatAttributeTests(unittest.TestCase):
result[getattr(stat, name)])
self.assertIn(attr, members)
+ # Make sure that the st_?time and st_?time_ns fields roughly agree
+ # (they should always agree up to around tens-of-microseconds)
+ for name in 'st_atime st_mtime st_ctime'.split():
+ floaty = int(getattr(result, name) * 100000)
+ nanosecondy = getattr(result, name + "_ns") // 10000
+ self.assertAlmostEqual(floaty, nanosecondy, delta=2)
+
try:
result[200]
self.fail("No exception raised")
@@ -208,7 +249,9 @@ class StatAttributeTests(unittest.TestCase):
fname = self.fname.encode(sys.getfilesystemencoding())
except UnicodeEncodeError:
self.skipTest("cannot encode %a for the filesystem" % self.fname)
- self.check_stat_attributes(fname)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", DeprecationWarning)
+ self.check_stat_attributes(fname)
def test_statvfs_attributes(self):
if not hasattr(os, "statvfs"):
@@ -265,6 +308,143 @@ class StatAttributeTests(unittest.TestCase):
st2 = os.stat(support.TESTFN)
self.assertEqual(st2.st_mtime, int(st.st_mtime-delta))
+ def _test_utime(self, filename, attr, utime, delta):
+ # Issue #13327 removed the requirement to pass None as the
+ # second argument. Check that the previous methods of passing
+ # a time tuple or None work in addition to no argument.
+ st0 = os.stat(filename)
+ # Doesn't set anything new, but sets the time tuple way
+ utime(filename, (attr(st0, "st_atime"), attr(st0, "st_mtime")))
+ # Setting the time to the time you just read, then reading again,
+ # should always return exactly the same times.
+ st1 = os.stat(filename)
+ self.assertEqual(attr(st0, "st_mtime"), attr(st1, "st_mtime"))
+ self.assertEqual(attr(st0, "st_atime"), attr(st1, "st_atime"))
+ # Set to the current time in the old explicit way.
+ os.utime(filename, None)
+ st2 = os.stat(support.TESTFN)
+ # Set to the current time in the new way
+ os.utime(filename)
+ st3 = os.stat(filename)
+ self.assertAlmostEqual(attr(st2, "st_mtime"), attr(st3, "st_mtime"), delta=delta)
+
+ def test_utime(self):
+ def utime(file, times):
+ return os.utime(file, times)
+ self._test_utime(self.fname, getattr, utime, 10)
+ self._test_utime(support.TESTFN, getattr, utime, 10)
+
+
+ def _test_utime_ns(self, set_times_ns, test_dir=True):
+ def getattr_ns(o, attr):
+ return getattr(o, attr + "_ns")
+ ten_s = 10 * 1000 * 1000 * 1000
+ self._test_utime(self.fname, getattr_ns, set_times_ns, ten_s)
+ if test_dir:
+ self._test_utime(support.TESTFN, getattr_ns, set_times_ns, ten_s)
+
+ def test_utime_ns(self):
+ def utime_ns(file, times):
+ return os.utime(file, ns=times)
+ self._test_utime_ns(utime_ns)
+
+ requires_utime_dir_fd = unittest.skipUnless(
+ os.utime in os.supports_dir_fd,
+ "dir_fd support for utime required for this test.")
+ requires_utime_fd = unittest.skipUnless(
+ os.utime in os.supports_fd,
+ "fd support for utime required for this test.")
+ requires_utime_nofollow_symlinks = unittest.skipUnless(
+ os.utime in os.supports_follow_symlinks,
+ "follow_symlinks support for utime required for this test.")
+
+ @requires_utime_nofollow_symlinks
+ def test_lutimes_ns(self):
+ def lutimes_ns(file, times):
+ return os.utime(file, ns=times, follow_symlinks=False)
+ self._test_utime_ns(lutimes_ns)
+
+ @requires_utime_fd
+ def test_futimes_ns(self):
+ def futimes_ns(file, times):
+ with open(file, "wb") as f:
+ os.utime(f.fileno(), ns=times)
+ self._test_utime_ns(futimes_ns, test_dir=False)
+
+ def _utime_invalid_arguments(self, name, arg):
+ with self.assertRaises(ValueError):
+ getattr(os, name)(arg, (5, 5), ns=(5, 5))
+
+ def test_utime_invalid_arguments(self):
+ self._utime_invalid_arguments('utime', self.fname)
+
+
+ @unittest.skipUnless(stat_supports_subsecond,
+ "os.stat() doesn't has a subsecond resolution")
+ def _test_utime_subsecond(self, set_time_func):
+ asec, amsec = 1, 901
+ atime = asec + amsec * 1e-3
+ msec, mmsec = 2, 901
+ mtime = msec + mmsec * 1e-3
+ filename = self.fname
+ os.utime(filename, (0, 0))
+ set_time_func(filename, atime, mtime)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", DeprecationWarning)
+ os.stat_float_times(True)
+ st = os.stat(filename)
+ self.assertAlmostEqual(st.st_atime, atime, places=3)
+ self.assertAlmostEqual(st.st_mtime, mtime, places=3)
+
+ def test_utime_subsecond(self):
+ def set_time(filename, atime, mtime):
+ os.utime(filename, (atime, mtime))
+ self._test_utime_subsecond(set_time)
+
+ @requires_utime_fd
+ def test_futimes_subsecond(self):
+ def set_time(filename, atime, mtime):
+ with open(filename, "wb") as f:
+ os.utime(f.fileno(), times=(atime, mtime))
+ self._test_utime_subsecond(set_time)
+
+ @requires_utime_fd
+ def test_futimens_subsecond(self):
+ def set_time(filename, atime, mtime):
+ with open(filename, "wb") as f:
+ os.utime(f.fileno(), times=(atime, mtime))
+ self._test_utime_subsecond(set_time)
+
+ @requires_utime_dir_fd
+ def test_futimesat_subsecond(self):
+ def set_time(filename, atime, mtime):
+ dirname = os.path.dirname(filename)
+ dirfd = os.open(dirname, os.O_RDONLY)
+ try:
+ os.utime(os.path.basename(filename), dir_fd=dirfd,
+ times=(atime, mtime))
+ finally:
+ os.close(dirfd)
+ self._test_utime_subsecond(set_time)
+
+ @requires_utime_nofollow_symlinks
+ def test_lutimes_subsecond(self):
+ def set_time(filename, atime, mtime):
+ os.utime(filename, (atime, mtime), follow_symlinks=False)
+ self._test_utime_subsecond(set_time)
+
+ @requires_utime_dir_fd
+ def test_utimensat_subsecond(self):
+ def set_time(filename, atime, mtime):
+ dirname = os.path.dirname(filename)
+ dirfd = os.open(dirname, os.O_RDONLY)
+ try:
+ os.utime(os.path.basename(filename), dir_fd=dirfd,
+ times=(atime, mtime))
+ finally:
+ os.close(dirfd)
+ self._test_utime_subsecond(set_time)
+
# Restrict test to Win32, since there is no guarantee other
# systems support centiseconds
if sys.platform == 'win32':
@@ -296,6 +476,19 @@ class StatAttributeTests(unittest.TestCase):
return
self.fail("Could not stat pagefile.sys")
+ @unittest.skipUnless(hasattr(os, "pipe"), "requires os.pipe()")
+ def test_15261(self):
+ # Verify that stat'ing a closed fd does not cause crash
+ r, w = os.pipe()
+ try:
+ os.stat(r) # should not raise error
+ finally:
+ os.close(r)
+ os.close(w)
+ with self.assertRaises(OSError) as ctx:
+ os.stat(r)
+ self.assertEqual(ctx.exception.errno, errno.EBADF)
+
from test import mapping_tests
class EnvironTests(mapping_tests.BasicTestMappingProtocol):
@@ -324,23 +517,23 @@ class EnvironTests(mapping_tests.BasicTestMappingProtocol):
return os.environ
# Bug 1110478
+ @unittest.skipUnless(os.path.exists('/bin/sh'), 'requires /bin/sh')
def test_update2(self):
os.environ.clear()
- if os.path.exists("/bin/sh"):
- os.environ.update(HELLO="World")
- with os.popen("/bin/sh -c 'echo $HELLO'") as popen:
- value = popen.read().strip()
- self.assertEqual(value, "World")
+ os.environ.update(HELLO="World")
+ with os.popen("/bin/sh -c 'echo $HELLO'") as popen:
+ value = popen.read().strip()
+ self.assertEqual(value, "World")
+ @unittest.skipUnless(os.path.exists('/bin/sh'), 'requires /bin/sh')
def test_os_popen_iter(self):
- if os.path.exists("/bin/sh"):
- with os.popen(
- "/bin/sh -c 'echo \"line1\nline2\nline3\"'") as popen:
- it = iter(popen)
- self.assertEqual(next(it), "line1\n")
- self.assertEqual(next(it), "line2\n")
- self.assertEqual(next(it), "line3\n")
- self.assertRaises(StopIteration, next, it)
+ with os.popen(
+ "/bin/sh -c 'echo \"line1\nline2\nline3\"'") as popen:
+ it = iter(popen)
+ self.assertEqual(next(it), "line1\n")
+ self.assertEqual(next(it), "line2\n")
+ self.assertEqual(next(it), "line3\n")
+ self.assertRaises(StopIteration, next, it)
# Verify environ keys and values from the OS are of the
# correct str type.
@@ -427,8 +620,8 @@ class EnvironTests(mapping_tests.BasicTestMappingProtocol):
# On FreeBSD < 7 and OS X < 10.6, unsetenv() doesn't return a value (issue
# #13415).
- @unittest.skipIf(sys.platform.startswith(('freebsd', 'darwin')),
- "due to known OS bug: see issue #13415")
+ @support.requires_freebsd_version(7)
+ @support.requires_mac_ver(10, 6)
def test_unset_error(self):
if sys.platform == "win32":
# an environment variable is limited to 32,767 characters
@@ -442,7 +635,7 @@ class EnvironTests(mapping_tests.BasicTestMappingProtocol):
class WalkTests(unittest.TestCase):
"""Tests for os.walk()."""
- def test_traversal(self):
+ def setUp(self):
import os
from os.path import join
@@ -456,6 +649,7 @@ class WalkTests(unittest.TestCase):
# SUB2/ a file kid and a dirsymlink kid
# tmp3
# link/ a symlink to TESTFN.2
+ # broken_link
# TEST2/
# tmp4 a lone file
walk_path = join(support.TESTFN, "TEST1")
@@ -468,6 +662,8 @@ class WalkTests(unittest.TestCase):
link_path = join(sub2_path, "link")
t2_path = join(support.TESTFN, "TEST2")
tmp4_path = join(support.TESTFN, "TEST2", "tmp4")
+ link_path = join(sub2_path, "link")
+ broken_link_path = join(sub2_path, "broken_link")
# Create stuff.
os.makedirs(sub11_path)
@@ -484,7 +680,8 @@ class WalkTests(unittest.TestCase):
else:
symlink_to_dir = os.symlink
symlink_to_dir(os.path.abspath(t2_path), link_path)
- sub2_tree = (sub2_path, ["link"], ["tmp3"])
+ symlink_to_dir('broken', broken_link_path)
+ sub2_tree = (sub2_path, ["link"], ["broken_link", "tmp3"])
else:
sub2_tree = (sub2_path, [], ["tmp3"])
@@ -496,6 +693,7 @@ class WalkTests(unittest.TestCase):
# flipped: TESTFN, SUB2, SUB1, SUB11
flipped = all[0][1][0] != "SUB1"
all[0][1].sort()
+ all[3 - 2 * flipped][-1].sort()
self.assertEqual(all[0], (walk_path, ["SUB1", "SUB2"], ["tmp1"]))
self.assertEqual(all[1 + flipped], (sub1_path, ["SUB11"], ["tmp2"]))
self.assertEqual(all[2 + flipped], (sub11_path, [], []))
@@ -511,6 +709,7 @@ class WalkTests(unittest.TestCase):
dirs.remove('SUB1')
self.assertEqual(len(all), 2)
self.assertEqual(all[0], (walk_path, ["SUB2"], ["tmp1"]))
+ all[1][-1].sort()
self.assertEqual(all[1], sub2_tree)
# Walk bottom-up.
@@ -521,6 +720,7 @@ class WalkTests(unittest.TestCase):
# flipped: SUB2, SUB11, SUB1, TESTFN
flipped = all[3][1][0] != "SUB1"
all[3][1].sort()
+ all[2 - 2 * flipped][-1].sort()
self.assertEqual(all[3], (walk_path, ["SUB1", "SUB2"], ["tmp1"]))
self.assertEqual(all[flipped], (sub11_path, [], []))
self.assertEqual(all[flipped + 1], (sub1_path, ["SUB11"], ["tmp2"]))
@@ -552,6 +752,82 @@ class WalkTests(unittest.TestCase):
os.remove(dirname)
os.rmdir(support.TESTFN)
+
+@unittest.skipUnless(hasattr(os, 'fwalk'), "Test needs os.fwalk()")
+class FwalkTests(WalkTests):
+ """Tests for os.fwalk()."""
+
+ def _compare_to_walk(self, walk_kwargs, fwalk_kwargs):
+ """
+ compare with walk() results.
+ """
+ walk_kwargs = walk_kwargs.copy()
+ fwalk_kwargs = fwalk_kwargs.copy()
+ for topdown, follow_symlinks in itertools.product((True, False), repeat=2):
+ walk_kwargs.update(topdown=topdown, followlinks=follow_symlinks)
+ fwalk_kwargs.update(topdown=topdown, follow_symlinks=follow_symlinks)
+
+ expected = {}
+ for root, dirs, files in os.walk(**walk_kwargs):
+ expected[root] = (set(dirs), set(files))
+
+ for root, dirs, files, rootfd in os.fwalk(**fwalk_kwargs):
+ self.assertIn(root, expected)
+ self.assertEqual(expected[root], (set(dirs), set(files)))
+
+ def test_compare_to_walk(self):
+ kwargs = {'top': support.TESTFN}
+ self._compare_to_walk(kwargs, kwargs)
+
+ def test_dir_fd(self):
+ try:
+ fd = os.open(".", os.O_RDONLY)
+ walk_kwargs = {'top': support.TESTFN}
+ fwalk_kwargs = walk_kwargs.copy()
+ fwalk_kwargs['dir_fd'] = fd
+ self._compare_to_walk(walk_kwargs, fwalk_kwargs)
+ finally:
+ os.close(fd)
+
+ def test_yields_correct_dir_fd(self):
+ # check returned file descriptors
+ for topdown, follow_symlinks in itertools.product((True, False), repeat=2):
+ args = support.TESTFN, topdown, None
+ for root, dirs, files, rootfd in os.fwalk(*args, follow_symlinks=follow_symlinks):
+ # check that the FD is valid
+ os.fstat(rootfd)
+ # redundant check
+ os.stat(rootfd)
+ # check that listdir() returns consistent information
+ self.assertEqual(set(os.listdir(rootfd)), set(dirs) | set(files))
+
+ def test_fd_leak(self):
+ # Since we're opening a lot of FDs, we must be careful to avoid leaks:
+ # we both check that calling fwalk() a large number of times doesn't
+ # yield EMFILE, and that the minimum allocated FD hasn't changed.
+ minfd = os.dup(1)
+ os.close(minfd)
+ for i in range(256):
+ for x in os.fwalk(support.TESTFN):
+ pass
+ newfd = os.dup(1)
+ self.addCleanup(os.close, newfd)
+ self.assertEqual(newfd, minfd)
+
+ def tearDown(self):
+ # cleanup
+ for root, dirs, files, rootfd in os.fwalk(support.TESTFN, topdown=False):
+ for name in files:
+ os.unlink(name, dir_fd=rootfd)
+ for name in dirs:
+ st = os.stat(name, dir_fd=rootfd, follow_symlinks=False)
+ if stat.S_ISDIR(st.st_mode):
+ os.rmdir(name, dir_fd=rootfd)
+ else:
+ os.unlink(name, dir_fd=rootfd)
+ os.rmdir(support.TESTFN)
+
+
class MakedirTests(unittest.TestCase):
def setUp(self):
os.mkdir(support.TESTFN)
@@ -575,14 +851,12 @@ class MakedirTests(unittest.TestCase):
path = os.path.join(support.TESTFN, 'dir1')
mode = 0o777
old_mask = os.umask(0o022)
- try:
- os.makedirs(path, mode)
- self.assertRaises(OSError, os.makedirs, path, mode)
- self.assertRaises(OSError, os.makedirs, path, mode, exist_ok=False)
- self.assertRaises(OSError, os.makedirs, path, 0o776, exist_ok=True)
- os.makedirs(path, mode=mode, exist_ok=True)
- finally:
- os.umask(old_mask)
+ os.makedirs(path, mode)
+ self.assertRaises(OSError, os.makedirs, path, mode)
+ self.assertRaises(OSError, os.makedirs, path, mode, exist_ok=False)
+ self.assertRaises(OSError, os.makedirs, path, 0o776, exist_ok=True)
+ os.makedirs(path, mode=mode, exist_ok=True)
+ os.umask(old_mask)
def test_exist_ok_s_isgid_directory(self):
path = os.path.join(support.TESTFN, 'dir1')
@@ -594,7 +868,7 @@ class MakedirTests(unittest.TestCase):
os.lstat(support.TESTFN).st_mode)
try:
os.chmod(support.TESTFN, existing_testfn_mode | S_ISGID)
- except OSError:
+ except PermissionError:
raise unittest.SkipTest('Cannot set S_ISGID for dir.')
if (os.lstat(support.TESTFN).st_mode & S_ISGID != S_ISGID):
raise unittest.SkipTest('No support for S_ISGID dir mode.')
@@ -894,10 +1168,12 @@ class TestInvalidFD(unittest.TestCase):
def test_fpathconf(self):
if hasattr(os, "fpathconf"):
+ self.check(os.pathconf, "PC_NAME_MAX")
self.check(os.fpathconf, "PC_NAME_MAX")
def test_ftruncate(self):
if hasattr(os, "ftruncate"):
+ self.check(os.truncate, 0)
self.check(os.ftruncate, 0)
def test_lseek(self):
@@ -931,7 +1207,9 @@ class LinkTests(unittest.TestCase):
with open(file1, "w") as f1:
f1.write("test")
- os.link(file1, file2)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", DeprecationWarning)
+ os.link(file1, file2)
with open(file1, "r") as f1, open(file2, "r") as f2:
self.assertTrue(os.path.sameopenfile(f1.fileno(), f2.fileno()))
@@ -965,7 +1243,7 @@ if sys.platform != 'win32':
if hasattr(os, 'setgid'):
def test_setgid(self):
- if os.getuid() != 0:
+ if os.getuid() != 0 and not HAVE_WHEEL_GROUP:
self.assertRaises(os.error, os.setgid, 0)
self.assertRaises(OverflowError, os.setgid, 1<<32)
@@ -977,7 +1255,7 @@ if sys.platform != 'win32':
if hasattr(os, 'setegid'):
def test_setegid(self):
- if os.getuid() != 0:
+ if os.getuid() != 0 and not HAVE_WHEEL_GROUP:
self.assertRaises(os.error, os.setegid, 0)
self.assertRaises(OverflowError, os.setegid, 1<<32)
@@ -997,7 +1275,7 @@ if sys.platform != 'win32':
if hasattr(os, 'setregid'):
def test_setregid(self):
- if os.getuid() != 0:
+ if os.getuid() != 0 and not HAVE_WHEEL_GROUP:
self.assertRaises(os.error, os.setregid, 0, 0)
self.assertRaises(OverflowError, os.setregid, 1<<32, 0)
self.assertRaises(OverflowError, os.setregid, 0, 1<<32)
@@ -1038,8 +1316,7 @@ if sys.platform != 'win32':
os.mkdir(self.dir)
try:
for fn in bytesfn:
- f = open(os.path.join(self.bdir, fn), "w")
- f.close()
+ support.create_empty_file(os.path.join(self.bdir, fn))
fn = os.fsdecode(fn)
if fn in self.unicodefn:
raise ValueError("duplicate filename")
@@ -1055,6 +1332,13 @@ if sys.platform != 'win32':
expected = self.unicodefn
found = set(os.listdir(self.dir))
self.assertEqual(found, expected)
+ # test listdir without arguments
+ current_directory = os.getcwd()
+ try:
+ os.chdir(os.sep)
+ self.assertEqual(set(os.listdir()), set(os.listdir(os.sep)))
+ finally:
+ os.chdir(current_directory)
def test_open(self):
for fn in self.unicodefn:
@@ -1267,8 +1551,10 @@ class Win32SymlinkTests(unittest.TestCase):
self.assertNotEqual(os.lstat(link), os.stat(link))
bytes_link = os.fsencode(link)
- self.assertEqual(os.stat(bytes_link), os.stat(target))
- self.assertNotEqual(os.lstat(bytes_link), os.stat(bytes_link))
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", DeprecationWarning)
+ self.assertEqual(os.stat(bytes_link), os.stat(target))
+ self.assertNotEqual(os.lstat(bytes_link), os.stat(bytes_link))
def test_12084(self):
level1 = os.path.abspath(support.TESTFN)
@@ -1327,6 +1613,22 @@ class FSEncodingTests(unittest.TestCase):
self.assertEqual(os.fsdecode(bytesfn), fn)
+
+class DeviceEncodingTests(unittest.TestCase):
+
+ def test_bad_fd(self):
+ # Return None when an fd doesn't actually exist.
+ self.assertIsNone(os.device_encoding(123456))
+
+ @unittest.skipUnless(os.isatty(0) and (sys.platform.startswith('win') or
+ (hasattr(locale, 'nl_langinfo') and hasattr(locale, 'CODESET'))),
+ 'test requires a tty and either Windows or nl_langinfo(CODESET)')
+ def test_device_encoding(self):
+ encoding = os.device_encoding(0)
+ self.assertIsNotNone(encoding)
+ self.assertTrue(codecs.lookup(encoding))
+
+
class PidTests(unittest.TestCase):
@unittest.skipUnless(hasattr(os, 'getppid'), "test needs os.getppid")
def test_getppid(self):
@@ -1348,12 +1650,471 @@ class LoginTests(unittest.TestCase):
self.assertNotEqual(len(user_name), 0)
+@unittest.skipUnless(hasattr(os, 'getpriority') and hasattr(os, 'setpriority'),
+ "needs os.getpriority and os.setpriority")
+class ProgramPriorityTests(unittest.TestCase):
+ """Tests for os.getpriority() and os.setpriority()."""
+
+ def test_set_get_priority(self):
+
+ base = os.getpriority(os.PRIO_PROCESS, os.getpid())
+ os.setpriority(os.PRIO_PROCESS, os.getpid(), base + 1)
+ try:
+ new_prio = os.getpriority(os.PRIO_PROCESS, os.getpid())
+ if base >= 19 and new_prio <= 19:
+ raise unittest.SkipTest(
+ "unable to reliably test setpriority at current nice level of %s" % base)
+ else:
+ self.assertEqual(new_prio, base + 1)
+ finally:
+ try:
+ os.setpriority(os.PRIO_PROCESS, os.getpid(), base)
+ except OSError as err:
+ if err.errno != errno.EACCES:
+ raise
+
+
+if threading is not None:
+ class SendfileTestServer(asyncore.dispatcher, threading.Thread):
+
+ class Handler(asynchat.async_chat):
+
+ def __init__(self, conn):
+ asynchat.async_chat.__init__(self, conn)
+ self.in_buffer = []
+ self.closed = False
+ self.push(b"220 ready\r\n")
+
+ def handle_read(self):
+ data = self.recv(4096)
+ self.in_buffer.append(data)
+
+ def get_data(self):
+ return b''.join(self.in_buffer)
+
+ def handle_close(self):
+ self.close()
+ self.closed = True
+
+ def handle_error(self):
+ raise
+
+ def __init__(self, address):
+ threading.Thread.__init__(self)
+ asyncore.dispatcher.__init__(self)
+ self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.bind(address)
+ self.listen(5)
+ self.host, self.port = self.socket.getsockname()[:2]
+ self.handler_instance = None
+ self._active = False
+ self._active_lock = threading.Lock()
+
+ # --- public API
+
+ @property
+ def running(self):
+ return self._active
+
+ def start(self):
+ assert not self.running
+ self.__flag = threading.Event()
+ threading.Thread.start(self)
+ self.__flag.wait()
+
+ def stop(self):
+ assert self.running
+ self._active = False
+ self.join()
+
+ def wait(self):
+ # wait for handler connection to be closed, then stop the server
+ while not getattr(self.handler_instance, "closed", False):
+ time.sleep(0.001)
+ self.stop()
+
+ # --- internals
+
+ def run(self):
+ self._active = True
+ self.__flag.set()
+ while self._active and asyncore.socket_map:
+ self._active_lock.acquire()
+ asyncore.loop(timeout=0.001, count=1)
+ self._active_lock.release()
+ asyncore.close_all()
+
+ def handle_accept(self):
+ conn, addr = self.accept()
+ self.handler_instance = self.Handler(conn)
+
+ def handle_connect(self):
+ self.close()
+ handle_read = handle_connect
+
+ def writable(self):
+ return 0
+
+ def handle_error(self):
+ raise
+
+
+@unittest.skipUnless(threading is not None, "test needs threading module")
+@unittest.skipUnless(hasattr(os, 'sendfile'), "test needs os.sendfile()")
+class TestSendfile(unittest.TestCase):
+
+ DATA = b"12345abcde" * 16 * 1024 # 160 KB
+ SUPPORT_HEADERS_TRAILERS = not sys.platform.startswith("linux") and \
+ not sys.platform.startswith("solaris") and \
+ not sys.platform.startswith("sunos")
+
+ @classmethod
+ def setUpClass(cls):
+ with open(support.TESTFN, "wb") as f:
+ f.write(cls.DATA)
+
+ @classmethod
+ def tearDownClass(cls):
+ support.unlink(support.TESTFN)
+
+ def setUp(self):
+ self.server = SendfileTestServer((support.HOST, 0))
+ self.server.start()
+ self.client = socket.socket()
+ self.client.connect((self.server.host, self.server.port))
+ self.client.settimeout(1)
+ # synchronize by waiting for "220 ready" response
+ self.client.recv(1024)
+ self.sockno = self.client.fileno()
+ self.file = open(support.TESTFN, 'rb')
+ self.fileno = self.file.fileno()
+
+ def tearDown(self):
+ self.file.close()
+ self.client.close()
+ if self.server.running:
+ self.server.stop()
+
+ def sendfile_wrapper(self, sock, file, offset, nbytes, headers=[], trailers=[]):
+ """A higher level wrapper representing how an application is
+ supposed to use sendfile().
+ """
+ while 1:
+ try:
+ if self.SUPPORT_HEADERS_TRAILERS:
+ return os.sendfile(sock, file, offset, nbytes, headers,
+ trailers)
+ else:
+ return os.sendfile(sock, file, offset, nbytes)
+ except OSError as err:
+ if err.errno == errno.ECONNRESET:
+ # disconnected
+ raise
+ elif err.errno in (errno.EAGAIN, errno.EBUSY):
+ # we have to retry send data
+ continue
+ else:
+ raise
+
+ def test_send_whole_file(self):
+ # normal send
+ total_sent = 0
+ offset = 0
+ nbytes = 4096
+ while total_sent < len(self.DATA):
+ sent = self.sendfile_wrapper(self.sockno, self.fileno, offset, nbytes)
+ if sent == 0:
+ break
+ offset += sent
+ total_sent += sent
+ self.assertTrue(sent <= nbytes)
+ self.assertEqual(offset, total_sent)
+
+ self.assertEqual(total_sent, len(self.DATA))
+ self.client.shutdown(socket.SHUT_RDWR)
+ self.client.close()
+ self.server.wait()
+ data = self.server.handler_instance.get_data()
+ self.assertEqual(len(data), len(self.DATA))
+ self.assertEqual(data, self.DATA)
+
+ def test_send_at_certain_offset(self):
+ # start sending a file at a certain offset
+ total_sent = 0
+ offset = len(self.DATA) // 2
+ must_send = len(self.DATA) - offset
+ nbytes = 4096
+ while total_sent < must_send:
+ sent = self.sendfile_wrapper(self.sockno, self.fileno, offset, nbytes)
+ if sent == 0:
+ break
+ offset += sent
+ total_sent += sent
+ self.assertTrue(sent <= nbytes)
+
+ self.client.shutdown(socket.SHUT_RDWR)
+ self.client.close()
+ self.server.wait()
+ data = self.server.handler_instance.get_data()
+ expected = self.DATA[len(self.DATA) // 2:]
+ self.assertEqual(total_sent, len(expected))
+ self.assertEqual(len(data), len(expected))
+ self.assertEqual(data, expected)
+
+ def test_offset_overflow(self):
+ # specify an offset > file size
+ offset = len(self.DATA) + 4096
+ try:
+ sent = os.sendfile(self.sockno, self.fileno, offset, 4096)
+ except OSError as e:
+ # Solaris can raise EINVAL if offset >= file length, ignore.
+ if e.errno != errno.EINVAL:
+ raise
+ else:
+ self.assertEqual(sent, 0)
+ self.client.shutdown(socket.SHUT_RDWR)
+ self.client.close()
+ self.server.wait()
+ data = self.server.handler_instance.get_data()
+ self.assertEqual(data, b'')
+
+ def test_invalid_offset(self):
+ with self.assertRaises(OSError) as cm:
+ os.sendfile(self.sockno, self.fileno, -1, 4096)
+ self.assertEqual(cm.exception.errno, errno.EINVAL)
+
+ # --- headers / trailers tests
+
+ if SUPPORT_HEADERS_TRAILERS:
+
+ def test_headers(self):
+ total_sent = 0
+ sent = os.sendfile(self.sockno, self.fileno, 0, 4096,
+ headers=[b"x" * 512])
+ total_sent += sent
+ offset = 4096
+ nbytes = 4096
+ while 1:
+ sent = self.sendfile_wrapper(self.sockno, self.fileno,
+ offset, nbytes)
+ if sent == 0:
+ break
+ total_sent += sent
+ offset += sent
+
+ expected_data = b"x" * 512 + self.DATA
+ self.assertEqual(total_sent, len(expected_data))
+ self.client.close()
+ self.server.wait()
+ data = self.server.handler_instance.get_data()
+ self.assertEqual(hash(data), hash(expected_data))
+
+ def test_trailers(self):
+ TESTFN2 = support.TESTFN + "2"
+ with open(TESTFN2, 'wb') as f:
+ f.write(b"abcde")
+ with open(TESTFN2, 'rb')as f:
+ self.addCleanup(os.remove, TESTFN2)
+ os.sendfile(self.sockno, f.fileno(), 0, 4096,
+ trailers=[b"12345"])
+ self.client.close()
+ self.server.wait()
+ data = self.server.handler_instance.get_data()
+ self.assertEqual(data, b"abcde12345")
+
+ if hasattr(os, "SF_NODISKIO"):
+ def test_flags(self):
+ try:
+ os.sendfile(self.sockno, self.fileno, 0, 4096,
+ flags=os.SF_NODISKIO)
+ except OSError as err:
+ if err.errno not in (errno.EBUSY, errno.EAGAIN):
+ raise
+
+
+def supports_extended_attributes():
+ if not hasattr(os, "setxattr"):
+ return False
+ try:
+ with open(support.TESTFN, "wb") as fp:
+ try:
+ os.setxattr(fp.fileno(), b"user.test", b"")
+ except OSError:
+ return False
+ finally:
+ support.unlink(support.TESTFN)
+ # Kernels < 2.6.39 don't respect setxattr flags.
+ kernel_version = platform.release()
+ m = re.match("2.6.(\d{1,2})", kernel_version)
+ return m is None or int(m.group(1)) >= 39
+
+
+@unittest.skipUnless(supports_extended_attributes(),
+ "no non-broken extended attribute support")
+class ExtendedAttributeTests(unittest.TestCase):
+
+ def tearDown(self):
+ support.unlink(support.TESTFN)
+
+ def _check_xattrs_str(self, s, getxattr, setxattr, removexattr, listxattr, **kwargs):
+ fn = support.TESTFN
+ open(fn, "wb").close()
+ with self.assertRaises(OSError) as cm:
+ getxattr(fn, s("user.test"), **kwargs)
+ self.assertEqual(cm.exception.errno, errno.ENODATA)
+ init_xattr = listxattr(fn)
+ self.assertIsInstance(init_xattr, list)
+ setxattr(fn, s("user.test"), b"", **kwargs)
+ xattr = set(init_xattr)
+ xattr.add("user.test")
+ self.assertEqual(set(listxattr(fn)), xattr)
+ self.assertEqual(getxattr(fn, b"user.test", **kwargs), b"")
+ setxattr(fn, s("user.test"), b"hello", os.XATTR_REPLACE, **kwargs)
+ self.assertEqual(getxattr(fn, b"user.test", **kwargs), b"hello")
+ with self.assertRaises(OSError) as cm:
+ setxattr(fn, s("user.test"), b"bye", os.XATTR_CREATE, **kwargs)
+ self.assertEqual(cm.exception.errno, errno.EEXIST)
+ with self.assertRaises(OSError) as cm:
+ setxattr(fn, s("user.test2"), b"bye", os.XATTR_REPLACE, **kwargs)
+ self.assertEqual(cm.exception.errno, errno.ENODATA)
+ setxattr(fn, s("user.test2"), b"foo", os.XATTR_CREATE, **kwargs)
+ xattr.add("user.test2")
+ self.assertEqual(set(listxattr(fn)), xattr)
+ removexattr(fn, s("user.test"), **kwargs)
+ with self.assertRaises(OSError) as cm:
+ getxattr(fn, s("user.test"), **kwargs)
+ self.assertEqual(cm.exception.errno, errno.ENODATA)
+ xattr.remove("user.test")
+ self.assertEqual(set(listxattr(fn)), xattr)
+ self.assertEqual(getxattr(fn, s("user.test2"), **kwargs), b"foo")
+ setxattr(fn, s("user.test"), b"a"*1024, **kwargs)
+ self.assertEqual(getxattr(fn, s("user.test"), **kwargs), b"a"*1024)
+ removexattr(fn, s("user.test"), **kwargs)
+ many = sorted("user.test{}".format(i) for i in range(100))
+ for thing in many:
+ setxattr(fn, thing, b"x", **kwargs)
+ self.assertEqual(set(listxattr(fn)), set(init_xattr) | set(many))
+
+ def _check_xattrs(self, *args, **kwargs):
+ def make_bytes(s):
+ return bytes(s, "ascii")
+ self._check_xattrs_str(str, *args, **kwargs)
+ support.unlink(support.TESTFN)
+ self._check_xattrs_str(make_bytes, *args, **kwargs)
+
+ def test_simple(self):
+ self._check_xattrs(os.getxattr, os.setxattr, os.removexattr,
+ os.listxattr)
+
+ def test_lpath(self):
+ self._check_xattrs(os.getxattr, os.setxattr, os.removexattr,
+ os.listxattr, follow_symlinks=False)
+
+ def test_fds(self):
+ def getxattr(path, *args):
+ with open(path, "rb") as fp:
+ return os.getxattr(fp.fileno(), *args)
+ def setxattr(path, *args):
+ with open(path, "wb") as fp:
+ os.setxattr(fp.fileno(), *args)
+ def removexattr(path, *args):
+ with open(path, "wb") as fp:
+ os.removexattr(fp.fileno(), *args)
+ def listxattr(path, *args):
+ with open(path, "rb") as fp:
+ return os.listxattr(fp.fileno(), *args)
+ self._check_xattrs(getxattr, setxattr, removexattr, listxattr)
+
+
+@unittest.skipUnless(sys.platform == "win32", "Win32 specific tests")
+class Win32DeprecatedBytesAPI(unittest.TestCase):
+ def test_deprecated(self):
+ import nt
+ filename = os.fsencode(support.TESTFN)
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", DeprecationWarning)
+ for func, *args in (
+ (nt._getfullpathname, filename),
+ (nt._isdir, filename),
+ (os.access, filename, os.R_OK),
+ (os.chdir, filename),
+ (os.chmod, filename, 0o777),
+ (os.getcwdb,),
+ (os.link, filename, filename),
+ (os.listdir, filename),
+ (os.lstat, filename),
+ (os.mkdir, filename),
+ (os.open, filename, os.O_RDONLY),
+ (os.rename, filename, filename),
+ (os.rmdir, filename),
+ (os.startfile, filename),
+ (os.stat, filename),
+ (os.unlink, filename),
+ (os.utime, filename),
+ ):
+ self.assertRaises(DeprecationWarning, func, *args)
+
+ @support.skip_unless_symlink
+ def test_symlink(self):
+ filename = os.fsencode(support.TESTFN)
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", DeprecationWarning)
+ self.assertRaises(DeprecationWarning,
+ os.symlink, filename, filename)
+
+
+@unittest.skipUnless(hasattr(os, 'get_terminal_size'), "requires os.get_terminal_size")
+class TermsizeTests(unittest.TestCase):
+ def test_does_not_crash(self):
+ """Check if get_terminal_size() returns a meaningful value.
+
+ There's no easy portable way to actually check the size of the
+ terminal, so let's check if it returns something sensible instead.
+ """
+ try:
+ size = os.get_terminal_size()
+ except OSError as e:
+ if sys.platform == "win32" or e.errno in (errno.EINVAL, errno.ENOTTY):
+ # Under win32 a generic OSError can be thrown if the
+ # handle cannot be retrieved
+ self.skipTest("failed to query terminal size")
+ raise
+
+ self.assertGreaterEqual(size.columns, 0)
+ self.assertGreaterEqual(size.lines, 0)
+
+ def test_stty_match(self):
+ """Check if stty returns the same results
+
+ stty actually tests stdin, so get_terminal_size is invoked on
+ stdin explicitly. If stty succeeded, then get_terminal_size()
+ should work too.
+ """
+ try:
+ size = subprocess.check_output(['stty', 'size']).decode().split()
+ except (FileNotFoundError, subprocess.CalledProcessError):
+ self.skipTest("stty invocation failed")
+ expected = (int(size[1]), int(size[0])) # reversed order
+
+ try:
+ actual = os.get_terminal_size(sys.__stdin__.fileno())
+ except OSError as e:
+ if sys.platform == "win32" or e.errno in (errno.EINVAL, errno.ENOTTY):
+ # Under win32 a generic OSError can be thrown if the
+ # handle cannot be retrieved
+ self.skipTest("failed to query terminal size")
+ raise
+ self.assertEqual(expected, actual)
+
+
+@support.reap_threads
def test_main():
support.run_unittest(
FileTests,
StatAttributeTests,
EnvironTests,
WalkTests,
+ FwalkTests,
MakedirTests,
DevNullTests,
URandomTests,
@@ -1365,9 +2126,15 @@ def test_main():
Win32KillTests,
Win32SymlinkTests,
FSEncodingTests,
+ DeviceEncodingTests,
PidTests,
LoginTests,
LinkTests,
+ TestSendfile,
+ ProgramPriorityTests,
+ ExtendedAttributeTests,
+ Win32DeprecatedBytesAPI,
+ TermsizeTests,
RemoveDirsTests,
)
diff --git a/Lib/test/test_ossaudiodev.py b/Lib/test/test_ossaudiodev.py
index 9cb89d6..3908a05 100644
--- a/Lib/test/test_ossaudiodev.py
+++ b/Lib/test/test_ossaudiodev.py
@@ -170,6 +170,22 @@ class OSSAudioDevTests(unittest.TestCase):
pass
self.assertTrue(dsp.closed)
+ def test_on_closed(self):
+ dsp = ossaudiodev.open('w')
+ dsp.close()
+ self.assertRaises(ValueError, dsp.fileno)
+ self.assertRaises(ValueError, dsp.read, 1)
+ self.assertRaises(ValueError, dsp.write, b'x')
+ self.assertRaises(ValueError, dsp.writeall, b'x')
+ self.assertRaises(ValueError, dsp.bufsize)
+ self.assertRaises(ValueError, dsp.obufcount)
+ self.assertRaises(ValueError, dsp.obufcount)
+ self.assertRaises(ValueError, dsp.obuffree)
+ self.assertRaises(ValueError, dsp.getptr)
+
+ mixer = ossaudiodev.openmixer()
+ mixer.close()
+ self.assertRaises(ValueError, mixer.fileno)
def test_main():
try:
diff --git a/Lib/test/test_osx_env.py b/Lib/test/test_osx_env.py
index 8b3df37..24ec2b4 100644
--- a/Lib/test/test_osx_env.py
+++ b/Lib/test/test_osx_env.py
@@ -5,6 +5,7 @@ Test suite for OS X interpreter environment variables.
from test.support import EnvironmentVarGuard, run_unittest
import subprocess
import sys
+import sysconfig
import unittest
class OSXEnvironmentVariableTestCase(unittest.TestCase):
@@ -27,8 +28,6 @@ class OSXEnvironmentVariableTestCase(unittest.TestCase):
self._check_sys('PYTHONEXECUTABLE', '==', 'sys.executable')
def test_main():
- from distutils import sysconfig
-
if sys.platform == 'darwin' and sysconfig.get_config_var('WITH_NEXT_FRAMEWORK'):
run_unittest(OSXEnvironmentVariableTestCase)
diff --git a/Lib/test/test_parser.py b/Lib/test/test_parser.py
index 70eb9c0..c93a2ca 100644
--- a/Lib/test/test_parser.py
+++ b/Lib/test/test_parser.py
@@ -51,6 +51,10 @@ class RoundtripLegalSyntaxTestCase(unittest.TestCase):
self.check_suite("def f(): (yield 1)*2")
self.check_suite("def f(): return; yield 1")
self.check_suite("def f(): yield 1; return")
+ self.check_suite("def f(): yield from 1")
+ self.check_suite("def f(): x = yield from 1")
+ self.check_suite("def f(): f((yield from 1))")
+ self.check_suite("def f(): yield 1; return 1")
self.check_suite("def f():\n"
" for x in range(30):\n"
" yield x\n")
@@ -717,6 +721,12 @@ class STObjectTestCase(unittest.TestCase):
# XXX tests for pickling and unpickling of ST objects should go here
+class OtherParserCase(unittest.TestCase):
+
+ def test_two_args_to_expr(self):
+ # See bug #12264
+ with self.assertRaises(TypeError):
+ parser.expr("a", "b")
def test_main():
support.run_unittest(
@@ -725,6 +735,7 @@ def test_main():
CompileTestCase,
ParserStackLimitTestCase,
STObjectTestCase,
+ OtherParserCase,
)
diff --git a/Lib/test/test_pdb.py b/Lib/test/test_pdb.py
index 8355e7d..4fb07a3 100644
--- a/Lib/test/test_pdb.py
+++ b/Lib/test/test_pdb.py
@@ -21,9 +21,12 @@ class PdbTestInput(object):
def __enter__(self):
self.real_stdin = sys.stdin
sys.stdin = _FakeInput(self.input)
+ self.orig_trace = sys.gettrace() if hasattr(sys, 'gettrace') else None
def __exit__(self, *exc):
sys.stdin = self.real_stdin
+ if self.orig_trace:
+ sys.settrace(self.orig_trace)
def test_pdb_displayhook():
diff --git a/Lib/test/test_peepholer.py b/Lib/test/test_peepholer.py
index e268ae2..1cacdea 100644
--- a/Lib/test/test_peepholer.py
+++ b/Lib/test/test_peepholer.py
@@ -3,13 +3,16 @@ import re
import sys
from io import StringIO
import unittest
+from math import copysign
def disassemble(func):
f = StringIO()
tmp = sys.stdout
sys.stdout = f
- dis.dis(func)
- sys.stdout = tmp
+ try:
+ dis.dis(func)
+ finally:
+ sys.stdout = tmp
result = f.getvalue()
f.close()
return result
@@ -17,6 +20,7 @@ def disassemble(func):
def dis_single(line):
return disassemble(compile(line, '', 'single'))
+
class TestTranforms(unittest.TestCase):
def test_unot(self):
@@ -99,6 +103,12 @@ class TestTranforms(unittest.TestCase):
self.assertIn(elem, asm)
self.assertNotIn('BUILD_TUPLE', asm)
+ # Long tuples should be folded too.
+ asm = dis_single(repr(tuple(range(10000))))
+ # One LOAD_CONST for the tuple, one for the None return value
+ self.assertEqual(asm.count('LOAD_CONST'), 2)
+ self.assertNotIn('BUILD_TUPLE', asm)
+
# Bug 1053819: Tuple of constants misidentified when presented with:
# . . . opcode_with_arg 100 unary_opcode BUILD_TUPLE 1 . . .
# The following would segfault upon compilation
@@ -196,27 +206,28 @@ class TestTranforms(unittest.TestCase):
self.assertIn('(1000)', asm)
def test_binary_subscr_on_unicode(self):
- # unicode strings don't get optimized
+ # valid code get optimized
asm = dis_single('"foo"[0]')
- self.assertNotIn("('f')", asm)
- self.assertIn('BINARY_SUBSCR', asm)
+ self.assertIn("('f')", asm)
+ self.assertNotIn('BINARY_SUBSCR', asm)
asm = dis_single('"\u0061\uffff"[1]')
- self.assertNotIn("('\\uffff')", asm)
- self.assertIn('BINARY_SUBSCR', asm)
+ self.assertIn("('\\uffff')", asm)
+ self.assertNotIn('BINARY_SUBSCR', asm)
+ asm = dis_single('"\U00012345abcdef"[3]')
+ self.assertIn("('c')", asm)
+ self.assertNotIn('BINARY_SUBSCR', asm)
+ # invalid code doesn't get optimized
# out of range
asm = dis_single('"fuu"[10]')
self.assertIn('BINARY_SUBSCR', asm)
- # non-BMP char (see #5057)
- asm = dis_single('"\U00012345"[0]')
- self.assertIn('BINARY_SUBSCR', asm)
- asm = dis_single('"\U00012345abcdef"[3]')
- self.assertIn('BINARY_SUBSCR', asm)
-
def test_folding_of_unaryops_on_constants(self):
for line, elem in (
('-0.5', '(-0.5)'), # unary negative
+ ('-0.0', '(-0.0)'), # -0.0
+ ('-(1.0-1.0)','(-0.0)'), # -0.0 after folding
+ ('-0', '(0)'), # -0
('~-2', '(1)'), # unary invert
('+1', '(1)'), # unary positive
):
@@ -224,6 +235,13 @@ class TestTranforms(unittest.TestCase):
self.assertIn(elem, asm, asm)
self.assertNotIn('UNARY_', asm)
+ # Check that -0.0 works after marshaling
+ def negzero():
+ return -(1.0-1.0)
+
+ self.assertNotIn('UNARY_', disassemble(negzero))
+ self.assertTrue(copysign(1.0, negzero()) < 0)
+
# Verify that unfoldables are skipped
for line, elem in (
('-"abc"', "('abc')"), # unary negative
@@ -286,6 +304,25 @@ class TestTranforms(unittest.TestCase):
asm = disassemble(f)
self.assertNotIn('BINARY_ADD', asm)
+ def test_constant_folding(self):
+ # Issue #11244: aggressive constant folding.
+ exprs = [
+ "3 * -5",
+ "-3 * 5",
+ "2 * (3 * 4)",
+ "(2 * 3) * 4",
+ "(-1, 2, 3)",
+ "(1, -2, 3)",
+ "(1, 2, -3)",
+ "(1, 2, -3) * 6",
+ "lambda x: x in {(3 * -5) + (-1 - 6), (1, -2, 3) * 2, None}",
+ ]
+ for e in exprs:
+ asm = dis_single(e)
+ self.assertNotIn('UNARY_', asm, e)
+ self.assertNotIn('BINARY_', asm, e)
+ self.assertNotIn('BUILD_', asm, e)
+
class TestBuglets(unittest.TestCase):
def test_bug_11510(self):
diff --git a/Lib/test/test_pep277.py b/Lib/test/test_pep277.py
index 6d891e5..4b16cbb 100644
--- a/Lib/test/test_pep277.py
+++ b/Lib/test/test_pep277.py
@@ -1,6 +1,9 @@
# Test the Unicode versions of normal file functions
# open, os.open, os.stat. os.listdir, os.rename, os.remove, os.mkdir, os.chdir, os.rmdir
-import sys, os, unittest
+import os
+import sys
+import unittest
+import warnings
from unicodedata import normalize
from test import support
@@ -38,8 +41,8 @@ if sys.platform != 'darwin':
'17_\u2001\u2001\u2001A',
'18_\u2003\u2003\u2003A', # == NFC('\u2001\u2001\u2001A')
'19_\u0020\u0020\u0020A', # '\u0020' == ' ' == NFKC('\u2000') ==
- # NFKC('\u2001') == NFKC('\u2003')
-])
+ # NFKC('\u2001') == NFKC('\u2003')
+ ])
# Is it Unicode-friendly?
@@ -71,7 +74,7 @@ class UnicodeFileTests(unittest.TestCase):
def setUp(self):
try:
os.mkdir(support.TESTFN)
- except OSError:
+ except FileExistsError:
pass
files = set()
for name in self.files:
@@ -90,15 +93,17 @@ class UnicodeFileTests(unittest.TestCase):
return normalize(self.normal_form, s)
return s
- def _apply_failure(self, fn, filename, expected_exception,
- check_fn_in_exception = True):
+ def _apply_failure(self, fn, filename,
+ expected_exception=FileNotFoundError,
+ check_filename=True):
with self.assertRaises(expected_exception) as c:
fn(filename)
exc_filename = c.exception.filename
- # the "filename" exception attribute may be encoded
- if isinstance(exc_filename, bytes):
- filename = filename.encode(sys.getfilesystemencoding())
- if check_fn_in_exception:
+ # listdir may append a wildcard to the filename
+ if fn is os.listdir and sys.platform == 'win32':
+ exc_filename, _, wildcard = exc_filename.rpartition(os.sep)
+ self.assertEqual(wildcard, '*.*')
+ if check_filename:
self.assertEqual(exc_filename, filename, "Function '%s(%a) failed "
"with bad filename in the exception: %a" %
(fn.__name__, filename, exc_filename))
@@ -107,13 +112,18 @@ class UnicodeFileTests(unittest.TestCase):
# Pass non-existing Unicode filenames all over the place.
for name in self.files:
name = "not_" + name
- self._apply_failure(open, name, IOError)
- self._apply_failure(os.stat, name, OSError)
- self._apply_failure(os.chdir, name, OSError)
- self._apply_failure(os.rmdir, name, OSError)
- self._apply_failure(os.remove, name, OSError)
- # listdir may append a wildcard to the filename, so dont check
- self._apply_failure(os.listdir, name, OSError, False)
+ self._apply_failure(open, name)
+ self._apply_failure(os.stat, name)
+ self._apply_failure(os.chdir, name)
+ self._apply_failure(os.rmdir, name)
+ self._apply_failure(os.remove, name)
+ self._apply_failure(os.listdir, name)
+
+ if sys.platform == 'win32':
+ # Windows is lunatic. Issue #13366.
+ _listdir_failure = NotADirectoryError, FileNotFoundError
+ else:
+ _listdir_failure = NotADirectoryError
def test_open(self):
for name in self.files:
@@ -121,12 +131,13 @@ class UnicodeFileTests(unittest.TestCase):
f.write((name+'\n').encode("utf-8"))
f.close()
os.stat(name)
+ self._apply_failure(os.listdir, name, self._listdir_failure)
# Skip the test on darwin, because darwin does normalize the filename to
# NFD (a variant of Unicode NFD form). Normalize the filename to NFC, NFKC,
# NFKD in Python is useless, because darwin will normalize it later and so
# open(), os.stat(), etc. don't raise any exception.
- @unittest.skipIf(sys.platform == 'darwin', 'irrevelant test on Mac OS X')
+ @unittest.skipIf(sys.platform == 'darwin', 'irrelevant test on Mac OS X')
def test_normalize(self):
files = set(self.files)
others = set()
@@ -134,21 +145,22 @@ class UnicodeFileTests(unittest.TestCase):
others |= set(normalize(nf, file) for file in files)
others -= files
for name in others:
- self._apply_failure(open, name, IOError)
- self._apply_failure(os.stat, name, OSError)
- self._apply_failure(os.chdir, name, OSError)
- self._apply_failure(os.rmdir, name, OSError)
- self._apply_failure(os.remove, name, OSError)
- # listdir may append a wildcard to the filename, so dont check
- self._apply_failure(os.listdir, name, OSError, False)
+ self._apply_failure(open, name)
+ self._apply_failure(os.stat, name)
+ self._apply_failure(os.chdir, name)
+ self._apply_failure(os.rmdir, name)
+ self._apply_failure(os.remove, name)
+ self._apply_failure(os.listdir, name)
# Skip the test on darwin, because darwin uses a normalization different
# than Python NFD normalization: filenames are different even if we use
# Python NFD normalization.
- @unittest.skipIf(sys.platform == 'darwin', 'irrevelant test on Mac OS X')
+ @unittest.skipIf(sys.platform == 'darwin', 'irrelevant test on Mac OS X')
def test_listdir(self):
sf0 = set(self.files)
- f1 = os.listdir(support.TESTFN.encode(sys.getfilesystemencoding()))
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", DeprecationWarning)
+ f1 = os.listdir(support.TESTFN.encode(sys.getfilesystemencoding()))
f2 = os.listdir(support.TESTFN)
sf2 = set(os.path.join(support.TESTFN, f) for f in f2)
self.assertEqual(sf0, sf2, "%a != %a" % (sf0, sf2))
diff --git a/Lib/test/test_pep292.py b/Lib/test/test_pep292.py
index 119c7ea..6da8d2e 100644
--- a/Lib/test/test_pep292.py
+++ b/Lib/test/test_pep292.py
@@ -42,19 +42,6 @@ class TestTemplate(unittest.TestCase):
s = Template('$who likes $$')
eq(s.substitute(dict(who='tim', what='ham')), 'tim likes $')
- def test_invalid(self):
- class MyPattern(Template):
- pattern = r"""
- (?:
- (?P<invalid>) |
- (?P<escaped>%(delim)s) |
- @(?P<named>%(id)s) |
- @{(?P<braced>%(id)s)}
- )
- """
- s = MyPattern('$')
- self.assertRaises(ValueError, s.substitute, dict())
-
def test_percents(self):
eq = self.assertEqual
s = Template('%(foo)s $foo ${foo}')
@@ -172,6 +159,26 @@ class TestTemplate(unittest.TestCase):
val = t.safe_substitute({'location': 'Cleveland'})
self.assertEqual(val, 'PyCon in Cleveland')
+ def test_invalid_with_no_lines(self):
+ # The error formatting for invalid templates
+ # has a special case for no data that the default
+ # pattern can't trigger (always has at least '$')
+ # So we craft a pattern that is always invalid
+ # with no leading data.
+ class MyTemplate(Template):
+ pattern = r"""
+ (?P<invalid>) |
+ unreachable(
+ (?P<named>) |
+ (?P<braced>) |
+ (?P<escaped>)
+ )
+ """
+ s = MyTemplate('')
+ with self.assertRaises(ValueError) as err:
+ s.substitute({})
+ self.assertIn('line 1, col 1', str(err.exception))
+
def test_unicode_values(self):
s = Template('$who likes $what')
d = dict(who='t\xffm', what='f\xfe\fed')
diff --git a/Lib/test/test_pep3131.py b/Lib/test/test_pep3131.py
index df0f64d..2e6b90a 100644
--- a/Lib/test/test_pep3131.py
+++ b/Lib/test/test_pep3131.py
@@ -17,12 +17,7 @@ class PEP3131Test(unittest.TestCase):
def test_non_bmp_normalized(self):
𝔘𝔫𝔦𝔠𝔬𝔡𝔢 = 1
- # On wide builds, this is normalized, but on narrow ones it is not. See
- # #12746.
- try:
- self.assertIn("𝔘𝔫𝔦𝔠𝔬𝔡𝔢", dir())
- except AssertionError:
- raise unittest.case._ExpectedFailure(sys.exc_info())
+ self.assertIn("Unicode", dir())
def test_invalid(self):
try:
diff --git a/Lib/test/test_pep3151.py b/Lib/test/test_pep3151.py
new file mode 100644
index 0000000..2792c10
--- /dev/null
+++ b/Lib/test/test_pep3151.py
@@ -0,0 +1,211 @@
+import builtins
+import os
+import select
+import socket
+import sys
+import unittest
+import errno
+from errno import EEXIST
+
+from test import support
+
+class SubOSError(OSError):
+ pass
+
+class SubOSErrorWithInit(OSError):
+ def __init__(self, message, bar):
+ self.bar = bar
+ super().__init__(message)
+
+class SubOSErrorWithNew(OSError):
+ def __new__(cls, message, baz):
+ self = super().__new__(cls, message)
+ self.baz = baz
+ return self
+
+class SubOSErrorCombinedInitFirst(SubOSErrorWithInit, SubOSErrorWithNew):
+ pass
+
+class SubOSErrorCombinedNewFirst(SubOSErrorWithNew, SubOSErrorWithInit):
+ pass
+
+class SubOSErrorWithStandaloneInit(OSError):
+ def __init__(self):
+ pass
+
+
+class HierarchyTest(unittest.TestCase):
+
+ def test_builtin_errors(self):
+ self.assertEqual(OSError.__name__, 'OSError')
+ self.assertIs(IOError, OSError)
+ self.assertIs(EnvironmentError, OSError)
+
+ def test_socket_errors(self):
+ self.assertIs(socket.error, IOError)
+ self.assertIs(socket.gaierror.__base__, OSError)
+ self.assertIs(socket.herror.__base__, OSError)
+ self.assertIs(socket.timeout.__base__, OSError)
+
+ def test_select_error(self):
+ self.assertIs(select.error, OSError)
+
+ # mmap.error is tested in test_mmap
+
+ _pep_map = """
+ +-- BlockingIOError EAGAIN, EALREADY, EWOULDBLOCK, EINPROGRESS
+ +-- ChildProcessError ECHILD
+ +-- ConnectionError
+ +-- BrokenPipeError EPIPE, ESHUTDOWN
+ +-- ConnectionAbortedError ECONNABORTED
+ +-- ConnectionRefusedError ECONNREFUSED
+ +-- ConnectionResetError ECONNRESET
+ +-- FileExistsError EEXIST
+ +-- FileNotFoundError ENOENT
+ +-- InterruptedError EINTR
+ +-- IsADirectoryError EISDIR
+ +-- NotADirectoryError ENOTDIR
+ +-- PermissionError EACCES, EPERM
+ +-- ProcessLookupError ESRCH
+ +-- TimeoutError ETIMEDOUT
+ """
+ def _make_map(s):
+ _map = {}
+ for line in s.splitlines():
+ line = line.strip('+- ')
+ if not line:
+ continue
+ excname, _, errnames = line.partition(' ')
+ for errname in filter(None, errnames.strip().split(', ')):
+ _map[getattr(errno, errname)] = getattr(builtins, excname)
+ return _map
+ _map = _make_map(_pep_map)
+
+ def test_errno_mapping(self):
+ # The OSError constructor maps errnos to subclasses
+ # A sample test for the basic functionality
+ e = OSError(EEXIST, "Bad file descriptor")
+ self.assertIs(type(e), FileExistsError)
+ # Exhaustive testing
+ for errcode, exc in self._map.items():
+ e = OSError(errcode, "Some message")
+ self.assertIs(type(e), exc)
+ othercodes = set(errno.errorcode) - set(self._map)
+ for errcode in othercodes:
+ e = OSError(errcode, "Some message")
+ self.assertIs(type(e), OSError)
+
+ def test_try_except(self):
+ filename = "some_hopefully_non_existing_file"
+
+ # This checks that try .. except checks the concrete exception
+ # (FileNotFoundError) and not the base type specified when
+ # PyErr_SetFromErrnoWithFilenameObject was called.
+ # (it is therefore deliberate that it doesn't use assertRaises)
+ try:
+ open(filename)
+ except FileNotFoundError:
+ pass
+ else:
+ self.fail("should have raised a FileNotFoundError")
+
+ # Another test for PyErr_SetExcFromWindowsErrWithFilenameObject()
+ self.assertFalse(os.path.exists(filename))
+ try:
+ os.unlink(filename)
+ except FileNotFoundError:
+ pass
+ else:
+ self.fail("should have raised a FileNotFoundError")
+
+
+class AttributesTest(unittest.TestCase):
+
+ def test_windows_error(self):
+ if os.name == "nt":
+ self.assertIn('winerror', dir(OSError))
+ else:
+ self.assertNotIn('winerror', dir(OSError))
+
+ def test_posix_error(self):
+ e = OSError(EEXIST, "File already exists", "foo.txt")
+ self.assertEqual(e.errno, EEXIST)
+ self.assertEqual(e.args[0], EEXIST)
+ self.assertEqual(e.strerror, "File already exists")
+ self.assertEqual(e.filename, "foo.txt")
+ if os.name == "nt":
+ self.assertEqual(e.winerror, None)
+
+ @unittest.skipUnless(os.name == "nt", "Windows-specific test")
+ def test_errno_translation(self):
+ # ERROR_ALREADY_EXISTS (183) -> EEXIST
+ e = OSError(0, "File already exists", "foo.txt", 183)
+ self.assertEqual(e.winerror, 183)
+ self.assertEqual(e.errno, EEXIST)
+ self.assertEqual(e.args[0], EEXIST)
+ self.assertEqual(e.strerror, "File already exists")
+ self.assertEqual(e.filename, "foo.txt")
+
+ def test_blockingioerror(self):
+ args = ("a", "b", "c", "d", "e")
+ for n in range(6):
+ e = BlockingIOError(*args[:n])
+ with self.assertRaises(AttributeError):
+ e.characters_written
+ e = BlockingIOError("a", "b", 3)
+ self.assertEqual(e.characters_written, 3)
+ e.characters_written = 5
+ self.assertEqual(e.characters_written, 5)
+
+ # XXX VMSError not tested
+
+
+class ExplicitSubclassingTest(unittest.TestCase):
+
+ def test_errno_mapping(self):
+ # When constructing an OSError subclass, errno mapping isn't done
+ e = SubOSError(EEXIST, "Bad file descriptor")
+ self.assertIs(type(e), SubOSError)
+
+ def test_init_overriden(self):
+ e = SubOSErrorWithInit("some message", "baz")
+ self.assertEqual(e.bar, "baz")
+ self.assertEqual(e.args, ("some message",))
+
+ def test_init_kwdargs(self):
+ e = SubOSErrorWithInit("some message", bar="baz")
+ self.assertEqual(e.bar, "baz")
+ self.assertEqual(e.args, ("some message",))
+
+ def test_new_overriden(self):
+ e = SubOSErrorWithNew("some message", "baz")
+ self.assertEqual(e.baz, "baz")
+ self.assertEqual(e.args, ("some message",))
+
+ def test_new_kwdargs(self):
+ e = SubOSErrorWithNew("some message", baz="baz")
+ self.assertEqual(e.baz, "baz")
+ self.assertEqual(e.args, ("some message",))
+
+ def test_init_new_overriden(self):
+ e = SubOSErrorCombinedInitFirst("some message", "baz")
+ self.assertEqual(e.bar, "baz")
+ self.assertEqual(e.baz, "baz")
+ self.assertEqual(e.args, ("some message",))
+ e = SubOSErrorCombinedNewFirst("some message", "baz")
+ self.assertEqual(e.bar, "baz")
+ self.assertEqual(e.baz, "baz")
+ self.assertEqual(e.args, ("some message",))
+
+ def test_init_standalone(self):
+ # __init__ doesn't propagate to OSError.__init__ (see issue #15229)
+ e = SubOSErrorWithStandaloneInit()
+ self.assertEqual(e.args, ())
+ self.assertEqual(str(e), '')
+
+
+def test_main():
+ support.run_unittest(__name__)
+
+if __name__=="__main__":
+ test_main()
diff --git a/Lib/test/test_pep380.py b/Lib/test/test_pep380.py
new file mode 100644
index 0000000..f3eed43
--- /dev/null
+++ b/Lib/test/test_pep380.py
@@ -0,0 +1,965 @@
+# -*- coding: utf-8 -*-
+
+"""
+Test suite for PEP 380 implementation
+
+adapted from original tests written by Greg Ewing
+see <http://www.cosc.canterbury.ac.nz/greg.ewing/python/yield-from/YieldFrom-Python3.1.2-rev5.zip>
+"""
+
+import unittest
+import io
+import sys
+import inspect
+import parser
+
+from test.support import captured_stderr
+
+class TestPEP380Operation(unittest.TestCase):
+ """
+ Test semantics.
+ """
+
+ def test_delegation_of_initial_next_to_subgenerator(self):
+ """
+ Test delegation of initial next() call to subgenerator
+ """
+ trace = []
+ def g1():
+ trace.append("Starting g1")
+ yield from g2()
+ trace.append("Finishing g1")
+ def g2():
+ trace.append("Starting g2")
+ yield 42
+ trace.append("Finishing g2")
+ for x in g1():
+ trace.append("Yielded %s" % (x,))
+ self.assertEqual(trace,[
+ "Starting g1",
+ "Starting g2",
+ "Yielded 42",
+ "Finishing g2",
+ "Finishing g1",
+ ])
+
+ def test_raising_exception_in_initial_next_call(self):
+ """
+ Test raising exception in initial next() call
+ """
+ trace = []
+ def g1():
+ try:
+ trace.append("Starting g1")
+ yield from g2()
+ finally:
+ trace.append("Finishing g1")
+ def g2():
+ try:
+ trace.append("Starting g2")
+ raise ValueError("spanish inquisition occurred")
+ finally:
+ trace.append("Finishing g2")
+ try:
+ for x in g1():
+ trace.append("Yielded %s" % (x,))
+ except ValueError as e:
+ self.assertEqual(e.args[0], "spanish inquisition occurred")
+ else:
+ self.fail("subgenerator failed to raise ValueError")
+ self.assertEqual(trace,[
+ "Starting g1",
+ "Starting g2",
+ "Finishing g2",
+ "Finishing g1",
+ ])
+
+ def test_delegation_of_next_call_to_subgenerator(self):
+ """
+ Test delegation of next() call to subgenerator
+ """
+ trace = []
+ def g1():
+ trace.append("Starting g1")
+ yield "g1 ham"
+ yield from g2()
+ yield "g1 eggs"
+ trace.append("Finishing g1")
+ def g2():
+ trace.append("Starting g2")
+ yield "g2 spam"
+ yield "g2 more spam"
+ trace.append("Finishing g2")
+ for x in g1():
+ trace.append("Yielded %s" % (x,))
+ self.assertEqual(trace,[
+ "Starting g1",
+ "Yielded g1 ham",
+ "Starting g2",
+ "Yielded g2 spam",
+ "Yielded g2 more spam",
+ "Finishing g2",
+ "Yielded g1 eggs",
+ "Finishing g1",
+ ])
+
+ def test_raising_exception_in_delegated_next_call(self):
+ """
+ Test raising exception in delegated next() call
+ """
+ trace = []
+ def g1():
+ try:
+ trace.append("Starting g1")
+ yield "g1 ham"
+ yield from g2()
+ yield "g1 eggs"
+ finally:
+ trace.append("Finishing g1")
+ def g2():
+ try:
+ trace.append("Starting g2")
+ yield "g2 spam"
+ raise ValueError("hovercraft is full of eels")
+ yield "g2 more spam"
+ finally:
+ trace.append("Finishing g2")
+ try:
+ for x in g1():
+ trace.append("Yielded %s" % (x,))
+ except ValueError as e:
+ self.assertEqual(e.args[0], "hovercraft is full of eels")
+ else:
+ self.fail("subgenerator failed to raise ValueError")
+ self.assertEqual(trace,[
+ "Starting g1",
+ "Yielded g1 ham",
+ "Starting g2",
+ "Yielded g2 spam",
+ "Finishing g2",
+ "Finishing g1",
+ ])
+
+ def test_delegation_of_send(self):
+ """
+ Test delegation of send()
+ """
+ trace = []
+ def g1():
+ trace.append("Starting g1")
+ x = yield "g1 ham"
+ trace.append("g1 received %s" % (x,))
+ yield from g2()
+ x = yield "g1 eggs"
+ trace.append("g1 received %s" % (x,))
+ trace.append("Finishing g1")
+ def g2():
+ trace.append("Starting g2")
+ x = yield "g2 spam"
+ trace.append("g2 received %s" % (x,))
+ x = yield "g2 more spam"
+ trace.append("g2 received %s" % (x,))
+ trace.append("Finishing g2")
+ g = g1()
+ y = next(g)
+ x = 1
+ try:
+ while 1:
+ y = g.send(x)
+ trace.append("Yielded %s" % (y,))
+ x += 1
+ except StopIteration:
+ pass
+ self.assertEqual(trace,[
+ "Starting g1",
+ "g1 received 1",
+ "Starting g2",
+ "Yielded g2 spam",
+ "g2 received 2",
+ "Yielded g2 more spam",
+ "g2 received 3",
+ "Finishing g2",
+ "Yielded g1 eggs",
+ "g1 received 4",
+ "Finishing g1",
+ ])
+
+ def test_handling_exception_while_delegating_send(self):
+ """
+ Test handling exception while delegating 'send'
+ """
+ trace = []
+ def g1():
+ trace.append("Starting g1")
+ x = yield "g1 ham"
+ trace.append("g1 received %s" % (x,))
+ yield from g2()
+ x = yield "g1 eggs"
+ trace.append("g1 received %s" % (x,))
+ trace.append("Finishing g1")
+ def g2():
+ trace.append("Starting g2")
+ x = yield "g2 spam"
+ trace.append("g2 received %s" % (x,))
+ raise ValueError("hovercraft is full of eels")
+ x = yield "g2 more spam"
+ trace.append("g2 received %s" % (x,))
+ trace.append("Finishing g2")
+ def run():
+ g = g1()
+ y = next(g)
+ x = 1
+ try:
+ while 1:
+ y = g.send(x)
+ trace.append("Yielded %s" % (y,))
+ x += 1
+ except StopIteration:
+ trace.append("StopIteration")
+ self.assertRaises(ValueError,run)
+ self.assertEqual(trace,[
+ "Starting g1",
+ "g1 received 1",
+ "Starting g2",
+ "Yielded g2 spam",
+ "g2 received 2",
+ ])
+
+ def test_delegating_close(self):
+ """
+ Test delegating 'close'
+ """
+ trace = []
+ def g1():
+ try:
+ trace.append("Starting g1")
+ yield "g1 ham"
+ yield from g2()
+ yield "g1 eggs"
+ finally:
+ trace.append("Finishing g1")
+ def g2():
+ try:
+ trace.append("Starting g2")
+ yield "g2 spam"
+ yield "g2 more spam"
+ finally:
+ trace.append("Finishing g2")
+ g = g1()
+ for i in range(2):
+ x = next(g)
+ trace.append("Yielded %s" % (x,))
+ g.close()
+ self.assertEqual(trace,[
+ "Starting g1",
+ "Yielded g1 ham",
+ "Starting g2",
+ "Yielded g2 spam",
+ "Finishing g2",
+ "Finishing g1"
+ ])
+
+ def test_handing_exception_while_delegating_close(self):
+ """
+ Test handling exception while delegating 'close'
+ """
+ trace = []
+ def g1():
+ try:
+ trace.append("Starting g1")
+ yield "g1 ham"
+ yield from g2()
+ yield "g1 eggs"
+ finally:
+ trace.append("Finishing g1")
+ def g2():
+ try:
+ trace.append("Starting g2")
+ yield "g2 spam"
+ yield "g2 more spam"
+ finally:
+ trace.append("Finishing g2")
+ raise ValueError("nybbles have exploded with delight")
+ try:
+ g = g1()
+ for i in range(2):
+ x = next(g)
+ trace.append("Yielded %s" % (x,))
+ g.close()
+ except ValueError as e:
+ self.assertEqual(e.args[0], "nybbles have exploded with delight")
+ self.assertIsInstance(e.__context__, GeneratorExit)
+ else:
+ self.fail("subgenerator failed to raise ValueError")
+ self.assertEqual(trace,[
+ "Starting g1",
+ "Yielded g1 ham",
+ "Starting g2",
+ "Yielded g2 spam",
+ "Finishing g2",
+ "Finishing g1",
+ ])
+
+ def test_delegating_throw(self):
+ """
+ Test delegating 'throw'
+ """
+ trace = []
+ def g1():
+ try:
+ trace.append("Starting g1")
+ yield "g1 ham"
+ yield from g2()
+ yield "g1 eggs"
+ finally:
+ trace.append("Finishing g1")
+ def g2():
+ try:
+ trace.append("Starting g2")
+ yield "g2 spam"
+ yield "g2 more spam"
+ finally:
+ trace.append("Finishing g2")
+ try:
+ g = g1()
+ for i in range(2):
+ x = next(g)
+ trace.append("Yielded %s" % (x,))
+ e = ValueError("tomato ejected")
+ g.throw(e)
+ except ValueError as e:
+ self.assertEqual(e.args[0], "tomato ejected")
+ else:
+ self.fail("subgenerator failed to raise ValueError")
+ self.assertEqual(trace,[
+ "Starting g1",
+ "Yielded g1 ham",
+ "Starting g2",
+ "Yielded g2 spam",
+ "Finishing g2",
+ "Finishing g1",
+ ])
+
+ def test_value_attribute_of_StopIteration_exception(self):
+ """
+ Test 'value' attribute of StopIteration exception
+ """
+ trace = []
+ def pex(e):
+ trace.append("%s: %s" % (e.__class__.__name__, e))
+ trace.append("value = %s" % (e.value,))
+ e = StopIteration()
+ pex(e)
+ e = StopIteration("spam")
+ pex(e)
+ e.value = "eggs"
+ pex(e)
+ self.assertEqual(trace,[
+ "StopIteration: ",
+ "value = None",
+ "StopIteration: spam",
+ "value = spam",
+ "StopIteration: spam",
+ "value = eggs",
+ ])
+
+
+ def test_exception_value_crash(self):
+ # There used to be a refcount error when the return value
+ # stored in the StopIteration has a refcount of 1.
+ def g1():
+ yield from g2()
+ def g2():
+ yield "g2"
+ return [42]
+ self.assertEqual(list(g1()), ["g2"])
+
+
+ def test_generator_return_value(self):
+ """
+ Test generator return value
+ """
+ trace = []
+ def g1():
+ trace.append("Starting g1")
+ yield "g1 ham"
+ ret = yield from g2()
+ trace.append("g2 returned %s" % (ret,))
+ ret = yield from g2(42)
+ trace.append("g2 returned %s" % (ret,))
+ yield "g1 eggs"
+ trace.append("Finishing g1")
+ def g2(v = None):
+ trace.append("Starting g2")
+ yield "g2 spam"
+ yield "g2 more spam"
+ trace.append("Finishing g2")
+ if v:
+ return v
+ for x in g1():
+ trace.append("Yielded %s" % (x,))
+ self.assertEqual(trace,[
+ "Starting g1",
+ "Yielded g1 ham",
+ "Starting g2",
+ "Yielded g2 spam",
+ "Yielded g2 more spam",
+ "Finishing g2",
+ "g2 returned None",
+ "Starting g2",
+ "Yielded g2 spam",
+ "Yielded g2 more spam",
+ "Finishing g2",
+ "g2 returned 42",
+ "Yielded g1 eggs",
+ "Finishing g1",
+ ])
+
+ def test_delegation_of_next_to_non_generator(self):
+ """
+ Test delegation of next() to non-generator
+ """
+ trace = []
+ def g():
+ yield from range(3)
+ for x in g():
+ trace.append("Yielded %s" % (x,))
+ self.assertEqual(trace,[
+ "Yielded 0",
+ "Yielded 1",
+ "Yielded 2",
+ ])
+
+
+ def test_conversion_of_sendNone_to_next(self):
+ """
+ Test conversion of send(None) to next()
+ """
+ trace = []
+ def g():
+ yield from range(3)
+ gi = g()
+ for x in range(3):
+ y = gi.send(None)
+ trace.append("Yielded: %s" % (y,))
+ self.assertEqual(trace,[
+ "Yielded: 0",
+ "Yielded: 1",
+ "Yielded: 2",
+ ])
+
+ def test_delegation_of_close_to_non_generator(self):
+ """
+ Test delegation of close() to non-generator
+ """
+ trace = []
+ def g():
+ try:
+ trace.append("starting g")
+ yield from range(3)
+ trace.append("g should not be here")
+ finally:
+ trace.append("finishing g")
+ gi = g()
+ next(gi)
+ with captured_stderr() as output:
+ gi.close()
+ self.assertEqual(output.getvalue(), '')
+ self.assertEqual(trace,[
+ "starting g",
+ "finishing g",
+ ])
+
+ def test_delegating_throw_to_non_generator(self):
+ """
+ Test delegating 'throw' to non-generator
+ """
+ trace = []
+ def g():
+ try:
+ trace.append("Starting g")
+ yield from range(10)
+ finally:
+ trace.append("Finishing g")
+ try:
+ gi = g()
+ for i in range(5):
+ x = next(gi)
+ trace.append("Yielded %s" % (x,))
+ e = ValueError("tomato ejected")
+ gi.throw(e)
+ except ValueError as e:
+ self.assertEqual(e.args[0],"tomato ejected")
+ else:
+ self.fail("subgenerator failed to raise ValueError")
+ self.assertEqual(trace,[
+ "Starting g",
+ "Yielded 0",
+ "Yielded 1",
+ "Yielded 2",
+ "Yielded 3",
+ "Yielded 4",
+ "Finishing g",
+ ])
+
+ def test_attempting_to_send_to_non_generator(self):
+ """
+ Test attempting to send to non-generator
+ """
+ trace = []
+ def g():
+ try:
+ trace.append("starting g")
+ yield from range(3)
+ trace.append("g should not be here")
+ finally:
+ trace.append("finishing g")
+ try:
+ gi = g()
+ next(gi)
+ for x in range(3):
+ y = gi.send(42)
+ trace.append("Should not have yielded: %s" % (y,))
+ except AttributeError as e:
+ self.assertIn("send", e.args[0])
+ else:
+ self.fail("was able to send into non-generator")
+ self.assertEqual(trace,[
+ "starting g",
+ "finishing g",
+ ])
+
+ def test_broken_getattr_handling(self):
+ """
+ Test subiterator with a broken getattr implementation
+ """
+ class Broken:
+ def __iter__(self):
+ return self
+ def __next__(self):
+ return 1
+ def __getattr__(self, attr):
+ 1/0
+
+ def g():
+ yield from Broken()
+
+ with self.assertRaises(ZeroDivisionError):
+ gi = g()
+ self.assertEqual(next(gi), 1)
+ gi.send(1)
+
+ with self.assertRaises(ZeroDivisionError):
+ gi = g()
+ self.assertEqual(next(gi), 1)
+ gi.throw(AttributeError)
+
+ with captured_stderr() as output:
+ gi = g()
+ self.assertEqual(next(gi), 1)
+ gi.close()
+ self.assertIn('ZeroDivisionError', output.getvalue())
+
+ def test_exception_in_initial_next_call(self):
+ """
+ Test exception in initial next() call
+ """
+ trace = []
+ def g1():
+ trace.append("g1 about to yield from g2")
+ yield from g2()
+ trace.append("g1 should not be here")
+ def g2():
+ yield 1/0
+ def run():
+ gi = g1()
+ next(gi)
+ self.assertRaises(ZeroDivisionError,run)
+ self.assertEqual(trace,[
+ "g1 about to yield from g2"
+ ])
+
+ def test_attempted_yield_from_loop(self):
+ """
+ Test attempted yield-from loop
+ """
+ trace = []
+ def g1():
+ trace.append("g1: starting")
+ yield "y1"
+ trace.append("g1: about to yield from g2")
+ yield from g2()
+ trace.append("g1 should not be here")
+
+ def g2():
+ trace.append("g2: starting")
+ yield "y2"
+ trace.append("g2: about to yield from g1")
+ yield from gi
+ trace.append("g2 should not be here")
+ try:
+ gi = g1()
+ for y in gi:
+ trace.append("Yielded: %s" % (y,))
+ except ValueError as e:
+ self.assertEqual(e.args[0],"generator already executing")
+ else:
+ self.fail("subgenerator didn't raise ValueError")
+ self.assertEqual(trace,[
+ "g1: starting",
+ "Yielded: y1",
+ "g1: about to yield from g2",
+ "g2: starting",
+ "Yielded: y2",
+ "g2: about to yield from g1",
+ ])
+
+ def test_returning_value_from_delegated_throw(self):
+ """
+ Test returning value from delegated 'throw'
+ """
+ trace = []
+ def g1():
+ try:
+ trace.append("Starting g1")
+ yield "g1 ham"
+ yield from g2()
+ yield "g1 eggs"
+ finally:
+ trace.append("Finishing g1")
+ def g2():
+ try:
+ trace.append("Starting g2")
+ yield "g2 spam"
+ yield "g2 more spam"
+ except LunchError:
+ trace.append("Caught LunchError in g2")
+ yield "g2 lunch saved"
+ yield "g2 yet more spam"
+ class LunchError(Exception):
+ pass
+ g = g1()
+ for i in range(2):
+ x = next(g)
+ trace.append("Yielded %s" % (x,))
+ e = LunchError("tomato ejected")
+ g.throw(e)
+ for x in g:
+ trace.append("Yielded %s" % (x,))
+ self.assertEqual(trace,[
+ "Starting g1",
+ "Yielded g1 ham",
+ "Starting g2",
+ "Yielded g2 spam",
+ "Caught LunchError in g2",
+ "Yielded g2 yet more spam",
+ "Yielded g1 eggs",
+ "Finishing g1",
+ ])
+
+ def test_next_and_return_with_value(self):
+ """
+ Test next and return with value
+ """
+ trace = []
+ def f(r):
+ gi = g(r)
+ next(gi)
+ try:
+ trace.append("f resuming g")
+ next(gi)
+ trace.append("f SHOULD NOT BE HERE")
+ except StopIteration as e:
+ trace.append("f caught %s" % (repr(e),))
+ def g(r):
+ trace.append("g starting")
+ yield
+ trace.append("g returning %s" % (r,))
+ return r
+ f(None)
+ f(42)
+ self.assertEqual(trace,[
+ "g starting",
+ "f resuming g",
+ "g returning None",
+ "f caught StopIteration()",
+ "g starting",
+ "f resuming g",
+ "g returning 42",
+ "f caught StopIteration(42,)",
+ ])
+
+ def test_send_and_return_with_value(self):
+ """
+ Test send and return with value
+ """
+ trace = []
+ def f(r):
+ gi = g(r)
+ next(gi)
+ try:
+ trace.append("f sending spam to g")
+ gi.send("spam")
+ trace.append("f SHOULD NOT BE HERE")
+ except StopIteration as e:
+ trace.append("f caught %r" % (e,))
+ def g(r):
+ trace.append("g starting")
+ x = yield
+ trace.append("g received %s" % (x,))
+ trace.append("g returning %s" % (r,))
+ return r
+ f(None)
+ f(42)
+ self.assertEqual(trace,[
+ "g starting",
+ "f sending spam to g",
+ "g received spam",
+ "g returning None",
+ "f caught StopIteration()",
+ "g starting",
+ "f sending spam to g",
+ "g received spam",
+ "g returning 42",
+ "f caught StopIteration(42,)",
+ ])
+
+ def test_catching_exception_from_subgen_and_returning(self):
+ """
+ Test catching an exception thrown into a
+ subgenerator and returning a value
+ """
+ trace = []
+ def inner():
+ try:
+ yield 1
+ except ValueError:
+ trace.append("inner caught ValueError")
+ return 2
+
+ def outer():
+ v = yield from inner()
+ trace.append("inner returned %r to outer" % v)
+ yield v
+ g = outer()
+ trace.append(next(g))
+ trace.append(g.throw(ValueError))
+ self.assertEqual(trace,[
+ 1,
+ "inner caught ValueError",
+ "inner returned 2 to outer",
+ 2,
+ ])
+
+ def test_throwing_GeneratorExit_into_subgen_that_returns(self):
+ """
+ Test throwing GeneratorExit into a subgenerator that
+ catches it and returns normally.
+ """
+ trace = []
+ def f():
+ try:
+ trace.append("Enter f")
+ yield
+ trace.append("Exit f")
+ except GeneratorExit:
+ return
+ def g():
+ trace.append("Enter g")
+ yield from f()
+ trace.append("Exit g")
+ try:
+ gi = g()
+ next(gi)
+ gi.throw(GeneratorExit)
+ except GeneratorExit:
+ pass
+ else:
+ self.fail("subgenerator failed to raise GeneratorExit")
+ self.assertEqual(trace,[
+ "Enter g",
+ "Enter f",
+ ])
+
+ def test_throwing_GeneratorExit_into_subgenerator_that_yields(self):
+ """
+ Test throwing GeneratorExit into a subgenerator that
+ catches it and yields.
+ """
+ trace = []
+ def f():
+ try:
+ trace.append("Enter f")
+ yield
+ trace.append("Exit f")
+ except GeneratorExit:
+ yield
+ def g():
+ trace.append("Enter g")
+ yield from f()
+ trace.append("Exit g")
+ try:
+ gi = g()
+ next(gi)
+ gi.throw(GeneratorExit)
+ except RuntimeError as e:
+ self.assertEqual(e.args[0], "generator ignored GeneratorExit")
+ else:
+ self.fail("subgenerator failed to raise GeneratorExit")
+ self.assertEqual(trace,[
+ "Enter g",
+ "Enter f",
+ ])
+
+ def test_throwing_GeneratorExit_into_subgen_that_raises(self):
+ """
+ Test throwing GeneratorExit into a subgenerator that
+ catches it and raises a different exception.
+ """
+ trace = []
+ def f():
+ try:
+ trace.append("Enter f")
+ yield
+ trace.append("Exit f")
+ except GeneratorExit:
+ raise ValueError("Vorpal bunny encountered")
+ def g():
+ trace.append("Enter g")
+ yield from f()
+ trace.append("Exit g")
+ try:
+ gi = g()
+ next(gi)
+ gi.throw(GeneratorExit)
+ except ValueError as e:
+ self.assertEqual(e.args[0], "Vorpal bunny encountered")
+ self.assertIsInstance(e.__context__, GeneratorExit)
+ else:
+ self.fail("subgenerator failed to raise ValueError")
+ self.assertEqual(trace,[
+ "Enter g",
+ "Enter f",
+ ])
+
+ def test_yield_from_empty(self):
+ def g():
+ yield from ()
+ self.assertRaises(StopIteration, next, g())
+
+ def test_delegating_generators_claim_to_be_running(self):
+ # Check with basic iteration
+ def one():
+ yield 0
+ yield from two()
+ yield 3
+ def two():
+ yield 1
+ try:
+ yield from g1
+ except ValueError:
+ pass
+ yield 2
+ g1 = one()
+ self.assertEqual(list(g1), [0, 1, 2, 3])
+ # Check with send
+ g1 = one()
+ res = [next(g1)]
+ try:
+ while True:
+ res.append(g1.send(42))
+ except StopIteration:
+ pass
+ self.assertEqual(res, [0, 1, 2, 3])
+ # Check with throw
+ class MyErr(Exception):
+ pass
+ def one():
+ try:
+ yield 0
+ except MyErr:
+ pass
+ yield from two()
+ try:
+ yield 3
+ except MyErr:
+ pass
+ def two():
+ try:
+ yield 1
+ except MyErr:
+ pass
+ try:
+ yield from g1
+ except ValueError:
+ pass
+ try:
+ yield 2
+ except MyErr:
+ pass
+ g1 = one()
+ res = [next(g1)]
+ try:
+ while True:
+ res.append(g1.throw(MyErr))
+ except StopIteration:
+ pass
+ # Check with close
+ class MyIt(object):
+ def __iter__(self):
+ return self
+ def __next__(self):
+ return 42
+ def close(self_):
+ self.assertTrue(g1.gi_running)
+ self.assertRaises(ValueError, next, g1)
+ def one():
+ yield from MyIt()
+ g1 = one()
+ next(g1)
+ g1.close()
+
+ def test_delegator_is_visible_to_debugger(self):
+ def call_stack():
+ return [f[3] for f in inspect.stack()]
+
+ def gen():
+ yield call_stack()
+ yield call_stack()
+ yield call_stack()
+
+ def spam(g):
+ yield from g
+
+ def eggs(g):
+ yield from g
+
+ for stack in spam(gen()):
+ self.assertTrue('spam' in stack)
+
+ for stack in spam(eggs(gen())):
+ self.assertTrue('spam' in stack and 'eggs' in stack)
+
+ def test_custom_iterator_return(self):
+ # See issue #15568
+ class MyIter:
+ def __iter__(self):
+ return self
+ def __next__(self):
+ raise StopIteration(42)
+ def gen():
+ nonlocal ret
+ ret = yield from MyIter()
+ ret = None
+ list(gen())
+ self.assertEqual(ret, 42)
+
+
+def test_main():
+ from test import support
+ test_classes = [TestPEP380Operation]
+ support.run_unittest(*test_classes)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py
index 9da2cae..f52d4bd 100644
--- a/Lib/test/test_pickle.py
+++ b/Lib/test/test_pickle.py
@@ -1,5 +1,6 @@
import pickle
import io
+import collections
from test import support
@@ -7,6 +8,7 @@ from test.pickletester import AbstractPickleTests
from test.pickletester import AbstractPickleModuleTests
from test.pickletester import AbstractPersistentPicklerTests
from test.pickletester import AbstractPicklerUnpicklerObjectTests
+from test.pickletester import AbstractDispatchTableTests
from test.pickletester import BigmemPickleTests
try:
@@ -80,6 +82,18 @@ class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
unpickler_class = pickle._Unpickler
+class PyDispatchTableTests(AbstractDispatchTableTests):
+ pickler_class = pickle._Pickler
+ def get_dispatch_table(self):
+ return pickle.dispatch_table.copy()
+
+
+class PyChainDispatchTableTests(AbstractDispatchTableTests):
+ pickler_class = pickle._Pickler
+ def get_dispatch_table(self):
+ return collections.ChainMap({}, pickle.dispatch_table)
+
+
if has_c_implementation:
class CPicklerTests(PyPicklerTests):
pickler = _pickle.Pickler
@@ -101,14 +115,26 @@ if has_c_implementation:
pickler_class = _pickle.Pickler
unpickler_class = _pickle.Unpickler
+ class CDispatchTableTests(AbstractDispatchTableTests):
+ pickler_class = pickle.Pickler
+ def get_dispatch_table(self):
+ return pickle.dispatch_table.copy()
+
+ class CChainDispatchTableTests(AbstractDispatchTableTests):
+ pickler_class = pickle.Pickler
+ def get_dispatch_table(self):
+ return collections.ChainMap({}, pickle.dispatch_table)
+
def test_main():
- tests = [PickleTests, PyPicklerTests, PyPersPicklerTests]
+ tests = [PickleTests, PyPicklerTests, PyPersPicklerTests,
+ PyDispatchTableTests, PyChainDispatchTableTests]
if has_c_implementation:
tests.extend([CPicklerTests, CPersPicklerTests,
CDumpPickle_LoadPickle, DumpPickle_CLoadPickle,
PyPicklerUnpicklerObjectTests,
CPicklerUnpicklerObjectTests,
+ CDispatchTableTests, CChainDispatchTableTests,
InMemoryPickleTests])
support.run_unittest(*tests)
support.run_doctest(pickle)
diff --git a/Lib/test/test_pipes.py b/Lib/test/test_pipes.py
index d5b886f..6a7b45f 100644
--- a/Lib/test/test_pipes.py
+++ b/Lib/test/test_pipes.py
@@ -79,21 +79,6 @@ class SimplePipeTests(unittest.TestCase):
with open(TESTFN) as f:
self.assertEqual(f.read(), d)
- def testQuoting(self):
- safeunquoted = string.ascii_letters + string.digits + '@%_-+=:,./'
- unicode_sample = '\xe9\xe0\xdf' # e + acute accent, a + grave, sharp s
- unsafe = '"`$\\!' + unicode_sample
-
- self.assertEqual(pipes.quote(''), "''")
- self.assertEqual(pipes.quote(safeunquoted), safeunquoted)
- self.assertEqual(pipes.quote('test file name'), "'test file name'")
- for u in unsafe:
- self.assertEqual(pipes.quote('test%sname' % u),
- "'test%sname'" % u)
- for u in unsafe:
- self.assertEqual(pipes.quote("test%s'name'" % u),
- "'test%s'\"'\"'name'\"'\"''" % u)
-
def testRepr(self):
t = pipes.Template()
self.assertEqual(repr(t), "<Template instance, steps=[]>")
diff --git a/Lib/test/test_pkg.py b/Lib/test/test_pkg.py
index a4ddb15..355efe7 100644
--- a/Lib/test/test_pkg.py
+++ b/Lib/test/test_pkg.py
@@ -23,6 +23,8 @@ def cleanout(root):
def fixdir(lst):
if "__builtins__" in lst:
lst.remove("__builtins__")
+ if "__initializing__" in lst:
+ lst.remove("__initializing__")
return lst
@@ -196,14 +198,15 @@ class TestPkg(unittest.TestCase):
import t5
self.assertEqual(fixdir(dir(t5)),
- ['__cached__', '__doc__', '__file__', '__name__',
- '__package__', '__path__', 'foo', 'string', 't5'])
+ ['__cached__', '__doc__', '__file__', '__loader__',
+ '__name__', '__package__', '__path__', 'foo',
+ 'string', 't5'])
self.assertEqual(fixdir(dir(t5.foo)),
- ['__cached__', '__doc__', '__file__', '__name__',
- '__package__', 'string'])
+ ['__cached__', '__doc__', '__file__', '__loader__',
+ '__name__', '__package__', 'string'])
self.assertEqual(fixdir(dir(t5.string)),
- ['__cached__', '__doc__', '__file__', '__name__',
- '__package__', 'spam'])
+ ['__cached__', '__doc__', '__file__', '__loader__',
+ '__name__', '__package__', 'spam'])
def test_6(self):
hier = [
@@ -219,14 +222,14 @@ class TestPkg(unittest.TestCase):
import t6
self.assertEqual(fixdir(dir(t6)),
['__all__', '__cached__', '__doc__', '__file__',
- '__name__', '__package__', '__path__'])
+ '__loader__', '__name__', '__package__', '__path__'])
s = """
import t6
from t6 import *
self.assertEqual(fixdir(dir(t6)),
['__all__', '__cached__', '__doc__', '__file__',
- '__name__', '__package__', '__path__',
- 'eggs', 'ham', 'spam'])
+ '__loader__', '__name__', '__package__',
+ '__path__', 'eggs', 'ham', 'spam'])
self.assertEqual(dir(), ['eggs', 'ham', 'self', 'spam', 't6'])
"""
self.run_code(s)
@@ -252,19 +255,19 @@ class TestPkg(unittest.TestCase):
t7, sub, subsub = None, None, None
import t7 as tas
self.assertEqual(fixdir(dir(tas)),
- ['__cached__', '__doc__', '__file__', '__name__',
- '__package__', '__path__'])
+ ['__cached__', '__doc__', '__file__', '__loader__',
+ '__name__', '__package__', '__path__'])
self.assertFalse(t7)
from t7 import sub as subpar
self.assertEqual(fixdir(dir(subpar)),
- ['__cached__', '__doc__', '__file__', '__name__',
- '__package__', '__path__'])
+ ['__cached__', '__doc__', '__file__', '__loader__',
+ '__name__', '__package__', '__path__'])
self.assertFalse(t7)
self.assertFalse(sub)
from t7.sub import subsub as subsubsub
self.assertEqual(fixdir(dir(subsubsub)),
- ['__cached__', '__doc__', '__file__', '__name__',
- '__package__', '__path__', 'spam'])
+ ['__cached__', '__doc__', '__file__', '__loader__',
+ '__name__', '__package__', '__path__', 'spam'])
self.assertFalse(t7)
self.assertFalse(sub)
self.assertFalse(subsub)
diff --git a/Lib/test/test_pkgimport.py b/Lib/test/test_pkgimport.py
index c37e936..a8426b5 100644
--- a/Lib/test/test_pkgimport.py
+++ b/Lib/test/test_pkgimport.py
@@ -7,7 +7,7 @@ import tempfile
import unittest
from imp import cache_from_source
-from test.support import run_unittest
+from test.support import run_unittest, create_empty_file
class TestImport(unittest.TestCase):
@@ -29,7 +29,7 @@ class TestImport(unittest.TestCase):
self.package_dir = os.path.join(self.test_dir,
self.package_name)
os.mkdir(self.package_dir)
- open(os.path.join(self.package_dir, '__init__.py'), 'w').close()
+ create_empty_file(os.path.join(self.package_dir, '__init__.py'))
self.module_path = os.path.join(self.package_dir, 'foo.py')
def tearDown(self):
diff --git a/Lib/test/test_pkgutil.py b/Lib/test/test_pkgutil.py
index 3b72f06..88e2d33 100644
--- a/Lib/test/test_pkgutil.py
+++ b/Lib/test/test_pkgutil.py
@@ -1,4 +1,4 @@
-from test.support import run_unittest, unload
+from test.support import run_unittest, unload, check_warnings
import unittest
import sys
import imp
@@ -9,7 +9,11 @@ import tempfile
import shutil
import zipfile
-
+# Note: pkgutil.walk_packages is currently tested in test_runpy. This is
+# a hack to get a major issue resolved for 3.3b2. Longer term, it should
+# be moved back here, perhaps by factoring out the helper code for
+# creating interesting package layouts to a separate module.
+# Issue #15348 declares this is indeed a dodgy hack ;)
class PkgutilTests(unittest.TestCase):
@@ -138,10 +142,11 @@ class PkgutilPEP302Tests(unittest.TestCase):
del sys.modules['foo']
+# These tests, especially the setup and cleanup, are hideous. They
+# need to be cleaned up once issue 14715 is addressed.
class ExtendPathTests(unittest.TestCase):
def create_init(self, pkgname):
dirname = tempfile.mkdtemp()
- self.addCleanup(shutil.rmtree, dirname)
sys.path.insert(0, dirname)
pkgdir = os.path.join(dirname, pkgname)
@@ -156,22 +161,40 @@ class ExtendPathTests(unittest.TestCase):
with open(module_name, 'w') as fl:
print('value={}'.format(value), file=fl)
- def setUp(self):
- # Create 2 directories on sys.path
- self.pkgname = 'foo'
- self.dirname_0 = self.create_init(self.pkgname)
- self.dirname_1 = self.create_init(self.pkgname)
+ def test_simple(self):
+ pkgname = 'foo'
+ dirname_0 = self.create_init(pkgname)
+ dirname_1 = self.create_init(pkgname)
+ self.create_submodule(dirname_0, pkgname, 'bar', 0)
+ self.create_submodule(dirname_1, pkgname, 'baz', 1)
+ import foo.bar
+ import foo.baz
+ # Ensure we read the expected values
+ self.assertEqual(foo.bar.value, 0)
+ self.assertEqual(foo.baz.value, 1)
- def tearDown(self):
+ # Ensure the path is set up correctly
+ self.assertEqual(sorted(foo.__path__),
+ sorted([os.path.join(dirname_0, pkgname),
+ os.path.join(dirname_1, pkgname)]))
+
+ # Cleanup
+ shutil.rmtree(dirname_0)
+ shutil.rmtree(dirname_1)
del sys.path[0]
del sys.path[0]
del sys.modules['foo']
del sys.modules['foo.bar']
del sys.modules['foo.baz']
- def test_simple(self):
- self.create_submodule(self.dirname_0, self.pkgname, 'bar', 0)
- self.create_submodule(self.dirname_1, self.pkgname, 'baz', 1)
+ def test_mixed_namespace(self):
+ pkgname = 'foo'
+ dirname_0 = self.create_init(pkgname)
+ dirname_1 = self.create_init(pkgname)
+ self.create_submodule(dirname_0, pkgname, 'bar', 0)
+ # Turn this into a PEP 420 namespace package
+ os.unlink(os.path.join(dirname_0, pkgname, '__init__.py'))
+ self.create_submodule(dirname_1, pkgname, 'baz', 1)
import foo.bar
import foo.baz
# Ensure we read the expected values
@@ -180,8 +203,17 @@ class ExtendPathTests(unittest.TestCase):
# Ensure the path is set up correctly
self.assertEqual(sorted(foo.__path__),
- sorted([os.path.join(self.dirname_0, self.pkgname),
- os.path.join(self.dirname_1, self.pkgname)]))
+ sorted([os.path.join(dirname_0, pkgname),
+ os.path.join(dirname_1, pkgname)]))
+
+ # Cleanup
+ shutil.rmtree(dirname_0)
+ shutil.rmtree(dirname_1)
+ del sys.path[0]
+ del sys.path[0]
+ del sys.modules['foo']
+ del sys.modules['foo.bar']
+ del sys.modules['foo.baz']
# XXX: test .pkg files
@@ -227,12 +259,52 @@ class NestedNamespacePackageTest(unittest.TestCase):
self.assertEqual(d, 2)
+class ImportlibMigrationTests(unittest.TestCase):
+ # With full PEP 302 support in the standard import machinery, the
+ # PEP 302 emulation in this module is in the process of being
+ # deprecated in favour of importlib proper
+
+ def check_deprecated(self):
+ return check_warnings(
+ ("This emulation is deprecated, use 'importlib' instead",
+ DeprecationWarning))
+
+ def test_importer_deprecated(self):
+ with self.check_deprecated():
+ x = pkgutil.ImpImporter("")
+
+ def test_loader_deprecated(self):
+ with self.check_deprecated():
+ x = pkgutil.ImpLoader("", "", "", "")
+
+ def test_get_loader_avoids_emulation(self):
+ with check_warnings() as w:
+ self.assertIsNotNone(pkgutil.get_loader("sys"))
+ self.assertIsNotNone(pkgutil.get_loader("os"))
+ self.assertIsNotNone(pkgutil.get_loader("test.support"))
+ self.assertEqual(len(w.warnings), 0)
+
+ def test_get_importer_avoids_emulation(self):
+ # We use an illegal path so *none* of the path hooks should fire
+ with check_warnings() as w:
+ self.assertIsNone(pkgutil.get_importer("*??"))
+ self.assertEqual(len(w.warnings), 0)
+
+ def test_iter_importers_avoids_emulation(self):
+ with check_warnings() as w:
+ for importer in pkgutil.iter_importers(): pass
+ self.assertEqual(len(w.warnings), 0)
+
+
def test_main():
run_unittest(PkgutilTests, PkgutilPEP302Tests, ExtendPathTests,
- NestedNamespacePackageTest)
+ NestedNamespacePackageTest, ImportlibMigrationTests)
# this is necessary if test is run repeated (like when finding leaks)
import zipimport
+ import importlib
zipimport._zip_directory_cache.clear()
+ importlib.invalidate_caches()
+
if __name__ == '__main__':
test_main()
diff --git a/Lib/test/test_platform.py b/Lib/test/test_platform.py
index 8751aa8..6abf443 100644
--- a/Lib/test/test_platform.py
+++ b/Lib/test/test_platform.py
@@ -1,8 +1,9 @@
-import sys
import os
-import unittest
import platform
import subprocess
+import sys
+import unittest
+import warnings
from test import support
@@ -56,13 +57,11 @@ class PlatformTest(unittest.TestCase):
def setUp(self):
self.save_version = sys.version
- self.save_subversion = sys.subversion
self.save_mercurial = sys._mercurial
self.save_platform = sys.platform
def tearDown(self):
sys.version = self.save_version
- sys.subversion = self.save_subversion
sys._mercurial = self.save_mercurial
sys.platform = self.save_platform
@@ -77,7 +76,7 @@ class PlatformTest(unittest.TestCase):
('IronPython', '1.0.0', '', '', '', '', '.NET 2.0.50727.42')),
):
# branch and revision are not "parsed", but fetched
- # from sys.subversion. Ignore them
+ # from sys._mercurial. Ignore them
(name, version, branch, revision, buildno, builddate, compiler) \
= platform._sys_version(input)
self.assertEqual(
@@ -113,8 +112,6 @@ class PlatformTest(unittest.TestCase):
if subversion is None:
if hasattr(sys, "_mercurial"):
del sys._mercurial
- if hasattr(sys, "subversion"):
- del sys.subversion
else:
sys._mercurial = subversion
if sys_platform is not None:
@@ -136,6 +133,12 @@ class PlatformTest(unittest.TestCase):
def test_uname(self):
res = platform.uname()
self.assertTrue(any(res))
+ self.assertEqual(res[0], res.system)
+ self.assertEqual(res[1], res.node)
+ self.assertEqual(res[2], res.release)
+ self.assertEqual(res[3], res.version)
+ self.assertEqual(res[4], res.machine)
+ self.assertEqual(res[5], res.processor)
@unittest.skipUnless(sys.platform.startswith('win'), "windows only test")
def test_uname_win32_ARCHITEW6432(self):
@@ -169,7 +172,7 @@ class PlatformTest(unittest.TestCase):
def test_mac_ver(self):
res = platform.mac_ver()
- if platform.uname()[0] == 'Darwin':
+ if platform.uname().system == 'Darwin':
# We're on a MacOSX system, check that
# the right version information is returned
fd = os.popen('sw_vers', 'r')
@@ -247,6 +250,38 @@ class PlatformTest(unittest.TestCase):
):
self.assertEqual(platform._parse_release_file(input), output)
+ def test_popen(self):
+ mswindows = (sys.platform == "win32")
+
+ if mswindows:
+ command = '"{}" -c "print(\'Hello\')"'.format(sys.executable)
+ else:
+ command = "'{}' -c 'print(\"Hello\")'".format(sys.executable)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", DeprecationWarning)
+ with platform.popen(command) as stdout:
+ hello = stdout.read().strip()
+ stdout.close()
+ self.assertEqual(hello, "Hello")
+
+ data = 'plop'
+ if mswindows:
+ command = '"{}" -c "import sys; data=sys.stdin.read(); exit(len(data))"'
+ else:
+ command = "'{}' -c 'import sys; data=sys.stdin.read(); exit(len(data))'"
+ command = command.format(sys.executable)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", DeprecationWarning)
+ with platform.popen(command, 'w') as stdin:
+ stdout = stdin.write(data)
+ ret = stdin.close()
+ self.assertIsNotNone(ret)
+ if os.name == 'nt':
+ returncode = ret
+ else:
+ returncode = ret >> 8
+ self.assertEqual(returncode, len(data))
+
def test_main():
support.run_unittest(
diff --git a/Lib/test/test_plistlib.py b/Lib/test/test_plistlib.py
index 5f980d0..a9e343e 100644
--- a/Lib/test/test_plistlib.py
+++ b/Lib/test/test_plistlib.py
@@ -55,6 +55,10 @@ TESTDATA = b"""<?xml version="1.0" encoding="UTF-8"?>
</array>
<key>aString</key>
<string>Doodah</string>
+ <key>anEmptyDict</key>
+ <dict/>
+ <key>anEmptyList</key>
+ <array/>
<key>anInt</key>
<integer>728</integer>
<key>nestedData</key>
@@ -112,6 +116,8 @@ class TestPlistlib(unittest.TestCase):
someMoreData = plistlib.Data(b"<lots of binary gunk>\0\1\2\3" * 10),
nestedData = [plistlib.Data(b"<lots of binary gunk>\0\1\2\3" * 10)],
aDate = datetime.datetime(2004, 10, 26, 10, 33, 33),
+ anEmptyDict = dict(),
+ anEmptyList = list()
)
pl['\xc5benraa'] = "That was a unicode key."
return pl
diff --git a/Lib/test/test_poplib.py b/Lib/test/test_poplib.py
index e3901b8..c0929a0 100644
--- a/Lib/test/test_poplib.py
+++ b/Lib/test/test_poplib.py
@@ -108,6 +108,10 @@ class DummyPOP3Handler(asynchat.async_chat):
def cmd_apop(self, arg):
self.push('+OK done nothing.')
+ def cmd_quit(self, arg):
+ self.push('+OK closing.')
+ self.close_when_done()
+
class DummyPOP3Server(asyncore.dispatcher, threading.Thread):
@@ -165,10 +169,10 @@ class TestPOP3Class(TestCase):
def setUp(self):
self.server = DummyPOP3Server((HOST, PORT))
self.server.start()
- self.client = poplib.POP3(self.server.host, self.server.port)
+ self.client = poplib.POP3(self.server.host, self.server.port, timeout=3)
def tearDown(self):
- self.client.quit()
+ self.client.close()
self.server.stop()
def test_getwelcome(self):
@@ -228,6 +232,12 @@ class TestPOP3Class(TestCase):
self.client.uidl()
self.client.uidl('foo')
+ def test_quit(self):
+ resp = self.client.quit()
+ self.assertTrue(resp)
+ self.assertIsNone(self.client.sock)
+ self.assertIsNone(self.client.file)
+
SUPPORTS_SSL = False
if hasattr(poplib, 'POP3_SSL'):
@@ -274,6 +284,7 @@ if hasattr(poplib, 'POP3_SSL'):
else:
DummyPOP3Handler.handle_read(self)
+
class TestPOP3_SSLClass(TestPOP3Class):
# repeat previous tests by using poplib.POP3_SSL
diff --git a/Lib/test/test_posix.py b/Lib/test/test_posix.py
index e06ffa5..cfc188a 100644
--- a/Lib/test/test_posix.py
+++ b/Lib/test/test_posix.py
@@ -9,6 +9,7 @@ import errno
import sys
import time
import os
+import fcntl
import platform
import pwd
import shutil
@@ -16,6 +17,7 @@ import stat
import tempfile
import unittest
import warnings
+import _testcapi
_DUMMY_SYMLINK = os.path.join(tempfile.gettempdir(),
support.TESTFN + '-dummy-symlink')
@@ -43,7 +45,7 @@ class PosixTester(unittest.TestCase):
NO_ARG_FUNCTIONS = [ "ctermid", "getcwd", "getcwdb", "uname",
"times", "getloadavg",
"getegid", "geteuid", "getgid", "getgroups",
- "getpid", "getpgrp", "getppid", "getuid",
+ "getpid", "getpgrp", "getppid", "getuid", "sync",
]
for name in NO_ARG_FUNCTIONS:
@@ -128,6 +130,7 @@ class PosixTester(unittest.TestCase):
fp = open(support.TESTFN)
try:
self.assertTrue(posix.fstatvfs(fp.fileno()))
+ self.assertTrue(posix.statvfs(fp.fileno()))
finally:
fp.close()
@@ -142,6 +145,151 @@ class PosixTester(unittest.TestCase):
finally:
fp.close()
+ @unittest.skipUnless(hasattr(posix, 'truncate'), "test needs posix.truncate()")
+ def test_truncate(self):
+ with open(support.TESTFN, 'w') as fp:
+ fp.write('test')
+ fp.flush()
+ posix.truncate(support.TESTFN, 0)
+
+ @unittest.skipUnless(getattr(os, 'execve', None) in os.supports_fd, "test needs execve() to support the fd parameter")
+ @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
+ @unittest.skipUnless(hasattr(os, 'waitpid'), "test needs os.waitpid()")
+ def test_fexecve(self):
+ fp = os.open(sys.executable, os.O_RDONLY)
+ try:
+ pid = os.fork()
+ if pid == 0:
+ os.chdir(os.path.split(sys.executable)[0])
+ posix.execve(fp, [sys.executable, '-c', 'pass'], os.environ)
+ else:
+ self.assertEqual(os.waitpid(pid, 0), (pid, 0))
+ finally:
+ os.close(fp)
+
+ @unittest.skipUnless(hasattr(posix, 'waitid'), "test needs posix.waitid()")
+ @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
+ def test_waitid(self):
+ pid = os.fork()
+ if pid == 0:
+ os.chdir(os.path.split(sys.executable)[0])
+ posix.execve(sys.executable, [sys.executable, '-c', 'pass'], os.environ)
+ else:
+ res = posix.waitid(posix.P_PID, pid, posix.WEXITED)
+ self.assertEqual(pid, res.si_pid)
+
+ @unittest.skipUnless(hasattr(posix, 'lockf'), "test needs posix.lockf()")
+ def test_lockf(self):
+ fd = os.open(support.TESTFN, os.O_WRONLY | os.O_CREAT)
+ try:
+ os.write(fd, b'test')
+ os.lseek(fd, 0, os.SEEK_SET)
+ posix.lockf(fd, posix.F_LOCK, 4)
+ # section is locked
+ posix.lockf(fd, posix.F_ULOCK, 4)
+ finally:
+ os.close(fd)
+
+ @unittest.skipUnless(hasattr(posix, 'pread'), "test needs posix.pread()")
+ def test_pread(self):
+ fd = os.open(support.TESTFN, os.O_RDWR | os.O_CREAT)
+ try:
+ os.write(fd, b'test')
+ os.lseek(fd, 0, os.SEEK_SET)
+ self.assertEqual(b'es', posix.pread(fd, 2, 1))
+ # the first pread() shouldn't disturb the file offset
+ self.assertEqual(b'te', posix.read(fd, 2))
+ finally:
+ os.close(fd)
+
+ @unittest.skipUnless(hasattr(posix, 'pwrite'), "test needs posix.pwrite()")
+ def test_pwrite(self):
+ fd = os.open(support.TESTFN, os.O_RDWR | os.O_CREAT)
+ try:
+ os.write(fd, b'test')
+ os.lseek(fd, 0, os.SEEK_SET)
+ posix.pwrite(fd, b'xx', 1)
+ self.assertEqual(b'txxt', posix.read(fd, 4))
+ finally:
+ os.close(fd)
+
+ @unittest.skipUnless(hasattr(posix, 'posix_fallocate'),
+ "test needs posix.posix_fallocate()")
+ def test_posix_fallocate(self):
+ fd = os.open(support.TESTFN, os.O_WRONLY | os.O_CREAT)
+ try:
+ posix.posix_fallocate(fd, 0, 10)
+ except OSError as inst:
+ # issue10812, ZFS doesn't appear to support posix_fallocate,
+ # so skip Solaris-based since they are likely to have ZFS.
+ if inst.errno != errno.EINVAL or not sys.platform.startswith("sunos"):
+ raise
+ finally:
+ os.close(fd)
+
+ @unittest.skipUnless(hasattr(posix, 'posix_fadvise'),
+ "test needs posix.posix_fadvise()")
+ def test_posix_fadvise(self):
+ fd = os.open(support.TESTFN, os.O_RDONLY)
+ try:
+ posix.posix_fadvise(fd, 0, 0, posix.POSIX_FADV_WILLNEED)
+ finally:
+ os.close(fd)
+
+ @unittest.skipUnless(os.utime in os.supports_fd, "test needs fd support in os.utime")
+ def test_utime_with_fd(self):
+ now = time.time()
+ fd = os.open(support.TESTFN, os.O_RDONLY)
+ try:
+ posix.utime(fd)
+ posix.utime(fd, None)
+ self.assertRaises(TypeError, posix.utime, fd, (None, None))
+ self.assertRaises(TypeError, posix.utime, fd, (now, None))
+ self.assertRaises(TypeError, posix.utime, fd, (None, now))
+ posix.utime(fd, (int(now), int(now)))
+ posix.utime(fd, (now, now))
+ self.assertRaises(ValueError, posix.utime, fd, (now, now), ns=(now, now))
+ self.assertRaises(ValueError, posix.utime, fd, (now, 0), ns=(None, None))
+ self.assertRaises(ValueError, posix.utime, fd, (None, None), ns=(now, 0))
+ posix.utime(fd, (int(now), int((now - int(now)) * 1e9)))
+ posix.utime(fd, ns=(int(now), int((now - int(now)) * 1e9)))
+
+ finally:
+ os.close(fd)
+
+ @unittest.skipUnless(os.utime in os.supports_follow_symlinks, "test needs follow_symlinks support in os.utime")
+ def test_utime_nofollow_symlinks(self):
+ now = time.time()
+ posix.utime(support.TESTFN, None, follow_symlinks=False)
+ self.assertRaises(TypeError, posix.utime, support.TESTFN, (None, None), follow_symlinks=False)
+ self.assertRaises(TypeError, posix.utime, support.TESTFN, (now, None), follow_symlinks=False)
+ self.assertRaises(TypeError, posix.utime, support.TESTFN, (None, now), follow_symlinks=False)
+ posix.utime(support.TESTFN, (int(now), int(now)), follow_symlinks=False)
+ posix.utime(support.TESTFN, (now, now), follow_symlinks=False)
+ posix.utime(support.TESTFN, follow_symlinks=False)
+
+ @unittest.skipUnless(hasattr(posix, 'writev'), "test needs posix.writev()")
+ def test_writev(self):
+ fd = os.open(support.TESTFN, os.O_RDWR | os.O_CREAT)
+ try:
+ os.writev(fd, (b'test1', b'tt2', b't3'))
+ os.lseek(fd, 0, os.SEEK_SET)
+ self.assertEqual(b'test1tt2t3', posix.read(fd, 10))
+ finally:
+ os.close(fd)
+
+ @unittest.skipUnless(hasattr(posix, 'readv'), "test needs posix.readv()")
+ def test_readv(self):
+ fd = os.open(support.TESTFN, os.O_RDWR | os.O_CREAT)
+ try:
+ os.write(fd, b'test1tt2t3')
+ os.lseek(fd, 0, os.SEEK_SET)
+ buf = [bytearray(i) for i in [5, 3, 2]]
+ self.assertEqual(posix.readv(fd, buf), 10)
+ self.assertEqual([b'test1', b'tt2', b't3'], [bytes(i) for i in buf])
+ finally:
+ os.close(fd)
+
def test_dup(self):
if hasattr(posix, 'dup'):
fp = open(support.TESTFN)
@@ -167,6 +315,13 @@ class PosixTester(unittest.TestCase):
fp1.close()
fp2.close()
+ @unittest.skipUnless(hasattr(os, 'O_CLOEXEC'), "needs os.O_CLOEXEC")
+ @support.requires_linux_version(2, 6, 23)
+ def test_oscloexec(self):
+ fd = os.open(support.TESTFN, os.O_RDONLY|os.O_CLOEXEC)
+ self.addCleanup(os.close, fd)
+ self.assertTrue(fcntl.fcntl(fd, fcntl.F_GETFD) & fcntl.FD_CLOEXEC)
+
def test_osexlock(self):
if hasattr(posix, "O_EXLOCK"):
fd = os.open(support.TESTFN,
@@ -203,12 +358,29 @@ class PosixTester(unittest.TestCase):
fp = open(support.TESTFN)
try:
self.assertTrue(posix.fstat(fp.fileno()))
+ self.assertTrue(posix.stat(fp.fileno()))
+
+ self.assertRaisesRegex(TypeError,
+ 'should be string, bytes or integer, not',
+ posix.stat, float(fp.fileno()))
finally:
fp.close()
def test_stat(self):
if hasattr(posix, 'stat'):
self.assertTrue(posix.stat(support.TESTFN))
+ self.assertTrue(posix.stat(os.fsencode(support.TESTFN)))
+ self.assertTrue(posix.stat(bytearray(os.fsencode(support.TESTFN))))
+
+ self.assertRaisesRegex(TypeError,
+ 'can\'t specify None for path argument',
+ posix.stat, None)
+ self.assertRaisesRegex(TypeError,
+ 'should be string, bytes or integer, not',
+ posix.stat, list(support.TESTFN))
+ self.assertRaisesRegex(TypeError,
+ 'should be string, bytes or integer, not',
+ posix.stat, list(os.fsencode(support.TESTFN)))
@unittest.skipUnless(hasattr(posix, 'mkfifo'), "don't have mkfifo()")
def test_mkfifo(self):
@@ -297,7 +469,7 @@ class PosixTester(unittest.TestCase):
self.assertRaises(OSError, posix.chown, support.TESTFN, -1, -1)
# re-create the file
- open(support.TESTFN, 'w').close()
+ support.create_empty_file(support.TESTFN)
self._test_all_chown_common(posix.chown, support.TESTFN,
getattr(posix, 'stat', None))
@@ -328,13 +500,32 @@ class PosixTester(unittest.TestCase):
self.assertRaises(OSError, posix.chdir, support.TESTFN)
def test_listdir(self):
- if hasattr(posix, 'listdir'):
- self.assertTrue(support.TESTFN in posix.listdir(os.curdir))
+ self.assertTrue(support.TESTFN in posix.listdir(os.curdir))
def test_listdir_default(self):
- # When listdir is called without argument, it's the same as listdir(os.curdir)
- if hasattr(posix, 'listdir'):
- self.assertTrue(support.TESTFN in posix.listdir())
+ # When listdir is called without argument,
+ # it's the same as listdir(os.curdir).
+ self.assertTrue(support.TESTFN in posix.listdir())
+
+ def test_listdir_bytes(self):
+ # When listdir is called with a bytes object,
+ # the returned strings are of type bytes.
+ self.assertTrue(os.fsencode(support.TESTFN) in posix.listdir(b'.'))
+
+ @unittest.skipUnless(posix.listdir in os.supports_fd,
+ "test needs fd support for posix.listdir()")
+ def test_listdir_fd(self):
+ f = posix.open(posix.getcwd(), posix.O_RDONLY)
+ self.addCleanup(posix.close, f)
+ self.assertEqual(
+ sorted(posix.listdir('.')),
+ sorted(posix.listdir(f))
+ )
+ # Check that the fd offset was reset (issue #13739)
+ self.assertEqual(
+ sorted(posix.listdir('.')),
+ sorted(posix.listdir(f))
+ )
def test_access(self):
if hasattr(posix, 'access'):
@@ -356,6 +547,36 @@ class PosixTester(unittest.TestCase):
os.close(reader)
os.close(writer)
+ @unittest.skipUnless(hasattr(os, 'pipe2'), "test needs os.pipe2()")
+ @support.requires_linux_version(2, 6, 27)
+ def test_pipe2(self):
+ self.assertRaises(TypeError, os.pipe2, 'DEADBEEF')
+ self.assertRaises(TypeError, os.pipe2, 0, 0)
+
+ # try calling with flags = 0, like os.pipe()
+ r, w = os.pipe2(0)
+ os.close(r)
+ os.close(w)
+
+ # test flags
+ r, w = os.pipe2(os.O_CLOEXEC|os.O_NONBLOCK)
+ self.addCleanup(os.close, r)
+ self.addCleanup(os.close, w)
+ self.assertTrue(fcntl.fcntl(r, fcntl.F_GETFD) & fcntl.FD_CLOEXEC)
+ self.assertTrue(fcntl.fcntl(w, fcntl.F_GETFD) & fcntl.FD_CLOEXEC)
+ # try reading from an empty pipe: this should fail, not block
+ self.assertRaises(OSError, os.read, r, 1)
+ # try a write big enough to fill-up the pipe: this should either
+ # fail or perform a partial write, not block
+ try:
+ os.write(w, b'x' * support.PIPE_MAX_SIZE)
+ except OSError:
+ pass
+
+ # Issue 15989
+ self.assertRaises(OverflowError, os.pipe2, _testcapi.INT_MAX + 1)
+ self.assertRaises(OverflowError, os.pipe2, _testcapi.UINT_MAX + 1)
+
def test_utime(self):
if hasattr(posix, 'utime'):
now = time.time()
@@ -366,13 +587,14 @@ class PosixTester(unittest.TestCase):
posix.utime(support.TESTFN, (int(now), int(now)))
posix.utime(support.TESTFN, (now, now))
- def _test_chflags_regular_file(self, chflags_func, target_file):
+ def _test_chflags_regular_file(self, chflags_func, target_file, **kwargs):
st = os.stat(target_file)
self.assertTrue(hasattr(st, 'st_flags'))
# ZFS returns EOPNOTSUPP when attempting to set flag UF_IMMUTABLE.
+ flags = st.st_flags | stat.UF_IMMUTABLE
try:
- chflags_func(target_file, st.st_flags | stat.UF_IMMUTABLE)
+ chflags_func(target_file, flags, **kwargs)
except OSError as err:
if err.errno != errno.EOPNOTSUPP:
raise
@@ -396,6 +618,7 @@ class PosixTester(unittest.TestCase):
@unittest.skipUnless(hasattr(posix, 'lchflags'), 'test needs os.lchflags()')
def test_lchflags_regular_file(self):
self._test_chflags_regular_file(posix.lchflags, support.TESTFN)
+ self._test_chflags_regular_file(posix.chflags, support.TESTFN, follow_symlinks=False)
@unittest.skipUnless(hasattr(posix, 'lchflags'), 'test needs os.lchflags()')
def test_lchflags_symlink(self):
@@ -407,25 +630,28 @@ class PosixTester(unittest.TestCase):
self.teardown_files.append(_DUMMY_SYMLINK)
dummy_symlink_st = os.lstat(_DUMMY_SYMLINK)
- # ZFS returns EOPNOTSUPP when attempting to set flag UF_IMMUTABLE.
- try:
- posix.lchflags(_DUMMY_SYMLINK,
- dummy_symlink_st.st_flags | stat.UF_IMMUTABLE)
- except OSError as err:
- if err.errno != errno.EOPNOTSUPP:
- raise
- msg = 'chflag UF_IMMUTABLE not supported by underlying fs'
- self.skipTest(msg)
+ def chflags_nofollow(path, flags):
+ return posix.chflags(path, flags, follow_symlinks=False)
- try:
- new_testfn_st = os.stat(support.TESTFN)
- new_dummy_symlink_st = os.lstat(_DUMMY_SYMLINK)
+ for fn in (posix.lchflags, chflags_nofollow):
+ # ZFS returns EOPNOTSUPP when attempting to set flag UF_IMMUTABLE.
+ flags = dummy_symlink_st.st_flags | stat.UF_IMMUTABLE
+ try:
+ fn(_DUMMY_SYMLINK, flags)
+ except OSError as err:
+ if err.errno != errno.EOPNOTSUPP:
+ raise
+ msg = 'chflag UF_IMMUTABLE not supported by underlying fs'
+ self.skipTest(msg)
+ try:
+ new_testfn_st = os.stat(support.TESTFN)
+ new_dummy_symlink_st = os.lstat(_DUMMY_SYMLINK)
- self.assertEqual(testfn_st.st_flags, new_testfn_st.st_flags)
- self.assertEqual(dummy_symlink_st.st_flags | stat.UF_IMMUTABLE,
- new_dummy_symlink_st.st_flags)
- finally:
- posix.lchflags(_DUMMY_SYMLINK, dummy_symlink_st.st_flags)
+ self.assertEqual(testfn_st.st_flags, new_testfn_st.st_flags)
+ self.assertEqual(dummy_symlink_st.st_flags | stat.UF_IMMUTABLE,
+ new_dummy_symlink_st.st_flags)
+ finally:
+ fn(_DUMMY_SYMLINK, dummy_symlink_st.st_flags)
def test_environ(self):
if os.name == "nt":
@@ -473,13 +699,22 @@ class PosixTester(unittest.TestCase):
os.chdir(curdir)
support.rmtree(base_path)
+ @unittest.skipUnless(hasattr(posix, 'getgrouplist'), "test needs posix.getgrouplist()")
+ @unittest.skipUnless(hasattr(pwd, 'getpwuid'), "test needs pwd.getpwuid()")
+ @unittest.skipUnless(hasattr(os, 'getuid'), "test needs os.getuid()")
+ def test_getgrouplist(self):
+ user = pwd.getpwuid(os.getuid())[0]
+ group = pwd.getpwuid(os.getuid())[3]
+ self.assertIn(group, posix.getgrouplist(user, group))
+
+
@unittest.skipUnless(hasattr(os, 'getegid'), "test needs os.getegid()")
def test_getgroups(self):
with os.popen('id -G') as idg:
groups = idg.read().strip()
ret = idg.close()
- if ret != None or not groups:
+ if ret is not None or not groups:
raise unittest.SkipTest("need working 'id -G'")
# Issues 16698: OS X ABIs prior to 10.6 have limits on getgroups()
@@ -497,6 +732,348 @@ class PosixTester(unittest.TestCase):
set([int(x) for x in groups.split()]),
set(posix.getgroups() + [posix.getegid()]))
+ # tests for the posix *at functions follow
+
+ @unittest.skipUnless(os.access in os.supports_dir_fd, "test needs dir_fd support for os.access()")
+ def test_access_dir_fd(self):
+ f = posix.open(posix.getcwd(), posix.O_RDONLY)
+ try:
+ self.assertTrue(posix.access(support.TESTFN, os.R_OK, dir_fd=f))
+ finally:
+ posix.close(f)
+
+ @unittest.skipUnless(os.chmod in os.supports_dir_fd, "test needs dir_fd support in os.chmod()")
+ def test_chmod_dir_fd(self):
+ os.chmod(support.TESTFN, stat.S_IRUSR)
+
+ f = posix.open(posix.getcwd(), posix.O_RDONLY)
+ try:
+ posix.chmod(support.TESTFN, stat.S_IRUSR | stat.S_IWUSR, dir_fd=f)
+
+ s = posix.stat(support.TESTFN)
+ self.assertEqual(s[0] & stat.S_IRWXU, stat.S_IRUSR | stat.S_IWUSR)
+ finally:
+ posix.close(f)
+
+ @unittest.skipUnless(os.chown in os.supports_dir_fd, "test needs dir_fd support in os.chown()")
+ def test_chown_dir_fd(self):
+ support.unlink(support.TESTFN)
+ support.create_empty_file(support.TESTFN)
+
+ f = posix.open(posix.getcwd(), posix.O_RDONLY)
+ try:
+ posix.chown(support.TESTFN, os.getuid(), os.getgid(), dir_fd=f)
+ finally:
+ posix.close(f)
+
+ @unittest.skipUnless(os.stat in os.supports_dir_fd, "test needs dir_fd support in os.stat()")
+ def test_stat_dir_fd(self):
+ support.unlink(support.TESTFN)
+ with open(support.TESTFN, 'w') as outfile:
+ outfile.write("testline\n")
+
+ f = posix.open(posix.getcwd(), posix.O_RDONLY)
+ try:
+ s1 = posix.stat(support.TESTFN)
+ s2 = posix.stat(support.TESTFN, dir_fd=f)
+ self.assertEqual(s1, s2)
+ s2 = posix.stat(support.TESTFN, dir_fd=None)
+ self.assertEqual(s1, s2)
+ self.assertRaisesRegex(TypeError, 'should be integer, not',
+ posix.stat, support.TESTFN, dir_fd=posix.getcwd())
+ self.assertRaisesRegex(TypeError, 'should be integer, not',
+ posix.stat, support.TESTFN, dir_fd=float(f))
+ self.assertRaises(OverflowError,
+ posix.stat, support.TESTFN, dir_fd=10**20)
+ finally:
+ posix.close(f)
+
+ @unittest.skipUnless(os.utime in os.supports_dir_fd, "test needs dir_fd support in os.utime()")
+ def test_utime_dir_fd(self):
+ f = posix.open(posix.getcwd(), posix.O_RDONLY)
+ try:
+ now = time.time()
+ posix.utime(support.TESTFN, None, dir_fd=f)
+ posix.utime(support.TESTFN, dir_fd=f)
+ self.assertRaises(TypeError, posix.utime, support.TESTFN, now, dir_fd=f)
+ self.assertRaises(TypeError, posix.utime, support.TESTFN, (None, None), dir_fd=f)
+ self.assertRaises(TypeError, posix.utime, support.TESTFN, (now, None), dir_fd=f)
+ self.assertRaises(TypeError, posix.utime, support.TESTFN, (None, now), dir_fd=f)
+ self.assertRaises(TypeError, posix.utime, support.TESTFN, (now, "x"), dir_fd=f)
+ posix.utime(support.TESTFN, (int(now), int(now)), dir_fd=f)
+ posix.utime(support.TESTFN, (now, now), dir_fd=f)
+ posix.utime(support.TESTFN,
+ (int(now), int((now - int(now)) * 1e9)), dir_fd=f)
+ posix.utime(support.TESTFN, dir_fd=f,
+ times=(int(now), int((now - int(now)) * 1e9)))
+
+ # try dir_fd and follow_symlinks together
+ if os.utime in os.supports_follow_symlinks:
+ try:
+ posix.utime(support.TESTFN, follow_symlinks=False, dir_fd=f)
+ except ValueError:
+ # whoops! using both together not supported on this platform.
+ pass
+
+ finally:
+ posix.close(f)
+
+ @unittest.skipUnless(os.link in os.supports_dir_fd, "test needs dir_fd support in os.link()")
+ def test_link_dir_fd(self):
+ f = posix.open(posix.getcwd(), posix.O_RDONLY)
+ try:
+ posix.link(support.TESTFN, support.TESTFN + 'link', src_dir_fd=f, dst_dir_fd=f)
+ # should have same inodes
+ self.assertEqual(posix.stat(support.TESTFN)[1],
+ posix.stat(support.TESTFN + 'link')[1])
+ finally:
+ posix.close(f)
+ support.unlink(support.TESTFN + 'link')
+
+ @unittest.skipUnless(os.mkdir in os.supports_dir_fd, "test needs dir_fd support in os.mkdir()")
+ def test_mkdir_dir_fd(self):
+ f = posix.open(posix.getcwd(), posix.O_RDONLY)
+ try:
+ posix.mkdir(support.TESTFN + 'dir', dir_fd=f)
+ posix.stat(support.TESTFN + 'dir') # should not raise exception
+ finally:
+ posix.close(f)
+ support.rmtree(support.TESTFN + 'dir')
+
+ @unittest.skipUnless((os.mknod in os.supports_dir_fd) and hasattr(stat, 'S_IFIFO'),
+ "test requires both stat.S_IFIFO and dir_fd support for os.mknod()")
+ def test_mknod_dir_fd(self):
+ # Test using mknodat() to create a FIFO (the only use specified
+ # by POSIX).
+ support.unlink(support.TESTFN)
+ mode = stat.S_IFIFO | stat.S_IRUSR | stat.S_IWUSR
+ f = posix.open(posix.getcwd(), posix.O_RDONLY)
+ try:
+ posix.mknod(support.TESTFN, mode, 0, dir_fd=f)
+ except OSError as e:
+ # Some old systems don't allow unprivileged users to use
+ # mknod(), or only support creating device nodes.
+ self.assertIn(e.errno, (errno.EPERM, errno.EINVAL))
+ else:
+ self.assertTrue(stat.S_ISFIFO(posix.stat(support.TESTFN).st_mode))
+ finally:
+ posix.close(f)
+
+ @unittest.skipUnless(os.open in os.supports_dir_fd, "test needs dir_fd support in os.open()")
+ def test_open_dir_fd(self):
+ support.unlink(support.TESTFN)
+ with open(support.TESTFN, 'w') as outfile:
+ outfile.write("testline\n")
+ a = posix.open(posix.getcwd(), posix.O_RDONLY)
+ b = posix.open(support.TESTFN, posix.O_RDONLY, dir_fd=a)
+ try:
+ res = posix.read(b, 9).decode(encoding="utf-8")
+ self.assertEqual("testline\n", res)
+ finally:
+ posix.close(a)
+ posix.close(b)
+
+ @unittest.skipUnless(os.readlink in os.supports_dir_fd, "test needs dir_fd support in os.readlink()")
+ def test_readlink_dir_fd(self):
+ os.symlink(support.TESTFN, support.TESTFN + 'link')
+ f = posix.open(posix.getcwd(), posix.O_RDONLY)
+ try:
+ self.assertEqual(posix.readlink(support.TESTFN + 'link'),
+ posix.readlink(support.TESTFN + 'link', dir_fd=f))
+ finally:
+ support.unlink(support.TESTFN + 'link')
+ posix.close(f)
+
+ @unittest.skipUnless(os.rename in os.supports_dir_fd, "test needs dir_fd support in os.rename()")
+ def test_rename_dir_fd(self):
+ support.unlink(support.TESTFN)
+ support.create_empty_file(support.TESTFN + 'ren')
+ f = posix.open(posix.getcwd(), posix.O_RDONLY)
+ try:
+ posix.rename(support.TESTFN + 'ren', support.TESTFN, src_dir_fd=f, dst_dir_fd=f)
+ except:
+ posix.rename(support.TESTFN + 'ren', support.TESTFN)
+ raise
+ else:
+ posix.stat(support.TESTFN) # should not raise exception
+ finally:
+ posix.close(f)
+
+ @unittest.skipUnless(os.symlink in os.supports_dir_fd, "test needs dir_fd support in os.symlink()")
+ def test_symlink_dir_fd(self):
+ f = posix.open(posix.getcwd(), posix.O_RDONLY)
+ try:
+ posix.symlink(support.TESTFN, support.TESTFN + 'link', dir_fd=f)
+ self.assertEqual(posix.readlink(support.TESTFN + 'link'), support.TESTFN)
+ finally:
+ posix.close(f)
+ support.unlink(support.TESTFN + 'link')
+
+ @unittest.skipUnless(os.unlink in os.supports_dir_fd, "test needs dir_fd support in os.unlink()")
+ def test_unlink_dir_fd(self):
+ f = posix.open(posix.getcwd(), posix.O_RDONLY)
+ support.create_empty_file(support.TESTFN + 'del')
+ posix.stat(support.TESTFN + 'del') # should not raise exception
+ try:
+ posix.unlink(support.TESTFN + 'del', dir_fd=f)
+ except:
+ support.unlink(support.TESTFN + 'del')
+ raise
+ else:
+ self.assertRaises(OSError, posix.stat, support.TESTFN + 'link')
+ finally:
+ posix.close(f)
+
+ @unittest.skipUnless(os.mkfifo in os.supports_dir_fd, "test needs dir_fd support in os.mkfifo()")
+ def test_mkfifo_dir_fd(self):
+ support.unlink(support.TESTFN)
+ f = posix.open(posix.getcwd(), posix.O_RDONLY)
+ try:
+ posix.mkfifo(support.TESTFN, stat.S_IRUSR | stat.S_IWUSR, dir_fd=f)
+ self.assertTrue(stat.S_ISFIFO(posix.stat(support.TESTFN).st_mode))
+ finally:
+ posix.close(f)
+
+ requires_sched_h = unittest.skipUnless(hasattr(posix, 'sched_yield'),
+ "don't have scheduling support")
+ requires_sched_affinity = unittest.skipUnless(hasattr(posix, 'sched_setaffinity'),
+ "don't have sched affinity support")
+
+ @requires_sched_h
+ def test_sched_yield(self):
+ # This has no error conditions (at least on Linux).
+ posix.sched_yield()
+
+ @requires_sched_h
+ @unittest.skipUnless(hasattr(posix, 'sched_get_priority_max'),
+ "requires sched_get_priority_max()")
+ def test_sched_priority(self):
+ # Round-robin usually has interesting priorities.
+ pol = posix.SCHED_RR
+ lo = posix.sched_get_priority_min(pol)
+ hi = posix.sched_get_priority_max(pol)
+ self.assertIsInstance(lo, int)
+ self.assertIsInstance(hi, int)
+ self.assertGreaterEqual(hi, lo)
+ # OSX evidently just returns 15 without checking the argument.
+ if sys.platform != "darwin":
+ self.assertRaises(OSError, posix.sched_get_priority_min, -23)
+ self.assertRaises(OSError, posix.sched_get_priority_max, -23)
+
+ @unittest.skipUnless(hasattr(posix, 'sched_setscheduler'), "can't change scheduler")
+ def test_get_and_set_scheduler_and_param(self):
+ possible_schedulers = [sched for name, sched in posix.__dict__.items()
+ if name.startswith("SCHED_")]
+ mine = posix.sched_getscheduler(0)
+ self.assertIn(mine, possible_schedulers)
+ try:
+ parent = posix.sched_getscheduler(os.getppid())
+ except OSError as e:
+ if e.errno != errno.EPERM:
+ raise
+ else:
+ self.assertIn(parent, possible_schedulers)
+ self.assertRaises(OSError, posix.sched_getscheduler, -1)
+ self.assertRaises(OSError, posix.sched_getparam, -1)
+ param = posix.sched_getparam(0)
+ self.assertIsInstance(param.sched_priority, int)
+
+ # POSIX states that calling sched_setparam() or sched_setscheduler() on
+ # a process with a scheduling policy other than SCHED_FIFO or SCHED_RR
+ # is implementation-defined: NetBSD and FreeBSD can return EINVAL.
+ if not sys.platform.startswith(('freebsd', 'netbsd')):
+ try:
+ posix.sched_setscheduler(0, mine, param)
+ posix.sched_setparam(0, param)
+ except OSError as e:
+ if e.errno != errno.EPERM:
+ raise
+ self.assertRaises(OSError, posix.sched_setparam, -1, param)
+
+ self.assertRaises(OSError, posix.sched_setscheduler, -1, mine, param)
+ self.assertRaises(TypeError, posix.sched_setscheduler, 0, mine, None)
+ self.assertRaises(TypeError, posix.sched_setparam, 0, 43)
+ param = posix.sched_param(None)
+ self.assertRaises(TypeError, posix.sched_setparam, 0, param)
+ large = 214748364700
+ param = posix.sched_param(large)
+ self.assertRaises(OverflowError, posix.sched_setparam, 0, param)
+ param = posix.sched_param(sched_priority=-large)
+ self.assertRaises(OverflowError, posix.sched_setparam, 0, param)
+
+ @unittest.skipUnless(hasattr(posix, "sched_rr_get_interval"), "no function")
+ def test_sched_rr_get_interval(self):
+ try:
+ interval = posix.sched_rr_get_interval(0)
+ except OSError as e:
+ # This likely means that sched_rr_get_interval is only valid for
+ # processes with the SCHED_RR scheduler in effect.
+ if e.errno != errno.EINVAL:
+ raise
+ self.skipTest("only works on SCHED_RR processes")
+ self.assertIsInstance(interval, float)
+ # Reasonable constraints, I think.
+ self.assertGreaterEqual(interval, 0.)
+ self.assertLess(interval, 1.)
+
+ @requires_sched_affinity
+ def test_sched_getaffinity(self):
+ mask = posix.sched_getaffinity(0)
+ self.assertIsInstance(mask, set)
+ self.assertGreaterEqual(len(mask), 1)
+ self.assertRaises(OSError, posix.sched_getaffinity, -1)
+ for cpu in mask:
+ self.assertIsInstance(cpu, int)
+ self.assertGreaterEqual(cpu, 0)
+ self.assertLess(cpu, 1 << 32)
+
+ @requires_sched_affinity
+ def test_sched_setaffinity(self):
+ mask = posix.sched_getaffinity(0)
+ if len(mask) > 1:
+ # Empty masks are forbidden
+ mask.pop()
+ posix.sched_setaffinity(0, mask)
+ self.assertEqual(posix.sched_getaffinity(0), mask)
+ self.assertRaises(OSError, posix.sched_setaffinity, 0, [])
+ self.assertRaises(ValueError, posix.sched_setaffinity, 0, [-10])
+ self.assertRaises(OverflowError, posix.sched_setaffinity, 0, [1<<128])
+ self.assertRaises(OSError, posix.sched_setaffinity, -1, mask)
+
+ def test_rtld_constants(self):
+ # check presence of major RTLD_* constants
+ posix.RTLD_LAZY
+ posix.RTLD_NOW
+ posix.RTLD_GLOBAL
+ posix.RTLD_LOCAL
+
+ @unittest.skipUnless(hasattr(os, 'SEEK_HOLE'),
+ "test needs an OS that reports file holes")
+ def test_fs_holes(self):
+ # Even if the filesystem doesn't report holes,
+ # if the OS supports it the SEEK_* constants
+ # will be defined and will have a consistent
+ # behaviour:
+ # os.SEEK_DATA = current position
+ # os.SEEK_HOLE = end of file position
+ with open(support.TESTFN, 'r+b') as fp:
+ fp.write(b"hello")
+ fp.flush()
+ size = fp.tell()
+ fno = fp.fileno()
+ try :
+ for i in range(size):
+ self.assertEqual(i, os.lseek(fno, i, os.SEEK_DATA))
+ self.assertLessEqual(size, os.lseek(fno, i, os.SEEK_HOLE))
+ self.assertRaises(OSError, os.lseek, fno, size, os.SEEK_DATA)
+ self.assertRaises(OSError, os.lseek, fno, size, os.SEEK_HOLE)
+ except OSError :
+ # Some OSs claim to support SEEK_HOLE/SEEK_DATA
+ # but it is not true.
+ # For instance:
+ # http://lists.freebsd.org/pipermail/freebsd-amd64/2012-January/014332.html
+ raise unittest.SkipTest("OSError raised!")
+
class PosixGroupsTester(unittest.TestCase):
def setUp(self):
@@ -532,9 +1109,11 @@ class PosixGroupsTester(unittest.TestCase):
posix.setgroups(groups)
self.assertListEqual(groups, posix.getgroups())
-
def test_main():
- support.run_unittest(PosixTester, PosixGroupsTester)
+ try:
+ support.run_unittest(PosixTester, PosixGroupsTester)
+ finally:
+ support.reap_children()
if __name__ == '__main__':
test_main()
diff --git a/Lib/test/test_posixpath.py b/Lib/test/test_posixpath.py
index 5d0f5b7..f545f08 100644
--- a/Lib/test/test_posixpath.py
+++ b/Lib/test/test_posixpath.py
@@ -1,11 +1,11 @@
-import unittest
-from test import support, test_genericpath
-
import itertools
-import posixpath
import os
+import posixpath
import sys
+import unittest
+import warnings
from posixpath import realpath, abspath, dirname, basename
+from test import support, test_genericpath
try:
import posix
@@ -245,7 +245,9 @@ class PosixPathTest(unittest.TestCase):
def test_ismount(self):
self.assertIs(posixpath.ismount("/"), True)
- self.assertIs(posixpath.ismount(b"/"), True)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", DeprecationWarning)
+ self.assertIs(posixpath.ismount(b"/"), True)
def test_ismount_non_existent(self):
# Non-existent mountpoint.
@@ -597,14 +599,10 @@ class PosixPathTest(unittest.TestCase):
self.assertTrue(posixpath.sameopenfile(a.fileno(), b.fileno()))
-class PosixCommonTest(test_genericpath.CommonTest):
+class PosixCommonTest(test_genericpath.CommonTest, unittest.TestCase):
pathmodule = posixpath
attributes = ['relpath', 'samefile', 'sameopenfile', 'samestat']
-def test_main():
- support.run_unittest(PosixPathTest, PosixCommonTest)
-
-
if __name__=="__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_pprint.py b/Lib/test/test_pprint.py
index 0c93d0e..d492d75 100644
--- a/Lib/test/test_pprint.py
+++ b/Lib/test/test_pprint.py
@@ -219,6 +219,8 @@ class QueryTestCase(unittest.TestCase):
others.should.not.be: like.this}"""
self.assertEqual(DottedPrettyPrinter().pformat(o), exp)
+ @unittest.expectedFailure
+ #See http://bugs.python.org/issue13907
@test.support.cpython_only
def test_set_reprs(self):
# This test creates a complex arrangement of frozensets and
@@ -241,10 +243,12 @@ class QueryTestCase(unittest.TestCase):
# Consequently, this test is fragile and
# implementation-dependent. Small changes to Python's sort
# algorithm cause the test to fail when it should pass.
+ # XXX Or changes to the dictionary implmentation...
self.assertEqual(pprint.pformat(set()), 'set()')
self.assertEqual(pprint.pformat(set(range(3))), '{0, 1, 2}')
self.assertEqual(pprint.pformat(frozenset()), 'frozenset()')
+
self.assertEqual(pprint.pformat(frozenset(range(3))), 'frozenset({0, 1, 2})')
cube_repr_tgt = """\
{frozenset(): frozenset({frozenset({2}), frozenset({0}), frozenset({1})}),
diff --git a/Lib/test/test_print.py b/Lib/test/test_print.py
index 8d37bbc..9d6dbea 100644
--- a/Lib/test/test_print.py
+++ b/Lib/test/test_print.py
@@ -111,6 +111,32 @@ class TestPrint(unittest.TestCase):
self.assertRaises(TypeError, print, '', end=3)
self.assertRaises(AttributeError, print, '', file='')
+ def test_print_flush(self):
+ # operation of the flush flag
+ class filelike():
+ def __init__(self):
+ self.written = ''
+ self.flushed = 0
+ def write(self, str):
+ self.written += str
+ def flush(self):
+ self.flushed += 1
+
+ f = filelike()
+ print(1, file=f, end='', flush=True)
+ print(2, file=f, end='', flush=True)
+ print(3, file=f, flush=False)
+ self.assertEqual(f.written, '123\n')
+ self.assertEqual(f.flushed, 2)
+
+ # ensure exceptions from flush are passed through
+ class noflush():
+ def write(self, str):
+ pass
+ def flush(self):
+ raise RuntimeError
+ self.assertRaises(RuntimeError, print, 1, file=noflush(), flush=True)
+
def test_main():
support.run_unittest(TestPrint)
diff --git a/Lib/test/test_property.py b/Lib/test/test_property.py
index b58270f..96eeb88 100644
--- a/Lib/test/test_property.py
+++ b/Lib/test/test_property.py
@@ -128,6 +128,29 @@ class PropertyTests(unittest.TestCase):
self.assertEqual(newgetter.spam, 8)
self.assertEqual(newgetter.__class__.spam.__doc__, "new docstring")
+ def test_property___isabstractmethod__descriptor(self):
+ for val in (True, False, [], [1], '', '1'):
+ class C(object):
+ def foo(self):
+ pass
+ foo.__isabstractmethod__ = val
+ foo = property(foo)
+ self.assertIs(C.foo.__isabstractmethod__, bool(val))
+
+ # check that the property's __isabstractmethod__ descriptor does the
+ # right thing when presented with a value that fails truth testing:
+ class NotBool(object):
+ def __nonzero__(self):
+ raise ValueError()
+ __len__ = __nonzero__
+ with self.assertRaises(ValueError):
+ class C(object):
+ def foo(self):
+ pass
+ foo.__isabstractmethod__ = NotBool()
+ foo = property(foo)
+ C.foo.__isabstractmethod__
+
# Issue 5890: subclasses of property do not preserve method __doc__ strings
class PropertySub(property):
diff --git a/Lib/test/test_pty.py b/Lib/test/test_pty.py
index 4d471d5..29297f8 100644
--- a/Lib/test/test_pty.py
+++ b/Lib/test/test_pty.py
@@ -205,6 +205,7 @@ class SmallPtyTests(unittest.TestCase):
self.orig_stdout_fileno = pty.STDOUT_FILENO
self.orig_pty_select = pty.select
self.fds = [] # A list of file descriptors to close.
+ self.files = []
self.select_rfds_lengths = []
self.select_rfds_results = []
@@ -212,6 +213,11 @@ class SmallPtyTests(unittest.TestCase):
pty.STDIN_FILENO = self.orig_stdin_fileno
pty.STDOUT_FILENO = self.orig_stdout_fileno
pty.select = self.orig_pty_select
+ for file in self.files:
+ try:
+ file.close()
+ except OSError:
+ pass
for fd in self.fds:
try:
os.close(fd)
@@ -223,6 +229,11 @@ class SmallPtyTests(unittest.TestCase):
self.fds.extend(pipe_fds)
return pipe_fds
+ def _socketpair(self):
+ socketpair = socket.socketpair()
+ self.files.extend(socketpair)
+ return socketpair
+
def _mock_select(self, rfds, wfds, xfds):
# This will raise IndexError when no more expected calls exist.
self.assertEqual(self.select_rfds_lengths.pop(0), len(rfds))
@@ -234,9 +245,7 @@ class SmallPtyTests(unittest.TestCase):
pty.STDOUT_FILENO = mock_stdout_fd
mock_stdin_fd, write_to_stdin_fd = self._pipe()
pty.STDIN_FILENO = mock_stdin_fd
- socketpair = socket.socketpair()
- for s in socketpair:
- self.addCleanup(s.close)
+ socketpair = self._socketpair()
masters = [s.fileno() for s in socketpair]
# Feed data. Smaller than PIPEBUF. These writes will not block.
@@ -264,9 +273,7 @@ class SmallPtyTests(unittest.TestCase):
pty.STDOUT_FILENO = mock_stdout_fd
mock_stdin_fd, write_to_stdin_fd = self._pipe()
pty.STDIN_FILENO = mock_stdin_fd
- socketpair = socket.socketpair()
- for s in socketpair:
- self.addCleanup(s.close)
+ socketpair = self._socketpair()
masters = [s.fileno() for s in socketpair]
os.close(masters[1])
diff --git a/Lib/test/test_pulldom.py b/Lib/test/test_pulldom.py
new file mode 100644
index 0000000..b81a595
--- /dev/null
+++ b/Lib/test/test_pulldom.py
@@ -0,0 +1,347 @@
+import io
+import unittest
+import sys
+import xml.sax
+
+from xml.sax.xmlreader import AttributesImpl
+from xml.dom import pulldom
+
+from test.support import run_unittest, findfile
+
+
+tstfile = findfile("test.xml", subdir="xmltestdata")
+
+# A handy XML snippet, containing attributes, a namespace prefix, and a
+# self-closing tag:
+SMALL_SAMPLE = """<?xml version="1.0"?>
+<html xmlns="http://www.w3.org/1999/xhtml" xmlns:xdc="http://www.xml.com/books">
+<!-- A comment -->
+<title>Introduction to XSL</title>
+<hr/>
+<p><xdc:author xdc:attrib="prefixed attribute" attrib="other attrib">A. Namespace</xdc:author></p>
+</html>"""
+
+
+class PullDOMTestCase(unittest.TestCase):
+
+ def test_parse(self):
+ """Minimal test of DOMEventStream.parse()"""
+
+ # This just tests that parsing from a stream works. Actual parser
+ # semantics are tested using parseString with a more focused XML
+ # fragment.
+
+ # Test with a filename:
+ handler = pulldom.parse(tstfile)
+ self.addCleanup(handler.stream.close)
+ list(handler)
+
+ # Test with a file object:
+ with open(tstfile, "rb") as fin:
+ list(pulldom.parse(fin))
+
+ def test_parse_semantics(self):
+ """Test DOMEventStream parsing semantics."""
+
+ items = pulldom.parseString(SMALL_SAMPLE)
+ evt, node = next(items)
+ # Just check the node is a Document:
+ self.assertTrue(hasattr(node, "createElement"))
+ self.assertEqual(pulldom.START_DOCUMENT, evt)
+ evt, node = next(items)
+ self.assertEqual(pulldom.START_ELEMENT, evt)
+ self.assertEqual("html", node.tagName)
+ self.assertEqual(2, len(node.attributes))
+ self.assertEqual(node.attributes.getNamedItem("xmlns:xdc").value,
+ "http://www.xml.com/books")
+ evt, node = next(items)
+ self.assertEqual(pulldom.CHARACTERS, evt) # Line break
+ evt, node = next(items)
+ # XXX - A comment should be reported here!
+ # self.assertEqual(pulldom.COMMENT, evt)
+ # Line break after swallowed comment:
+ self.assertEqual(pulldom.CHARACTERS, evt)
+ evt, node = next(items)
+ self.assertEqual("title", node.tagName)
+ title_node = node
+ evt, node = next(items)
+ self.assertEqual(pulldom.CHARACTERS, evt)
+ self.assertEqual("Introduction to XSL", node.data)
+ evt, node = next(items)
+ self.assertEqual(pulldom.END_ELEMENT, evt)
+ self.assertEqual("title", node.tagName)
+ self.assertTrue(title_node is node)
+ evt, node = next(items)
+ self.assertEqual(pulldom.CHARACTERS, evt)
+ evt, node = next(items)
+ self.assertEqual(pulldom.START_ELEMENT, evt)
+ self.assertEqual("hr", node.tagName)
+ evt, node = next(items)
+ self.assertEqual(pulldom.END_ELEMENT, evt)
+ self.assertEqual("hr", node.tagName)
+ evt, node = next(items)
+ self.assertEqual(pulldom.CHARACTERS, evt)
+ evt, node = next(items)
+ self.assertEqual(pulldom.START_ELEMENT, evt)
+ self.assertEqual("p", node.tagName)
+ evt, node = next(items)
+ self.assertEqual(pulldom.START_ELEMENT, evt)
+ self.assertEqual("xdc:author", node.tagName)
+ evt, node = next(items)
+ self.assertEqual(pulldom.CHARACTERS, evt)
+ evt, node = next(items)
+ self.assertEqual(pulldom.END_ELEMENT, evt)
+ self.assertEqual("xdc:author", node.tagName)
+ evt, node = next(items)
+ self.assertEqual(pulldom.END_ELEMENT, evt)
+ evt, node = next(items)
+ self.assertEqual(pulldom.CHARACTERS, evt)
+ evt, node = next(items)
+ self.assertEqual(pulldom.END_ELEMENT, evt)
+ # XXX No END_DOCUMENT item is ever obtained:
+ #evt, node = next(items)
+ #self.assertEqual(pulldom.END_DOCUMENT, evt)
+
+ def test_expandItem(self):
+ """Ensure expandItem works as expected."""
+ items = pulldom.parseString(SMALL_SAMPLE)
+ # Loop through the nodes until we get to a "title" start tag:
+ for evt, item in items:
+ if evt == pulldom.START_ELEMENT and item.tagName == "title":
+ items.expandNode(item)
+ self.assertEqual(1, len(item.childNodes))
+ break
+ else:
+ self.fail("No \"title\" element detected in SMALL_SAMPLE!")
+ # Loop until we get to the next start-element:
+ for evt, node in items:
+ if evt == pulldom.START_ELEMENT:
+ break
+ self.assertEqual("hr", node.tagName,
+ "expandNode did not leave DOMEventStream in the correct state.")
+ # Attempt to expand a standalone element:
+ items.expandNode(node)
+ self.assertEqual(next(items)[0], pulldom.CHARACTERS)
+ evt, node = next(items)
+ self.assertEqual(node.tagName, "p")
+ items.expandNode(node)
+ next(items) # Skip character data
+ evt, node = next(items)
+ self.assertEqual(node.tagName, "html")
+ with self.assertRaises(StopIteration):
+ next(items)
+ items.clear()
+ self.assertIsNone(items.parser)
+ self.assertIsNone(items.stream)
+
+ @unittest.expectedFailure
+ def test_comment(self):
+ """PullDOM does not receive "comment" events."""
+ items = pulldom.parseString(SMALL_SAMPLE)
+ for evt, _ in items:
+ if evt == pulldom.COMMENT:
+ break
+ else:
+ self.fail("No comment was encountered")
+
+ @unittest.expectedFailure
+ def test_end_document(self):
+ """PullDOM does not receive "end-document" events."""
+ items = pulldom.parseString(SMALL_SAMPLE)
+ # Read all of the nodes up to and including </html>:
+ for evt, node in items:
+ if evt == pulldom.END_ELEMENT and node.tagName == "html":
+ break
+ try:
+ # Assert that the next node is END_DOCUMENT:
+ evt, node = next(items)
+ self.assertEqual(pulldom.END_DOCUMENT, evt)
+ except StopIteration:
+ self.fail(
+ "Ran out of events, but should have received END_DOCUMENT")
+
+
+class ThoroughTestCase(unittest.TestCase):
+ """Test the hard-to-reach parts of pulldom."""
+
+ def test_thorough_parse(self):
+ """Test some of the hard-to-reach parts of PullDOM."""
+ self._test_thorough(pulldom.parse(None, parser=SAXExerciser()))
+
+ @unittest.expectedFailure
+ def test_sax2dom_fail(self):
+ """SAX2DOM can"t handle a PI before the root element."""
+ pd = SAX2DOMTestHelper(None, SAXExerciser(), 12)
+ self._test_thorough(pd)
+
+ def test_thorough_sax2dom(self):
+ """Test some of the hard-to-reach parts of SAX2DOM."""
+ pd = SAX2DOMTestHelper(None, SAX2DOMExerciser(), 12)
+ self._test_thorough(pd, False)
+
+ def _test_thorough(self, pd, before_root=True):
+ """Test some of the hard-to-reach parts of the parser, using a mock
+ parser."""
+
+ evt, node = next(pd)
+ self.assertEqual(pulldom.START_DOCUMENT, evt)
+ # Just check the node is a Document:
+ self.assertTrue(hasattr(node, "createElement"))
+
+ if before_root:
+ evt, node = next(pd)
+ self.assertEqual(pulldom.COMMENT, evt)
+ self.assertEqual("a comment", node.data)
+ evt, node = next(pd)
+ self.assertEqual(pulldom.PROCESSING_INSTRUCTION, evt)
+ self.assertEqual("target", node.target)
+ self.assertEqual("data", node.data)
+
+ evt, node = next(pd)
+ self.assertEqual(pulldom.START_ELEMENT, evt)
+ self.assertEqual("html", node.tagName)
+
+ evt, node = next(pd)
+ self.assertEqual(pulldom.COMMENT, evt)
+ self.assertEqual("a comment", node.data)
+ evt, node = next(pd)
+ self.assertEqual(pulldom.PROCESSING_INSTRUCTION, evt)
+ self.assertEqual("target", node.target)
+ self.assertEqual("data", node.data)
+
+ evt, node = next(pd)
+ self.assertEqual(pulldom.START_ELEMENT, evt)
+ self.assertEqual("p", node.tagName)
+
+ evt, node = next(pd)
+ self.assertEqual(pulldom.CHARACTERS, evt)
+ self.assertEqual("text", node.data)
+ evt, node = next(pd)
+ self.assertEqual(pulldom.END_ELEMENT, evt)
+ self.assertEqual("p", node.tagName)
+ evt, node = next(pd)
+ self.assertEqual(pulldom.END_ELEMENT, evt)
+ self.assertEqual("html", node.tagName)
+ evt, node = next(pd)
+ self.assertEqual(pulldom.END_DOCUMENT, evt)
+
+
+class SAXExerciser(object):
+ """A fake sax parser that calls some of the harder-to-reach sax methods to
+ ensure it emits the correct events"""
+
+ def setContentHandler(self, handler):
+ self._handler = handler
+
+ def parse(self, _):
+ h = self._handler
+ h.startDocument()
+
+ # The next two items ensure that items preceding the first
+ # start_element are properly stored and emitted:
+ h.comment("a comment")
+ h.processingInstruction("target", "data")
+
+ h.startElement("html", AttributesImpl({}))
+
+ h.comment("a comment")
+ h.processingInstruction("target", "data")
+
+ h.startElement("p", AttributesImpl({"class": "paraclass"}))
+ h.characters("text")
+ h.endElement("p")
+ h.endElement("html")
+ h.endDocument()
+
+ def stub(self, *args, **kwargs):
+ """Stub method. Does nothing."""
+ pass
+ setProperty = stub
+ setFeature = stub
+
+
+class SAX2DOMExerciser(SAXExerciser):
+ """The same as SAXExerciser, but without the processing instruction and
+ comment before the root element, because S2D can"t handle it"""
+
+ def parse(self, _):
+ h = self._handler
+ h.startDocument()
+ h.startElement("html", AttributesImpl({}))
+ h.comment("a comment")
+ h.processingInstruction("target", "data")
+ h.startElement("p", AttributesImpl({"class": "paraclass"}))
+ h.characters("text")
+ h.endElement("p")
+ h.endElement("html")
+ h.endDocument()
+
+
+class SAX2DOMTestHelper(pulldom.DOMEventStream):
+ """Allows us to drive SAX2DOM from a DOMEventStream."""
+
+ def reset(self):
+ self.pulldom = pulldom.SAX2DOM()
+ # This content handler relies on namespace support
+ self.parser.setFeature(xml.sax.handler.feature_namespaces, 1)
+ self.parser.setContentHandler(self.pulldom)
+
+
+class SAX2DOMTestCase(unittest.TestCase):
+
+ def confirm(self, test, testname="Test"):
+ self.assertTrue(test, testname)
+
+ def test_basic(self):
+ """Ensure SAX2DOM can parse from a stream."""
+ with io.StringIO(SMALL_SAMPLE) as fin:
+ sd = SAX2DOMTestHelper(fin, xml.sax.make_parser(),
+ len(SMALL_SAMPLE))
+ for evt, node in sd:
+ if evt == pulldom.START_ELEMENT and node.tagName == "html":
+ break
+ # Because the buffer is the same length as the XML, all the
+ # nodes should have been parsed and added:
+ self.assertGreater(len(node.childNodes), 0)
+
+ def testSAX2DOM(self):
+ """Ensure SAX2DOM expands nodes as expected."""
+ sax2dom = pulldom.SAX2DOM()
+ sax2dom.startDocument()
+ sax2dom.startElement("doc", {})
+ sax2dom.characters("text")
+ sax2dom.startElement("subelm", {})
+ sax2dom.characters("text")
+ sax2dom.endElement("subelm")
+ sax2dom.characters("text")
+ sax2dom.endElement("doc")
+ sax2dom.endDocument()
+
+ doc = sax2dom.document
+ root = doc.documentElement
+ (text1, elm1, text2) = root.childNodes
+ text3 = elm1.childNodes[0]
+
+ self.assertIsNone(text1.previousSibling)
+ self.assertIs(text1.nextSibling, elm1)
+ self.assertIs(elm1.previousSibling, text1)
+ self.assertIs(elm1.nextSibling, text2)
+ self.assertIs(text2.previousSibling, elm1)
+ self.assertIsNone(text2.nextSibling)
+ self.assertIsNone(text3.previousSibling)
+ self.assertIsNone(text3.nextSibling)
+
+ self.assertIs(root.parentNode, doc)
+ self.assertIs(text1.parentNode, root)
+ self.assertIs(elm1.parentNode, root)
+ self.assertIs(text2.parentNode, root)
+ self.assertIs(text3.parentNode, elm1)
+ doc.unlink()
+
+
+def test_main():
+ run_unittest(PullDOMTestCase, ThoroughTestCase, SAX2DOMTestCase)
+
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_pydoc.py b/Lib/test/test_pydoc.py
index c7318ff..d98a526 100644
--- a/Lib/test/test_pydoc.py
+++ b/Lib/test/test_pydoc.py
@@ -207,7 +207,7 @@ expected_html_data_docstrings = tuple(s.replace(' ', '&nbsp;')
missing_pattern = "no Python documentation found for '%s'"
# output pattern for module with bad imports
-badimport_pattern = "problem in %s - ImportError: No module named %s"
+badimport_pattern = "problem in %s - ImportError: No module named %r"
def run_pydoc(module_name, *args, **env):
"""
@@ -245,8 +245,8 @@ def get_pydoc_text(module):
def print_diffs(text1, text2):
"Prints unified diffs for two texts"
# XXX now obsolete, use unittest built-in support
- lines1 = text1.splitlines(True)
- lines2 = text2.splitlines(True)
+ lines1 = text1.splitlines(keepends=True)
+ lines2 = text2.splitlines(keepends=True)
diffs = difflib.unified_diff(lines1, lines2, n=0, fromfile='expected',
tofile='got')
print('\n' + ''.join(diffs))
@@ -263,6 +263,8 @@ class PydocDocTest(unittest.TestCase):
@unittest.skipIf(sys.flags.optimize >= 2,
"Docstrings are omitted with -O2 and above")
+ @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(),
+ 'trace function introduces __locals__ unexpectedly')
def test_html_doc(self):
result, doc_loc = get_pydoc_html(pydoc_mod)
mod_file = inspect.getabsfile(pydoc_mod)
@@ -280,6 +282,8 @@ class PydocDocTest(unittest.TestCase):
@unittest.skipIf(sys.flags.optimize >= 2,
"Docstrings are omitted with -O2 and above")
+ @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(),
+ 'trace function introduces __locals__ unexpectedly')
def test_text_doc(self):
result, doc_loc = get_pydoc_text(pydoc_mod)
expected_text = expected_text_pattern % (
@@ -334,6 +338,8 @@ class PydocDocTest(unittest.TestCase):
@unittest.skipIf(sys.flags.optimize >= 2,
'Docstrings are omitted with -O2 and above')
+ @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(),
+ 'trace function introduces __locals__ unexpectedly')
def test_help_output_redirect(self):
# issue 940286, if output is set in Helper, then all output from
# Helper.help should be redirected
@@ -403,11 +409,10 @@ class PydocImportTest(unittest.TestCase):
modname = 'testmod_xyzzy'
testpairs = (
('i_am_not_here', 'i_am_not_here'),
- ('test.i_am_not_here_either', 'i_am_not_here_either'),
- ('test.i_am_not_here.neither_am_i', 'i_am_not_here.neither_am_i'),
- ('i_am_not_here.{}'.format(modname),
- 'i_am_not_here.{}'.format(modname)),
- ('test.{}'.format(modname), modname),
+ ('test.i_am_not_here_either', 'test.i_am_not_here_either'),
+ ('test.i_am_not_here.neither_am_i', 'test.i_am_not_here'),
+ ('i_am_not_here.{}'.format(modname), 'i_am_not_here'),
+ ('test.{}'.format(modname), 'test.{}'.format(modname)),
)
sourcefn = os.path.join(TESTFN, modname) + os.extsep + "py"
diff --git a/Lib/test/test_raise.py b/Lib/test/test_raise.py
index e02c1af..be5c1c6 100644
--- a/Lib/test/test_raise.py
+++ b/Lib/test/test_raise.py
@@ -4,6 +4,7 @@
"""Tests for the raise statement."""
from test import support
+import re
import sys
import types
import unittest
@@ -77,6 +78,16 @@ class TestRaise(unittest.TestCase):
nested_reraise()
self.assertRaises(TypeError, reraise)
+ def test_raise_from_None(self):
+ try:
+ try:
+ raise TypeError("foo")
+ except:
+ raise ValueError() from None
+ except ValueError as e:
+ self.assertTrue(isinstance(e.__context__, TypeError))
+ self.assertIsNone(e.__cause__)
+
def test_with_reraise1(self):
def reraise():
try:
@@ -130,8 +141,35 @@ class TestRaise(unittest.TestCase):
with self.assertRaises(TypeError):
raise MyException
+ def test_assert_with_tuple_arg(self):
+ try:
+ assert False, (3,)
+ except AssertionError as e:
+ self.assertEqual(str(e), "(3,)")
+
+
class TestCause(unittest.TestCase):
+
+ def testCauseSyntax(self):
+ try:
+ try:
+ try:
+ raise TypeError
+ except Exception:
+ raise ValueError from None
+ except ValueError as exc:
+ self.assertIsNone(exc.__cause__)
+ self.assertTrue(exc.__suppress_context__)
+ exc.__suppress_context__ = False
+ raise exc
+ except ValueError as exc:
+ e = exc
+
+ self.assertIsNone(e.__cause__)
+ self.assertFalse(e.__suppress_context__)
+ self.assertIsInstance(e.__context__, TypeError)
+
def test_invalid_cause(self):
try:
raise IndexError from 5
@@ -171,6 +209,7 @@ class TestCause(unittest.TestCase):
class TestTraceback(unittest.TestCase):
+
def test_sets_traceback(self):
try:
raise IndexError()
diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py
index 97d8098..c3ab7d2 100644
--- a/Lib/test/test_random.py
+++ b/Lib/test/test_random.py
@@ -34,8 +34,12 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(randseq, self.randomlist(N))
def test_seedargs(self):
+ # Seed value with a negative hash.
+ class MySeed(object):
+ def __hash__(self):
+ return -1729
for arg in [None, 0, 0, 1, 1, -1, -1, 10**20, -(10**20),
- 3.14, 1+2j, 'a', tuple('abc')]:
+ 3.14, 1+2j, 'a', tuple('abc'), MySeed()]:
self.gen.seed(arg)
for arg in [list(range(3)), dict(one=1)]:
self.assertRaises(TypeError, self.gen.seed, arg)
diff --git a/Lib/test/test_range.py b/Lib/test/test_range.py
index 4dab448..2a13bfe 100644
--- a/Lib/test/test_range.py
+++ b/Lib/test/test_range.py
@@ -350,13 +350,35 @@ class RangeTest(unittest.TestCase):
def test_pickling(self):
testcases = [(13,), (0, 11), (-22, 10), (20, 3, -1),
- (13, 21, 3), (-2, 2, 2)]
+ (13, 21, 3), (-2, 2, 2), (2**65, 2**65+2)]
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
for t in testcases:
r = range(*t)
self.assertEqual(list(pickle.loads(pickle.dumps(r, proto))),
list(r))
+ def test_iterator_pickling(self):
+ testcases = [(13,), (0, 11), (-22, 10), (20, 3, -1),
+ (13, 21, 3), (-2, 2, 2), (2**65, 2**65+2)]
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ for t in testcases:
+ it = itorg = iter(range(*t))
+ data = list(range(*t))
+
+ d = pickle.dumps(it)
+ it = pickle.loads(d)
+ self.assertEqual(type(itorg), type(it))
+ self.assertEqual(list(it), data)
+
+ it = pickle.loads(d)
+ try:
+ next(it)
+ except StopIteration:
+ continue
+ d = pickle.dumps(it)
+ it = pickle.loads(d)
+ self.assertEqual(list(it), data[1:])
+
def test_odd_bug(self):
# This used to raise a "SystemError: NULL result without error"
# because the range validation step was eating the exception
@@ -516,6 +538,87 @@ class RangeTest(unittest.TestCase):
for k in values - {0}:
r[i:j:k]
+ def test_comparison(self):
+ test_ranges = [range(0), range(0, -1), range(1, 1, 3),
+ range(1), range(5, 6), range(5, 6, 2),
+ range(5, 7, 2), range(2), range(0, 4, 2),
+ range(0, 5, 2), range(0, 6, 2)]
+ test_tuples = list(map(tuple, test_ranges))
+
+ # Check that equality of ranges matches equality of the corresponding
+ # tuples for each pair from the test lists above.
+ ranges_eq = [a == b for a in test_ranges for b in test_ranges]
+ tuples_eq = [a == b for a in test_tuples for b in test_tuples]
+ self.assertEqual(ranges_eq, tuples_eq)
+
+ # Check that != correctly gives the logical negation of ==
+ ranges_ne = [a != b for a in test_ranges for b in test_ranges]
+ self.assertEqual(ranges_ne, [not x for x in ranges_eq])
+
+ # Equal ranges should have equal hashes.
+ for a in test_ranges:
+ for b in test_ranges:
+ if a == b:
+ self.assertEqual(hash(a), hash(b))
+
+ # Ranges are unequal to other types (even sequence types)
+ self.assertIs(range(0) == (), False)
+ self.assertIs(() == range(0), False)
+ self.assertIs(range(2) == [0, 1], False)
+
+ # Huge integers aren't a problem.
+ self.assertEqual(range(0, 2**100 - 1, 2),
+ range(0, 2**100, 2))
+ self.assertEqual(hash(range(0, 2**100 - 1, 2)),
+ hash(range(0, 2**100, 2)))
+ self.assertNotEqual(range(0, 2**100, 2),
+ range(0, 2**100 + 1, 2))
+ self.assertEqual(range(2**200, 2**201 - 2**99, 2**100),
+ range(2**200, 2**201, 2**100))
+ self.assertEqual(hash(range(2**200, 2**201 - 2**99, 2**100)),
+ hash(range(2**200, 2**201, 2**100)))
+ self.assertNotEqual(range(2**200, 2**201, 2**100),
+ range(2**200, 2**201 + 1, 2**100))
+
+ # Order comparisons are not implemented for ranges.
+ with self.assertRaises(TypeError):
+ range(0) < range(0)
+ with self.assertRaises(TypeError):
+ range(0) > range(0)
+ with self.assertRaises(TypeError):
+ range(0) <= range(0)
+ with self.assertRaises(TypeError):
+ range(0) >= range(0)
+
+
+ def test_attributes(self):
+ # test the start, stop and step attributes of range objects
+ self.assert_attrs(range(0), 0, 0, 1)
+ self.assert_attrs(range(10), 0, 10, 1)
+ self.assert_attrs(range(-10), 0, -10, 1)
+ self.assert_attrs(range(0, 10, 1), 0, 10, 1)
+ self.assert_attrs(range(0, 10, 3), 0, 10, 3)
+ self.assert_attrs(range(10, 0, -1), 10, 0, -1)
+ self.assert_attrs(range(10, 0, -3), 10, 0, -3)
+
+ def assert_attrs(self, rangeobj, start, stop, step):
+ self.assertEqual(rangeobj.start, start)
+ self.assertEqual(rangeobj.stop, stop)
+ self.assertEqual(rangeobj.step, step)
+
+ with self.assertRaises(AttributeError):
+ rangeobj.start = 0
+ with self.assertRaises(AttributeError):
+ rangeobj.stop = 10
+ with self.assertRaises(AttributeError):
+ rangeobj.step = 1
+
+ with self.assertRaises(AttributeError):
+ del rangeobj.start
+ with self.assertRaises(AttributeError):
+ del rangeobj.stop
+ with self.assertRaises(AttributeError):
+ del rangeobj.step
def test_main():
test.support.run_unittest(RangeTest)
diff --git a/Lib/test/test_re.py b/Lib/test/test_re.py
index 0c95f4e..f96c3f9 100644
--- a/Lib/test/test_re.py
+++ b/Lib/test/test_re.py
@@ -8,9 +8,6 @@ import string
import traceback
from weakref import proxy
-from test.test_bigmem import character_size
-
-
# Misc tests from Tim Peters' re.doc
# WARNING: Don't change details in these tests if you don't know
@@ -496,7 +493,7 @@ class ReTests(unittest.TestCase):
self.assertEqual(m.span(), span)
def test_re_escape(self):
- alnum_chars = string.ascii_letters + string.digits
+ alnum_chars = string.ascii_letters + string.digits + '_'
p = ''.join(chr(i) for i in range(256))
for c in p:
if c in alnum_chars:
@@ -509,7 +506,7 @@ class ReTests(unittest.TestCase):
self.assertMatch(re.escape(p), p)
def test_re_escape_byte(self):
- alnum_chars = (string.ascii_letters + string.digits).encode('ascii')
+ alnum_chars = (string.ascii_letters + string.digits + '_').encode('ascii')
p = bytes(range(256))
for i in p:
b = bytes([i])
@@ -556,24 +553,92 @@ class ReTests(unittest.TestCase):
self.assertNotEqual(re.compile('^pattern$', flag), None)
def test_sre_character_literals(self):
- for i in [0, 8, 16, 32, 64, 127, 128, 255]:
- self.assertNotEqual(re.match(r"\%03o" % i, chr(i)), None)
- self.assertNotEqual(re.match(r"\%03o0" % i, chr(i)+"0"), None)
- self.assertNotEqual(re.match(r"\%03o8" % i, chr(i)+"8"), None)
- self.assertNotEqual(re.match(r"\x%02x" % i, chr(i)), None)
- self.assertNotEqual(re.match(r"\x%02x0" % i, chr(i)+"0"), None)
- self.assertNotEqual(re.match(r"\x%02xz" % i, chr(i)+"z"), None)
- self.assertRaises(re.error, re.match, "\911", "")
+ for i in [0, 8, 16, 32, 64, 127, 128, 255, 256, 0xFFFF, 0x10000, 0x10FFFF]:
+ if i < 256:
+ self.assertIsNotNone(re.match(r"\%03o" % i, chr(i)))
+ self.assertIsNotNone(re.match(r"\%03o0" % i, chr(i)+"0"))
+ self.assertIsNotNone(re.match(r"\%03o8" % i, chr(i)+"8"))
+ self.assertIsNotNone(re.match(r"\x%02x" % i, chr(i)))
+ self.assertIsNotNone(re.match(r"\x%02x0" % i, chr(i)+"0"))
+ self.assertIsNotNone(re.match(r"\x%02xz" % i, chr(i)+"z"))
+ if i < 0x10000:
+ self.assertIsNotNone(re.match(r"\u%04x" % i, chr(i)))
+ self.assertIsNotNone(re.match(r"\u%04x0" % i, chr(i)+"0"))
+ self.assertIsNotNone(re.match(r"\u%04xz" % i, chr(i)+"z"))
+ self.assertIsNotNone(re.match(r"\U%08x" % i, chr(i)))
+ self.assertIsNotNone(re.match(r"\U%08x0" % i, chr(i)+"0"))
+ self.assertIsNotNone(re.match(r"\U%08xz" % i, chr(i)+"z"))
+ self.assertIsNotNone(re.match(r"\0", "\000"))
+ self.assertIsNotNone(re.match(r"\08", "\0008"))
+ self.assertIsNotNone(re.match(r"\01", "\001"))
+ self.assertIsNotNone(re.match(r"\018", "\0018"))
+ self.assertIsNotNone(re.match(r"\567", chr(0o167)))
+ self.assertRaises(re.error, re.match, r"\911", "")
+ self.assertRaises(re.error, re.match, r"\x1", "")
+ self.assertRaises(re.error, re.match, r"\x1z", "")
+ self.assertRaises(re.error, re.match, r"\u123", "")
+ self.assertRaises(re.error, re.match, r"\u123z", "")
+ self.assertRaises(re.error, re.match, r"\U0001234", "")
+ self.assertRaises(re.error, re.match, r"\U0001234z", "")
+ self.assertRaises(re.error, re.match, r"\U00110000", "")
def test_sre_character_class_literals(self):
+ for i in [0, 8, 16, 32, 64, 127, 128, 255, 256, 0xFFFF, 0x10000, 0x10FFFF]:
+ if i < 256:
+ self.assertIsNotNone(re.match(r"[\%o]" % i, chr(i)))
+ self.assertIsNotNone(re.match(r"[\%o8]" % i, chr(i)))
+ self.assertIsNotNone(re.match(r"[\%03o]" % i, chr(i)))
+ self.assertIsNotNone(re.match(r"[\%03o0]" % i, chr(i)))
+ self.assertIsNotNone(re.match(r"[\%03o8]" % i, chr(i)))
+ self.assertIsNotNone(re.match(r"[\x%02x]" % i, chr(i)))
+ self.assertIsNotNone(re.match(r"[\x%02x0]" % i, chr(i)))
+ self.assertIsNotNone(re.match(r"[\x%02xz]" % i, chr(i)))
+ if i < 0x10000:
+ self.assertIsNotNone(re.match(r"[\u%04x]" % i, chr(i)))
+ self.assertIsNotNone(re.match(r"[\u%04x0]" % i, chr(i)))
+ self.assertIsNotNone(re.match(r"[\u%04xz]" % i, chr(i)))
+ self.assertIsNotNone(re.match(r"[\U%08x]" % i, chr(i)))
+ self.assertIsNotNone(re.match(r"[\U%08x0]" % i, chr(i)+"0"))
+ self.assertIsNotNone(re.match(r"[\U%08xz]" % i, chr(i)+"z"))
+ self.assertRaises(re.error, re.match, r"[\911]", "")
+ self.assertRaises(re.error, re.match, r"[\x1z]", "")
+ self.assertRaises(re.error, re.match, r"[\u123z]", "")
+ self.assertRaises(re.error, re.match, r"[\U0001234z]", "")
+ self.assertRaises(re.error, re.match, r"[\U00110000]", "")
+
+ def test_sre_byte_literals(self):
+ for i in [0, 8, 16, 32, 64, 127, 128, 255]:
+ self.assertIsNotNone(re.match((r"\%03o" % i).encode(), bytes([i])))
+ self.assertIsNotNone(re.match((r"\%03o0" % i).encode(), bytes([i])+b"0"))
+ self.assertIsNotNone(re.match((r"\%03o8" % i).encode(), bytes([i])+b"8"))
+ self.assertIsNotNone(re.match((r"\x%02x" % i).encode(), bytes([i])))
+ self.assertIsNotNone(re.match((r"\x%02x0" % i).encode(), bytes([i])+b"0"))
+ self.assertIsNotNone(re.match((r"\x%02xz" % i).encode(), bytes([i])+b"z"))
+ self.assertIsNotNone(re.match(br"\u", b'u'))
+ self.assertIsNotNone(re.match(br"\U", b'U'))
+ self.assertIsNotNone(re.match(br"\0", b"\000"))
+ self.assertIsNotNone(re.match(br"\08", b"\0008"))
+ self.assertIsNotNone(re.match(br"\01", b"\001"))
+ self.assertIsNotNone(re.match(br"\018", b"\0018"))
+ self.assertIsNotNone(re.match(br"\567", bytes([0o167])))
+ self.assertRaises(re.error, re.match, br"\911", b"")
+ self.assertRaises(re.error, re.match, br"\x1", b"")
+ self.assertRaises(re.error, re.match, br"\x1z", b"")
+
+ def test_sre_byte_class_literals(self):
for i in [0, 8, 16, 32, 64, 127, 128, 255]:
- self.assertNotEqual(re.match(r"[\%03o]" % i, chr(i)), None)
- self.assertNotEqual(re.match(r"[\%03o0]" % i, chr(i)), None)
- self.assertNotEqual(re.match(r"[\%03o8]" % i, chr(i)), None)
- self.assertNotEqual(re.match(r"[\x%02x]" % i, chr(i)), None)
- self.assertNotEqual(re.match(r"[\x%02x0]" % i, chr(i)), None)
- self.assertNotEqual(re.match(r"[\x%02xz]" % i, chr(i)), None)
- self.assertRaises(re.error, re.match, "[\911]", "")
+ self.assertIsNotNone(re.match((r"[\%o]" % i).encode(), bytes([i])))
+ self.assertIsNotNone(re.match((r"[\%o8]" % i).encode(), bytes([i])))
+ self.assertIsNotNone(re.match((r"[\%03o]" % i).encode(), bytes([i])))
+ self.assertIsNotNone(re.match((r"[\%03o0]" % i).encode(), bytes([i])))
+ self.assertIsNotNone(re.match((r"[\%03o8]" % i).encode(), bytes([i])))
+ self.assertIsNotNone(re.match((r"[\x%02x]" % i).encode(), bytes([i])))
+ self.assertIsNotNone(re.match((r"[\x%02x0]" % i).encode(), bytes([i])))
+ self.assertIsNotNone(re.match((r"[\x%02xz]" % i).encode(), bytes([i])))
+ self.assertIsNotNone(re.match(br"[\u]", b'u'))
+ self.assertIsNotNone(re.match(br"[\U]", b'U'))
+ self.assertRaises(re.error, re.match, br"[\911]", "")
+ self.assertRaises(re.error, re.match, br"[\x1z]", "")
def test_bug_113254(self):
self.assertEqual(re.match(r'(a)|(b)', 'b').start(1), -1)
@@ -691,6 +756,26 @@ class ReTests(unittest.TestCase):
self.assertEqual([item.group(0) for item in iter],
[":", "::", ":::"])
+ pat = re.compile(r":+")
+ iter = pat.finditer("a:b::c:::d", 1, 10)
+ self.assertEqual([item.group(0) for item in iter],
+ [":", "::", ":::"])
+
+ pat = re.compile(r":+")
+ iter = pat.finditer("a:b::c:::d", pos=1, endpos=10)
+ self.assertEqual([item.group(0) for item in iter],
+ [":", "::", ":::"])
+
+ pat = re.compile(r":+")
+ iter = pat.finditer("a:b::c:::d", endpos=10, pos=1)
+ self.assertEqual([item.group(0) for item in iter],
+ [":", "::", ":::"])
+
+ pat = re.compile(r":+")
+ iter = pat.finditer("a:b::c:::d", pos=3, endpos=8)
+ self.assertEqual([item.group(0) for item in iter],
+ ["::", "::"])
+
def test_bug_926075(self):
self.assertTrue(re.compile('bug_926075') is not
re.compile(b'bug_926075'))
@@ -857,6 +942,13 @@ class ReTests(unittest.TestCase):
self.assertRaises(OverflowError, _sre.compile, "abc", 0, [long_overflow])
self.assertRaises(TypeError, _sre.compile, {}, 0, [])
+ def test_search_dot_unicode(self):
+ self.assertIsNotNone(re.search("123.*-", '123abc-'))
+ self.assertIsNotNone(re.search("123.*-", '123\xe9-'))
+ self.assertIsNotNone(re.search("123.*-", '123\u20ac-'))
+ self.assertIsNotNone(re.search("123.*-", '123\U0010ffff-'))
+ self.assertIsNotNone(re.search("123.*-", '123\xe9\u20ac\U0010ffff-'))
+
def test_compile(self):
# Test return value when given string and pattern as parameter
pattern = re.compile('random pattern')
@@ -873,7 +965,7 @@ class ReTests(unittest.TestCase):
self.assertEqual(re.findall(r'[\A\B\b\C\Z]', 'AB\bCZ'),
['A', 'B', '\b', 'C', 'Z'])
- @bigmemtest(size=_2G, memuse=character_size)
+ @bigmemtest(size=_2G, memuse=1)
def test_large_search(self, size):
# Issue #10182: indices were 32-bit-truncated.
s = 'a' * size
@@ -884,7 +976,7 @@ class ReTests(unittest.TestCase):
# The huge memuse is because of re.sub() using a list and a join()
# to create the replacement result.
- @bigmemtest(size=_2G, memuse=16 + 2 * character_size)
+ @bigmemtest(size=_2G, memuse=16 + 2)
def test_large_subn(self, size):
# Issue #10182: indices were 32-bit-truncated.
s = 'a' * size
@@ -892,6 +984,11 @@ class ReTests(unittest.TestCase):
self.assertEqual(r, s)
self.assertEqual(n, size + 1)
+ def test_bug_16688(self):
+ # Issue 16688: Backreferences make case-insensitive regex fail on
+ # non-ASCII strings.
+ self.assertEqual(re.findall(r"(?i)(a)\1", "aa \u0100"), ['a'])
+ self.assertEqual(re.match(r"(?s).{1,3}", "\u0100\u0100").span(), (0, 2))
def test_repeat_minmax_overflow(self):
# Issue #13169
diff --git a/Lib/test/test_reprlib.py b/Lib/test/test_reprlib.py
index b0dc4d7..589ecdd 100644
--- a/Lib/test/test_reprlib.py
+++ b/Lib/test/test_reprlib.py
@@ -3,12 +3,14 @@
Nick Mathewson
"""
+import imp
import sys
import os
import shutil
+import importlib
import unittest
-from test.support import run_unittest
+from test.support import run_unittest, create_empty_file, verbose
from reprlib import repr as r # Don't shadow builtin repr
from reprlib import Repr
from reprlib import recursive_repr
@@ -129,8 +131,8 @@ class ReprTests(unittest.TestCase):
self.assertIn(s.find("..."), [12, 13])
def test_lambda(self):
- self.assertTrue(repr(lambda x: x).startswith(
- "<function <lambda"))
+ r = repr(lambda x: x)
+ self.assertTrue(r.startswith("<function ReprTests.test_lambda.<locals>.<lambda"), r)
# XXX anonymous functions? see func_repr
def test_builtin_function(self):
@@ -193,26 +195,29 @@ class ReprTests(unittest.TestCase):
r(y)
r(z)
-def touch(path, text=''):
- fp = open(path, 'w')
- fp.write(text)
- fp.close()
+def write_file(path, text):
+ with open(path, 'w', encoding='ASCII') as fp:
+ fp.write(text)
class LongReprTest(unittest.TestCase):
+ longname = 'areallylongpackageandmodulenametotestreprtruncation'
+
def setUp(self):
- longname = 'areallylongpackageandmodulenametotestreprtruncation'
- self.pkgname = os.path.join(longname)
- self.subpkgname = os.path.join(longname, longname)
+ self.pkgname = os.path.join(self.longname)
+ self.subpkgname = os.path.join(self.longname, self.longname)
# Make the package and subpackage
shutil.rmtree(self.pkgname, ignore_errors=True)
os.mkdir(self.pkgname)
- touch(os.path.join(self.pkgname, '__init__.py'))
+ create_empty_file(os.path.join(self.pkgname, '__init__.py'))
shutil.rmtree(self.subpkgname, ignore_errors=True)
os.mkdir(self.subpkgname)
- touch(os.path.join(self.subpkgname, '__init__.py'))
+ create_empty_file(os.path.join(self.subpkgname, '__init__.py'))
# Remember where we are
self.here = os.getcwd()
sys.path.insert(0, self.here)
+ # When regrtest is run with its -j option, this command alone is not
+ # enough.
+ importlib.invalidate_caches()
def tearDown(self):
actions = []
@@ -229,20 +234,40 @@ class LongReprTest(unittest.TestCase):
os.remove(p)
del sys.path[0]
+ def _check_path_limitations(self, module_name):
+ # base directory
+ source_path_len = len(self.here)
+ # a path separator + `longname` (twice)
+ source_path_len += 2 * (len(self.longname) + 1)
+ # a path separator + `module_name` + ".py"
+ source_path_len += len(module_name) + 1 + len(".py")
+ cached_path_len = source_path_len + len(imp.cache_from_source("x.py")) - len("x.py")
+ if os.name == 'nt' and cached_path_len >= 258:
+ # Under Windows, the max path len is 260 including C's terminating
+ # NUL character.
+ # (see http://msdn.microsoft.com/en-us/library/windows/desktop/aa365247%28v=vs.85%29.aspx#maxpath)
+ self.skipTest("test paths too long (%d characters) for Windows' 260 character limit"
+ % cached_path_len)
+ elif os.name == 'nt' and verbose:
+ print("cached_path_len =", cached_path_len)
+
def test_module(self):
- eq = self.assertEqual
- touch(os.path.join(self.subpkgname, self.pkgname + '.py'))
+ self._check_path_limitations(self.pkgname)
+ create_empty_file(os.path.join(self.subpkgname, self.pkgname + '.py'))
+ importlib.invalidate_caches()
from areallylongpackageandmodulenametotestreprtruncation.areallylongpackageandmodulenametotestreprtruncation import areallylongpackageandmodulenametotestreprtruncation
- eq(repr(areallylongpackageandmodulenametotestreprtruncation),
- "<module '%s' from '%s'>" % (areallylongpackageandmodulenametotestreprtruncation.__name__, areallylongpackageandmodulenametotestreprtruncation.__file__))
- eq(repr(sys), "<module 'sys' (built-in)>")
+ module = areallylongpackageandmodulenametotestreprtruncation
+ self.assertEqual(repr(module), "<module %r from %r>" % (module.__name__, module.__file__))
+ self.assertEqual(repr(sys), "<module 'sys' (built-in)>")
def test_type(self):
+ self._check_path_limitations('foo')
eq = self.assertEqual
- touch(os.path.join(self.subpkgname, 'foo.py'), '''\
+ write_file(os.path.join(self.subpkgname, 'foo.py'), '''\
class foo(object):
pass
''')
+ importlib.invalidate_caches()
from areallylongpackageandmodulenametotestreprtruncation.areallylongpackageandmodulenametotestreprtruncation import foo
eq(repr(foo.foo),
"<class '%s.foo'>" % foo.__name__)
@@ -253,39 +278,46 @@ class foo(object):
pass
def test_class(self):
- touch(os.path.join(self.subpkgname, 'bar.py'), '''\
+ self._check_path_limitations('bar')
+ write_file(os.path.join(self.subpkgname, 'bar.py'), '''\
class bar:
pass
''')
+ importlib.invalidate_caches()
from areallylongpackageandmodulenametotestreprtruncation.areallylongpackageandmodulenametotestreprtruncation import bar
# Module name may be prefixed with "test.", depending on how run.
self.assertEqual(repr(bar.bar), "<class '%s.bar'>" % bar.__name__)
def test_instance(self):
- touch(os.path.join(self.subpkgname, 'baz.py'), '''\
+ self._check_path_limitations('baz')
+ write_file(os.path.join(self.subpkgname, 'baz.py'), '''\
class baz:
pass
''')
+ importlib.invalidate_caches()
from areallylongpackageandmodulenametotestreprtruncation.areallylongpackageandmodulenametotestreprtruncation import baz
ibaz = baz.baz()
self.assertTrue(repr(ibaz).startswith(
"<%s.baz object at 0x" % baz.__name__))
def test_method(self):
+ self._check_path_limitations('qux')
eq = self.assertEqual
- touch(os.path.join(self.subpkgname, 'qux.py'), '''\
+ write_file(os.path.join(self.subpkgname, 'qux.py'), '''\
class aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:
def amethod(self): pass
''')
+ importlib.invalidate_caches()
from areallylongpackageandmodulenametotestreprtruncation.areallylongpackageandmodulenametotestreprtruncation import qux
# Unbound methods first
- self.assertTrue(repr(qux.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.amethod).startswith(
- '<function amethod'))
+ r = repr(qux.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.amethod)
+ self.assertTrue(r.startswith('<function aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.amethod'), r)
# Bound method next
iqux = qux.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa()
- self.assertTrue(repr(iqux.amethod).startswith(
+ r = repr(iqux.amethod)
+ self.assertTrue(r.startswith(
'<bound method aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.amethod of <%s.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa object at 0x' \
- % (qux.__name__,) ))
+ % (qux.__name__,) ), r)
def test_builtin_function(self):
# XXX test built-in functions and methods with really long names
diff --git a/Lib/test/test_richcmp.py b/Lib/test/test_richcmp.py
index f8f3717..0b629dc 100644
--- a/Lib/test/test_richcmp.py
+++ b/Lib/test/test_richcmp.py
@@ -220,6 +220,7 @@ class MiscTest(unittest.TestCase):
for func in (do, operator.not_):
self.assertRaises(Exc, func, Bad())
+ @support.no_tracing
def test_recursion(self):
# Check that comparison for recursive objects fails gracefully
from collections import UserList
diff --git a/Lib/test/test_runpy.py b/Lib/test/test_runpy.py
index 96d2beb..fefefc4 100644
--- a/Lib/test/test_runpy.py
+++ b/Lib/test/test_runpy.py
@@ -5,11 +5,15 @@ import os.path
import sys
import re
import tempfile
+import importlib
import py_compile
-from test.support import forget, make_legacy_pyc, run_unittest, unload, verbose
+from test.support import (
+ forget, make_legacy_pyc, run_unittest, unload, verbose, no_tracing,
+ create_empty_file)
from test.script_helper import (
make_pkg, make_script, make_zip_pkg, make_zip_script, temp_dir)
+
import runpy
from runpy import _run_code, _run_module_code, run_module, run_path
# Note: This module can't safely test _run_module_as_main as it
@@ -145,7 +149,7 @@ class ExecutionLayerTestCase(unittest.TestCase, CodeExecutionMixin):
mod_package)
self.check_code_execution(create_ns, expected_ns)
-
+# TODO: Use self.addCleanup to get rid of a lot of try-finally blocks
class RunModuleTestCase(unittest.TestCase, CodeExecutionMixin):
"""Unit tests for runpy.run_module"""
@@ -175,8 +179,7 @@ class RunModuleTestCase(unittest.TestCase, CodeExecutionMixin):
def _add_pkg_dir(self, pkg_dir):
os.mkdir(pkg_dir)
pkg_fname = os.path.join(pkg_dir, "__init__.py")
- pkg_file = open(pkg_fname, "w")
- pkg_file.close()
+ create_empty_file(pkg_fname)
return pkg_fname
def _make_pkg(self, source, depth, mod_base="runpy_test"):
@@ -252,10 +255,12 @@ class RunModuleTestCase(unittest.TestCase, CodeExecutionMixin):
try:
if verbose > 1: print("Running from source:", mod_name)
self.check_code_execution(create_ns, expected_ns)
+ importlib.invalidate_caches()
__import__(mod_name)
os.remove(mod_fname)
make_legacy_pyc(mod_fname)
unload(mod_name) # In case loader caches paths
+ importlib.invalidate_caches()
if verbose > 1: print("Running from compiled:", mod_name)
self._fix_ns_for_legacy_pyc(expected_ns, alter_sys)
self.check_code_execution(create_ns, expected_ns)
@@ -285,11 +290,13 @@ class RunModuleTestCase(unittest.TestCase, CodeExecutionMixin):
try:
if verbose > 1: print("Running from source:", pkg_name)
self.check_code_execution(create_ns, expected_ns)
+ importlib.invalidate_caches()
__import__(mod_name)
os.remove(mod_fname)
make_legacy_pyc(mod_fname)
unload(mod_name) # In case loader caches paths
if verbose > 1: print("Running from compiled:", pkg_name)
+ importlib.invalidate_caches()
self._fix_ns_for_legacy_pyc(expected_ns, alter_sys)
self.check_code_execution(create_ns, expected_ns)
finally:
@@ -306,8 +313,7 @@ class RunModuleTestCase(unittest.TestCase, CodeExecutionMixin):
module_dir = os.path.join(module_dir, pkg_name)
# Add sibling module
sibling_fname = os.path.join(module_dir, "sibling.py")
- sibling_file = open(sibling_fname, "w")
- sibling_file.close()
+ create_empty_file(sibling_fname)
if verbose > 1: print(" Added sibling module:", sibling_fname)
# Add nephew module
uncle_dir = os.path.join(parent_dir, "uncle")
@@ -317,8 +323,7 @@ class RunModuleTestCase(unittest.TestCase, CodeExecutionMixin):
self._add_pkg_dir(cousin_dir)
if verbose > 1: print(" Added cousin package:", cousin_dir)
nephew_fname = os.path.join(cousin_dir, "nephew.py")
- nephew_file = open(nephew_fname, "w")
- nephew_file.close()
+ create_empty_file(nephew_fname)
if verbose > 1: print(" Added nephew module:", nephew_fname)
def _check_relative_imports(self, depth, run_name=None):
@@ -343,11 +348,13 @@ from ..uncle.cousin import nephew
self.assertIn("sibling", d1)
self.assertIn("nephew", d1)
del d1 # Ensure __loader__ entry doesn't keep file open
+ importlib.invalidate_caches()
__import__(mod_name)
os.remove(mod_fname)
make_legacy_pyc(mod_fname)
unload(mod_name) # In case the loader caches paths
if verbose > 1: print("Running from compiled:", mod_name)
+ importlib.invalidate_caches()
d2 = run_module(mod_name, run_name=run_name) # Read from bytecode
self.assertEqual(d2["__name__"], expected_name)
self.assertEqual(d2["__package__"], pkg_name)
@@ -407,6 +414,40 @@ from ..uncle.cousin import nephew
finally:
self._del_pkg(pkg_dir, depth, mod_name)
+ def test_pkgutil_walk_packages(self):
+ # This is a dodgy hack to use the test_runpy infrastructure to test
+ # issue #15343. Issue #15348 declares this is indeed a dodgy hack ;)
+ import pkgutil
+ max_depth = 4
+ base_name = "__runpy_pkg__"
+ package_suffixes = ["uncle", "uncle.cousin"]
+ module_suffixes = ["uncle.cousin.nephew", base_name + ".sibling"]
+ expected_packages = set()
+ expected_modules = set()
+ for depth in range(1, max_depth):
+ pkg_name = ".".join([base_name] * depth)
+ expected_packages.add(pkg_name)
+ for name in package_suffixes:
+ expected_packages.add(pkg_name + "." + name)
+ for name in module_suffixes:
+ expected_modules.add(pkg_name + "." + name)
+ pkg_name = ".".join([base_name] * max_depth)
+ expected_packages.add(pkg_name)
+ expected_modules.add(pkg_name + ".runpy_test")
+ pkg_dir, mod_fname, mod_name = (
+ self._make_pkg("", max_depth))
+ self.addCleanup(self._del_pkg, pkg_dir, max_depth, mod_name)
+ for depth in range(2, max_depth+1):
+ self._add_relative_modules(pkg_dir, "", depth)
+ for finder, mod_name, ispkg in pkgutil.walk_packages([pkg_dir]):
+ self.assertIsInstance(finder,
+ importlib.machinery.FileFinder)
+ if ispkg:
+ expected_packages.remove(mod_name)
+ else:
+ expected_modules.remove(mod_name)
+ self.assertEqual(len(expected_packages), 0, expected_packages)
+ self.assertEqual(len(expected_modules), 0, expected_modules)
class RunPathTestCase(unittest.TestCase, CodeExecutionMixin):
"""Unit tests for runpy.run_path"""
@@ -507,6 +548,7 @@ class RunPathTestCase(unittest.TestCase, CodeExecutionMixin):
msg = "can't find '__main__' module in %r" % zip_name
self._check_import_error(zip_name, msg)
+ @no_tracing
def test_main_recursion_error(self):
with temp_dir() as script_dir, temp_dir() as dummy_dir:
mod_name = '__main__'
diff --git a/Lib/test/test_sax.py b/Lib/test/test_sax.py
index 218d519..98cc58e 100644
--- a/Lib/test/test_sax.py
+++ b/Lib/test/test_sax.py
@@ -23,8 +23,8 @@ import unittest
TEST_XMLFILE = findfile("test.xml", subdir="xmltestdata")
TEST_XMLFILE_OUT = findfile("test.xml.out", subdir="xmltestdata")
try:
- TEST_XMLFILE.encode("utf8")
- TEST_XMLFILE_OUT.encode("utf8")
+ TEST_XMLFILE.encode("utf-8")
+ TEST_XMLFILE_OUT.encode("utf-8")
except UnicodeEncodeError:
raise unittest.SkipTest("filename is not encodable to utf8")
diff --git a/Lib/test/test_sched.py b/Lib/test/test_sched.py
index 91b8f0c..1fe6ad4 100644
--- a/Lib/test/test_sched.py
+++ b/Lib/test/test_sched.py
@@ -1,9 +1,44 @@
#!/usr/bin/env python
+import queue
import sched
import time
import unittest
from test import support
+try:
+ import threading
+except ImportError:
+ threading = None
+
+TIMEOUT = 10
+
+
+class Timer:
+ def __init__(self):
+ self._cond = threading.Condition()
+ self._time = 0
+ self._stop = 0
+
+ def time(self):
+ with self._cond:
+ return self._time
+
+ # increase the time but not beyond the established limit
+ def sleep(self, t):
+ assert t >= 0
+ with self._cond:
+ t += self._time
+ while self._stop < t:
+ self._time = self._stop
+ self._cond.wait()
+ self._time = t
+
+ # advance time limit for user code
+ def advance(self, t):
+ assert t >= 0
+ with self._cond:
+ self._stop += t
+ self._cond.notify_all()
class TestCase(unittest.TestCase):
@@ -26,6 +61,37 @@ class TestCase(unittest.TestCase):
scheduler.run()
self.assertEqual(l, [0.01, 0.02, 0.03, 0.04, 0.05])
+ @unittest.skipUnless(threading, 'Threading required for this test.')
+ def test_enter_concurrent(self):
+ q = queue.Queue()
+ fun = q.put
+ timer = Timer()
+ scheduler = sched.scheduler(timer.time, timer.sleep)
+ scheduler.enter(1, 1, fun, (1,))
+ scheduler.enter(3, 1, fun, (3,))
+ t = threading.Thread(target=scheduler.run)
+ t.start()
+ timer.advance(1)
+ self.assertEqual(q.get(timeout=TIMEOUT), 1)
+ self.assertTrue(q.empty())
+ for x in [4, 5, 2]:
+ z = scheduler.enter(x - 1, 1, fun, (x,))
+ timer.advance(2)
+ self.assertEqual(q.get(timeout=TIMEOUT), 2)
+ self.assertEqual(q.get(timeout=TIMEOUT), 3)
+ self.assertTrue(q.empty())
+ timer.advance(1)
+ self.assertEqual(q.get(timeout=TIMEOUT), 4)
+ self.assertTrue(q.empty())
+ timer.advance(1)
+ self.assertEqual(q.get(timeout=TIMEOUT), 5)
+ self.assertTrue(q.empty())
+ timer.advance(1000)
+ t.join(timeout=TIMEOUT)
+ self.assertFalse(t.is_alive())
+ self.assertTrue(q.empty())
+ self.assertEqual(timer.time(), 5)
+
def test_priority(self):
l = []
fun = lambda x: l.append(x)
@@ -50,6 +116,39 @@ class TestCase(unittest.TestCase):
scheduler.run()
self.assertEqual(l, [0.02, 0.03, 0.04])
+ @unittest.skipUnless(threading, 'Threading required for this test.')
+ def test_cancel_concurrent(self):
+ q = queue.Queue()
+ fun = q.put
+ timer = Timer()
+ scheduler = sched.scheduler(timer.time, timer.sleep)
+ now = timer.time()
+ event1 = scheduler.enterabs(now + 1, 1, fun, (1,))
+ event2 = scheduler.enterabs(now + 2, 1, fun, (2,))
+ event4 = scheduler.enterabs(now + 4, 1, fun, (4,))
+ event5 = scheduler.enterabs(now + 5, 1, fun, (5,))
+ event3 = scheduler.enterabs(now + 3, 1, fun, (3,))
+ t = threading.Thread(target=scheduler.run)
+ t.start()
+ timer.advance(1)
+ self.assertEqual(q.get(timeout=TIMEOUT), 1)
+ self.assertTrue(q.empty())
+ scheduler.cancel(event2)
+ scheduler.cancel(event5)
+ timer.advance(1)
+ self.assertTrue(q.empty())
+ timer.advance(1)
+ self.assertEqual(q.get(timeout=TIMEOUT), 3)
+ self.assertTrue(q.empty())
+ timer.advance(1)
+ self.assertEqual(q.get(timeout=TIMEOUT), 4)
+ self.assertTrue(q.empty())
+ timer.advance(1000)
+ t.join(timeout=TIMEOUT)
+ self.assertFalse(t.is_alive())
+ self.assertTrue(q.empty())
+ self.assertEqual(timer.time(), 4)
+
def test_empty(self):
l = []
fun = lambda x: l.append(x)
@@ -63,15 +162,39 @@ class TestCase(unittest.TestCase):
def test_queue(self):
l = []
- events = []
fun = lambda x: l.append(x)
scheduler = sched.scheduler(time.time, time.sleep)
- self.assertEqual(scheduler._queue, [])
- for x in [0.05, 0.04, 0.03, 0.02, 0.01]:
- events.append(scheduler.enterabs(x, 1, fun, (x,)))
- self.assertEqual(scheduler._queue.sort(), events.sort())
+ now = time.time()
+ e5 = scheduler.enterabs(now + 0.05, 1, fun)
+ e1 = scheduler.enterabs(now + 0.01, 1, fun)
+ e2 = scheduler.enterabs(now + 0.02, 1, fun)
+ e4 = scheduler.enterabs(now + 0.04, 1, fun)
+ e3 = scheduler.enterabs(now + 0.03, 1, fun)
+ # queue property is supposed to return an order list of
+ # upcoming events
+ self.assertEqual(list(scheduler.queue), [e1, e2, e3, e4, e5])
+
+ def test_args_kwargs(self):
+ flag = []
+
+ def fun(*a, **b):
+ flag.append(None)
+ self.assertEqual(a, (1,2,3))
+ self.assertEqual(b, {"foo":1})
+
+ scheduler = sched.scheduler(time.time, time.sleep)
+ z = scheduler.enterabs(0.01, 1, fun, argument=(1,2,3), kwargs={"foo":1})
scheduler.run()
- self.assertEqual(scheduler._queue, [])
+ self.assertEqual(flag, [None])
+
+ def test_run_non_blocking(self):
+ l = []
+ fun = lambda x: l.append(x)
+ scheduler = sched.scheduler(time.time, time.sleep)
+ for x in [10, 9, 8, 7, 6]:
+ scheduler.enter(x, 1, fun, (x,))
+ scheduler.run(blocking=False)
+ self.assertEqual(l, [])
def test_main():
diff --git a/Lib/test/test_scope.py b/Lib/test/test_scope.py
index fbc87aa..129a18a 100644
--- a/Lib/test/test_scope.py
+++ b/Lib/test/test_scope.py
@@ -1,5 +1,5 @@
import unittest
-from test.support import check_syntax_error, run_unittest
+from test.support import check_syntax_error, cpython_only, run_unittest
class ScopeTests(unittest.TestCase):
@@ -496,23 +496,22 @@ class ScopeTests(unittest.TestCase):
self.assertNotIn("x", varnames)
self.assertIn("y", varnames)
+ @cpython_only
def testLocalsClass_WithTrace(self):
# Issue23728: after the trace function returns, the locals()
# dictionary is used to update all variables, this used to
# include free variables. But in class statements, free
# variables are not inserted...
import sys
+ self.addCleanup(sys.settrace, sys.gettrace())
sys.settrace(lambda a,b,c:None)
- try:
- x = 12
+ x = 12
- class C:
- def f(self):
- return x
+ class C:
+ def f(self):
+ return x
- self.assertEqual(x, 12) # Used to raise UnboundLocalError
- finally:
- sys.settrace(None)
+ self.assertEqual(x, 12) # Used to raise UnboundLocalError
def testBoundAndFree(self):
# var is bound and free in class
@@ -527,6 +526,7 @@ class ScopeTests(unittest.TestCase):
inst = f(3)()
self.assertEqual(inst.a, inst.m())
+ @cpython_only
def testInteractionWithTraceFunc(self):
import sys
@@ -543,6 +543,7 @@ class ScopeTests(unittest.TestCase):
class TestClass:
pass
+ self.addCleanup(sys.settrace, sys.gettrace())
sys.settrace(tracer)
adaptgetter("foo", TestClass, (1, ""))
sys.settrace(None)
diff --git a/Lib/test/test_select.py b/Lib/test/test_select.py
index 4ddc217..ddb9a0f 100644
--- a/Lib/test/test_select.py
+++ b/Lib/test/test_select.py
@@ -1,8 +1,9 @@
-from test import support
-import unittest
-import select
+import errno
import os
+import select
import sys
+import unittest
+from test import support
@unittest.skipIf(sys.platform[:3] in ('win', 'os2', 'riscos'),
"can't easily test on this system")
@@ -20,6 +21,21 @@ class SelectTestCase(unittest.TestCase):
self.assertRaises(TypeError, select.select, [self.Nope()], [], [])
self.assertRaises(TypeError, select.select, [self.Almost()], [], [])
self.assertRaises(TypeError, select.select, [], [], [], "not a number")
+ self.assertRaises(ValueError, select.select, [], [], [], -1)
+
+ # Issue #12367: http://www.freebsd.org/cgi/query-pr.cgi?pr=kern/155606
+ @unittest.skipIf(sys.platform.startswith('freebsd'),
+ 'skip because of a FreeBSD bug: kern/155606')
+ def test_errno(self):
+ with open(__file__, 'rb') as fp:
+ fd = fp.fileno()
+ fp.close()
+ try:
+ select.select([fd], [], [], 0)
+ except select.error as err:
+ self.assertEqual(err.errno, errno.EBADF)
+ else:
+ self.fail("exception not raised")
def test_returned_list_identity(self):
# See issue #8329
diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py
index 6642440..8e9e587 100644
--- a/Lib/test/test_set.py
+++ b/Lib/test/test_set.py
@@ -9,6 +9,7 @@ from random import randrange, shuffle
import sys
import warnings
import collections
+import collections.abc
class PassThru(Exception):
pass
@@ -234,6 +235,26 @@ class TestJointOps(unittest.TestCase):
dup = pickle.loads(p)
self.assertEqual(self.s.x, dup.x)
+ def test_iterator_pickling(self):
+ itorg = iter(self.s)
+ data = self.thetype(self.s)
+ d = pickle.dumps(itorg)
+ it = pickle.loads(d)
+ # Set iterators unpickle as list iterators due to the
+ # undefined order of set items.
+ # self.assertEqual(type(itorg), type(it))
+ self.assertTrue(isinstance(it, collections.abc.Iterator))
+ self.assertEqual(self.thetype(it), data)
+
+ it = pickle.loads(d)
+ try:
+ drop = next(it)
+ except StopIteration:
+ return
+ d = pickle.dumps(it)
+ it = pickle.loads(d)
+ self.assertEqual(self.thetype(it), data - self.thetype((drop,)))
+
def test_deepcopy(self):
class Tracer:
def __init__(self, value):
diff --git a/Lib/test/test_shelve.py b/Lib/test/test_shelve.py
index 3e73f52..13c1265 100644
--- a/Lib/test/test_shelve.py
+++ b/Lib/test/test_shelve.py
@@ -2,7 +2,7 @@ import unittest
import shelve
import glob
from test import support
-from collections import MutableMapping
+from collections.abc import MutableMapping
from test.test_dbm import dbm_iterator
def L1(s):
@@ -129,8 +129,8 @@ class TestCase(unittest.TestCase):
shelve.Shelf(d)[key] = [1]
self.assertIn(key.encode('utf-8'), d)
# but a different one can be given
- shelve.Shelf(d, keyencoding='latin1')[key] = [1]
- self.assertIn(key.encode('latin1'), d)
+ shelve.Shelf(d, keyencoding='latin-1')[key] = [1]
+ self.assertIn(key.encode('latin-1'), d)
# with all consequences
s = shelve.Shelf(d, keyencoding='ascii')
self.assertRaises(UnicodeEncodeError, s.__setitem__, key, [1])
diff --git a/Lib/test/test_shlex.py b/Lib/test/test_shlex.py
index 25e4b6d..d4463f30 100644
--- a/Lib/test/test_shlex.py
+++ b/Lib/test/test_shlex.py
@@ -1,6 +1,7 @@
-import unittest
-import os, sys, io
+import io
import shlex
+import string
+import unittest
from test import support
@@ -173,6 +174,22 @@ class ShlexTest(unittest.TestCase):
"%s: %s != %s" %
(self.data[i][0], l, self.data[i][1:]))
+ def testQuote(self):
+ safeunquoted = string.ascii_letters + string.digits + '@%_-+=:,./'
+ unicode_sample = '\xe9\xe0\xdf' # e + acute accent, a + grave, sharp s
+ unsafe = '"`$\\!' + unicode_sample
+
+ self.assertEqual(shlex.quote(''), "''")
+ self.assertEqual(shlex.quote(safeunquoted), safeunquoted)
+ self.assertEqual(shlex.quote('test file name'), "'test file name'")
+ for u in unsafe:
+ self.assertEqual(shlex.quote('test%sname' % u),
+ "'test%sname'" % u)
+ for u in unsafe:
+ self.assertEqual(shlex.quote("test%s'name'" % u),
+ "'test%s'\"'\"'name'\"'\"''" % u)
+
+
# Allow this test to be used with old shlex.py
if not getattr(shlex, "split", None):
for methname in dir(ShlexTest):
diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py
index a9b4676..3b1b694 100644
--- a/Lib/test/test_shutil.py
+++ b/Lib/test/test_shutil.py
@@ -7,8 +7,9 @@ import sys
import stat
import os
import os.path
-import functools
import errno
+import functools
+import subprocess
from test import support
from test.support import TESTFN
from os.path import splitdrive
@@ -22,7 +23,7 @@ import tarfile
import warnings
from test import support
-from test.support import TESTFN, check_warnings, captured_stdout
+from test.support import TESTFN, check_warnings, captured_stdout, requires_zlib
try:
import bz2
@@ -40,11 +41,6 @@ except ImportError:
UID_GID_SUPPORT = False
try:
- import zlib
-except ImportError:
- zlib = None
-
-try:
import zipfile
ZIP_SUPPORT = True
except ImportError:
@@ -52,7 +48,7 @@ except ImportError:
def _fake_rename(*args, **kwargs):
# Pretend the destination path is on a different filesystem.
- raise OSError()
+ raise OSError(getattr(errno, 'EXDEV', 18), "Invalid cross-device link")
def mock_rename(func):
@functools.wraps(func)
@@ -65,6 +61,31 @@ def mock_rename(func):
os.rename = builtin_rename
return wrap
+def write_file(path, content, binary=False):
+ """Write *content* to a file located at *path*.
+
+ If *path* is a tuple instead of a string, os.path.join will be used to
+ make a path. If *binary* is true, the file will be opened in binary
+ mode.
+ """
+ if isinstance(path, tuple):
+ path = os.path.join(*path)
+ with open(path, 'wb' if binary else 'w') as fp:
+ fp.write(content)
+
+def read_file(path, binary=False):
+ """Return contents from a file located at *path*.
+
+ If *path* is a tuple instead of a string, os.path.join will be used to
+ make a path. If *binary* is true, the file will be opened in binary
+ mode.
+ """
+ if isinstance(path, tuple):
+ path = os.path.join(*path)
+ with open(path, 'rb' if binary else 'r') as fp:
+ return fp.read()
+
+
class TestShutil(unittest.TestCase):
def setUp(self):
@@ -77,19 +98,6 @@ class TestShutil(unittest.TestCase):
d = self.tempdirs.pop()
shutil.rmtree(d, os.name in ('nt', 'cygwin'))
- def write_file(self, path, content='xxx'):
- """Writes a file in the given path.
-
-
- path can be a string or a sequence.
- """
- if isinstance(path, (list, tuple)):
- path = os.path.join(*path)
- f = open(path, 'w')
- try:
- f.write(content)
- finally:
- f.close()
def mkdtemp(self):
"""Create a temporary directory that will be cleaned up.
@@ -100,6 +108,15 @@ class TestShutil(unittest.TestCase):
self.tempdirs.append(d)
return d
+ def test_rmtree_works_on_bytes(self):
+ tmp = self.mkdtemp()
+ victim = os.path.join(tmp, 'killme')
+ os.mkdir(victim)
+ write_file(os.path.join(victim, 'somefile'), 'foo')
+ victim = os.fsencode(victim)
+ self.assertIsInstance(victim, bytes)
+ shutil.rmtree(victim)
+
@support.skip_unless_symlink
def test_rmtree_fails_on_symlink(self):
tmp = self.mkdtemp()
@@ -119,18 +136,40 @@ class TestShutil(unittest.TestCase):
self.assertEqual(errors[0][1], link)
self.assertIsInstance(errors[0][2][1], OSError)
+ @support.skip_unless_symlink
+ def test_rmtree_works_on_symlinks(self):
+ tmp = self.mkdtemp()
+ dir1 = os.path.join(tmp, 'dir1')
+ dir2 = os.path.join(dir1, 'dir2')
+ dir3 = os.path.join(tmp, 'dir3')
+ for d in dir1, dir2, dir3:
+ os.mkdir(d)
+ file1 = os.path.join(tmp, 'file1')
+ write_file(file1, 'foo')
+ link1 = os.path.join(dir1, 'link1')
+ os.symlink(dir2, link1)
+ link2 = os.path.join(dir1, 'link2')
+ os.symlink(dir3, link2)
+ link3 = os.path.join(dir1, 'link3')
+ os.symlink(file1, link3)
+ # make sure symlinks are removed but not followed
+ shutil.rmtree(dir1)
+ self.assertFalse(os.path.exists(dir1))
+ self.assertTrue(os.path.exists(dir3))
+ self.assertTrue(os.path.exists(file1))
+
def test_rmtree_errors(self):
# filename is guaranteed not to exist
filename = tempfile.mktemp()
- self.assertRaises(OSError, shutil.rmtree, filename)
- # test that ignore_errors option is honoured
+ self.assertRaises(FileNotFoundError, shutil.rmtree, filename)
+ # test that ignore_errors option is honored
shutil.rmtree(filename, ignore_errors=True)
# existing file
tmpdir = self.mkdtemp()
- self.write_file((tmpdir, "tstfile"), "")
+ write_file((tmpdir, "tstfile"), "")
filename = os.path.join(tmpdir, "tstfile")
- with self.assertRaises(OSError) as cm:
+ with self.assertRaises(NotADirectoryError) as cm:
shutil.rmtree(filename)
# The reason for this rather odd construct is that Windows sprinkles
# a \*.* at the end of file names. But only sometimes on some buildbots
@@ -147,13 +186,14 @@ class TestShutil(unittest.TestCase):
self.assertEqual(len(errors), 2)
self.assertIs(errors[0][0], os.listdir)
self.assertEqual(errors[0][1], filename)
- self.assertIsInstance(errors[0][2][1], OSError)
+ self.assertIsInstance(errors[0][2][1], NotADirectoryError)
self.assertIn(errors[0][2][1].filename, possible_args)
self.assertIs(errors[1][0], os.rmdir)
self.assertEqual(errors[1][1], filename)
- self.assertIsInstance(errors[1][2][1], OSError)
+ self.assertIsInstance(errors[1][2][1], NotADirectoryError)
self.assertIn(errors[1][2][1].filename, possible_args)
+
# See bug #1071513 for why we don't run this on cygwin
# and bug #1076467 for why we don't run this as root.
if (hasattr(os, 'chmod') and sys.platform[:6] != 'cygwin'
@@ -161,30 +201,34 @@ class TestShutil(unittest.TestCase):
def test_on_error(self):
self.errorState = 0
os.mkdir(TESTFN)
- self.childpath = os.path.join(TESTFN, 'a')
- f = open(self.childpath, 'w')
- f.close()
+ self.addCleanup(shutil.rmtree, TESTFN)
+
+ self.child_file_path = os.path.join(TESTFN, 'a')
+ self.child_dir_path = os.path.join(TESTFN, 'b')
+ support.create_empty_file(self.child_file_path)
+ os.mkdir(self.child_dir_path)
old_dir_mode = os.stat(TESTFN).st_mode
- old_child_mode = os.stat(self.childpath).st_mode
+ old_child_file_mode = os.stat(self.child_file_path).st_mode
+ old_child_dir_mode = os.stat(self.child_dir_path).st_mode
# Make unwritable.
- os.chmod(self.childpath, stat.S_IREAD)
- os.chmod(TESTFN, stat.S_IREAD)
+ new_mode = stat.S_IREAD|stat.S_IEXEC
+ os.chmod(self.child_file_path, new_mode)
+ os.chmod(self.child_dir_path, new_mode)
+ os.chmod(TESTFN, new_mode)
+
+ self.addCleanup(os.chmod, TESTFN, old_dir_mode)
+ self.addCleanup(os.chmod, self.child_file_path, old_child_file_mode)
+ self.addCleanup(os.chmod, self.child_dir_path, old_child_dir_mode)
shutil.rmtree(TESTFN, onerror=self.check_args_to_onerror)
# Test whether onerror has actually been called.
- self.assertEqual(self.errorState, 2,
- "Expected call to onerror function did not happen.")
-
- # Make writable again.
- os.chmod(TESTFN, old_dir_mode)
- os.chmod(self.childpath, old_child_mode)
-
- # Clean up.
- shutil.rmtree(TESTFN)
+ self.assertEqual(self.errorState, 3,
+ "Expected call to onerror function did not "
+ "happen.")
def check_args_to_onerror(self, func, arg, exc):
# test_rmtree_errors deliberately runs rmtree
- # on a directory that is chmod 400, which will fail.
+ # on a directory that is chmod 500, which will fail.
# This function is run when shutil.rmtree fails.
# 99.9% of the time it initially fails to remove
# a file in the directory, so the first time through
@@ -193,99 +237,447 @@ class TestShutil(unittest.TestCase):
# FUSE experienced a failure earlier in the process
# at os.listdir. The first failure may legally
# be either.
- if self.errorState == 0:
- if func is os.remove:
- self.assertEqual(arg, self.childpath)
+ if self.errorState < 2:
+ if func is os.unlink:
+ self.assertEqual(arg, self.child_file_path)
+ elif func is os.rmdir:
+ self.assertEqual(arg, self.child_dir_path)
else:
- self.assertIs(func, os.listdir,
- "func must be either os.remove or os.listdir")
- self.assertEqual(arg, TESTFN)
+ self.assertIs(func, os.listdir)
+ self.assertIn(arg, [TESTFN, self.child_dir_path])
self.assertTrue(issubclass(exc[0], OSError))
- self.errorState = 1
+ self.errorState += 1
else:
self.assertEqual(func, os.rmdir)
self.assertEqual(arg, TESTFN)
self.assertTrue(issubclass(exc[0], OSError))
- self.errorState = 2
+ self.errorState = 3
+
+ def test_rmtree_does_not_choke_on_failing_lstat(self):
+ try:
+ orig_lstat = os.lstat
+ def raiser(fn, *args, **kwargs):
+ if fn != TESTFN:
+ raise OSError()
+ else:
+ return orig_lstat(fn)
+ os.lstat = raiser
+
+ os.mkdir(TESTFN)
+ write_file((TESTFN, 'foo'), 'foo')
+ shutil.rmtree(TESTFN)
+ finally:
+ os.lstat = orig_lstat
+
+ @unittest.skipUnless(hasattr(os, 'chmod'), 'requires os.chmod')
+ @support.skip_unless_symlink
+ def test_copymode_follow_symlinks(self):
+ tmp_dir = self.mkdtemp()
+ src = os.path.join(tmp_dir, 'foo')
+ dst = os.path.join(tmp_dir, 'bar')
+ src_link = os.path.join(tmp_dir, 'baz')
+ dst_link = os.path.join(tmp_dir, 'quux')
+ write_file(src, 'foo')
+ write_file(dst, 'foo')
+ os.symlink(src, src_link)
+ os.symlink(dst, dst_link)
+ os.chmod(src, stat.S_IRWXU|stat.S_IRWXG)
+ # file to file
+ os.chmod(dst, stat.S_IRWXO)
+ self.assertNotEqual(os.stat(src).st_mode, os.stat(dst).st_mode)
+ shutil.copymode(src, dst)
+ self.assertEqual(os.stat(src).st_mode, os.stat(dst).st_mode)
+ # follow src link
+ os.chmod(dst, stat.S_IRWXO)
+ shutil.copymode(src_link, dst)
+ self.assertEqual(os.stat(src).st_mode, os.stat(dst).st_mode)
+ # follow dst link
+ os.chmod(dst, stat.S_IRWXO)
+ shutil.copymode(src, dst_link)
+ self.assertEqual(os.stat(src).st_mode, os.stat(dst).st_mode)
+ # follow both links
+ os.chmod(dst, stat.S_IRWXO)
+ shutil.copymode(src_link, dst)
+ self.assertEqual(os.stat(src).st_mode, os.stat(dst).st_mode)
+
+ @unittest.skipUnless(hasattr(os, 'lchmod'), 'requires os.lchmod')
+ @support.skip_unless_symlink
+ def test_copymode_symlink_to_symlink(self):
+ tmp_dir = self.mkdtemp()
+ src = os.path.join(tmp_dir, 'foo')
+ dst = os.path.join(tmp_dir, 'bar')
+ src_link = os.path.join(tmp_dir, 'baz')
+ dst_link = os.path.join(tmp_dir, 'quux')
+ write_file(src, 'foo')
+ write_file(dst, 'foo')
+ os.symlink(src, src_link)
+ os.symlink(dst, dst_link)
+ os.chmod(src, stat.S_IRWXU|stat.S_IRWXG)
+ os.chmod(dst, stat.S_IRWXU)
+ os.lchmod(src_link, stat.S_IRWXO|stat.S_IRWXG)
+ # link to link
+ os.lchmod(dst_link, stat.S_IRWXO)
+ shutil.copymode(src_link, dst_link, follow_symlinks=False)
+ self.assertEqual(os.lstat(src_link).st_mode,
+ os.lstat(dst_link).st_mode)
+ self.assertNotEqual(os.stat(src).st_mode, os.stat(dst).st_mode)
+ # src link - use chmod
+ os.lchmod(dst_link, stat.S_IRWXO)
+ shutil.copymode(src_link, dst, follow_symlinks=False)
+ self.assertEqual(os.stat(src).st_mode, os.stat(dst).st_mode)
+ # dst link - use chmod
+ os.lchmod(dst_link, stat.S_IRWXO)
+ shutil.copymode(src, dst_link, follow_symlinks=False)
+ self.assertEqual(os.stat(src).st_mode, os.stat(dst).st_mode)
+
+ @unittest.skipIf(hasattr(os, 'lchmod'), 'requires os.lchmod to be missing')
+ @support.skip_unless_symlink
+ def test_copymode_symlink_to_symlink_wo_lchmod(self):
+ tmp_dir = self.mkdtemp()
+ src = os.path.join(tmp_dir, 'foo')
+ dst = os.path.join(tmp_dir, 'bar')
+ src_link = os.path.join(tmp_dir, 'baz')
+ dst_link = os.path.join(tmp_dir, 'quux')
+ write_file(src, 'foo')
+ write_file(dst, 'foo')
+ os.symlink(src, src_link)
+ os.symlink(dst, dst_link)
+ shutil.copymode(src_link, dst_link, follow_symlinks=False) # silent fail
+
+ @support.skip_unless_symlink
+ def test_copystat_symlinks(self):
+ tmp_dir = self.mkdtemp()
+ src = os.path.join(tmp_dir, 'foo')
+ dst = os.path.join(tmp_dir, 'bar')
+ src_link = os.path.join(tmp_dir, 'baz')
+ dst_link = os.path.join(tmp_dir, 'qux')
+ write_file(src, 'foo')
+ src_stat = os.stat(src)
+ os.utime(src, (src_stat.st_atime,
+ src_stat.st_mtime - 42.0)) # ensure different mtimes
+ write_file(dst, 'bar')
+ self.assertNotEqual(os.stat(src).st_mtime, os.stat(dst).st_mtime)
+ os.symlink(src, src_link)
+ os.symlink(dst, dst_link)
+ if hasattr(os, 'lchmod'):
+ os.lchmod(src_link, stat.S_IRWXO)
+ if hasattr(os, 'lchflags') and hasattr(stat, 'UF_NODUMP'):
+ os.lchflags(src_link, stat.UF_NODUMP)
+ src_link_stat = os.lstat(src_link)
+ # follow
+ if hasattr(os, 'lchmod'):
+ shutil.copystat(src_link, dst_link, follow_symlinks=True)
+ self.assertNotEqual(src_link_stat.st_mode, os.stat(dst).st_mode)
+ # don't follow
+ shutil.copystat(src_link, dst_link, follow_symlinks=False)
+ dst_link_stat = os.lstat(dst_link)
+ if os.utime in os.supports_follow_symlinks:
+ for attr in 'st_atime', 'st_mtime':
+ # The modification times may be truncated in the new file.
+ self.assertLessEqual(getattr(src_link_stat, attr),
+ getattr(dst_link_stat, attr) + 1)
+ if hasattr(os, 'lchmod'):
+ self.assertEqual(src_link_stat.st_mode, dst_link_stat.st_mode)
+ if hasattr(os, 'lchflags') and hasattr(src_link_stat, 'st_flags'):
+ self.assertEqual(src_link_stat.st_flags, dst_link_stat.st_flags)
+ # tell to follow but dst is not a link
+ shutil.copystat(src_link, dst, follow_symlinks=False)
+ self.assertTrue(abs(os.stat(src).st_mtime - os.stat(dst).st_mtime) <
+ 00000.1)
+
+ @unittest.skipUnless(hasattr(os, 'chflags') and
+ hasattr(errno, 'EOPNOTSUPP') and
+ hasattr(errno, 'ENOTSUP'),
+ "requires os.chflags, EOPNOTSUPP & ENOTSUP")
+ def test_copystat_handles_harmless_chflags_errors(self):
+ tmpdir = self.mkdtemp()
+ file1 = os.path.join(tmpdir, 'file1')
+ file2 = os.path.join(tmpdir, 'file2')
+ write_file(file1, 'xxx')
+ write_file(file2, 'xxx')
+
+ def make_chflags_raiser(err):
+ ex = OSError()
+
+ def _chflags_raiser(path, flags, *, follow_symlinks=True):
+ ex.errno = err
+ raise ex
+ return _chflags_raiser
+ old_chflags = os.chflags
+ try:
+ for err in errno.EOPNOTSUPP, errno.ENOTSUP:
+ os.chflags = make_chflags_raiser(err)
+ shutil.copystat(file1, file2)
+ # assert others errors break it
+ os.chflags = make_chflags_raiser(errno.EOPNOTSUPP + errno.ENOTSUP)
+ self.assertRaises(OSError, shutil.copystat, file1, file2)
+ finally:
+ os.chflags = old_chflags
+
+ @support.skip_unless_xattr
+ def test_copyxattr(self):
+ tmp_dir = self.mkdtemp()
+ src = os.path.join(tmp_dir, 'foo')
+ write_file(src, 'foo')
+ dst = os.path.join(tmp_dir, 'bar')
+ write_file(dst, 'bar')
+
+ # no xattr == no problem
+ shutil._copyxattr(src, dst)
+ # common case
+ os.setxattr(src, 'user.foo', b'42')
+ os.setxattr(src, 'user.bar', b'43')
+ shutil._copyxattr(src, dst)
+ self.assertEqual(os.listxattr(src), os.listxattr(dst))
+ self.assertEqual(
+ os.getxattr(src, 'user.foo'),
+ os.getxattr(dst, 'user.foo'))
+ # check errors don't affect other attrs
+ os.remove(dst)
+ write_file(dst, 'bar')
+ os_error = OSError(errno.EPERM, 'EPERM')
+
+ def _raise_on_user_foo(fname, attr, val, **kwargs):
+ if attr == 'user.foo':
+ raise os_error
+ else:
+ orig_setxattr(fname, attr, val, **kwargs)
+ try:
+ orig_setxattr = os.setxattr
+ os.setxattr = _raise_on_user_foo
+ shutil._copyxattr(src, dst)
+ self.assertIn('user.bar', os.listxattr(dst))
+ finally:
+ os.setxattr = orig_setxattr
+ # the source filesystem not supporting xattrs should be ok, too.
+ def _raise_on_src(fname, *, follow_symlinks=True):
+ if fname == src:
+ raise OSError(errno.ENOTSUP, 'Operation not supported')
+ return orig_listxattr(fname, follow_symlinks=follow_symlinks)
+ try:
+ orig_listxattr = os.listxattr
+ os.listxattr = _raise_on_src
+ shutil._copyxattr(src, dst)
+ finally:
+ os.listxattr = orig_listxattr
+
+ # test that shutil.copystat copies xattrs
+ src = os.path.join(tmp_dir, 'the_original')
+ write_file(src, src)
+ os.setxattr(src, 'user.the_value', b'fiddly')
+ dst = os.path.join(tmp_dir, 'the_copy')
+ write_file(dst, dst)
+ shutil.copystat(src, dst)
+ self.assertEqual(os.getxattr(dst, 'user.the_value'), b'fiddly')
+
+ @support.skip_unless_symlink
+ @support.skip_unless_xattr
+ @unittest.skipUnless(hasattr(os, 'geteuid') and os.geteuid() == 0,
+ 'root privileges required')
+ def test_copyxattr_symlinks(self):
+ # On Linux, it's only possible to access non-user xattr for symlinks;
+ # which in turn require root privileges. This test should be expanded
+ # as soon as other platforms gain support for extended attributes.
+ tmp_dir = self.mkdtemp()
+ src = os.path.join(tmp_dir, 'foo')
+ src_link = os.path.join(tmp_dir, 'baz')
+ write_file(src, 'foo')
+ os.symlink(src, src_link)
+ os.setxattr(src, 'trusted.foo', b'42')
+ os.setxattr(src_link, 'trusted.foo', b'43', follow_symlinks=False)
+ dst = os.path.join(tmp_dir, 'bar')
+ dst_link = os.path.join(tmp_dir, 'qux')
+ write_file(dst, 'bar')
+ os.symlink(dst, dst_link)
+ shutil._copyxattr(src_link, dst_link, follow_symlinks=False)
+ self.assertEqual(os.getxattr(dst_link, 'trusted.foo', follow_symlinks=False), b'43')
+ self.assertRaises(OSError, os.getxattr, dst, 'trusted.foo')
+ shutil._copyxattr(src_link, dst, follow_symlinks=False)
+ self.assertEqual(os.getxattr(dst, 'trusted.foo'), b'43')
+
+ @support.skip_unless_symlink
+ def test_copy_symlinks(self):
+ tmp_dir = self.mkdtemp()
+ src = os.path.join(tmp_dir, 'foo')
+ dst = os.path.join(tmp_dir, 'bar')
+ src_link = os.path.join(tmp_dir, 'baz')
+ write_file(src, 'foo')
+ os.symlink(src, src_link)
+ if hasattr(os, 'lchmod'):
+ os.lchmod(src_link, stat.S_IRWXU | stat.S_IRWXO)
+ # don't follow
+ shutil.copy(src_link, dst, follow_symlinks=True)
+ self.assertFalse(os.path.islink(dst))
+ self.assertEqual(read_file(src), read_file(dst))
+ os.remove(dst)
+ # follow
+ shutil.copy(src_link, dst, follow_symlinks=False)
+ self.assertTrue(os.path.islink(dst))
+ self.assertEqual(os.readlink(dst), os.readlink(src_link))
+ if hasattr(os, 'lchmod'):
+ self.assertEqual(os.lstat(src_link).st_mode,
+ os.lstat(dst).st_mode)
+
+ @support.skip_unless_symlink
+ def test_copy2_symlinks(self):
+ tmp_dir = self.mkdtemp()
+ src = os.path.join(tmp_dir, 'foo')
+ dst = os.path.join(tmp_dir, 'bar')
+ src_link = os.path.join(tmp_dir, 'baz')
+ write_file(src, 'foo')
+ os.symlink(src, src_link)
+ if hasattr(os, 'lchmod'):
+ os.lchmod(src_link, stat.S_IRWXU | stat.S_IRWXO)
+ if hasattr(os, 'lchflags') and hasattr(stat, 'UF_NODUMP'):
+ os.lchflags(src_link, stat.UF_NODUMP)
+ src_stat = os.stat(src)
+ src_link_stat = os.lstat(src_link)
+ # follow
+ shutil.copy2(src_link, dst, follow_symlinks=True)
+ self.assertFalse(os.path.islink(dst))
+ self.assertEqual(read_file(src), read_file(dst))
+ os.remove(dst)
+ # don't follow
+ shutil.copy2(src_link, dst, follow_symlinks=False)
+ self.assertTrue(os.path.islink(dst))
+ self.assertEqual(os.readlink(dst), os.readlink(src_link))
+ dst_stat = os.lstat(dst)
+ if os.utime in os.supports_follow_symlinks:
+ for attr in 'st_atime', 'st_mtime':
+ # The modification times may be truncated in the new file.
+ self.assertLessEqual(getattr(src_link_stat, attr),
+ getattr(dst_stat, attr) + 1)
+ if hasattr(os, 'lchmod'):
+ self.assertEqual(src_link_stat.st_mode, dst_stat.st_mode)
+ self.assertNotEqual(src_stat.st_mode, dst_stat.st_mode)
+ if hasattr(os, 'lchflags') and hasattr(src_link_stat, 'st_flags'):
+ self.assertEqual(src_link_stat.st_flags, dst_stat.st_flags)
+
+ @support.skip_unless_xattr
+ def test_copy2_xattr(self):
+ tmp_dir = self.mkdtemp()
+ src = os.path.join(tmp_dir, 'foo')
+ dst = os.path.join(tmp_dir, 'bar')
+ write_file(src, 'foo')
+ os.setxattr(src, 'user.foo', b'42')
+ shutil.copy2(src, dst)
+ self.assertEqual(
+ os.getxattr(src, 'user.foo'),
+ os.getxattr(dst, 'user.foo'))
+ os.remove(dst)
+
+ @support.skip_unless_symlink
+ def test_copyfile_symlinks(self):
+ tmp_dir = self.mkdtemp()
+ src = os.path.join(tmp_dir, 'src')
+ dst = os.path.join(tmp_dir, 'dst')
+ dst_link = os.path.join(tmp_dir, 'dst_link')
+ link = os.path.join(tmp_dir, 'link')
+ write_file(src, 'foo')
+ os.symlink(src, link)
+ # don't follow
+ shutil.copyfile(link, dst_link, follow_symlinks=False)
+ self.assertTrue(os.path.islink(dst_link))
+ self.assertEqual(os.readlink(link), os.readlink(dst_link))
+ # follow
+ shutil.copyfile(link, dst)
+ self.assertFalse(os.path.islink(dst))
+
+ def test_rmtree_uses_safe_fd_version_if_available(self):
+ _use_fd_functions = ({os.open, os.stat, os.unlink, os.rmdir} <=
+ os.supports_dir_fd and
+ os.listdir in os.supports_fd and
+ os.stat in os.supports_follow_symlinks)
+ if _use_fd_functions:
+ self.assertTrue(shutil._use_fd_functions)
+ self.assertTrue(shutil.rmtree.avoids_symlink_attacks)
+ tmp_dir = self.mkdtemp()
+ d = os.path.join(tmp_dir, 'a')
+ os.mkdir(d)
+ try:
+ real_rmtree = shutil._rmtree_safe_fd
+ class Called(Exception): pass
+ def _raiser(*args, **kwargs):
+ raise Called
+ shutil._rmtree_safe_fd = _raiser
+ self.assertRaises(Called, shutil.rmtree, d)
+ finally:
+ shutil._rmtree_safe_fd = real_rmtree
+ else:
+ self.assertFalse(shutil._use_fd_functions)
+ self.assertFalse(shutil.rmtree.avoids_symlink_attacks)
def test_rmtree_dont_delete_file(self):
# When called on a file instead of a directory, don't delete it.
handle, path = tempfile.mkstemp()
- os.fdopen(handle).close()
- self.assertRaises(OSError, shutil.rmtree, path)
+ os.close(handle)
+ self.assertRaises(NotADirectoryError, shutil.rmtree, path)
os.remove(path)
- def _write_data(self, path, data):
- f = open(path, "w")
- f.write(data)
- f.close()
-
def test_copytree_simple(self):
-
- def read_data(path):
- f = open(path)
- data = f.read()
- f.close()
- return data
-
src_dir = tempfile.mkdtemp()
dst_dir = os.path.join(tempfile.mkdtemp(), 'destination')
- self._write_data(os.path.join(src_dir, 'test.txt'), '123')
+ self.addCleanup(shutil.rmtree, src_dir)
+ self.addCleanup(shutil.rmtree, os.path.dirname(dst_dir))
+ write_file((src_dir, 'test.txt'), '123')
os.mkdir(os.path.join(src_dir, 'test_dir'))
- self._write_data(os.path.join(src_dir, 'test_dir', 'test.txt'), '456')
+ write_file((src_dir, 'test_dir', 'test.txt'), '456')
+
+ shutil.copytree(src_dir, dst_dir)
+ self.assertTrue(os.path.isfile(os.path.join(dst_dir, 'test.txt')))
+ self.assertTrue(os.path.isdir(os.path.join(dst_dir, 'test_dir')))
+ self.assertTrue(os.path.isfile(os.path.join(dst_dir, 'test_dir',
+ 'test.txt')))
+ actual = read_file((dst_dir, 'test.txt'))
+ self.assertEqual(actual, '123')
+ actual = read_file((dst_dir, 'test_dir', 'test.txt'))
+ self.assertEqual(actual, '456')
- try:
- shutil.copytree(src_dir, dst_dir)
- self.assertTrue(os.path.isfile(os.path.join(dst_dir, 'test.txt')))
- self.assertTrue(os.path.isdir(os.path.join(dst_dir, 'test_dir')))
- self.assertTrue(os.path.isfile(os.path.join(dst_dir, 'test_dir',
- 'test.txt')))
- actual = read_data(os.path.join(dst_dir, 'test.txt'))
- self.assertEqual(actual, '123')
- actual = read_data(os.path.join(dst_dir, 'test_dir', 'test.txt'))
- self.assertEqual(actual, '456')
- finally:
- for path in (
- os.path.join(src_dir, 'test.txt'),
- os.path.join(dst_dir, 'test.txt'),
- os.path.join(src_dir, 'test_dir', 'test.txt'),
- os.path.join(dst_dir, 'test_dir', 'test.txt'),
- ):
- if os.path.exists(path):
- os.remove(path)
- for path in (src_dir,
- os.path.dirname(dst_dir)
- ):
- if os.path.exists(path):
- shutil.rmtree(path)
+ @support.skip_unless_symlink
+ def test_copytree_symlinks(self):
+ tmp_dir = self.mkdtemp()
+ src_dir = os.path.join(tmp_dir, 'src')
+ dst_dir = os.path.join(tmp_dir, 'dst')
+ sub_dir = os.path.join(src_dir, 'sub')
+ os.mkdir(src_dir)
+ os.mkdir(sub_dir)
+ write_file((src_dir, 'file.txt'), 'foo')
+ src_link = os.path.join(sub_dir, 'link')
+ dst_link = os.path.join(dst_dir, 'sub/link')
+ os.symlink(os.path.join(src_dir, 'file.txt'),
+ src_link)
+ if hasattr(os, 'lchmod'):
+ os.lchmod(src_link, stat.S_IRWXU | stat.S_IRWXO)
+ if hasattr(os, 'lchflags') and hasattr(stat, 'UF_NODUMP'):
+ os.lchflags(src_link, stat.UF_NODUMP)
+ src_stat = os.lstat(src_link)
+ shutil.copytree(src_dir, dst_dir, symlinks=True)
+ self.assertTrue(os.path.islink(os.path.join(dst_dir, 'sub', 'link')))
+ self.assertEqual(os.readlink(os.path.join(dst_dir, 'sub', 'link')),
+ os.path.join(src_dir, 'file.txt'))
+ dst_stat = os.lstat(dst_link)
+ if hasattr(os, 'lchmod'):
+ self.assertEqual(dst_stat.st_mode, src_stat.st_mode)
+ if hasattr(os, 'lchflags'):
+ self.assertEqual(dst_stat.st_flags, src_stat.st_flags)
def test_copytree_with_exclude(self):
-
- def read_data(path):
- f = open(path)
- data = f.read()
- f.close()
- return data
-
# creating data
join = os.path.join
exists = os.path.exists
src_dir = tempfile.mkdtemp()
try:
dst_dir = join(tempfile.mkdtemp(), 'destination')
- self._write_data(join(src_dir, 'test.txt'), '123')
- self._write_data(join(src_dir, 'test.tmp'), '123')
+ write_file((src_dir, 'test.txt'), '123')
+ write_file((src_dir, 'test.tmp'), '123')
os.mkdir(join(src_dir, 'test_dir'))
- self._write_data(join(src_dir, 'test_dir', 'test.txt'), '456')
+ write_file((src_dir, 'test_dir', 'test.txt'), '456')
os.mkdir(join(src_dir, 'test_dir2'))
- self._write_data(join(src_dir, 'test_dir2', 'test.txt'), '456')
+ write_file((src_dir, 'test_dir2', 'test.txt'), '456')
os.mkdir(join(src_dir, 'test_dir2', 'subdir'))
os.mkdir(join(src_dir, 'test_dir2', 'subdir2'))
- self._write_data(join(src_dir, 'test_dir2', 'subdir', 'test.txt'),
- '456')
- self._write_data(join(src_dir, 'test_dir2', 'subdir2', 'test.py'),
- '456')
-
+ write_file((src_dir, 'test_dir2', 'subdir', 'test.txt'), '456')
+ write_file((src_dir, 'test_dir2', 'subdir2', 'test.py'), '456')
# testing glob-like patterns
try:
@@ -293,21 +685,19 @@ class TestShutil(unittest.TestCase):
shutil.copytree(src_dir, dst_dir, ignore=patterns)
# checking the result: some elements should not be copied
self.assertTrue(exists(join(dst_dir, 'test.txt')))
- self.assertTrue(not exists(join(dst_dir, 'test.tmp')))
- self.assertTrue(not exists(join(dst_dir, 'test_dir2')))
+ self.assertFalse(exists(join(dst_dir, 'test.tmp')))
+ self.assertFalse(exists(join(dst_dir, 'test_dir2')))
finally:
- if os.path.exists(dst_dir):
- shutil.rmtree(dst_dir)
+ shutil.rmtree(dst_dir)
try:
patterns = shutil.ignore_patterns('*.tmp', 'subdir*')
shutil.copytree(src_dir, dst_dir, ignore=patterns)
# checking the result: some elements should not be copied
- self.assertTrue(not exists(join(dst_dir, 'test.tmp')))
- self.assertTrue(not exists(join(dst_dir, 'test_dir2', 'subdir2')))
- self.assertTrue(not exists(join(dst_dir, 'test_dir2', 'subdir')))
+ self.assertFalse(exists(join(dst_dir, 'test.tmp')))
+ self.assertFalse(exists(join(dst_dir, 'test_dir2', 'subdir2')))
+ self.assertFalse(exists(join(dst_dir, 'test_dir2', 'subdir')))
finally:
- if os.path.exists(dst_dir):
- shutil.rmtree(dst_dir)
+ shutil.rmtree(dst_dir)
# testing callable-style
try:
@@ -326,13 +716,12 @@ class TestShutil(unittest.TestCase):
shutil.copytree(src_dir, dst_dir, ignore=_filter)
# checking the result: some elements should not be copied
- self.assertTrue(not exists(join(dst_dir, 'test_dir2', 'subdir2',
- 'test.py')))
- self.assertTrue(not exists(join(dst_dir, 'test_dir2', 'subdir')))
+ self.assertFalse(exists(join(dst_dir, 'test_dir2', 'subdir2',
+ 'test.py')))
+ self.assertFalse(exists(join(dst_dir, 'test_dir2', 'subdir')))
finally:
- if os.path.exists(dst_dir):
- shutil.rmtree(dst_dir)
+ shutil.rmtree(dst_dir)
finally:
shutil.rmtree(src_dir)
shutil.rmtree(os.path.dirname(dst_dir))
@@ -357,35 +746,6 @@ class TestShutil(unittest.TestCase):
finally:
shutil.rmtree(TESTFN, ignore_errors=True)
- @unittest.skipUnless(hasattr(os, 'chflags') and
- hasattr(errno, 'EOPNOTSUPP') and
- hasattr(errno, 'ENOTSUP'),
- "requires os.chflags, EOPNOTSUPP & ENOTSUP")
- def test_copystat_handles_harmless_chflags_errors(self):
- tmpdir = self.mkdtemp()
- file1 = os.path.join(tmpdir, 'file1')
- file2 = os.path.join(tmpdir, 'file2')
- self.write_file(file1, 'xxx')
- self.write_file(file2, 'xxx')
-
- def make_chflags_raiser(err):
- ex = OSError()
-
- def _chflags_raiser(path, flags):
- ex.errno = err
- raise ex
- return _chflags_raiser
- old_chflags = os.chflags
- try:
- for err in errno.EOPNOTSUPP, errno.ENOTSUP:
- os.chflags = make_chflags_raiser(err)
- shutil.copystat(file1, file2)
- # assert others errors break it
- os.chflags = make_chflags_raiser(errno.EOPNOTSUPP + errno.ENOTSUP)
- self.assertRaises(OSError, shutil.copystat, file1, file2)
- finally:
- os.chflags = old_chflags
-
@support.skip_unless_symlink
def test_dont_copy_file_onto_symlink_to_itself(self):
# bug 851123.
@@ -416,6 +776,7 @@ class TestShutil(unittest.TestCase):
os.mkdir(src)
os.symlink(src, dst)
self.assertRaises(OSError, shutil.rmtree, dst)
+ shutil.rmtree(dst, ignore_errors=True)
finally:
shutil.rmtree(TESTFN, ignore_errors=True)
@@ -456,9 +817,9 @@ class TestShutil(unittest.TestCase):
src_dir = self.mkdtemp()
dst_dir = os.path.join(self.mkdtemp(), 'destination')
- self._write_data(os.path.join(src_dir, 'test.txt'), '123')
+ write_file((src_dir, 'test.txt'), '123')
os.mkdir(os.path.join(src_dir, 'test_dir'))
- self._write_data(os.path.join(src_dir, 'test_dir', 'test.txt'), '456')
+ write_file((src_dir, 'test_dir', 'test.txt'), '456')
copied = []
def _copy(src, dst):
@@ -475,7 +836,7 @@ class TestShutil(unittest.TestCase):
dst_dir = os.path.join(self.mkdtemp(), 'destination')
os.symlink('IDONTEXIST', os.path.join(src_dir, 'test.txt'))
os.mkdir(os.path.join(src_dir, 'test_dir'))
- self._write_data(os.path.join(src_dir, 'test_dir', 'test.txt'), '456')
+ write_file((src_dir, 'test_dir', 'test.txt'), '456')
self.assertRaises(Error, shutil.copytree, src_dir, dst_dir)
# a dangling symlink is ignored with the proper flag
@@ -491,7 +852,7 @@ class TestShutil(unittest.TestCase):
def _copy_file(self, method):
fname = 'test.txt'
tmpdir = self.mkdtemp()
- self.write_file([tmpdir, fname])
+ write_file((tmpdir, fname), 'xxx')
file1 = os.path.join(tmpdir, fname)
tmpdir2 = self.mkdtemp()
method(file1, tmpdir2)
@@ -523,14 +884,14 @@ class TestShutil(unittest.TestCase):
self.assertEqual(getattr(file1_stat, 'st_flags'),
getattr(file2_stat, 'st_flags'))
- @unittest.skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_make_tarball(self):
# creating something to tar
tmpdir = self.mkdtemp()
- self.write_file([tmpdir, 'file1'], 'xxx')
- self.write_file([tmpdir, 'file2'], 'xxx')
+ write_file((tmpdir, 'file1'), 'xxx')
+ write_file((tmpdir, 'file2'), 'xxx')
os.mkdir(os.path.join(tmpdir, 'sub'))
- self.write_file([tmpdir, 'sub', 'file3'], 'xxx')
+ write_file((tmpdir, 'sub', 'file3'), 'xxx')
tmpdir2 = self.mkdtemp()
# force shutil to create the directory
@@ -577,16 +938,16 @@ class TestShutil(unittest.TestCase):
tmpdir = self.mkdtemp()
dist = os.path.join(tmpdir, 'dist')
os.mkdir(dist)
- self.write_file([dist, 'file1'], 'xxx')
- self.write_file([dist, 'file2'], 'xxx')
+ write_file((dist, 'file1'), 'xxx')
+ write_file((dist, 'file2'), 'xxx')
os.mkdir(os.path.join(dist, 'sub'))
- self.write_file([dist, 'sub', 'file3'], 'xxx')
+ write_file((dist, 'sub', 'file3'), 'xxx')
os.mkdir(os.path.join(dist, 'sub2'))
tmpdir2 = self.mkdtemp()
base_name = os.path.join(tmpdir2, 'archive')
return tmpdir, tmpdir2, base_name
- @unittest.skipUnless(zlib, "Requires zlib")
+ @requires_zlib
@unittest.skipUnless(find_executable('tar') and find_executable('gzip'),
'Need the tar command to run')
def test_tarfile_vs_tar(self):
@@ -641,13 +1002,13 @@ class TestShutil(unittest.TestCase):
tarball = base_name + '.tar'
self.assertTrue(os.path.exists(tarball))
- @unittest.skipUnless(zlib, "Requires zlib")
+ @requires_zlib
@unittest.skipUnless(ZIP_SUPPORT, 'Need zip support to run')
def test_make_zipfile(self):
# creating something to tar
tmpdir = self.mkdtemp()
- self.write_file([tmpdir, 'file1'], 'xxx')
- self.write_file([tmpdir, 'file2'], 'xxx')
+ write_file((tmpdir, 'file1'), 'xxx')
+ write_file((tmpdir, 'file2'), 'xxx')
tmpdir2 = self.mkdtemp()
# force shutil to create the directory
@@ -665,7 +1026,7 @@ class TestShutil(unittest.TestCase):
base_name = os.path.join(tmpdir, 'archive')
self.assertRaises(ValueError, make_archive, base_name, 'xxx')
- @unittest.skipUnless(zlib, "Requires zlib")
+ @requires_zlib
def test_make_archive_owner_group(self):
# testing make_archive with owner and group, with various combinations
# this works even if there's not gid/uid support
@@ -693,7 +1054,7 @@ class TestShutil(unittest.TestCase):
self.assertTrue(os.path.exists(res))
- @unittest.skipUnless(zlib, "Requires zlib")
+ @requires_zlib
@unittest.skipUnless(UID_GID_SUPPORT, "Requires grp and pwd support")
def test_tarfile_root_owner(self):
tmpdir, tmpdir2, base_name = self._create_files()
@@ -762,7 +1123,7 @@ class TestShutil(unittest.TestCase):
diff.append(file_)
return diff
- @unittest.skipUnless(zlib, "Requires zlib")
+ @requires_zlib
def test_unpack_archive(self):
formats = ['tar', 'gztar', 'zip']
if BZ2_SUPPORTED:
@@ -813,6 +1174,184 @@ class TestShutil(unittest.TestCase):
unregister_unpack_format('Boo2')
self.assertEqual(get_unpack_formats(), formats)
+ @unittest.skipUnless(hasattr(shutil, 'disk_usage'),
+ "disk_usage not available on this platform")
+ def test_disk_usage(self):
+ usage = shutil.disk_usage(os.getcwd())
+ self.assertGreater(usage.total, 0)
+ self.assertGreater(usage.used, 0)
+ self.assertGreaterEqual(usage.free, 0)
+ self.assertGreaterEqual(usage.total, usage.used)
+ self.assertGreater(usage.total, usage.free)
+
+ @unittest.skipUnless(UID_GID_SUPPORT, "Requires grp and pwd support")
+ @unittest.skipUnless(hasattr(os, 'chown'), 'requires os.chown')
+ def test_chown(self):
+
+ # cleaned-up automatically by TestShutil.tearDown method
+ dirname = self.mkdtemp()
+ filename = tempfile.mktemp(dir=dirname)
+ write_file(filename, 'testing chown function')
+
+ with self.assertRaises(ValueError):
+ shutil.chown(filename)
+
+ with self.assertRaises(LookupError):
+ shutil.chown(filename, user='non-exising username')
+
+ with self.assertRaises(LookupError):
+ shutil.chown(filename, group='non-exising groupname')
+
+ with self.assertRaises(TypeError):
+ shutil.chown(filename, b'spam')
+
+ with self.assertRaises(TypeError):
+ shutil.chown(filename, 3.14)
+
+ uid = os.getuid()
+ gid = os.getgid()
+
+ def check_chown(path, uid=None, gid=None):
+ s = os.stat(filename)
+ if uid is not None:
+ self.assertEqual(uid, s.st_uid)
+ if gid is not None:
+ self.assertEqual(gid, s.st_gid)
+
+ shutil.chown(filename, uid, gid)
+ check_chown(filename, uid, gid)
+ shutil.chown(filename, uid)
+ check_chown(filename, uid)
+ shutil.chown(filename, user=uid)
+ check_chown(filename, uid)
+ shutil.chown(filename, group=gid)
+ check_chown(filename, gid=gid)
+
+ shutil.chown(dirname, uid, gid)
+ check_chown(dirname, uid, gid)
+ shutil.chown(dirname, uid)
+ check_chown(dirname, uid)
+ shutil.chown(dirname, user=uid)
+ check_chown(dirname, uid)
+ shutil.chown(dirname, group=gid)
+ check_chown(dirname, gid=gid)
+
+ user = pwd.getpwuid(uid)[0]
+ group = grp.getgrgid(gid)[0]
+ shutil.chown(filename, user, group)
+ check_chown(filename, uid, gid)
+ shutil.chown(dirname, user, group)
+ check_chown(dirname, uid, gid)
+
+ def test_copy_return_value(self):
+ # copy and copy2 both return their destination path.
+ for fn in (shutil.copy, shutil.copy2):
+ src_dir = self.mkdtemp()
+ dst_dir = self.mkdtemp()
+ src = os.path.join(src_dir, 'foo')
+ write_file(src, 'foo')
+ rv = fn(src, dst_dir)
+ self.assertEqual(rv, os.path.join(dst_dir, 'foo'))
+ rv = fn(src, os.path.join(dst_dir, 'bar'))
+ self.assertEqual(rv, os.path.join(dst_dir, 'bar'))
+
+ def test_copyfile_return_value(self):
+ # copytree returns its destination path.
+ src_dir = self.mkdtemp()
+ dst_dir = self.mkdtemp()
+ dst_file = os.path.join(dst_dir, 'bar')
+ src_file = os.path.join(src_dir, 'foo')
+ write_file(src_file, 'foo')
+ rv = shutil.copyfile(src_file, dst_file)
+ self.assertTrue(os.path.exists(rv))
+ self.assertEqual(read_file(src_file), read_file(dst_file))
+
+ def test_copytree_return_value(self):
+ # copytree returns its destination path.
+ src_dir = self.mkdtemp()
+ dst_dir = src_dir + "dest"
+ self.addCleanup(shutil.rmtree, dst_dir, True)
+ src = os.path.join(src_dir, 'foo')
+ write_file(src, 'foo')
+ rv = shutil.copytree(src_dir, dst_dir)
+ self.assertEqual(['foo'], os.listdir(rv))
+
+
+class TestWhich(unittest.TestCase):
+
+ def setUp(self):
+ self.temp_dir = tempfile.mkdtemp(prefix="Tmp")
+ self.addCleanup(shutil.rmtree, self.temp_dir, True)
+ # Give the temp_file an ".exe" suffix for all.
+ # It's needed on Windows and not harmful on other platforms.
+ self.temp_file = tempfile.NamedTemporaryFile(dir=self.temp_dir,
+ prefix="Tmp",
+ suffix=".Exe")
+ os.chmod(self.temp_file.name, stat.S_IXUSR)
+ self.addCleanup(self.temp_file.close)
+ self.dir, self.file = os.path.split(self.temp_file.name)
+
+ def test_basic(self):
+ # Given an EXE in a directory, it should be returned.
+ rv = shutil.which(self.file, path=self.dir)
+ self.assertEqual(rv, self.temp_file.name)
+
+ def test_absolute_cmd(self):
+ # When given the fully qualified path to an executable that exists,
+ # it should be returned.
+ rv = shutil.which(self.temp_file.name, path=self.temp_dir)
+ self.assertEqual(rv, self.temp_file.name)
+
+ def test_relative_cmd(self):
+ # When given the relative path with a directory part to an executable
+ # that exists, it should be returned.
+ base_dir, tail_dir = os.path.split(self.dir)
+ relpath = os.path.join(tail_dir, self.file)
+ with support.temp_cwd(path=base_dir):
+ rv = shutil.which(relpath, path=self.temp_dir)
+ self.assertEqual(rv, relpath)
+ # But it shouldn't be searched in PATH directories (issue #16957).
+ with support.temp_cwd(path=self.dir):
+ rv = shutil.which(relpath, path=base_dir)
+ self.assertIsNone(rv)
+
+ def test_cwd(self):
+ # Issue #16957
+ base_dir = os.path.dirname(self.dir)
+ with support.temp_cwd(path=self.dir):
+ rv = shutil.which(self.file, path=base_dir)
+ if sys.platform == "win32":
+ # Windows: current directory implicitly on PATH
+ self.assertEqual(rv, os.path.join(os.curdir, self.file))
+ else:
+ # Other platforms: shouldn't match in the current directory.
+ self.assertIsNone(rv)
+
+ def test_non_matching_mode(self):
+ # Set the file read-only and ask for writeable files.
+ os.chmod(self.temp_file.name, stat.S_IREAD)
+ rv = shutil.which(self.file, path=self.dir, mode=os.W_OK)
+ self.assertIsNone(rv)
+
+ def test_relative_path(self):
+ base_dir, tail_dir = os.path.split(self.dir)
+ with support.temp_cwd(path=base_dir):
+ rv = shutil.which(self.file, path=tail_dir)
+ self.assertEqual(rv, os.path.join(tail_dir, self.file))
+
+ def test_nonexistent_file(self):
+ # Return None when no matching executable file is found on the path.
+ rv = shutil.which("foo.exe", path=self.dir)
+ self.assertIsNone(rv)
+
+ @unittest.skipUnless(sys.platform == "win32",
+ "pathext check is Windows-only")
+ def test_pathext_checking(self):
+ # Ask for the file without the ".exe" extension, then ensure that
+ # it gets found properly with the extension.
+ rv = shutil.which(self.file[:-4], path=self.dir)
+ self.assertEqual(rv, self.temp_file.name[:-4] + ".EXE")
+
class TestMove(unittest.TestCase):
@@ -926,6 +1465,58 @@ class TestMove(unittest.TestCase):
finally:
shutil.rmtree(TESTFN, ignore_errors=True)
+ @support.skip_unless_symlink
+ @mock_rename
+ def test_move_file_symlink(self):
+ dst = os.path.join(self.src_dir, 'bar')
+ os.symlink(self.src_file, dst)
+ shutil.move(dst, self.dst_file)
+ self.assertTrue(os.path.islink(self.dst_file))
+ self.assertTrue(os.path.samefile(self.src_file, self.dst_file))
+
+ @support.skip_unless_symlink
+ @mock_rename
+ def test_move_file_symlink_to_dir(self):
+ filename = "bar"
+ dst = os.path.join(self.src_dir, filename)
+ os.symlink(self.src_file, dst)
+ shutil.move(dst, self.dst_dir)
+ final_link = os.path.join(self.dst_dir, filename)
+ self.assertTrue(os.path.islink(final_link))
+ self.assertTrue(os.path.samefile(self.src_file, final_link))
+
+ @support.skip_unless_symlink
+ @mock_rename
+ def test_move_dangling_symlink(self):
+ src = os.path.join(self.src_dir, 'baz')
+ dst = os.path.join(self.src_dir, 'bar')
+ os.symlink(src, dst)
+ dst_link = os.path.join(self.dst_dir, 'quux')
+ shutil.move(dst, dst_link)
+ self.assertTrue(os.path.islink(dst_link))
+ self.assertEqual(os.path.realpath(src), os.path.realpath(dst_link))
+
+ @support.skip_unless_symlink
+ @mock_rename
+ def test_move_dir_symlink(self):
+ src = os.path.join(self.src_dir, 'baz')
+ dst = os.path.join(self.src_dir, 'bar')
+ os.mkdir(src)
+ os.symlink(src, dst)
+ dst_link = os.path.join(self.dst_dir, 'quux')
+ shutil.move(dst, dst_link)
+ self.assertTrue(os.path.islink(dst_link))
+ self.assertTrue(os.path.samefile(src, dst_link))
+
+ def test_move_return_value(self):
+ rv = shutil.move(self.src_file, self.dst_dir)
+ self.assertEqual(rv,
+ os.path.join(self.dst_dir, os.path.basename(self.src_file)))
+
+ def test_move_as_rename_return_value(self):
+ rv = shutil.move(self.src_file, os.path.join(self.dst_dir, 'bar'))
+ self.assertEqual(rv, os.path.join(self.dst_dir, 'bar'))
+
class TestCopyFile(unittest.TestCase):
@@ -1035,6 +1626,7 @@ class TestCopyFile(unittest.TestCase):
# but a different case.
self.src_dir = tempfile.mkdtemp()
+ self.addCleanup(shutil.rmtree, self.src_dir, True)
dst_dir = os.path.join(
os.path.dirname(self.src_dir),
os.path.basename(self.src_dir).upper())
@@ -1044,13 +1636,57 @@ class TestCopyFile(unittest.TestCase):
shutil.move(self.src_dir, dst_dir)
self.assertTrue(os.path.isdir(dst_dir))
finally:
- if os.path.exists(dst_dir):
- os.rmdir(dst_dir)
+ os.rmdir(dst_dir)
+
+class TermsizeTests(unittest.TestCase):
+ def test_does_not_crash(self):
+ """Check if get_terminal_size() returns a meaningful value.
+
+ There's no easy portable way to actually check the size of the
+ terminal, so let's check if it returns something sensible instead.
+ """
+ size = shutil.get_terminal_size()
+ self.assertGreaterEqual(size.columns, 0)
+ self.assertGreaterEqual(size.lines, 0)
+
+ def test_os_environ_first(self):
+ "Check if environment variables have precedence"
+
+ with support.EnvironmentVarGuard() as env:
+ env['COLUMNS'] = '777'
+ size = shutil.get_terminal_size()
+ self.assertEqual(size.columns, 777)
+
+ with support.EnvironmentVarGuard() as env:
+ env['LINES'] = '888'
+ size = shutil.get_terminal_size()
+ self.assertEqual(size.lines, 888)
+
+ @unittest.skipUnless(os.isatty(sys.__stdout__.fileno()), "not on tty")
+ def test_stty_match(self):
+ """Check if stty returns the same results ignoring env
+
+ This test will fail if stdin and stdout are connected to
+ different terminals with different sizes. Nevertheless, such
+ situations should be pretty rare.
+ """
+ try:
+ size = subprocess.check_output(['stty', 'size']).decode().split()
+ except (FileNotFoundError, subprocess.CalledProcessError):
+ self.skipTest("stty invocation failed")
+ expected = (int(size[1]), int(size[0])) # reversed order
+
+ with support.EnvironmentVarGuard() as env:
+ del env['LINES']
+ del env['COLUMNS']
+ actual = shutil.get_terminal_size()
+ self.assertEqual(expected, actual)
def test_main():
- support.run_unittest(TestShutil, TestMove, TestCopyFile)
+ support.run_unittest(TestShutil, TestMove, TestCopyFile,
+ TermsizeTests, TestWhich)
if __name__ == '__main__':
test_main()
diff --git a/Lib/test/test_signal.py b/Lib/test/test_signal.py
index 4a1b4a6..99243df 100644
--- a/Lib/test/test_signal.py
+++ b/Lib/test/test_signal.py
@@ -1,17 +1,19 @@
-import errno
+import unittest
+from test import support
+from contextlib import closing
import gc
-import os
import pickle
import select
import signal
+import struct
import subprocess
-import sys
-import time
import traceback
-import unittest
-from test import support
-from contextlib import closing
+import sys, os, time, errno
from test.script_helper import assert_python_ok, spawn_python
+try:
+ import threading
+except ImportError:
+ threading = None
if sys.platform in ('os2', 'riscos'):
raise unittest.SkipTest("Can't test signal on %s" % sys.platform)
@@ -57,15 +59,9 @@ class InterProcessSignalTests(unittest.TestCase):
def handlerA(self, signum, frame):
self.a_called = True
- if support.verbose:
- print("handlerA invoked from signal %s at:\n%s" % (
- signum, self.format_frame(frame, limit=1)))
def handlerB(self, signum, frame):
self.b_called = True
- if support.verbose:
- print ("handlerB invoked from signal %s at:\n%s" % (
- signum, self.format_frame(frame, limit=1)))
raise HandlerBCalled(signum, self.format_frame(frame))
def wait(self, child):
@@ -92,8 +88,6 @@ class InterProcessSignalTests(unittest.TestCase):
# Let the sub-processes know who to send signals to.
pid = os.getpid()
- if support.verbose:
- print("test runner's pid is", pid)
child = ignoring_eintr(subprocess.Popen, ['kill', '-HUP', str(pid)])
if child:
@@ -117,8 +111,6 @@ class InterProcessSignalTests(unittest.TestCase):
except HandlerBCalled:
self.assertTrue(self.b_called)
self.assertFalse(self.a_called)
- if support.verbose:
- print("HandlerBCalled exception caught")
child = ignoring_eintr(subprocess.Popen, ['kill', '-USR2', str(pid)])
if child:
@@ -134,8 +126,7 @@ class InterProcessSignalTests(unittest.TestCase):
# may return early.
time.sleep(1)
except KeyboardInterrupt:
- if support.verbose:
- print("KeyboardInterrupt (the alarm() went off)")
+ pass
except:
self.fail("Some other exception woke us from pause: %s" %
traceback.format_exc())
@@ -191,7 +182,7 @@ class InterProcessSignalTests(unittest.TestCase):
@unittest.skipIf(sys.platform == "win32", "Not valid on Windows")
-class BasicSignalTests(unittest.TestCase):
+class PosixTests(unittest.TestCase):
def trivial_signal_handler(self, *args):
pass
@@ -231,33 +222,53 @@ class WindowsSignalTests(unittest.TestCase):
signal.signal(7, handler)
+class WakeupFDTests(unittest.TestCase):
+
+ def test_invalid_fd(self):
+ fd = support.make_bad_fd()
+ self.assertRaises(ValueError, signal.set_wakeup_fd, fd)
+
+
@unittest.skipIf(sys.platform == "win32", "Not valid on Windows")
class WakeupSignalTests(unittest.TestCase):
- def check_wakeup(self, test_body):
- # use a subprocess to have only one thread and to not change signal
- # handling of the parent process
+ def check_wakeup(self, test_body, *signals, ordered=True):
+ # use a subprocess to have only one thread
code = """if 1:
import fcntl
import os
import signal
+ import struct
+
+ signals = {!r}
def handler(signum, frame):
pass
+ def check_signum(signals):
+ data = os.read(read, len(signals)+1)
+ raised = struct.unpack('%uB' % len(data), data)
+ if not {!r}:
+ raised = set(raised)
+ signals = set(signals)
+ if raised != signals:
+ raise Exception("%r != %r" % (raised, signals))
+
{}
signal.signal(signal.SIGALRM, handler)
read, write = os.pipe()
- flags = fcntl.fcntl(write, fcntl.F_GETFL, 0)
- flags = flags | os.O_NONBLOCK
- fcntl.fcntl(write, fcntl.F_SETFL, flags)
+ for fd in (read, write):
+ flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
+ flags = flags | os.O_NONBLOCK
+ fcntl.fcntl(fd, fcntl.F_SETFL, flags)
signal.set_wakeup_fd(write)
test()
+ check_signum(signals)
os.close(read)
os.close(write)
- """.format(test_body)
+ """.format(signals, ordered, test_body)
assert_python_ok('-c', code)
@@ -283,7 +294,7 @@ class WakeupSignalTests(unittest.TestCase):
dt = after_time - mid_time
if dt >= TIMEOUT_HALF:
raise Exception("%s >= %s" % (dt, TIMEOUT_HALF))
- """)
+ """, signal.SIGALRM)
def test_wakeup_fd_during(self):
self.check_wakeup("""def test():
@@ -306,7 +317,32 @@ class WakeupSignalTests(unittest.TestCase):
dt = after_time - before_time
if dt >= TIMEOUT_HALF:
raise Exception("%s >= %s" % (dt, TIMEOUT_HALF))
- """)
+ """, signal.SIGALRM)
+
+ def test_signum(self):
+ self.check_wakeup("""def test():
+ signal.signal(signal.SIGUSR1, handler)
+ os.kill(os.getpid(), signal.SIGUSR1)
+ os.kill(os.getpid(), signal.SIGALRM)
+ """, signal.SIGUSR1, signal.SIGALRM)
+
+ @unittest.skipUnless(hasattr(signal, 'pthread_sigmask'),
+ 'need signal.pthread_sigmask()')
+ def test_pending(self):
+ self.check_wakeup("""def test():
+ signum1 = signal.SIGUSR1
+ signum2 = signal.SIGUSR2
+
+ signal.signal(signum1, handler)
+ signal.signal(signum2, handler)
+
+ signal.pthread_sigmask(signal.SIG_BLOCK, (signum1, signum2))
+ os.kill(os.getpid(), signum1)
+ os.kill(os.getpid(), signum2)
+ # Unblocking the 2 signals calls the C signal handler twice
+ signal.pthread_sigmask(signal.SIG_UNBLOCK, (signum1, signum2))
+ """, signal.SIGUSR1, signal.SIGUSR2, ordered=False)
+
@unittest.skipIf(sys.platform == "win32", "Not valid on Windows")
class SiginterruptTest(unittest.TestCase):
@@ -316,9 +352,6 @@ class SiginterruptTest(unittest.TestCase):
read is interrupted by the signal and raises an exception. Return False
if it returns normally.
"""
- class Timeout(Exception):
- pass
-
# use a subprocess to have only one thread, to have a timeout on the
# blocking read and to not touch signal handling in this process
code = """if 1:
@@ -359,18 +392,8 @@ class SiginterruptTest(unittest.TestCase):
# wait until the child process is loaded and has started
first_line = process.stdout.readline()
- # Wait the process with a timeout of 5 seconds
- timeout = time.time() + 5.0
- while True:
- if timeout < time.time():
- raise Timeout()
- status = process.poll()
- if status is not None:
- break
- time.sleep(0.1)
-
- stdout, stderr = process.communicate()
- except Timeout:
+ stdout, stderr = process.communicate(timeout=5.0)
+ except subprocess.TimeoutExpired:
process.kill()
return False
else:
@@ -419,8 +442,6 @@ class ItimerTest(unittest.TestCase):
def sig_alrm(self, *args):
self.hndl_called = True
- if support.verbose:
- print("SIGALRM handler invoked", args)
def sig_vtalrm(self, *args):
self.hndl_called = True
@@ -432,21 +453,13 @@ class ItimerTest(unittest.TestCase):
elif self.hndl_count == 3:
# disable ITIMER_VIRTUAL, this function shouldn't be called anymore
signal.setitimer(signal.ITIMER_VIRTUAL, 0)
- if support.verbose:
- print("last SIGVTALRM handler call")
self.hndl_count += 1
- if support.verbose:
- print("SIGVTALRM handler invoked", args)
-
def sig_prof(self, *args):
self.hndl_called = True
signal.setitimer(signal.ITIMER_PROF, 0)
- if support.verbose:
- print("SIGPROF handler invoked", args)
-
def test_itimer_exc(self):
# XXX I'm assuming -1 is an invalid itimer, but maybe some platform
# defines it ?
@@ -459,10 +472,7 @@ class ItimerTest(unittest.TestCase):
def test_itimer_real(self):
self.itimer = signal.ITIMER_REAL
signal.setitimer(self.itimer, 1.0)
- if support.verbose:
- print("\ncall pause()...")
signal.pause()
-
self.assertEqual(self.hndl_called, True)
# Issue 3864, unknown if this affects earlier versions of freebsd also
@@ -511,11 +521,359 @@ class ItimerTest(unittest.TestCase):
# and the handler should have been called
self.assertEqual(self.hndl_called, True)
+
+class PendingSignalsTests(unittest.TestCase):
+ """
+ Test pthread_sigmask(), pthread_kill(), sigpending() and sigwait()
+ functions.
+ """
+ @unittest.skipUnless(hasattr(signal, 'sigpending'),
+ 'need signal.sigpending()')
+ def test_sigpending_empty(self):
+ self.assertEqual(signal.sigpending(), set())
+
+ @unittest.skipUnless(hasattr(signal, 'pthread_sigmask'),
+ 'need signal.pthread_sigmask()')
+ @unittest.skipUnless(hasattr(signal, 'sigpending'),
+ 'need signal.sigpending()')
+ def test_sigpending(self):
+ code = """if 1:
+ import os
+ import signal
+
+ def handler(signum, frame):
+ 1/0
+
+ signum = signal.SIGUSR1
+ signal.signal(signum, handler)
+
+ signal.pthread_sigmask(signal.SIG_BLOCK, [signum])
+ os.kill(os.getpid(), signum)
+ pending = signal.sigpending()
+ if pending != {signum}:
+ raise Exception('%s != {%s}' % (pending, signum))
+ try:
+ signal.pthread_sigmask(signal.SIG_UNBLOCK, [signum])
+ except ZeroDivisionError:
+ pass
+ else:
+ raise Exception("ZeroDivisionError not raised")
+ """
+ assert_python_ok('-c', code)
+
+ @unittest.skipUnless(hasattr(signal, 'pthread_kill'),
+ 'need signal.pthread_kill()')
+ def test_pthread_kill(self):
+ code = """if 1:
+ import signal
+ import threading
+ import sys
+
+ signum = signal.SIGUSR1
+
+ def handler(signum, frame):
+ 1/0
+
+ signal.signal(signum, handler)
+
+ if sys.platform == 'freebsd6':
+ # Issue #12392 and #12469: send a signal to the main thread
+ # doesn't work before the creation of the first thread on
+ # FreeBSD 6
+ def noop():
+ pass
+ thread = threading.Thread(target=noop)
+ thread.start()
+ thread.join()
+
+ tid = threading.get_ident()
+ try:
+ signal.pthread_kill(tid, signum)
+ except ZeroDivisionError:
+ pass
+ else:
+ raise Exception("ZeroDivisionError not raised")
+ """
+ assert_python_ok('-c', code)
+
+ @unittest.skipUnless(hasattr(signal, 'pthread_sigmask'),
+ 'need signal.pthread_sigmask()')
+ def wait_helper(self, blocked, test):
+ """
+ test: body of the "def test(signum):" function.
+ blocked: number of the blocked signal
+ """
+ code = '''if 1:
+ import signal
+ import sys
+
+ def handler(signum, frame):
+ 1/0
+
+ %s
+
+ blocked = %s
+ signum = signal.SIGALRM
+
+ # child: block and wait the signal
+ try:
+ signal.signal(signum, handler)
+ signal.pthread_sigmask(signal.SIG_BLOCK, [blocked])
+
+ # Do the tests
+ test(signum)
+
+ # The handler must not be called on unblock
+ try:
+ signal.pthread_sigmask(signal.SIG_UNBLOCK, [blocked])
+ except ZeroDivisionError:
+ print("the signal handler has been called",
+ file=sys.stderr)
+ sys.exit(1)
+ except BaseException as err:
+ print("error: {}".format(err), file=sys.stderr)
+ sys.stderr.flush()
+ sys.exit(1)
+ ''' % (test.strip(), blocked)
+
+ # sig*wait* must be called with the signal blocked: since the current
+ # process might have several threads running, use a subprocess to have
+ # a single thread.
+ assert_python_ok('-c', code)
+
+ @unittest.skipUnless(hasattr(signal, 'sigwait'),
+ 'need signal.sigwait()')
+ def test_sigwait(self):
+ self.wait_helper(signal.SIGALRM, '''
+ def test(signum):
+ signal.alarm(1)
+ received = signal.sigwait([signum])
+ if received != signum:
+ raise Exception('received %s, not %s' % (received, signum))
+ ''')
+
+ @unittest.skipUnless(hasattr(signal, 'sigwaitinfo'),
+ 'need signal.sigwaitinfo()')
+ def test_sigwaitinfo(self):
+ self.wait_helper(signal.SIGALRM, '''
+ def test(signum):
+ signal.alarm(1)
+ info = signal.sigwaitinfo([signum])
+ if info.si_signo != signum:
+ raise Exception("info.si_signo != %s" % signum)
+ ''')
+
+ @unittest.skipUnless(hasattr(signal, 'sigtimedwait'),
+ 'need signal.sigtimedwait()')
+ def test_sigtimedwait(self):
+ self.wait_helper(signal.SIGALRM, '''
+ def test(signum):
+ signal.alarm(1)
+ info = signal.sigtimedwait([signum], 10.1000)
+ if info.si_signo != signum:
+ raise Exception('info.si_signo != %s' % signum)
+ ''')
+
+ @unittest.skipUnless(hasattr(signal, 'sigtimedwait'),
+ 'need signal.sigtimedwait()')
+ def test_sigtimedwait_poll(self):
+ # check that polling with sigtimedwait works
+ self.wait_helper(signal.SIGALRM, '''
+ def test(signum):
+ import os
+ os.kill(os.getpid(), signum)
+ info = signal.sigtimedwait([signum], 0)
+ if info.si_signo != signum:
+ raise Exception('info.si_signo != %s' % signum)
+ ''')
+
+ @unittest.skipUnless(hasattr(signal, 'sigtimedwait'),
+ 'need signal.sigtimedwait()')
+ def test_sigtimedwait_timeout(self):
+ self.wait_helper(signal.SIGALRM, '''
+ def test(signum):
+ received = signal.sigtimedwait([signum], 1.0)
+ if received is not None:
+ raise Exception("received=%r" % (received,))
+ ''')
+
+ @unittest.skipUnless(hasattr(signal, 'sigtimedwait'),
+ 'need signal.sigtimedwait()')
+ def test_sigtimedwait_negative_timeout(self):
+ signum = signal.SIGALRM
+ self.assertRaises(ValueError, signal.sigtimedwait, [signum], -1.0)
+
+ @unittest.skipUnless(hasattr(signal, 'sigwaitinfo'),
+ 'need signal.sigwaitinfo()')
+ def test_sigwaitinfo_interrupted(self):
+ self.wait_helper(signal.SIGUSR1, '''
+ def test(signum):
+ import errno
+
+ hndl_called = True
+ def alarm_handler(signum, frame):
+ hndl_called = False
+
+ signal.signal(signal.SIGALRM, alarm_handler)
+ signal.alarm(1)
+ try:
+ signal.sigwaitinfo([signal.SIGUSR1])
+ except OSError as e:
+ if e.errno == errno.EINTR:
+ if not hndl_called:
+ raise Exception("SIGALRM handler not called")
+ else:
+ raise Exception("Expected EINTR to be raised by sigwaitinfo")
+ else:
+ raise Exception("Expected EINTR to be raised by sigwaitinfo")
+ ''')
+
+ @unittest.skipUnless(hasattr(signal, 'sigwait'),
+ 'need signal.sigwait()')
+ @unittest.skipUnless(hasattr(signal, 'pthread_sigmask'),
+ 'need signal.pthread_sigmask()')
+ @unittest.skipIf(threading is None, "test needs threading module")
+ def test_sigwait_thread(self):
+ # Check that calling sigwait() from a thread doesn't suspend the whole
+ # process. A new interpreter is spawned to avoid problems when mixing
+ # threads and fork(): only async-safe functions are allowed between
+ # fork() and exec().
+ assert_python_ok("-c", """if True:
+ import os, threading, sys, time, signal
+
+ # the default handler terminates the process
+ signum = signal.SIGUSR1
+
+ def kill_later():
+ # wait until the main thread is waiting in sigwait()
+ time.sleep(1)
+ os.kill(os.getpid(), signum)
+
+ # the signal must be blocked by all the threads
+ signal.pthread_sigmask(signal.SIG_BLOCK, [signum])
+ killer = threading.Thread(target=kill_later)
+ killer.start()
+ received = signal.sigwait([signum])
+ if received != signum:
+ print("sigwait() received %s, not %s" % (received, signum),
+ file=sys.stderr)
+ sys.exit(1)
+ killer.join()
+ # unblock the signal, which should have been cleared by sigwait()
+ signal.pthread_sigmask(signal.SIG_UNBLOCK, [signum])
+ """)
+
+ @unittest.skipUnless(hasattr(signal, 'pthread_sigmask'),
+ 'need signal.pthread_sigmask()')
+ def test_pthread_sigmask_arguments(self):
+ self.assertRaises(TypeError, signal.pthread_sigmask)
+ self.assertRaises(TypeError, signal.pthread_sigmask, 1)
+ self.assertRaises(TypeError, signal.pthread_sigmask, 1, 2, 3)
+ self.assertRaises(OSError, signal.pthread_sigmask, 1700, [])
+
+ @unittest.skipUnless(hasattr(signal, 'pthread_sigmask'),
+ 'need signal.pthread_sigmask()')
+ def test_pthread_sigmask(self):
+ code = """if 1:
+ import signal
+ import os; import threading
+
+ def handler(signum, frame):
+ 1/0
+
+ def kill(signum):
+ os.kill(os.getpid(), signum)
+
+ def read_sigmask():
+ return signal.pthread_sigmask(signal.SIG_BLOCK, [])
+
+ signum = signal.SIGUSR1
+
+ # Install our signal handler
+ old_handler = signal.signal(signum, handler)
+
+ # Unblock SIGUSR1 (and copy the old mask) to test our signal handler
+ old_mask = signal.pthread_sigmask(signal.SIG_UNBLOCK, [signum])
+ try:
+ kill(signum)
+ except ZeroDivisionError:
+ pass
+ else:
+ raise Exception("ZeroDivisionError not raised")
+
+ # Block and then raise SIGUSR1. The signal is blocked: the signal
+ # handler is not called, and the signal is now pending
+ signal.pthread_sigmask(signal.SIG_BLOCK, [signum])
+ kill(signum)
+
+ # Check the new mask
+ blocked = read_sigmask()
+ if signum not in blocked:
+ raise Exception("%s not in %s" % (signum, blocked))
+ if old_mask ^ blocked != {signum}:
+ raise Exception("%s ^ %s != {%s}" % (old_mask, blocked, signum))
+
+ # Unblock SIGUSR1
+ try:
+ # unblock the pending signal calls immediatly the signal handler
+ signal.pthread_sigmask(signal.SIG_UNBLOCK, [signum])
+ except ZeroDivisionError:
+ pass
+ else:
+ raise Exception("ZeroDivisionError not raised")
+ try:
+ kill(signum)
+ except ZeroDivisionError:
+ pass
+ else:
+ raise Exception("ZeroDivisionError not raised")
+
+ # Check the new mask
+ unblocked = read_sigmask()
+ if signum in unblocked:
+ raise Exception("%s in %s" % (signum, unblocked))
+ if blocked ^ unblocked != {signum}:
+ raise Exception("%s ^ %s != {%s}" % (blocked, unblocked, signum))
+ if old_mask != unblocked:
+ raise Exception("%s != %s" % (old_mask, unblocked))
+ """
+ assert_python_ok('-c', code)
+
+ @unittest.skipIf(sys.platform == 'freebsd6',
+ "issue #12392: send a signal to the main thread doesn't work "
+ "before the creation of the first thread on FreeBSD 6")
+ @unittest.skipUnless(hasattr(signal, 'pthread_kill'),
+ 'need signal.pthread_kill()')
+ def test_pthread_kill_main_thread(self):
+ # Test that a signal can be sent to the main thread with pthread_kill()
+ # before any other thread has been created (see issue #12392).
+ code = """if True:
+ import threading
+ import signal
+ import sys
+
+ def handler(signum, frame):
+ sys.exit(3)
+
+ signal.signal(signal.SIGUSR1, handler)
+ signal.pthread_kill(threading.get_ident(), signal.SIGUSR1)
+ sys.exit(2)
+ """
+
+ with spawn_python('-c', code) as process:
+ stdout, stderr = process.communicate()
+ exitcode = process.wait()
+ if exitcode != 3:
+ raise Exception("Child error (exit code %s): %s" %
+ (exitcode, stdout))
+
+
def test_main():
try:
- support.run_unittest(BasicSignalTests, InterProcessSignalTests,
- WakeupSignalTests, SiginterruptTest,
- ItimerTest, WindowsSignalTests)
+ support.run_unittest(PosixTests, InterProcessSignalTests,
+ WakeupFDTests, WakeupSignalTests,
+ SiginterruptTest, ItimerTest, WindowsSignalTests,
+ PendingSignalsTests)
finally:
support.reap_children()
diff --git a/Lib/test/test_site.py b/Lib/test/test_site.py
index ba42649..29286c7 100644
--- a/Lib/test/test_site.py
+++ b/Lib/test/test_site.py
@@ -39,6 +39,7 @@ class HelperFunctionsTests(unittest.TestCase):
self.old_base = site.USER_BASE
self.old_site = site.USER_SITE
self.old_prefixes = site.PREFIXES
+ self.original_vars = sysconfig._CONFIG_VARS
self.old_vars = copy(sysconfig._CONFIG_VARS)
def tearDown(self):
@@ -47,7 +48,9 @@ class HelperFunctionsTests(unittest.TestCase):
site.USER_BASE = self.old_base
site.USER_SITE = self.old_site
site.PREFIXES = self.old_prefixes
- sysconfig._CONFIG_VARS = self.old_vars
+ sysconfig._CONFIG_VARS = self.original_vars
+ sysconfig._CONFIG_VARS.clear()
+ sysconfig._CONFIG_VARS.update(self.old_vars)
def test_makepath(self):
# Test makepath() have an absolute path for its first return value
diff --git a/Lib/test/test_smtpd.py b/Lib/test/test_smtpd.py
index 68ccc29..93f14c4 100644
--- a/Lib/test/test_smtpd.py
+++ b/Lib/test/test_smtpd.py
@@ -1,4 +1,4 @@
-from unittest import TestCase
+import unittest
from test import support, mock_socket
import socket
import io
@@ -26,7 +26,7 @@ class BrokenDummyServer(DummyServer):
raise DummyDispatcherBroken()
-class SMTPDServerTest(TestCase):
+class SMTPDServerTest(unittest.TestCase):
def setUp(self):
smtpd.socket = asyncore.socket = mock_socket
@@ -39,6 +39,7 @@ class SMTPDServerTest(TestCase):
channel.socket.queue_recv(line)
channel.handle_read()
+ write_line(b'HELO example')
write_line(b'MAIL From:eggs@example')
write_line(b'RCPT To:spam@example')
write_line(b'DATA')
@@ -49,7 +50,7 @@ class SMTPDServerTest(TestCase):
asyncore.socket = smtpd.socket = socket
-class SMTPDChannelTest(TestCase):
+class SMTPDChannelTest(unittest.TestCase):
def setUp(self):
smtpd.socket = asyncore.socket = mock_socket
self.old_debugstream = smtpd.DEBUGSTREAM
@@ -78,31 +79,94 @@ class SMTPDChannelTest(TestCase):
self.assertEqual(self.channel.socket.last,
b'500 Error: bad syntax\r\n')
- def test_EHLO_not_implemented(self):
- self.write_line(b'EHLO test.example')
+ def test_EHLO(self):
+ self.write_line(b'EHLO example')
+ self.assertEqual(self.channel.socket.last, b'250 HELP\r\n')
+
+ def test_EHLO_bad_syntax(self):
+ self.write_line(b'EHLO')
+ self.assertEqual(self.channel.socket.last,
+ b'501 Syntax: EHLO hostname\r\n')
+
+ def test_EHLO_duplicate(self):
+ self.write_line(b'EHLO example')
+ self.write_line(b'EHLO example')
self.assertEqual(self.channel.socket.last,
- b'502 Error: command "EHLO" not implemented\r\n')
+ b'503 Duplicate HELO/EHLO\r\n')
+
+ def test_EHLO_HELO_duplicate(self):
+ self.write_line(b'EHLO example')
+ self.write_line(b'HELO example')
+ self.assertEqual(self.channel.socket.last,
+ b'503 Duplicate HELO/EHLO\r\n')
def test_HELO(self):
name = smtpd.socket.getfqdn()
- self.write_line(b'HELO test.example')
+ self.write_line(b'HELO example')
self.assertEqual(self.channel.socket.last,
'250 {}\r\n'.format(name).encode('ascii'))
+ def test_HELO_EHLO_duplicate(self):
+ self.write_line(b'HELO example')
+ self.write_line(b'EHLO example')
+ self.assertEqual(self.channel.socket.last,
+ b'503 Duplicate HELO/EHLO\r\n')
+
+ def test_HELP(self):
+ self.write_line(b'HELP')
+ self.assertEqual(self.channel.socket.last,
+ b'250 Supported commands: EHLO HELO MAIL RCPT ' + \
+ b'DATA RSET NOOP QUIT VRFY\r\n')
+
+ def test_HELP_command(self):
+ self.write_line(b'HELP MAIL')
+ self.assertEqual(self.channel.socket.last,
+ b'250 Syntax: MAIL FROM: <address>\r\n')
+
+ def test_HELP_command_unknown(self):
+ self.write_line(b'HELP SPAM')
+ self.assertEqual(self.channel.socket.last,
+ b'501 Supported commands: EHLO HELO MAIL RCPT ' + \
+ b'DATA RSET NOOP QUIT VRFY\r\n')
+
def test_HELO_bad_syntax(self):
self.write_line(b'HELO')
self.assertEqual(self.channel.socket.last,
b'501 Syntax: HELO hostname\r\n')
def test_HELO_duplicate(self):
- self.write_line(b'HELO test.example')
- self.write_line(b'HELO test.example')
+ self.write_line(b'HELO example')
+ self.write_line(b'HELO example')
self.assertEqual(self.channel.socket.last,
b'503 Duplicate HELO/EHLO\r\n')
+ def test_HELO_parameter_rejected_when_extensions_not_enabled(self):
+ self.extended_smtp = False
+ self.write_line(b'HELO example')
+ self.write_line(b'MAIL from:<foo@example.com> SIZE=1234')
+ self.assertEqual(self.channel.socket.last,
+ b'501 Syntax: MAIL FROM: <address>\r\n')
+
+ def test_MAIL_allows_space_after_colon(self):
+ self.write_line(b'HELO example')
+ self.write_line(b'MAIL from: <foo@example.com>')
+ self.assertEqual(self.channel.socket.last,
+ b'250 OK\r\n')
+
+ def test_extended_MAIL_allows_space_after_colon(self):
+ self.write_line(b'EHLO example')
+ self.write_line(b'MAIL from: <foo@example.com> size=20')
+ self.assertEqual(self.channel.socket.last,
+ b'250 OK\r\n')
+
def test_NOOP(self):
self.write_line(b'NOOP')
- self.assertEqual(self.channel.socket.last, b'250 Ok\r\n')
+ self.assertEqual(self.channel.socket.last, b'250 OK\r\n')
+
+ def test_HELO_NOOP(self):
+ self.write_line(b'HELO example')
+ self.write_line(b'NOOP')
+ self.assertEqual(self.channel.socket.last, b'250 OK\r\n')
def test_NOOP_bad_syntax(self):
self.write_line(b'NOOP hi')
@@ -113,26 +177,47 @@ class SMTPDChannelTest(TestCase):
self.write_line(b'QUIT')
self.assertEqual(self.channel.socket.last, b'221 Bye\r\n')
+ def test_HELO_QUIT(self):
+ self.write_line(b'HELO example')
+ self.write_line(b'QUIT')
+ self.assertEqual(self.channel.socket.last, b'221 Bye\r\n')
+
def test_QUIT_arg_ignored(self):
self.write_line(b'QUIT bye bye')
self.assertEqual(self.channel.socket.last, b'221 Bye\r\n')
def test_bad_state(self):
self.channel.smtp_state = 'BAD STATE'
- self.write_line(b'HELO')
+ self.write_line(b'HELO example')
self.assertEqual(self.channel.socket.last,
b'451 Internal confusion\r\n')
def test_command_too_long(self):
- self.write_line(b'MAIL from ' +
+ self.write_line(b'HELO example')
+ self.write_line(b'MAIL from: ' +
b'a' * self.channel.command_size_limit +
b'@example')
self.assertEqual(self.channel.socket.last,
b'500 Error: line too long\r\n')
- def test_data_too_long(self):
- # Small hack. Setting limit to 2K octets here will save us some time.
- self.channel.data_size_limit = 2048
+ def test_MAIL_command_limit_extended_with_SIZE(self):
+ self.write_line(b'EHLO example')
+ fill_len = self.channel.command_size_limit - len('MAIL from:<@example>')
+ self.write_line(b'MAIL from:<' +
+ b'a' * fill_len +
+ b'@example> SIZE=1234')
+ self.assertEqual(self.channel.socket.last, b'250 OK\r\n')
+
+ self.write_line(b'MAIL from:<' +
+ b'a' * (fill_len + 26) +
+ b'@example> SIZE=1234')
+ self.assertEqual(self.channel.socket.last,
+ b'500 Error: line too long\r\n')
+
+ def test_data_longer_than_default_data_size_limit(self):
+ # Hack the default so we don't have to generate so much data.
+ self.channel.data_size_limit = 1048
+ self.write_line(b'HELO example')
self.write_line(b'MAIL From:eggs@example')
self.write_line(b'RCPT To:spam@example')
self.write_line(b'DATA')
@@ -141,64 +226,183 @@ class SMTPDChannelTest(TestCase):
self.assertEqual(self.channel.socket.last,
b'552 Error: Too much mail data\r\n')
+ def test_MAIL_size_parameter(self):
+ self.write_line(b'EHLO example')
+ self.write_line(b'MAIL FROM:<eggs@example> SIZE=512')
+ self.assertEqual(self.channel.socket.last,
+ b'250 OK\r\n')
+
+ def test_MAIL_invalid_size_parameter(self):
+ self.write_line(b'EHLO example')
+ self.write_line(b'MAIL FROM:<eggs@example> SIZE=invalid')
+ self.assertEqual(self.channel.socket.last,
+ b'501 Syntax: MAIL FROM: <address> [SP <mail-parameters>]\r\n')
+
+ def test_MAIL_RCPT_unknown_parameters(self):
+ self.write_line(b'EHLO example')
+ self.write_line(b'MAIL FROM:<eggs@example> ham=green')
+ self.assertEqual(self.channel.socket.last,
+ b'555 MAIL FROM parameters not recognized or not implemented\r\n')
+
+ self.write_line(b'MAIL FROM:<eggs@example>')
+ self.write_line(b'RCPT TO:<eggs@example> ham=green')
+ self.assertEqual(self.channel.socket.last,
+ b'555 RCPT TO parameters not recognized or not implemented\r\n')
+
+ def test_MAIL_size_parameter_larger_than_default_data_size_limit(self):
+ self.channel.data_size_limit = 1048
+ self.write_line(b'EHLO example')
+ self.write_line(b'MAIL FROM:<eggs@example> SIZE=2096')
+ self.assertEqual(self.channel.socket.last,
+ b'552 Error: message size exceeds fixed maximum message size\r\n')
+
def test_need_MAIL(self):
+ self.write_line(b'HELO example')
self.write_line(b'RCPT to:spam@example')
self.assertEqual(self.channel.socket.last,
b'503 Error: need MAIL command\r\n')
- def test_MAIL_syntax(self):
+ def test_MAIL_syntax_HELO(self):
+ self.write_line(b'HELO example')
self.write_line(b'MAIL from eggs@example')
self.assertEqual(self.channel.socket.last,
- b'501 Syntax: MAIL FROM:<address>\r\n')
+ b'501 Syntax: MAIL FROM: <address>\r\n')
- def test_MAIL_missing_from(self):
+ def test_MAIL_syntax_EHLO(self):
+ self.write_line(b'EHLO example')
+ self.write_line(b'MAIL from eggs@example')
+ self.assertEqual(self.channel.socket.last,
+ b'501 Syntax: MAIL FROM: <address> [SP <mail-parameters>]\r\n')
+
+ def test_MAIL_missing_address(self):
+ self.write_line(b'HELO example')
self.write_line(b'MAIL from:')
self.assertEqual(self.channel.socket.last,
- b'501 Syntax: MAIL FROM:<address>\r\n')
+ b'501 Syntax: MAIL FROM: <address>\r\n')
def test_MAIL_chevrons(self):
+ self.write_line(b'HELO example')
self.write_line(b'MAIL from:<eggs@example>')
- self.assertEqual(self.channel.socket.last, b'250 Ok\r\n')
+ self.assertEqual(self.channel.socket.last, b'250 OK\r\n')
+
+ def test_MAIL_empty_chevrons(self):
+ self.write_line(b'EHLO example')
+ self.write_line(b'MAIL from:<>')
+ self.assertEqual(self.channel.socket.last, b'250 OK\r\n')
+
+ def test_MAIL_quoted_localpart(self):
+ self.write_line(b'EHLO example')
+ self.write_line(b'MAIL from: <"Fred Blogs"@example.com>')
+ self.assertEqual(self.channel.socket.last, b'250 OK\r\n')
+ self.assertEqual(self.channel.mailfrom, '"Fred Blogs"@example.com')
+
+ def test_MAIL_quoted_localpart_no_angles(self):
+ self.write_line(b'EHLO example')
+ self.write_line(b'MAIL from: "Fred Blogs"@example.com')
+ self.assertEqual(self.channel.socket.last, b'250 OK\r\n')
+ self.assertEqual(self.channel.mailfrom, '"Fred Blogs"@example.com')
+
+ def test_MAIL_quoted_localpart_with_size(self):
+ self.write_line(b'EHLO example')
+ self.write_line(b'MAIL from: <"Fred Blogs"@example.com> SIZE=1000')
+ self.assertEqual(self.channel.socket.last, b'250 OK\r\n')
+ self.assertEqual(self.channel.mailfrom, '"Fred Blogs"@example.com')
+
+ def test_MAIL_quoted_localpart_with_size_no_angles(self):
+ self.write_line(b'EHLO example')
+ self.write_line(b'MAIL from: "Fred Blogs"@example.com SIZE=1000')
+ self.assertEqual(self.channel.socket.last, b'250 OK\r\n')
+ self.assertEqual(self.channel.mailfrom, '"Fred Blogs"@example.com')
def test_nested_MAIL(self):
+ self.write_line(b'HELO example')
self.write_line(b'MAIL from:eggs@example')
self.write_line(b'MAIL from:spam@example')
self.assertEqual(self.channel.socket.last,
b'503 Error: nested MAIL command\r\n')
+ def test_VRFY(self):
+ self.write_line(b'VRFY eggs@example')
+ self.assertEqual(self.channel.socket.last,
+ b'252 Cannot VRFY user, but will accept message and attempt ' + \
+ b'delivery\r\n')
+
+ def test_VRFY_syntax(self):
+ self.write_line(b'VRFY')
+ self.assertEqual(self.channel.socket.last,
+ b'501 Syntax: VRFY <address>\r\n')
+
+ def test_EXPN_not_implemented(self):
+ self.write_line(b'EXPN')
+ self.assertEqual(self.channel.socket.last,
+ b'502 EXPN not implemented\r\n')
+
+ def test_no_HELO_MAIL(self):
+ self.write_line(b'MAIL from:<foo@example.com>')
+ self.assertEqual(self.channel.socket.last,
+ b'503 Error: send HELO first\r\n')
+
def test_need_RCPT(self):
+ self.write_line(b'HELO example')
self.write_line(b'MAIL From:eggs@example')
self.write_line(b'DATA')
self.assertEqual(self.channel.socket.last,
b'503 Error: need RCPT command\r\n')
- def test_RCPT_syntax(self):
- self.write_line(b'MAIL From:eggs@example')
+ def test_RCPT_syntax_HELO(self):
+ self.write_line(b'HELO example')
+ self.write_line(b'MAIL From: eggs@example')
self.write_line(b'RCPT to eggs@example')
self.assertEqual(self.channel.socket.last,
b'501 Syntax: RCPT TO: <address>\r\n')
+ def test_RCPT_syntax_EHLO(self):
+ self.write_line(b'EHLO example')
+ self.write_line(b'MAIL From: eggs@example')
+ self.write_line(b'RCPT to eggs@example')
+ self.assertEqual(self.channel.socket.last,
+ b'501 Syntax: RCPT TO: <address> [SP <mail-parameters>]\r\n')
+
+ def test_RCPT_lowercase_to_OK(self):
+ self.write_line(b'HELO example')
+ self.write_line(b'MAIL From: eggs@example')
+ self.write_line(b'RCPT to: <eggs@example>')
+ self.assertEqual(self.channel.socket.last, b'250 OK\r\n')
+
+ def test_no_HELO_RCPT(self):
+ self.write_line(b'RCPT to eggs@example')
+ self.assertEqual(self.channel.socket.last,
+ b'503 Error: send HELO first\r\n')
+
def test_data_dialog(self):
+ self.write_line(b'HELO example')
self.write_line(b'MAIL From:eggs@example')
- self.assertEqual(self.channel.socket.last, b'250 Ok\r\n')
+ self.assertEqual(self.channel.socket.last, b'250 OK\r\n')
self.write_line(b'RCPT To:spam@example')
- self.assertEqual(self.channel.socket.last, b'250 Ok\r\n')
+ self.assertEqual(self.channel.socket.last, b'250 OK\r\n')
self.write_line(b'DATA')
self.assertEqual(self.channel.socket.last,
b'354 End data with <CR><LF>.<CR><LF>\r\n')
self.write_line(b'data\r\nmore\r\n.')
- self.assertEqual(self.channel.socket.last, b'250 Ok\r\n')
+ self.assertEqual(self.channel.socket.last, b'250 OK\r\n')
self.assertEqual(self.server.messages,
[('peer', 'eggs@example', ['spam@example'], 'data\nmore')])
def test_DATA_syntax(self):
+ self.write_line(b'HELO example')
self.write_line(b'MAIL From:eggs@example')
self.write_line(b'RCPT To:spam@example')
self.write_line(b'DATA spam')
self.assertEqual(self.channel.socket.last, b'501 Syntax: DATA\r\n')
+ def test_no_HELO_DATA(self):
+ self.write_line(b'DATA spam')
+ self.assertEqual(self.channel.socket.last,
+ b'503 Error: send HELO first\r\n')
+
def test_data_transparency_section_4_5_2(self):
+ self.write_line(b'HELO example')
self.write_line(b'MAIL From:eggs@example')
self.write_line(b'RCPT To:spam@example')
self.write_line(b'DATA')
@@ -206,6 +410,7 @@ class SMTPDChannelTest(TestCase):
self.assertEqual(self.channel.received_data, '.')
def test_multiple_RCPT(self):
+ self.write_line(b'HELO example')
self.write_line(b'MAIL From:eggs@example')
self.write_line(b'RCPT To:spam@example')
self.write_line(b'RCPT To:ham@example')
@@ -216,6 +421,7 @@ class SMTPDChannelTest(TestCase):
def test_manual_status(self):
# checks that the Channel is able to return a custom status message
+ self.write_line(b'HELO example')
self.write_line(b'MAIL From:eggs@example')
self.write_line(b'RCPT To:spam@example')
self.write_line(b'DATA')
@@ -223,10 +429,11 @@ class SMTPDChannelTest(TestCase):
self.assertEqual(self.channel.socket.last, b'250 Okish\r\n')
def test_RSET(self):
+ self.write_line(b'HELO example')
self.write_line(b'MAIL From:eggs@example')
self.write_line(b'RCPT To:spam@example')
self.write_line(b'RSET')
- self.assertEqual(self.channel.socket.last, b'250 Ok\r\n')
+ self.assertEqual(self.channel.socket.last, b'250 OK\r\n')
self.write_line(b'MAIL From:foo@example')
self.write_line(b'RCPT To:eggs@example')
self.write_line(b'DATA')
@@ -234,58 +441,117 @@ class SMTPDChannelTest(TestCase):
self.assertEqual(self.server.messages,
[('peer', 'foo@example', ['eggs@example'], 'data')])
+ def test_HELO_RSET(self):
+ self.write_line(b'HELO example')
+ self.write_line(b'RSET')
+ self.assertEqual(self.channel.socket.last, b'250 OK\r\n')
+
def test_RSET_syntax(self):
self.write_line(b'RSET hi')
self.assertEqual(self.channel.socket.last, b'501 Syntax: RSET\r\n')
+ def test_unknown_command(self):
+ self.write_line(b'UNKNOWN_CMD')
+ self.assertEqual(self.channel.socket.last,
+ b'500 Error: command "UNKNOWN_CMD" not ' + \
+ b'recognized\r\n')
+
def test_attribute_deprecations(self):
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
spam = self.channel._SMTPChannel__server
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
self.channel._SMTPChannel__server = 'spam'
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
spam = self.channel._SMTPChannel__line
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
self.channel._SMTPChannel__line = 'spam'
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
spam = self.channel._SMTPChannel__state
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
self.channel._SMTPChannel__state = 'spam'
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
spam = self.channel._SMTPChannel__greeting
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
self.channel._SMTPChannel__greeting = 'spam'
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
spam = self.channel._SMTPChannel__mailfrom
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
self.channel._SMTPChannel__mailfrom = 'spam'
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
spam = self.channel._SMTPChannel__rcpttos
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
self.channel._SMTPChannel__rcpttos = 'spam'
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
spam = self.channel._SMTPChannel__data
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
self.channel._SMTPChannel__data = 'spam'
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
spam = self.channel._SMTPChannel__fqdn
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
self.channel._SMTPChannel__fqdn = 'spam'
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
spam = self.channel._SMTPChannel__peer
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
self.channel._SMTPChannel__peer = 'spam'
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
spam = self.channel._SMTPChannel__conn
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
self.channel._SMTPChannel__conn = 'spam'
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
spam = self.channel._SMTPChannel__addr
- with support.check_warnings(('', PendingDeprecationWarning)):
+ with support.check_warnings(('', DeprecationWarning)):
self.channel._SMTPChannel__addr = 'spam'
-def test_main():
- support.run_unittest(SMTPDServerTest, SMTPDChannelTest)
+
+class SMTPDChannelWithDataSizeLimitTest(unittest.TestCase):
+
+ def setUp(self):
+ smtpd.socket = asyncore.socket = mock_socket
+ self.old_debugstream = smtpd.DEBUGSTREAM
+ self.debug = smtpd.DEBUGSTREAM = io.StringIO()
+ self.server = DummyServer('a', 'b')
+ conn, addr = self.server.accept()
+ # Set DATA size limit to 32 bytes for easy testing
+ self.channel = smtpd.SMTPChannel(self.server, conn, addr, 32)
+
+ def tearDown(self):
+ asyncore.close_all()
+ asyncore.socket = smtpd.socket = socket
+ smtpd.DEBUGSTREAM = self.old_debugstream
+
+ def write_line(self, line):
+ self.channel.socket.queue_recv(line)
+ self.channel.handle_read()
+
+ def test_data_limit_dialog(self):
+ self.write_line(b'HELO example')
+ self.write_line(b'MAIL From:eggs@example')
+ self.assertEqual(self.channel.socket.last, b'250 OK\r\n')
+ self.write_line(b'RCPT To:spam@example')
+ self.assertEqual(self.channel.socket.last, b'250 OK\r\n')
+
+ self.write_line(b'DATA')
+ self.assertEqual(self.channel.socket.last,
+ b'354 End data with <CR><LF>.<CR><LF>\r\n')
+ self.write_line(b'data\r\nmore\r\n.')
+ self.assertEqual(self.channel.socket.last, b'250 OK\r\n')
+ self.assertEqual(self.server.messages,
+ [('peer', 'eggs@example', ['spam@example'], 'data\nmore')])
+
+ def test_data_limit_dialog_too_much_data(self):
+ self.write_line(b'HELO example')
+ self.write_line(b'MAIL From:eggs@example')
+ self.assertEqual(self.channel.socket.last, b'250 OK\r\n')
+ self.write_line(b'RCPT To:spam@example')
+ self.assertEqual(self.channel.socket.last, b'250 OK\r\n')
+
+ self.write_line(b'DATA')
+ self.assertEqual(self.channel.socket.last,
+ b'354 End data with <CR><LF>.<CR><LF>\r\n')
+ self.write_line(b'This message is longer than 32 bytes\r\n.')
+ self.assertEqual(self.channel.socket.last,
+ b'552 Error: Too much mail data\r\n')
+
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py
index 2cb0d1a..befc49e 100644
--- a/Lib/test/test_smtplib.py
+++ b/Lib/test/test_smtplib.py
@@ -9,6 +9,7 @@ import re
import sys
import time
import select
+import errno
import unittest
from test import support, mock_socket
@@ -72,6 +73,14 @@ class GeneralTests(unittest.TestCase):
smtp = smtplib.SMTP(HOST, self.port)
smtp.close()
+ def testSourceAddress(self):
+ mock_socket.reply_with(b"220 Hola mundo")
+ # connects
+ smtp = smtplib.SMTP(HOST, self.port,
+ source_address=('127.0.0.1',19876))
+ self.assertEqual(smtp.source_address, ('127.0.0.1', 19876))
+ smtp.close()
+
def testBasic2(self):
mock_socket.reply_with(b"220 Hola mundo")
# connects, include port in host name
@@ -204,15 +213,29 @@ class DebuggingServerTests(unittest.TestCase):
smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=3)
smtp.quit()
+ def testSourceAddress(self):
+ # connect
+ port = support.find_unused_port()
+ try:
+ smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost',
+ timeout=3, source_address=('127.0.0.1', port))
+ self.assertEqual(smtp.source_address, ('127.0.0.1', port))
+ self.assertEqual(smtp.local_hostname, 'localhost')
+ smtp.quit()
+ except IOError as e:
+ if e.errno == errno.EADDRINUSE:
+ self.skipTest("couldn't bind to port %d" % port)
+ raise
+
def testNOOP(self):
smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=3)
- expected = (250, b'Ok')
+ expected = (250, b'OK')
self.assertEqual(smtp.noop(), expected)
smtp.quit()
def testRSET(self):
smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=3)
- expected = (250, b'Ok')
+ expected = (250, b'OK')
self.assertEqual(smtp.rset(), expected)
smtp.quit()
@@ -223,10 +246,18 @@ class DebuggingServerTests(unittest.TestCase):
self.assertEqual(smtp.ehlo(), expected)
smtp.quit()
+ def testNotImplemented(self):
+ # EXPN isn't implemented in DebuggingServer
+ smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=3)
+ expected = (502, b'EXPN not implemented')
+ smtp.putcmd('EXPN')
+ self.assertEqual(smtp.getreply(), expected)
+ smtp.quit()
+
def testVRFY(self):
- # VRFY isn't implemented in DebuggingServer
smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=3)
- expected = (502, b'Error: command "VRFY" not implemented')
+ expected = (252, b'Cannot VRFY user, but will accept message ' + \
+ b'and attempt delivery')
self.assertEqual(smtp.vrfy('nobody@nowhere.com'), expected)
self.assertEqual(smtp.verify('nobody@nowhere.com'), expected)
smtp.quit()
@@ -242,7 +273,8 @@ class DebuggingServerTests(unittest.TestCase):
def testHELP(self):
smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=3)
- self.assertEqual(smtp.help(), b'Error: command "HELP" not implemented')
+ self.assertEqual(smtp.help(), b'Supported commands: EHLO HELO MAIL ' + \
+ b'RCPT DATA RSET NOOP QUIT VRFY')
smtp.quit()
def testSend(self):
@@ -560,6 +592,9 @@ sim_lists = {'list-1':['Mr.A@somewhere.com','Mrs.C@somewhereesle.com'],
# Simulated SMTP channel & server
class SimSMTPChannel(smtpd.SMTPChannel):
+ # For testing failures in QUIT when using the context manager API.
+ quit_response = None
+
def __init__(self, extra_features, *args, **kw):
self._extrafeatures = ''.join(
[ "250-{0}\r\n".format(x) for x in extra_features ])
@@ -610,19 +645,31 @@ class SimSMTPChannel(smtpd.SMTPChannel):
else:
self.push('550 No access for you!')
+ def smtp_QUIT(self, arg):
+ # args is ignored
+ if self.quit_response is None:
+ super(SimSMTPChannel, self).smtp_QUIT(arg)
+ else:
+ self.push(self.quit_response)
+ self.close_when_done()
+
def handle_error(self):
raise
class SimSMTPServer(smtpd.SMTPServer):
+ # For testing failures in QUIT when using the context manager API.
+ quit_response = None
+
def __init__(self, *args, **kw):
self._extra_features = []
smtpd.SMTPServer.__init__(self, *args, **kw)
def handle_accepted(self, conn, addr):
- self._SMTPchannel = SimSMTPChannel(self._extra_features,
- self, conn, addr)
+ self._SMTPchannel = SimSMTPChannel(
+ self._extra_features, self, conn, addr)
+ self._SMTPchannel.quit_response = self.quit_response
def process_message(self, peer, mailfrom, rcpttos, data):
pass
@@ -752,6 +799,25 @@ class SMTPSimTests(unittest.TestCase):
self.assertIn(sim_auth_credentials['cram-md5'], str(err))
smtp.close()
+ def test_with_statement(self):
+ with smtplib.SMTP(HOST, self.port) as smtp:
+ code, message = smtp.noop()
+ self.assertEqual(code, 250)
+ self.assertRaises(smtplib.SMTPServerDisconnected, smtp.send, b'foo')
+ with smtplib.SMTP(HOST, self.port) as smtp:
+ smtp.close()
+ self.assertRaises(smtplib.SMTPServerDisconnected, smtp.send, b'foo')
+
+ def test_with_statement_QUIT_failure(self):
+ self.serv.quit_response = '421 QUIT FAILED'
+ with self.assertRaises(smtplib.SMTPResponseException) as error:
+ with smtplib.SMTP(HOST, self.port) as smtp:
+ smtp.noop()
+ self.assertEqual(error.exception.smtp_code, 421)
+ self.assertEqual(error.exception.smtp_error, b'QUIT FAILED')
+ # We don't need to clean up self.serv.quit_response because a new
+ # server is always instantiated in the setUp().
+
#TODO: add tests for correct AUTH method fallback now that the
#test infrastructure can support it.
diff --git a/Lib/test/test_smtpnet.py b/Lib/test/test_smtpnet.py
index 0198ab6..86224ef 100644
--- a/Lib/test/test_smtpnet.py
+++ b/Lib/test/test_smtpnet.py
@@ -4,28 +4,60 @@ import unittest
from test import support
import smtplib
+ssl = support.import_module("ssl")
+
support.requires("network")
+
+class SmtpTest(unittest.TestCase):
+ testServer = 'smtp.gmail.com'
+ remotePort = 25
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+
+ def test_connect_starttls(self):
+ support.get_attribute(smtplib, 'SMTP_SSL')
+ with support.transient_internet(self.testServer):
+ server = smtplib.SMTP(self.testServer, self.remotePort)
+ try:
+ server.starttls(context=self.context)
+ except smtplib.SMTPException as e:
+ if e.args[0] == 'STARTTLS extension not supported by server.':
+ unittest.skip(e.args[0])
+ else:
+ raise
+ server.ehlo()
+ server.quit()
+
+
class SmtpSSLTest(unittest.TestCase):
testServer = 'smtp.gmail.com'
remotePort = 465
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
def test_connect(self):
support.get_attribute(smtplib, 'SMTP_SSL')
with support.transient_internet(self.testServer):
server = smtplib.SMTP_SSL(self.testServer, self.remotePort)
- server.ehlo()
- server.quit()
+ server.ehlo()
+ server.quit()
def test_connect_default_port(self):
support.get_attribute(smtplib, 'SMTP_SSL')
with support.transient_internet(self.testServer):
server = smtplib.SMTP_SSL(self.testServer)
- server.ehlo()
- server.quit()
+ server.ehlo()
+ server.quit()
+
+ def test_connect_using_sslcontext(self):
+ support.get_attribute(smtplib, 'SMTP_SSL')
+ with support.transient_internet(self.testServer):
+ server = smtplib.SMTP_SSL(self.testServer, self.remotePort, context=self.context)
+ server.ehlo()
+ server.quit()
+
def test_main():
- support.run_unittest(SmtpSSLTest)
+ support.run_unittest(SmtpTest, SmtpSSLTest)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
index 5573421..da6ef05 100644
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -2,11 +2,13 @@
import unittest
from test import support
+from unittest.case import _ExpectedFailure
import errno
import io
import socket
import select
+import tempfile
import _testcapi
import time
import traceback
@@ -19,34 +21,19 @@ import contextlib
from weakref import proxy
import signal
import math
+import pickle
+import struct
try:
import fcntl
except ImportError:
fcntl = False
-
-def try_address(host, port=0, family=socket.AF_INET):
- """Try to bind a socket on the given host:port and return True
- if that has been possible."""
- try:
- sock = socket.socket(family, socket.SOCK_STREAM)
- sock.bind((host, port))
- except (socket.error, socket.gaierror):
- return False
- else:
- sock.close()
- return True
-
-def linux_version():
- try:
- # platform.release() is something like '2.6.33.7-desktop-2mnb'
- version_string = platform.release().split('-')[0]
- return tuple(map(int, version_string.split('.')))
- except ValueError:
- return 0, 0, 0
+try:
+ import multiprocessing
+except ImportError:
+ multiprocessing = False
HOST = support.HOST
-MSG = 'Michael Gilfix was here\u1234\r\n'.encode('utf8') ## test unicode string and carriage return
-SUPPORTS_IPV6 = socket.has_ipv6 and try_address('::1', family=socket.AF_INET6)
+MSG = 'Michael Gilfix was here\u1234\r\n'.encode('utf-8') ## test unicode string and carriage return
try:
import _thread as thread
@@ -55,6 +42,33 @@ except ImportError:
thread = None
threading = None
+def _have_socket_can():
+ """Check whether CAN sockets are supported on this host."""
+ try:
+ s = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW)
+ except (AttributeError, socket.error, OSError):
+ return False
+ else:
+ s.close()
+ return True
+
+def _have_socket_rds():
+ """Check whether RDS sockets are supported on this host."""
+ try:
+ s = socket.socket(socket.PF_RDS, socket.SOCK_SEQPACKET, 0)
+ except (AttributeError, OSError):
+ return False
+ else:
+ s.close()
+ return True
+
+HAVE_SOCKET_CAN = _have_socket_can()
+
+HAVE_SOCKET_RDS = _have_socket_rds()
+
+# Size in bytes of the int type
+SIZEOF_INT = array.array("i").itemsize
+
class SocketTCPTest(unittest.TestCase):
def setUp(self):
@@ -76,6 +90,63 @@ class SocketUDPTest(unittest.TestCase):
self.serv.close()
self.serv = None
+class ThreadSafeCleanupTestCase(unittest.TestCase):
+ """Subclass of unittest.TestCase with thread-safe cleanup methods.
+
+ This subclass protects the addCleanup() and doCleanups() methods
+ with a recursive lock.
+ """
+
+ if threading:
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._cleanup_lock = threading.RLock()
+
+ def addCleanup(self, *args, **kwargs):
+ with self._cleanup_lock:
+ return super().addCleanup(*args, **kwargs)
+
+ def doCleanups(self, *args, **kwargs):
+ with self._cleanup_lock:
+ return super().doCleanups(*args, **kwargs)
+
+class SocketCANTest(unittest.TestCase):
+
+ """To be able to run this test, a `vcan0` CAN interface can be created with
+ the following commands:
+ # modprobe vcan
+ # ip link add dev vcan0 type vcan
+ # ifconfig vcan0 up
+ """
+ interface = 'vcan0'
+ bufsize = 128
+
+ def setUp(self):
+ self.s = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW)
+ self.addCleanup(self.s.close)
+ try:
+ self.s.bind((self.interface,))
+ except socket.error:
+ self.skipTest('network interface `%s` does not exist' %
+ self.interface)
+
+
+class SocketRDSTest(unittest.TestCase):
+
+ """To be able to run this test, the `rds` kernel module must be loaded:
+ # modprobe rds
+ """
+ bufsize = 8192
+
+ def setUp(self):
+ self.serv = socket.socket(socket.PF_RDS, socket.SOCK_SEQPACKET, 0)
+ self.addCleanup(self.serv.close)
+ try:
+ self.port = support.bind_port(self.serv)
+ except OSError:
+ self.skipTest('unable to bind RDS socket')
+
+
class ThreadableTest:
"""Threadable Test class
@@ -133,6 +204,7 @@ class ThreadableTest:
self.client_ready = threading.Event()
self.done = threading.Event()
self.queue = queue.Queue(1)
+ self.server_crashed = False
# Do some munging to start the client test.
methodname = self.id()
@@ -142,8 +214,12 @@ class ThreadableTest:
self.client_thread = thread.start_new_thread(
self.clientRun, (test_method,))
- self.__setUp()
- if not self.server_ready.is_set():
+ try:
+ self.__setUp()
+ except:
+ self.server_crashed = True
+ raise
+ finally:
self.server_ready.set()
self.client_ready.wait()
@@ -159,10 +235,16 @@ class ThreadableTest:
self.server_ready.wait()
self.clientSetUp()
self.client_ready.set()
+ if self.server_crashed:
+ self.clientTearDown()
+ return
if not hasattr(test_func, '__call__'):
raise TypeError("test_func must be a callable function")
try:
test_func()
+ except _ExpectedFailure:
+ # We deliberately ignore expected failures
+ pass
except BaseException as e:
self.queue.put(e)
finally:
@@ -203,6 +285,48 @@ class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest):
self.cli = None
ThreadableTest.clientTearDown(self)
+class ThreadedCANSocketTest(SocketCANTest, ThreadableTest):
+
+ def __init__(self, methodName='runTest'):
+ SocketCANTest.__init__(self, methodName=methodName)
+ ThreadableTest.__init__(self)
+
+ def clientSetUp(self):
+ self.cli = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW)
+ try:
+ self.cli.bind((self.interface,))
+ except socket.error:
+ # skipTest should not be called here, and will be called in the
+ # server instead
+ pass
+
+ def clientTearDown(self):
+ self.cli.close()
+ self.cli = None
+ ThreadableTest.clientTearDown(self)
+
+class ThreadedRDSSocketTest(SocketRDSTest, ThreadableTest):
+
+ def __init__(self, methodName='runTest'):
+ SocketRDSTest.__init__(self, methodName=methodName)
+ ThreadableTest.__init__(self)
+
+ def clientSetUp(self):
+ self.cli = socket.socket(socket.PF_RDS, socket.SOCK_SEQPACKET, 0)
+ try:
+ # RDS sockets must be bound explicitly to send or receive data
+ self.cli.bind((HOST, 0))
+ self.cli_addr = self.cli.getsockname()
+ except OSError:
+ # skipTest should not be called here, and will be called in the
+ # server instead
+ pass
+
+ def clientTearDown(self):
+ self.cli.close()
+ self.cli = None
+ ThreadableTest.clientTearDown(self)
+
class SocketConnectedTest(ThreadedTCPSocketTest):
"""Socket tests for client-server connection.
@@ -258,6 +382,243 @@ class SocketPairTest(unittest.TestCase, ThreadableTest):
ThreadableTest.clientTearDown(self)
+# The following classes are used by the sendmsg()/recvmsg() tests.
+# Combining, for instance, ConnectedStreamTestMixin and TCPTestBase
+# gives a drop-in replacement for SocketConnectedTest, but different
+# address families can be used, and the attributes serv_addr and
+# cli_addr will be set to the addresses of the endpoints.
+
+class SocketTestBase(unittest.TestCase):
+ """A base class for socket tests.
+
+ Subclasses must provide methods newSocket() to return a new socket
+ and bindSock(sock) to bind it to an unused address.
+
+ Creates a socket self.serv and sets self.serv_addr to its address.
+ """
+
+ def setUp(self):
+ self.serv = self.newSocket()
+ self.bindServer()
+
+ def bindServer(self):
+ """Bind server socket and set self.serv_addr to its address."""
+ self.bindSock(self.serv)
+ self.serv_addr = self.serv.getsockname()
+
+ def tearDown(self):
+ self.serv.close()
+ self.serv = None
+
+
+class SocketListeningTestMixin(SocketTestBase):
+ """Mixin to listen on the server socket."""
+
+ def setUp(self):
+ super().setUp()
+ self.serv.listen(1)
+
+
+class ThreadedSocketTestMixin(ThreadSafeCleanupTestCase, SocketTestBase,
+ ThreadableTest):
+ """Mixin to add client socket and allow client/server tests.
+
+ Client socket is self.cli and its address is self.cli_addr. See
+ ThreadableTest for usage information.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ ThreadableTest.__init__(self)
+
+ def clientSetUp(self):
+ self.cli = self.newClientSocket()
+ self.bindClient()
+
+ def newClientSocket(self):
+ """Return a new socket for use as client."""
+ return self.newSocket()
+
+ def bindClient(self):
+ """Bind client socket and set self.cli_addr to its address."""
+ self.bindSock(self.cli)
+ self.cli_addr = self.cli.getsockname()
+
+ def clientTearDown(self):
+ self.cli.close()
+ self.cli = None
+ ThreadableTest.clientTearDown(self)
+
+
+class ConnectedStreamTestMixin(SocketListeningTestMixin,
+ ThreadedSocketTestMixin):
+ """Mixin to allow client/server stream tests with connected client.
+
+ Server's socket representing connection to client is self.cli_conn
+ and client's connection to server is self.serv_conn. (Based on
+ SocketConnectedTest.)
+ """
+
+ def setUp(self):
+ super().setUp()
+ # Indicate explicitly we're ready for the client thread to
+ # proceed and then perform the blocking call to accept
+ self.serverExplicitReady()
+ conn, addr = self.serv.accept()
+ self.cli_conn = conn
+
+ def tearDown(self):
+ self.cli_conn.close()
+ self.cli_conn = None
+ super().tearDown()
+
+ def clientSetUp(self):
+ super().clientSetUp()
+ self.cli.connect(self.serv_addr)
+ self.serv_conn = self.cli
+
+ def clientTearDown(self):
+ self.serv_conn.close()
+ self.serv_conn = None
+ super().clientTearDown()
+
+
+class UnixSocketTestBase(SocketTestBase):
+ """Base class for Unix-domain socket tests."""
+
+ # This class is used for file descriptor passing tests, so we
+ # create the sockets in a private directory so that other users
+ # can't send anything that might be problematic for a privileged
+ # user running the tests.
+
+ def setUp(self):
+ self.dir_path = tempfile.mkdtemp()
+ self.addCleanup(os.rmdir, self.dir_path)
+ super().setUp()
+
+ def bindSock(self, sock):
+ path = tempfile.mktemp(dir=self.dir_path)
+ sock.bind(path)
+ self.addCleanup(support.unlink, path)
+
+class UnixStreamBase(UnixSocketTestBase):
+ """Base class for Unix-domain SOCK_STREAM tests."""
+
+ def newSocket(self):
+ return socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+
+
+class InetTestBase(SocketTestBase):
+ """Base class for IPv4 socket tests."""
+
+ host = HOST
+
+ def setUp(self):
+ super().setUp()
+ self.port = self.serv_addr[1]
+
+ def bindSock(self, sock):
+ support.bind_port(sock, host=self.host)
+
+class TCPTestBase(InetTestBase):
+ """Base class for TCP-over-IPv4 tests."""
+
+ def newSocket(self):
+ return socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+
+class UDPTestBase(InetTestBase):
+ """Base class for UDP-over-IPv4 tests."""
+
+ def newSocket(self):
+ return socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+
+class SCTPStreamBase(InetTestBase):
+ """Base class for SCTP tests in one-to-one (SOCK_STREAM) mode."""
+
+ def newSocket(self):
+ return socket.socket(socket.AF_INET, socket.SOCK_STREAM,
+ socket.IPPROTO_SCTP)
+
+
+class Inet6TestBase(InetTestBase):
+ """Base class for IPv6 socket tests."""
+
+ # Don't use "localhost" here - it may not have an IPv6 address
+ # assigned to it by default (e.g. in /etc/hosts), and if someone
+ # has assigned it an IPv4-mapped address, then it's unlikely to
+ # work with the full IPv6 API.
+ host = "::1"
+
+class UDP6TestBase(Inet6TestBase):
+ """Base class for UDP-over-IPv6 tests."""
+
+ def newSocket(self):
+ return socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+
+
+# Test-skipping decorators for use with ThreadableTest.
+
+def skipWithClientIf(condition, reason):
+ """Skip decorated test if condition is true, add client_skip decorator.
+
+ If the decorated object is not a class, sets its attribute
+ "client_skip" to a decorator which will return an empty function
+ if the test is to be skipped, or the original function if it is
+ not. This can be used to avoid running the client part of a
+ skipped test when using ThreadableTest.
+ """
+ def client_pass(*args, **kwargs):
+ pass
+ def skipdec(obj):
+ retval = unittest.skip(reason)(obj)
+ if not isinstance(obj, type):
+ retval.client_skip = lambda f: client_pass
+ return retval
+ def noskipdec(obj):
+ if not (isinstance(obj, type) or hasattr(obj, "client_skip")):
+ obj.client_skip = lambda f: f
+ return obj
+ return skipdec if condition else noskipdec
+
+
+def requireAttrs(obj, *attributes):
+ """Skip decorated test if obj is missing any of the given attributes.
+
+ Sets client_skip attribute as skipWithClientIf() does.
+ """
+ missing = [name for name in attributes if not hasattr(obj, name)]
+ return skipWithClientIf(
+ missing, "don't have " + ", ".join(name for name in missing))
+
+
+def requireSocket(*args):
+ """Skip decorated test if a socket cannot be created with given arguments.
+
+ When an argument is given as a string, will use the value of that
+ attribute of the socket module, or skip the test if it doesn't
+ exist. Sets client_skip attribute as skipWithClientIf() does.
+ """
+ err = None
+ missing = [obj for obj in args if
+ isinstance(obj, str) and not hasattr(socket, obj)]
+ if missing:
+ err = "don't have " + ", ".join(name for name in missing)
+ else:
+ callargs = [getattr(socket, obj) if isinstance(obj, str) else obj
+ for obj in args]
+ try:
+ s = socket.socket(*callargs)
+ except socket.error as e:
+ # XXX: check errno?
+ err = str(e)
+ else:
+ s.close()
+ return skipWithClientIf(
+ err is not None,
+ "can't create socket({0}): {1}".format(
+ ", ".join(str(o) for o in args), err))
+
+
#######################################################################
## Begin Tests
@@ -283,18 +644,13 @@ class GeneralModuleTests(unittest.TestCase):
def testSocketError(self):
# Testing socket module exceptions
- def raise_error(*args, **kwargs):
+ msg = "Error raising socket exception (%s)."
+ with self.assertRaises(socket.error, msg=msg % 'socket.error'):
raise socket.error
- def raise_herror(*args, **kwargs):
+ with self.assertRaises(socket.error, msg=msg % 'socket.herror'):
raise socket.herror
- def raise_gaierror(*args, **kwargs):
+ with self.assertRaises(socket.error, msg=msg % 'socket.gaierror'):
raise socket.gaierror
- self.assertRaises(socket.error, raise_error,
- "Error raising socket exception.")
- self.assertRaises(socket.error, raise_herror,
- "Error raising socket exception.")
- self.assertRaises(socket.error, raise_gaierror,
- "Error raising socket exception.")
def testSendtoErrors(self):
# Testing that sendto doens't masks failures. See #10169.
@@ -370,6 +726,52 @@ class GeneralModuleTests(unittest.TestCase):
if not fqhn in all_host_names:
self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names)))
+ @unittest.skipUnless(hasattr(socket, 'sethostname'), "test needs socket.sethostname()")
+ @unittest.skipUnless(hasattr(socket, 'gethostname'), "test needs socket.gethostname()")
+ def test_sethostname(self):
+ oldhn = socket.gethostname()
+ try:
+ socket.sethostname('new')
+ except socket.error as e:
+ if e.errno == errno.EPERM:
+ self.skipTest("test should be run as root")
+ else:
+ raise
+ try:
+ # running test as root!
+ self.assertEqual(socket.gethostname(), 'new')
+ # Should work with bytes objects too
+ socket.sethostname(b'bar')
+ self.assertEqual(socket.gethostname(), 'bar')
+ finally:
+ socket.sethostname(oldhn)
+
+ @unittest.skipUnless(hasattr(socket, 'if_nameindex'),
+ 'socket.if_nameindex() not available.')
+ def testInterfaceNameIndex(self):
+ interfaces = socket.if_nameindex()
+ for index, name in interfaces:
+ self.assertIsInstance(index, int)
+ self.assertIsInstance(name, str)
+ # interface indices are non-zero integers
+ self.assertGreater(index, 0)
+ _index = socket.if_nametoindex(name)
+ self.assertIsInstance(_index, int)
+ self.assertEqual(index, _index)
+ _name = socket.if_indextoname(index)
+ self.assertIsInstance(_name, str)
+ self.assertEqual(name, _name)
+
+ @unittest.skipUnless(hasattr(socket, 'if_nameindex'),
+ 'socket.if_nameindex() not available.')
+ def testInvalidInterfaceNameIndex(self):
+ # test nonexistent interface index/name
+ self.assertRaises(socket.error, socket.if_indextoname, 0)
+ self.assertRaises(socket.error, socket.if_nametoindex, '_DEADBEEF')
+ # test with invalid values
+ self.assertRaises(TypeError, socket.if_nametoindex, 0)
+ self.assertRaises(TypeError, socket.if_indextoname, '_DEADBEEF')
+
def testRefCountGetNameInfo(self):
# Testing reference count for getnameinfo
if hasattr(sys, "getrefcount"):
@@ -422,10 +824,8 @@ class GeneralModuleTests(unittest.TestCase):
# Find one service that exists, then check all the related interfaces.
# I've ordered this by protocols that have both a tcp and udp
# protocol, at least for modern Linuxes.
- if (sys.platform.startswith('linux') or
- sys.platform.startswith('freebsd') or
- sys.platform.startswith('netbsd') or
- sys.platform == 'darwin'):
+ if (sys.platform.startswith(('freebsd', 'netbsd'))
+ or sys.platform in ('linux', 'darwin')):
# avoid the 'echo' service on this platform, as there is an
# assumption breaking non-standard port/protocol entry
services = ('daytime', 'qotd', 'domain')
@@ -720,7 +1120,7 @@ class GeneralModuleTests(unittest.TestCase):
socket.getaddrinfo('localhost', 80)
socket.getaddrinfo('127.0.0.1', 80)
socket.getaddrinfo(None, 80)
- if SUPPORTS_IPV6:
+ if support.IPV6_ENABLED:
socket.getaddrinfo('::1', 80)
# port can be a string service name such as "http", a numeric
# port number or None
@@ -773,7 +1173,12 @@ class GeneralModuleTests(unittest.TestCase):
@unittest.skipUnless(support.is_resource_enabled('network'),
'network is not enabled')
def test_idna(self):
- support.requires('network')
+ # Check for internet access before running test (issue #12804).
+ try:
+ socket.gethostbyname('python.org')
+ except socket.gaierror as e:
+ if e.errno == socket.EAI_NODATA:
+ self.skipTest('internet access required for this test')
# these should all be successful
socket.gethostbyname('испытание.python.org')
socket.gethostbyname_ex('испытание.python.org')
@@ -851,6 +1256,12 @@ class GeneralModuleTests(unittest.TestCase):
self.assertRaises(ValueError, fp.writable)
self.assertRaises(ValueError, fp.seekable)
+ def test_pickle(self):
+ sock = socket.socket()
+ with sock:
+ for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
+ self.assertRaises(TypeError, pickle.dumps, sock, protocol)
+
def test_listen_backlog(self):
for backlog in 0, -1:
srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -864,7 +1275,7 @@ class GeneralModuleTests(unittest.TestCase):
self.assertRaises(OverflowError, srv.listen, _testcapi.INT_MAX + 1)
srv.close()
- @unittest.skipUnless(SUPPORTS_IPV6, 'IPv6 required for this test.')
+ @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.')
def test_flowinfo(self):
self.assertRaises(OverflowError, socket.getnameinfo,
('::1',0, 0xffffffff), 0)
@@ -872,6 +1283,222 @@ class GeneralModuleTests(unittest.TestCase):
self.assertRaises(OverflowError, s.bind, ('::1', 0, -10))
+@unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.')
+class BasicCANTest(unittest.TestCase):
+
+ def testCrucialConstants(self):
+ socket.AF_CAN
+ socket.PF_CAN
+ socket.CAN_RAW
+
+ def testCreateSocket(self):
+ with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s:
+ pass
+
+ def testBindAny(self):
+ with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s:
+ s.bind(('', ))
+
+ def testTooLongInterfaceName(self):
+ # most systems limit IFNAMSIZ to 16, take 1024 to be sure
+ with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s:
+ self.assertRaisesRegex(socket.error, 'interface name too long',
+ s.bind, ('x' * 1024,))
+
+ @unittest.skipUnless(hasattr(socket, "CAN_RAW_LOOPBACK"),
+ 'socket.CAN_RAW_LOOPBACK required for this test.')
+ def testLoopback(self):
+ with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s:
+ for loopback in (0, 1):
+ s.setsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_LOOPBACK,
+ loopback)
+ self.assertEqual(loopback,
+ s.getsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_LOOPBACK))
+
+ @unittest.skipUnless(hasattr(socket, "CAN_RAW_FILTER"),
+ 'socket.CAN_RAW_FILTER required for this test.')
+ def testFilter(self):
+ can_id, can_mask = 0x200, 0x700
+ can_filter = struct.pack("=II", can_id, can_mask)
+ with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s:
+ s.setsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_FILTER, can_filter)
+ self.assertEqual(can_filter,
+ s.getsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_FILTER, 8))
+
+
+@unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.')
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class CANTest(ThreadedCANSocketTest):
+
+ """The CAN frame structure is defined in <linux/can.h>:
+
+ struct can_frame {
+ canid_t can_id; /* 32 bit CAN_ID + EFF/RTR/ERR flags */
+ __u8 can_dlc; /* data length code: 0 .. 8 */
+ __u8 data[8] __attribute__((aligned(8)));
+ };
+ """
+ can_frame_fmt = "=IB3x8s"
+
+ def __init__(self, methodName='runTest'):
+ ThreadedCANSocketTest.__init__(self, methodName=methodName)
+
+ @classmethod
+ def build_can_frame(cls, can_id, data):
+ """Build a CAN frame."""
+ can_dlc = len(data)
+ data = data.ljust(8, b'\x00')
+ return struct.pack(cls.can_frame_fmt, can_id, can_dlc, data)
+
+ @classmethod
+ def dissect_can_frame(cls, frame):
+ """Dissect a CAN frame."""
+ can_id, can_dlc, data = struct.unpack(cls.can_frame_fmt, frame)
+ return (can_id, can_dlc, data[:can_dlc])
+
+ def testSendFrame(self):
+ cf, addr = self.s.recvfrom(self.bufsize)
+ self.assertEqual(self.cf, cf)
+ self.assertEqual(addr[0], self.interface)
+ self.assertEqual(addr[1], socket.AF_CAN)
+
+ def _testSendFrame(self):
+ self.cf = self.build_can_frame(0x00, b'\x01\x02\x03\x04\x05')
+ self.cli.send(self.cf)
+
+ def testSendMaxFrame(self):
+ cf, addr = self.s.recvfrom(self.bufsize)
+ self.assertEqual(self.cf, cf)
+
+ def _testSendMaxFrame(self):
+ self.cf = self.build_can_frame(0x00, b'\x07' * 8)
+ self.cli.send(self.cf)
+
+ def testSendMultiFrames(self):
+ cf, addr = self.s.recvfrom(self.bufsize)
+ self.assertEqual(self.cf1, cf)
+
+ cf, addr = self.s.recvfrom(self.bufsize)
+ self.assertEqual(self.cf2, cf)
+
+ def _testSendMultiFrames(self):
+ self.cf1 = self.build_can_frame(0x07, b'\x44\x33\x22\x11')
+ self.cli.send(self.cf1)
+
+ self.cf2 = self.build_can_frame(0x12, b'\x99\x22\x33')
+ self.cli.send(self.cf2)
+
+
+@unittest.skipUnless(HAVE_SOCKET_RDS, 'RDS sockets required for this test.')
+class BasicRDSTest(unittest.TestCase):
+
+ def testCrucialConstants(self):
+ socket.AF_RDS
+ socket.PF_RDS
+
+ def testCreateSocket(self):
+ with socket.socket(socket.PF_RDS, socket.SOCK_SEQPACKET, 0) as s:
+ pass
+
+ def testSocketBufferSize(self):
+ bufsize = 16384
+ with socket.socket(socket.PF_RDS, socket.SOCK_SEQPACKET, 0) as s:
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, bufsize)
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, bufsize)
+
+
+@unittest.skipUnless(HAVE_SOCKET_RDS, 'RDS sockets required for this test.')
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RDSTest(ThreadedRDSSocketTest):
+
+ def __init__(self, methodName='runTest'):
+ ThreadedRDSSocketTest.__init__(self, methodName=methodName)
+
+ def setUp(self):
+ super().setUp()
+ self.evt = threading.Event()
+
+ def testSendAndRecv(self):
+ data, addr = self.serv.recvfrom(self.bufsize)
+ self.assertEqual(self.data, data)
+ self.assertEqual(self.cli_addr, addr)
+
+ def _testSendAndRecv(self):
+ self.data = b'spam'
+ self.cli.sendto(self.data, 0, (HOST, self.port))
+
+ def testPeek(self):
+ data, addr = self.serv.recvfrom(self.bufsize, socket.MSG_PEEK)
+ self.assertEqual(self.data, data)
+ data, addr = self.serv.recvfrom(self.bufsize)
+ self.assertEqual(self.data, data)
+
+ def _testPeek(self):
+ self.data = b'spam'
+ self.cli.sendto(self.data, 0, (HOST, self.port))
+
+ @requireAttrs(socket.socket, 'recvmsg')
+ def testSendAndRecvMsg(self):
+ data, ancdata, msg_flags, addr = self.serv.recvmsg(self.bufsize)
+ self.assertEqual(self.data, data)
+
+ @requireAttrs(socket.socket, 'sendmsg')
+ def _testSendAndRecvMsg(self):
+ self.data = b'hello ' * 10
+ self.cli.sendmsg([self.data], (), 0, (HOST, self.port))
+
+ def testSendAndRecvMulti(self):
+ data, addr = self.serv.recvfrom(self.bufsize)
+ self.assertEqual(self.data1, data)
+
+ data, addr = self.serv.recvfrom(self.bufsize)
+ self.assertEqual(self.data2, data)
+
+ def _testSendAndRecvMulti(self):
+ self.data1 = b'bacon'
+ self.cli.sendto(self.data1, 0, (HOST, self.port))
+
+ self.data2 = b'egg'
+ self.cli.sendto(self.data2, 0, (HOST, self.port))
+
+ def testSelect(self):
+ r, w, x = select.select([self.serv], [], [], 3.0)
+ self.assertIn(self.serv, r)
+ data, addr = self.serv.recvfrom(self.bufsize)
+ self.assertEqual(self.data, data)
+
+ def _testSelect(self):
+ self.data = b'select'
+ self.cli.sendto(self.data, 0, (HOST, self.port))
+
+ def testCongestion(self):
+ # wait until the sender is done
+ self.evt.wait()
+
+ def _testCongestion(self):
+ # test the behavior in case of congestion
+ self.data = b'fill'
+ self.cli.setblocking(False)
+ try:
+ # try to lower the receiver's socket buffer size
+ self.cli.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 16384)
+ except OSError:
+ pass
+ with self.assertRaises(OSError) as cm:
+ try:
+ # fill the receiver's socket buffer
+ while True:
+ self.cli.sendto(self.data, 0, (HOST, self.port))
+ finally:
+ # signal the receiver we're done
+ self.evt.set()
+ # sendto() should have failed with ENOBUFS
+ self.assertEqual(cm.exception.errno, errno.ENOBUFS)
+ # and we should have received a congestion notification through poll
+ r, w, x = select.select([self.serv], [], [], 3.0)
+ self.assertIn(self.serv, r)
+
+
@unittest.skipUnless(thread, 'Threading required for this test.')
class BasicTCPTest(SocketConnectedTest):
@@ -1016,6 +1643,1866 @@ class BasicUDPTest(ThreadedUDPSocketTest):
def _testRecvFromNegative(self):
self.cli.sendto(MSG, 0, (HOST, self.port))
+# Tests for the sendmsg()/recvmsg() interface. Where possible, the
+# same test code is used with different families and types of socket
+# (e.g. stream, datagram), and tests using recvmsg() are repeated
+# using recvmsg_into().
+#
+# The generic test classes such as SendmsgTests and
+# RecvmsgGenericTests inherit from SendrecvmsgBase and expect to be
+# supplied with sockets cli_sock and serv_sock representing the
+# client's and the server's end of the connection respectively, and
+# attributes cli_addr and serv_addr holding their (numeric where
+# appropriate) addresses.
+#
+# The final concrete test classes combine these with subclasses of
+# SocketTestBase which set up client and server sockets of a specific
+# type, and with subclasses of SendrecvmsgBase such as
+# SendrecvmsgDgramBase and SendrecvmsgConnectedBase which map these
+# sockets to cli_sock and serv_sock and override the methods and
+# attributes of SendrecvmsgBase to fill in destination addresses if
+# needed when sending, check for specific flags in msg_flags, etc.
+#
+# RecvmsgIntoMixin provides a version of doRecvmsg() implemented using
+# recvmsg_into().
+
+# XXX: like the other datagram (UDP) tests in this module, the code
+# here assumes that datagram delivery on the local machine will be
+# reliable.
+
+class SendrecvmsgBase(ThreadSafeCleanupTestCase):
+ # Base class for sendmsg()/recvmsg() tests.
+
+ # Time in seconds to wait before considering a test failed, or
+ # None for no timeout. Not all tests actually set a timeout.
+ fail_timeout = 3.0
+
+ def setUp(self):
+ self.misc_event = threading.Event()
+ super().setUp()
+
+ def sendToServer(self, msg):
+ # Send msg to the server.
+ return self.cli_sock.send(msg)
+
+ # Tuple of alternative default arguments for sendmsg() when called
+ # via sendmsgToServer() (e.g. to include a destination address).
+ sendmsg_to_server_defaults = ()
+
+ def sendmsgToServer(self, *args):
+ # Call sendmsg() on self.cli_sock with the given arguments,
+ # filling in any arguments which are not supplied with the
+ # corresponding items of self.sendmsg_to_server_defaults, if
+ # any.
+ return self.cli_sock.sendmsg(
+ *(args + self.sendmsg_to_server_defaults[len(args):]))
+
+ def doRecvmsg(self, sock, bufsize, *args):
+ # Call recvmsg() on sock with given arguments and return its
+ # result. Should be used for tests which can use either
+ # recvmsg() or recvmsg_into() - RecvmsgIntoMixin overrides
+ # this method with one which emulates it using recvmsg_into(),
+ # thus allowing the same test to be used for both methods.
+ result = sock.recvmsg(bufsize, *args)
+ self.registerRecvmsgResult(result)
+ return result
+
+ def registerRecvmsgResult(self, result):
+ # Called by doRecvmsg() with the return value of recvmsg() or
+ # recvmsg_into(). Can be overridden to arrange cleanup based
+ # on the returned ancillary data, for instance.
+ pass
+
+ def checkRecvmsgAddress(self, addr1, addr2):
+ # Called to compare the received address with the address of
+ # the peer.
+ self.assertEqual(addr1, addr2)
+
+ # Flags that are normally unset in msg_flags
+ msg_flags_common_unset = 0
+ for name in ("MSG_CTRUNC", "MSG_OOB"):
+ msg_flags_common_unset |= getattr(socket, name, 0)
+
+ # Flags that are normally set
+ msg_flags_common_set = 0
+
+ # Flags set when a complete record has been received (e.g. MSG_EOR
+ # for SCTP)
+ msg_flags_eor_indicator = 0
+
+ # Flags set when a complete record has not been received
+ # (e.g. MSG_TRUNC for datagram sockets)
+ msg_flags_non_eor_indicator = 0
+
+ def checkFlags(self, flags, eor=None, checkset=0, checkunset=0, ignore=0):
+ # Method to check the value of msg_flags returned by recvmsg[_into]().
+ #
+ # Checks that all bits in msg_flags_common_set attribute are
+ # set in "flags" and all bits in msg_flags_common_unset are
+ # unset.
+ #
+ # The "eor" argument specifies whether the flags should
+ # indicate that a full record (or datagram) has been received.
+ # If "eor" is None, no checks are done; otherwise, checks
+ # that:
+ #
+ # * if "eor" is true, all bits in msg_flags_eor_indicator are
+ # set and all bits in msg_flags_non_eor_indicator are unset
+ #
+ # * if "eor" is false, all bits in msg_flags_non_eor_indicator
+ # are set and all bits in msg_flags_eor_indicator are unset
+ #
+ # If "checkset" and/or "checkunset" are supplied, they require
+ # the given bits to be set or unset respectively, overriding
+ # what the attributes require for those bits.
+ #
+ # If any bits are set in "ignore", they will not be checked,
+ # regardless of the other inputs.
+ #
+ # Will raise Exception if the inputs require a bit to be both
+ # set and unset, and it is not ignored.
+
+ defaultset = self.msg_flags_common_set
+ defaultunset = self.msg_flags_common_unset
+
+ if eor:
+ defaultset |= self.msg_flags_eor_indicator
+ defaultunset |= self.msg_flags_non_eor_indicator
+ elif eor is not None:
+ defaultset |= self.msg_flags_non_eor_indicator
+ defaultunset |= self.msg_flags_eor_indicator
+
+ # Function arguments override defaults
+ defaultset &= ~checkunset
+ defaultunset &= ~checkset
+
+ # Merge arguments with remaining defaults, and check for conflicts
+ checkset |= defaultset
+ checkunset |= defaultunset
+ inboth = checkset & checkunset & ~ignore
+ if inboth:
+ raise Exception("contradictory set, unset requirements for flags "
+ "{0:#x}".format(inboth))
+
+ # Compare with given msg_flags value
+ mask = (checkset | checkunset) & ~ignore
+ self.assertEqual(flags & mask, checkset & mask)
+
+
+class RecvmsgIntoMixin(SendrecvmsgBase):
+ # Mixin to implement doRecvmsg() using recvmsg_into().
+
+ def doRecvmsg(self, sock, bufsize, *args):
+ buf = bytearray(bufsize)
+ result = sock.recvmsg_into([buf], *args)
+ self.registerRecvmsgResult(result)
+ self.assertGreaterEqual(result[0], 0)
+ self.assertLessEqual(result[0], bufsize)
+ return (bytes(buf[:result[0]]),) + result[1:]
+
+
+class SendrecvmsgDgramFlagsBase(SendrecvmsgBase):
+ # Defines flags to be checked in msg_flags for datagram sockets.
+
+ @property
+ def msg_flags_non_eor_indicator(self):
+ return super().msg_flags_non_eor_indicator | socket.MSG_TRUNC
+
+
+class SendrecvmsgSCTPFlagsBase(SendrecvmsgBase):
+ # Defines flags to be checked in msg_flags for SCTP sockets.
+
+ @property
+ def msg_flags_eor_indicator(self):
+ return super().msg_flags_eor_indicator | socket.MSG_EOR
+
+
+class SendrecvmsgConnectionlessBase(SendrecvmsgBase):
+ # Base class for tests on connectionless-mode sockets. Users must
+ # supply sockets on attributes cli and serv to be mapped to
+ # cli_sock and serv_sock respectively.
+
+ @property
+ def serv_sock(self):
+ return self.serv
+
+ @property
+ def cli_sock(self):
+ return self.cli
+
+ @property
+ def sendmsg_to_server_defaults(self):
+ return ([], [], 0, self.serv_addr)
+
+ def sendToServer(self, msg):
+ return self.cli_sock.sendto(msg, self.serv_addr)
+
+
+class SendrecvmsgConnectedBase(SendrecvmsgBase):
+ # Base class for tests on connected sockets. Users must supply
+ # sockets on attributes serv_conn and cli_conn (representing the
+ # connections *to* the server and the client), to be mapped to
+ # cli_sock and serv_sock respectively.
+
+ @property
+ def serv_sock(self):
+ return self.cli_conn
+
+ @property
+ def cli_sock(self):
+ return self.serv_conn
+
+ def checkRecvmsgAddress(self, addr1, addr2):
+ # Address is currently "unspecified" for a connected socket,
+ # so we don't examine it
+ pass
+
+
+class SendrecvmsgServerTimeoutBase(SendrecvmsgBase):
+ # Base class to set a timeout on server's socket.
+
+ def setUp(self):
+ super().setUp()
+ self.serv_sock.settimeout(self.fail_timeout)
+
+
+class SendmsgTests(SendrecvmsgServerTimeoutBase):
+ # Tests for sendmsg() which can use any socket type and do not
+ # involve recvmsg() or recvmsg_into().
+
+ def testSendmsg(self):
+ # Send a simple message with sendmsg().
+ self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)
+
+ def _testSendmsg(self):
+ self.assertEqual(self.sendmsgToServer([MSG]), len(MSG))
+
+ def testSendmsgDataGenerator(self):
+ # Send from buffer obtained from a generator (not a sequence).
+ self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)
+
+ def _testSendmsgDataGenerator(self):
+ self.assertEqual(self.sendmsgToServer((o for o in [MSG])),
+ len(MSG))
+
+ def testSendmsgAncillaryGenerator(self):
+ # Gather (empty) ancillary data from a generator.
+ self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)
+
+ def _testSendmsgAncillaryGenerator(self):
+ self.assertEqual(self.sendmsgToServer([MSG], (o for o in [])),
+ len(MSG))
+
+ def testSendmsgArray(self):
+ # Send data from an array instead of the usual bytes object.
+ self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)
+
+ def _testSendmsgArray(self):
+ self.assertEqual(self.sendmsgToServer([array.array("B", MSG)]),
+ len(MSG))
+
+ def testSendmsgGather(self):
+ # Send message data from more than one buffer (gather write).
+ self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)
+
+ def _testSendmsgGather(self):
+ self.assertEqual(self.sendmsgToServer([MSG[:3], MSG[3:]]), len(MSG))
+
+ def testSendmsgBadArgs(self):
+ # Check that sendmsg() rejects invalid arguments.
+ self.assertEqual(self.serv_sock.recv(1000), b"done")
+
+ def _testSendmsgBadArgs(self):
+ self.assertRaises(TypeError, self.cli_sock.sendmsg)
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ b"not in an iterable")
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ object())
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [object()])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG, object()])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], object())
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [], object())
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [], 0, object())
+ self.sendToServer(b"done")
+
+ def testSendmsgBadCmsg(self):
+ # Check that invalid ancillary data items are rejected.
+ self.assertEqual(self.serv_sock.recv(1000), b"done")
+
+ def _testSendmsgBadCmsg(self):
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [object()])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [(object(), 0, b"data")])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [(0, object(), b"data")])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [(0, 0, object())])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [(0, 0)])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [(0, 0, b"data", 42)])
+ self.sendToServer(b"done")
+
+ @requireAttrs(socket, "CMSG_SPACE")
+ def testSendmsgBadMultiCmsg(self):
+ # Check that invalid ancillary data items are rejected when
+ # more than one item is present.
+ self.assertEqual(self.serv_sock.recv(1000), b"done")
+
+ @testSendmsgBadMultiCmsg.client_skip
+ def _testSendmsgBadMultiCmsg(self):
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [0, 0, b""])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [(0, 0, b""), object()])
+ self.sendToServer(b"done")
+
+ def testSendmsgExcessCmsgReject(self):
+ # Check that sendmsg() rejects excess ancillary data items
+ # when the number that can be sent is limited.
+ self.assertEqual(self.serv_sock.recv(1000), b"done")
+
+ def _testSendmsgExcessCmsgReject(self):
+ if not hasattr(socket, "CMSG_SPACE"):
+ # Can only send one item
+ with self.assertRaises(socket.error) as cm:
+ self.sendmsgToServer([MSG], [(0, 0, b""), (0, 0, b"")])
+ self.assertIsNone(cm.exception.errno)
+ self.sendToServer(b"done")
+
+ def testSendmsgAfterClose(self):
+ # Check that sendmsg() fails on a closed socket.
+ pass
+
+ def _testSendmsgAfterClose(self):
+ self.cli_sock.close()
+ self.assertRaises(socket.error, self.sendmsgToServer, [MSG])
+
+
+class SendmsgStreamTests(SendmsgTests):
+ # Tests for sendmsg() which require a stream socket and do not
+ # involve recvmsg() or recvmsg_into().
+
+ def testSendmsgExplicitNoneAddr(self):
+ # Check that peer address can be specified as None.
+ self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)
+
+ def _testSendmsgExplicitNoneAddr(self):
+ self.assertEqual(self.sendmsgToServer([MSG], [], 0, None), len(MSG))
+
+ def testSendmsgTimeout(self):
+ # Check that timeout works with sendmsg().
+ self.assertEqual(self.serv_sock.recv(512), b"a"*512)
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+
+ def _testSendmsgTimeout(self):
+ try:
+ self.cli_sock.settimeout(0.03)
+ with self.assertRaises(socket.timeout):
+ while True:
+ self.sendmsgToServer([b"a"*512])
+ finally:
+ self.misc_event.set()
+
+ # XXX: would be nice to have more tests for sendmsg flags argument.
+
+ # Linux supports MSG_DONTWAIT when sending, but in general, it
+ # only works when receiving. Could add other platforms if they
+ # support it too.
+ @skipWithClientIf(sys.platform not in {"linux2"},
+ "MSG_DONTWAIT not known to work on this platform when "
+ "sending")
+ def testSendmsgDontWait(self):
+ # Check that MSG_DONTWAIT in flags causes non-blocking behaviour.
+ self.assertEqual(self.serv_sock.recv(512), b"a"*512)
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+
+ @testSendmsgDontWait.client_skip
+ def _testSendmsgDontWait(self):
+ try:
+ with self.assertRaises(socket.error) as cm:
+ while True:
+ self.sendmsgToServer([b"a"*512], [], socket.MSG_DONTWAIT)
+ self.assertIn(cm.exception.errno,
+ (errno.EAGAIN, errno.EWOULDBLOCK))
+ finally:
+ self.misc_event.set()
+
+
+class SendmsgConnectionlessTests(SendmsgTests):
+ # Tests for sendmsg() which require a connectionless-mode
+ # (e.g. datagram) socket, and do not involve recvmsg() or
+ # recvmsg_into().
+
+ def testSendmsgNoDestAddr(self):
+ # Check that sendmsg() fails when no destination address is
+ # given for unconnected socket.
+ pass
+
+ def _testSendmsgNoDestAddr(self):
+ self.assertRaises(socket.error, self.cli_sock.sendmsg,
+ [MSG])
+ self.assertRaises(socket.error, self.cli_sock.sendmsg,
+ [MSG], [], 0, None)
+
+
+class RecvmsgGenericTests(SendrecvmsgBase):
+ # Tests for recvmsg() which can also be emulated using
+ # recvmsg_into(), and can use any socket type.
+
+ def testRecvmsg(self):
+ # Receive a simple message with recvmsg[_into]().
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG))
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsg(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgExplicitDefaults(self):
+ # Test recvmsg[_into]() with default arguments provided explicitly.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), 0, 0)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgExplicitDefaults(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgShorter(self):
+ # Receive a message smaller than buffer.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG) + 42)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgShorter(self):
+ self.sendToServer(MSG)
+
+ # FreeBSD < 8 doesn't always set the MSG_TRUNC flag when a truncated
+ # datagram is received (issue #13001).
+ @support.requires_freebsd_version(8)
+ def testRecvmsgTrunc(self):
+ # Receive part of message, check for truncation indicators.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG) - 3)
+ self.assertEqual(msg, MSG[:-3])
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=False)
+
+ @support.requires_freebsd_version(8)
+ def _testRecvmsgTrunc(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgShortAncillaryBuf(self):
+ # Test ancillary data buffer too small to hold any ancillary data.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), 1)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgShortAncillaryBuf(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgLongAncillaryBuf(self):
+ # Test large ancillary data buffer.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), 10240)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgLongAncillaryBuf(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgAfterClose(self):
+ # Check that recvmsg[_into]() fails on a closed socket.
+ self.serv_sock.close()
+ self.assertRaises(socket.error, self.doRecvmsg, self.serv_sock, 1024)
+
+ def _testRecvmsgAfterClose(self):
+ pass
+
+ def testRecvmsgTimeout(self):
+ # Check that timeout works.
+ try:
+ self.serv_sock.settimeout(0.03)
+ self.assertRaises(socket.timeout,
+ self.doRecvmsg, self.serv_sock, len(MSG))
+ finally:
+ self.misc_event.set()
+
+ def _testRecvmsgTimeout(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+
+ @requireAttrs(socket, "MSG_PEEK")
+ def testRecvmsgPeek(self):
+ # Check that MSG_PEEK in flags enables examination of pending
+ # data without consuming it.
+
+ # Receive part of data with MSG_PEEK.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG) - 3, 0,
+ socket.MSG_PEEK)
+ self.assertEqual(msg, MSG[:-3])
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ # Ignoring MSG_TRUNC here (so this test is the same for stream
+ # and datagram sockets). Some wording in POSIX seems to
+ # suggest that it needn't be set when peeking, but that may
+ # just be a slip.
+ self.checkFlags(flags, eor=False,
+ ignore=getattr(socket, "MSG_TRUNC", 0))
+
+ # Receive all data with MSG_PEEK.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), 0,
+ socket.MSG_PEEK)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ # Check that the same data can still be received normally.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG))
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ @testRecvmsgPeek.client_skip
+ def _testRecvmsgPeek(self):
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket.socket, "sendmsg")
+ def testRecvmsgFromSendmsg(self):
+ # Test receiving with recvmsg[_into]() when message is sent
+ # using sendmsg().
+ self.serv_sock.settimeout(self.fail_timeout)
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG))
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ @testRecvmsgFromSendmsg.client_skip
+ def _testRecvmsgFromSendmsg(self):
+ self.assertEqual(self.sendmsgToServer([MSG[:3], MSG[3:]]), len(MSG))
+
+
+class RecvmsgGenericStreamTests(RecvmsgGenericTests):
+ # Tests which require a stream socket and can use either recvmsg()
+ # or recvmsg_into().
+
+ def testRecvmsgEOF(self):
+ # Receive end-of-stream indicator (b"", peer socket closed).
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, 1024)
+ self.assertEqual(msg, b"")
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=None) # Might not have end-of-record marker
+
+ def _testRecvmsgEOF(self):
+ self.cli_sock.close()
+
+ def testRecvmsgOverflow(self):
+ # Receive a message in more than one chunk.
+ seg1, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG) - 3)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=False)
+
+ seg2, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, 1024)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ msg = seg1 + seg2
+ self.assertEqual(msg, MSG)
+
+ def _testRecvmsgOverflow(self):
+ self.sendToServer(MSG)
+
+
+class RecvmsgTests(RecvmsgGenericTests):
+ # Tests for recvmsg() which can use any socket type.
+
+ def testRecvmsgBadArgs(self):
+ # Check that recvmsg() rejects invalid arguments.
+ self.assertRaises(TypeError, self.serv_sock.recvmsg)
+ self.assertRaises(ValueError, self.serv_sock.recvmsg,
+ -1, 0, 0)
+ self.assertRaises(ValueError, self.serv_sock.recvmsg,
+ len(MSG), -1, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg,
+ [bytearray(10)], 0, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg,
+ object(), 0, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg,
+ len(MSG), object(), 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg,
+ len(MSG), 0, object())
+
+ msg, ancdata, flags, addr = self.serv_sock.recvmsg(len(MSG), 0, 0)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgBadArgs(self):
+ self.sendToServer(MSG)
+
+
+class RecvmsgIntoTests(RecvmsgIntoMixin, RecvmsgGenericTests):
+ # Tests for recvmsg_into() which can use any socket type.
+
+ def testRecvmsgIntoBadArgs(self):
+ # Check that recvmsg_into() rejects invalid arguments.
+ buf = bytearray(len(MSG))
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ len(MSG), 0, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ buf, 0, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ [object()], 0, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ [b"I'm not writable"], 0, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ [buf, object()], 0, 0)
+ self.assertRaises(ValueError, self.serv_sock.recvmsg_into,
+ [buf], -1, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ [buf], object(), 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ [buf], 0, object())
+
+ nbytes, ancdata, flags, addr = self.serv_sock.recvmsg_into([buf], 0, 0)
+ self.assertEqual(nbytes, len(MSG))
+ self.assertEqual(buf, bytearray(MSG))
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgIntoBadArgs(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgIntoGenerator(self):
+ # Receive into buffer obtained from a generator (not a sequence).
+ buf = bytearray(len(MSG))
+ nbytes, ancdata, flags, addr = self.serv_sock.recvmsg_into(
+ (o for o in [buf]))
+ self.assertEqual(nbytes, len(MSG))
+ self.assertEqual(buf, bytearray(MSG))
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgIntoGenerator(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgIntoArray(self):
+ # Receive into an array rather than the usual bytearray.
+ buf = array.array("B", [0] * len(MSG))
+ nbytes, ancdata, flags, addr = self.serv_sock.recvmsg_into([buf])
+ self.assertEqual(nbytes, len(MSG))
+ self.assertEqual(buf.tobytes(), MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgIntoArray(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgIntoScatter(self):
+ # Receive into multiple buffers (scatter write).
+ b1 = bytearray(b"----")
+ b2 = bytearray(b"0123456789")
+ b3 = bytearray(b"--------------")
+ nbytes, ancdata, flags, addr = self.serv_sock.recvmsg_into(
+ [b1, memoryview(b2)[2:9], b3])
+ self.assertEqual(nbytes, len(b"Mary had a little lamb"))
+ self.assertEqual(b1, bytearray(b"Mary"))
+ self.assertEqual(b2, bytearray(b"01 had a 9"))
+ self.assertEqual(b3, bytearray(b"little lamb---"))
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgIntoScatter(self):
+ self.sendToServer(b"Mary had a little lamb")
+
+
+class CmsgMacroTests(unittest.TestCase):
+ # Test the functions CMSG_LEN() and CMSG_SPACE(). Tests
+ # assumptions used by sendmsg() and recvmsg[_into](), which share
+ # code with these functions.
+
+ # Match the definition in socketmodule.c
+ socklen_t_limit = min(0x7fffffff, _testcapi.INT_MAX)
+
+ @requireAttrs(socket, "CMSG_LEN")
+ def testCMSG_LEN(self):
+ # Test CMSG_LEN() with various valid and invalid values,
+ # checking the assumptions used by recvmsg() and sendmsg().
+ toobig = self.socklen_t_limit - socket.CMSG_LEN(0) + 1
+ values = list(range(257)) + list(range(toobig - 257, toobig))
+
+ # struct cmsghdr has at least three members, two of which are ints
+ self.assertGreater(socket.CMSG_LEN(0), array.array("i").itemsize * 2)
+ for n in values:
+ ret = socket.CMSG_LEN(n)
+ # This is how recvmsg() calculates the data size
+ self.assertEqual(ret - socket.CMSG_LEN(0), n)
+ self.assertLessEqual(ret, self.socklen_t_limit)
+
+ self.assertRaises(OverflowError, socket.CMSG_LEN, -1)
+ # sendmsg() shares code with these functions, and requires
+ # that it reject values over the limit.
+ self.assertRaises(OverflowError, socket.CMSG_LEN, toobig)
+ self.assertRaises(OverflowError, socket.CMSG_LEN, sys.maxsize)
+
+ @requireAttrs(socket, "CMSG_SPACE")
+ def testCMSG_SPACE(self):
+ # Test CMSG_SPACE() with various valid and invalid values,
+ # checking the assumptions used by sendmsg().
+ toobig = self.socklen_t_limit - socket.CMSG_SPACE(1) + 1
+ values = list(range(257)) + list(range(toobig - 257, toobig))
+
+ last = socket.CMSG_SPACE(0)
+ # struct cmsghdr has at least three members, two of which are ints
+ self.assertGreater(last, array.array("i").itemsize * 2)
+ for n in values:
+ ret = socket.CMSG_SPACE(n)
+ self.assertGreaterEqual(ret, last)
+ self.assertGreaterEqual(ret, socket.CMSG_LEN(n))
+ self.assertGreaterEqual(ret, n + socket.CMSG_LEN(0))
+ self.assertLessEqual(ret, self.socklen_t_limit)
+ last = ret
+
+ self.assertRaises(OverflowError, socket.CMSG_SPACE, -1)
+ # sendmsg() shares code with these functions, and requires
+ # that it reject values over the limit.
+ self.assertRaises(OverflowError, socket.CMSG_SPACE, toobig)
+ self.assertRaises(OverflowError, socket.CMSG_SPACE, sys.maxsize)
+
+
+class SCMRightsTest(SendrecvmsgServerTimeoutBase):
+ # Tests for file descriptor passing on Unix-domain sockets.
+
+ # Invalid file descriptor value that's unlikely to evaluate to a
+ # real FD even if one of its bytes is replaced with a different
+ # value (which shouldn't actually happen).
+ badfd = -0x5555
+
+ def newFDs(self, n):
+ # Return a list of n file descriptors for newly-created files
+ # containing their list indices as ASCII numbers.
+ fds = []
+ for i in range(n):
+ fd, path = tempfile.mkstemp()
+ self.addCleanup(os.unlink, path)
+ self.addCleanup(os.close, fd)
+ os.write(fd, str(i).encode())
+ fds.append(fd)
+ return fds
+
+ def checkFDs(self, fds):
+ # Check that the file descriptors in the given list contain
+ # their correct list indices as ASCII numbers.
+ for n, fd in enumerate(fds):
+ os.lseek(fd, 0, os.SEEK_SET)
+ self.assertEqual(os.read(fd, 1024), str(n).encode())
+
+ def registerRecvmsgResult(self, result):
+ self.addCleanup(self.closeRecvmsgFDs, result)
+
+ def closeRecvmsgFDs(self, recvmsg_result):
+ # Close all file descriptors specified in the ancillary data
+ # of the given return value from recvmsg() or recvmsg_into().
+ for cmsg_level, cmsg_type, cmsg_data in recvmsg_result[1]:
+ if (cmsg_level == socket.SOL_SOCKET and
+ cmsg_type == socket.SCM_RIGHTS):
+ fds = array.array("i")
+ fds.frombytes(cmsg_data[:
+ len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
+ for fd in fds:
+ os.close(fd)
+
+ def createAndSendFDs(self, n):
+ # Send n new file descriptors created by newFDs() to the
+ # server, with the constant MSG as the non-ancillary data.
+ self.assertEqual(
+ self.sendmsgToServer([MSG],
+ [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", self.newFDs(n)))]),
+ len(MSG))
+
+ def checkRecvmsgFDs(self, numfds, result, maxcmsgs=1, ignoreflags=0):
+ # Check that constant MSG was received with numfds file
+ # descriptors in a maximum of maxcmsgs control messages (which
+ # must contain only complete integers). By default, check
+ # that MSG_CTRUNC is unset, but ignore any flags in
+ # ignoreflags.
+ msg, ancdata, flags, addr = result
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkunset=socket.MSG_CTRUNC,
+ ignore=ignoreflags)
+
+ self.assertIsInstance(ancdata, list)
+ self.assertLessEqual(len(ancdata), maxcmsgs)
+ fds = array.array("i")
+ for item in ancdata:
+ self.assertIsInstance(item, tuple)
+ cmsg_level, cmsg_type, cmsg_data = item
+ self.assertEqual(cmsg_level, socket.SOL_SOCKET)
+ self.assertEqual(cmsg_type, socket.SCM_RIGHTS)
+ self.assertIsInstance(cmsg_data, bytes)
+ self.assertEqual(len(cmsg_data) % SIZEOF_INT, 0)
+ fds.frombytes(cmsg_data)
+
+ self.assertEqual(len(fds), numfds)
+ self.checkFDs(fds)
+
+ def testFDPassSimple(self):
+ # Pass a single FD (array read from bytes object).
+ self.checkRecvmsgFDs(1, self.doRecvmsg(self.serv_sock,
+ len(MSG), 10240))
+
+ def _testFDPassSimple(self):
+ self.assertEqual(
+ self.sendmsgToServer(
+ [MSG],
+ [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", self.newFDs(1)).tobytes())]),
+ len(MSG))
+
+ def testMultipleFDPass(self):
+ # Pass multiple FDs in a single array.
+ self.checkRecvmsgFDs(4, self.doRecvmsg(self.serv_sock,
+ len(MSG), 10240))
+
+ def _testMultipleFDPass(self):
+ self.createAndSendFDs(4)
+
+ @requireAttrs(socket, "CMSG_SPACE")
+ def testFDPassCMSG_SPACE(self):
+ # Test using CMSG_SPACE() to calculate ancillary buffer size.
+ self.checkRecvmsgFDs(
+ 4, self.doRecvmsg(self.serv_sock, len(MSG),
+ socket.CMSG_SPACE(4 * SIZEOF_INT)))
+
+ @testFDPassCMSG_SPACE.client_skip
+ def _testFDPassCMSG_SPACE(self):
+ self.createAndSendFDs(4)
+
+ def testFDPassCMSG_LEN(self):
+ # Test using CMSG_LEN() to calculate ancillary buffer size.
+ self.checkRecvmsgFDs(1,
+ self.doRecvmsg(self.serv_sock, len(MSG),
+ socket.CMSG_LEN(4 * SIZEOF_INT)),
+ # RFC 3542 says implementations may set
+ # MSG_CTRUNC if there isn't enough space
+ # for trailing padding.
+ ignoreflags=socket.MSG_CTRUNC)
+
+ def _testFDPassCMSG_LEN(self):
+ self.createAndSendFDs(1)
+
+ # Issue #12958: The following test has problems on Mac OS X
+ @support.anticipate_failure(sys.platform == "darwin")
+ @requireAttrs(socket, "CMSG_SPACE")
+ def testFDPassSeparate(self):
+ # Pass two FDs in two separate arrays. Arrays may be combined
+ # into a single control message by the OS.
+ self.checkRecvmsgFDs(2,
+ self.doRecvmsg(self.serv_sock, len(MSG), 10240),
+ maxcmsgs=2)
+
+ @testFDPassSeparate.client_skip
+ @support.anticipate_failure(sys.platform == "darwin")
+ def _testFDPassSeparate(self):
+ fd0, fd1 = self.newFDs(2)
+ self.assertEqual(
+ self.sendmsgToServer([MSG], [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [fd0])),
+ (socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [fd1]))]),
+ len(MSG))
+
+ # Issue #12958: The following test has problems on Mac OS X
+ @support.anticipate_failure(sys.platform == "darwin")
+ @requireAttrs(socket, "CMSG_SPACE")
+ def testFDPassSeparateMinSpace(self):
+ # Pass two FDs in two separate arrays, receiving them into the
+ # minimum space for two arrays.
+ self.checkRecvmsgFDs(2,
+ self.doRecvmsg(self.serv_sock, len(MSG),
+ socket.CMSG_SPACE(SIZEOF_INT) +
+ socket.CMSG_LEN(SIZEOF_INT)),
+ maxcmsgs=2, ignoreflags=socket.MSG_CTRUNC)
+
+ @testFDPassSeparateMinSpace.client_skip
+ @support.anticipate_failure(sys.platform == "darwin")
+ def _testFDPassSeparateMinSpace(self):
+ fd0, fd1 = self.newFDs(2)
+ self.assertEqual(
+ self.sendmsgToServer([MSG], [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [fd0])),
+ (socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [fd1]))]),
+ len(MSG))
+
+ def sendAncillaryIfPossible(self, msg, ancdata):
+ # Try to send msg and ancdata to server, but if the system
+ # call fails, just send msg with no ancillary data.
+ try:
+ nbytes = self.sendmsgToServer([msg], ancdata)
+ except socket.error as e:
+ # Check that it was the system call that failed
+ self.assertIsInstance(e.errno, int)
+ nbytes = self.sendmsgToServer([msg])
+ self.assertEqual(nbytes, len(msg))
+
+ def testFDPassEmpty(self):
+ # Try to pass an empty FD array. Can receive either no array
+ # or an empty array.
+ self.checkRecvmsgFDs(0, self.doRecvmsg(self.serv_sock,
+ len(MSG), 10240),
+ ignoreflags=socket.MSG_CTRUNC)
+
+ def _testFDPassEmpty(self):
+ self.sendAncillaryIfPossible(MSG, [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ b"")])
+
+ def testFDPassPartialInt(self):
+ # Try to pass a truncated FD array.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), 10240)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, ignore=socket.MSG_CTRUNC)
+ self.assertLessEqual(len(ancdata), 1)
+ for cmsg_level, cmsg_type, cmsg_data in ancdata:
+ self.assertEqual(cmsg_level, socket.SOL_SOCKET)
+ self.assertEqual(cmsg_type, socket.SCM_RIGHTS)
+ self.assertLess(len(cmsg_data), SIZEOF_INT)
+
+ def _testFDPassPartialInt(self):
+ self.sendAncillaryIfPossible(
+ MSG,
+ [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [self.badfd]).tobytes()[:-1])])
+
+ @requireAttrs(socket, "CMSG_SPACE")
+ def testFDPassPartialIntInMiddle(self):
+ # Try to pass two FD arrays, the first of which is truncated.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), 10240)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, ignore=socket.MSG_CTRUNC)
+ self.assertLessEqual(len(ancdata), 2)
+ fds = array.array("i")
+ # Arrays may have been combined in a single control message
+ for cmsg_level, cmsg_type, cmsg_data in ancdata:
+ self.assertEqual(cmsg_level, socket.SOL_SOCKET)
+ self.assertEqual(cmsg_type, socket.SCM_RIGHTS)
+ fds.frombytes(cmsg_data[:
+ len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
+ self.assertLessEqual(len(fds), 2)
+ self.checkFDs(fds)
+
+ @testFDPassPartialIntInMiddle.client_skip
+ def _testFDPassPartialIntInMiddle(self):
+ fd0, fd1 = self.newFDs(2)
+ self.sendAncillaryIfPossible(
+ MSG,
+ [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [fd0, self.badfd]).tobytes()[:-1]),
+ (socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [fd1]))])
+
+ def checkTruncatedHeader(self, result, ignoreflags=0):
+ # Check that no ancillary data items are returned when data is
+ # truncated inside the cmsghdr structure.
+ msg, ancdata, flags, addr = result
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC,
+ ignore=ignoreflags)
+
+ def testCmsgTruncNoBufSize(self):
+ # Check that no ancillary data is received when no buffer size
+ # is specified.
+ self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG)),
+ # BSD seems to set MSG_CTRUNC only
+ # if an item has been partially
+ # received.
+ ignoreflags=socket.MSG_CTRUNC)
+
+ def _testCmsgTruncNoBufSize(self):
+ self.createAndSendFDs(1)
+
+ def testCmsgTrunc0(self):
+ # Check that no ancillary data is received when buffer size is 0.
+ self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG), 0),
+ ignoreflags=socket.MSG_CTRUNC)
+
+ def _testCmsgTrunc0(self):
+ self.createAndSendFDs(1)
+
+ # Check that no ancillary data is returned for various non-zero
+ # (but still too small) buffer sizes.
+
+ def testCmsgTrunc1(self):
+ self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG), 1))
+
+ def _testCmsgTrunc1(self):
+ self.createAndSendFDs(1)
+
+ def testCmsgTrunc2Int(self):
+ # The cmsghdr structure has at least three members, two of
+ # which are ints, so we still shouldn't see any ancillary
+ # data.
+ self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG),
+ SIZEOF_INT * 2))
+
+ def _testCmsgTrunc2Int(self):
+ self.createAndSendFDs(1)
+
+ def testCmsgTruncLen0Minus1(self):
+ self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG),
+ socket.CMSG_LEN(0) - 1))
+
+ def _testCmsgTruncLen0Minus1(self):
+ self.createAndSendFDs(1)
+
+ # The following tests try to truncate the control message in the
+ # middle of the FD array.
+
+ def checkTruncatedArray(self, ancbuf, maxdata, mindata=0):
+ # Check that file descriptor data is truncated to between
+ # mindata and maxdata bytes when received with buffer size
+ # ancbuf, and that any complete file descriptor numbers are
+ # valid.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), ancbuf)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC)
+
+ if mindata == 0 and ancdata == []:
+ return
+ self.assertEqual(len(ancdata), 1)
+ cmsg_level, cmsg_type, cmsg_data = ancdata[0]
+ self.assertEqual(cmsg_level, socket.SOL_SOCKET)
+ self.assertEqual(cmsg_type, socket.SCM_RIGHTS)
+ self.assertGreaterEqual(len(cmsg_data), mindata)
+ self.assertLessEqual(len(cmsg_data), maxdata)
+ fds = array.array("i")
+ fds.frombytes(cmsg_data[:
+ len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
+ self.checkFDs(fds)
+
+ def testCmsgTruncLen0(self):
+ self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(0), maxdata=0)
+
+ def _testCmsgTruncLen0(self):
+ self.createAndSendFDs(1)
+
+ def testCmsgTruncLen0Plus1(self):
+ self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(0) + 1, maxdata=1)
+
+ def _testCmsgTruncLen0Plus1(self):
+ self.createAndSendFDs(2)
+
+ def testCmsgTruncLen1(self):
+ self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(SIZEOF_INT),
+ maxdata=SIZEOF_INT)
+
+ def _testCmsgTruncLen1(self):
+ self.createAndSendFDs(2)
+
+ def testCmsgTruncLen2Minus1(self):
+ self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(2 * SIZEOF_INT) - 1,
+ maxdata=(2 * SIZEOF_INT) - 1)
+
+ def _testCmsgTruncLen2Minus1(self):
+ self.createAndSendFDs(2)
+
+
+class RFC3542AncillaryTest(SendrecvmsgServerTimeoutBase):
+ # Test sendmsg() and recvmsg[_into]() using the ancillary data
+ # features of the RFC 3542 Advanced Sockets API for IPv6.
+ # Currently we can only handle certain data items (e.g. traffic
+ # class, hop limit, MTU discovery and fragmentation settings)
+ # without resorting to unportable means such as the struct module,
+ # but the tests here are aimed at testing the ancillary data
+ # handling in sendmsg() and recvmsg() rather than the IPv6 API
+ # itself.
+
+ # Test value to use when setting hop limit of packet
+ hop_limit = 2
+
+ # Test value to use when setting traffic class of packet.
+ # -1 means "use kernel default".
+ traffic_class = -1
+
+ def ancillaryMapping(self, ancdata):
+ # Given ancillary data list ancdata, return a mapping from
+ # pairs (cmsg_level, cmsg_type) to corresponding cmsg_data.
+ # Check that no (level, type) pair appears more than once.
+ d = {}
+ for cmsg_level, cmsg_type, cmsg_data in ancdata:
+ self.assertNotIn((cmsg_level, cmsg_type), d)
+ d[(cmsg_level, cmsg_type)] = cmsg_data
+ return d
+
+ def checkHopLimit(self, ancbufsize, maxhop=255, ignoreflags=0):
+ # Receive hop limit into ancbufsize bytes of ancillary data
+ # space. Check that data is MSG, ancillary data is not
+ # truncated (but ignore any flags in ignoreflags), and hop
+ # limit is between 0 and maxhop inclusive.
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVHOPLIMIT, 1)
+ self.misc_event.set()
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), ancbufsize)
+
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkunset=socket.MSG_CTRUNC,
+ ignore=ignoreflags)
+
+ self.assertEqual(len(ancdata), 1)
+ self.assertIsInstance(ancdata[0], tuple)
+ cmsg_level, cmsg_type, cmsg_data = ancdata[0]
+ self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
+ self.assertEqual(cmsg_type, socket.IPV6_HOPLIMIT)
+ self.assertIsInstance(cmsg_data, bytes)
+ self.assertEqual(len(cmsg_data), SIZEOF_INT)
+ a = array.array("i")
+ a.frombytes(cmsg_data)
+ self.assertGreaterEqual(a[0], 0)
+ self.assertLessEqual(a[0], maxhop)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testRecvHopLimit(self):
+ # Test receiving the packet hop limit as ancillary data.
+ self.checkHopLimit(ancbufsize=10240)
+
+ @testRecvHopLimit.client_skip
+ def _testRecvHopLimit(self):
+ # Need to wait until server has asked to receive ancillary
+ # data, as implementations are not required to buffer it
+ # otherwise.
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testRecvHopLimitCMSG_SPACE(self):
+ # Test receiving hop limit, using CMSG_SPACE to calculate buffer size.
+ self.checkHopLimit(ancbufsize=socket.CMSG_SPACE(SIZEOF_INT))
+
+ @testRecvHopLimitCMSG_SPACE.client_skip
+ def _testRecvHopLimitCMSG_SPACE(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ # Could test receiving into buffer sized using CMSG_LEN, but RFC
+ # 3542 says portable applications must provide space for trailing
+ # padding. Implementations may set MSG_CTRUNC if there isn't
+ # enough space for the padding.
+
+ @requireAttrs(socket.socket, "sendmsg")
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testSetHopLimit(self):
+ # Test setting hop limit on outgoing packet and receiving it
+ # at the other end.
+ self.checkHopLimit(ancbufsize=10240, maxhop=self.hop_limit)
+
+ @testSetHopLimit.client_skip
+ def _testSetHopLimit(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.assertEqual(
+ self.sendmsgToServer([MSG],
+ [(socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
+ array.array("i", [self.hop_limit]))]),
+ len(MSG))
+
+ def checkTrafficClassAndHopLimit(self, ancbufsize, maxhop=255,
+ ignoreflags=0):
+ # Receive traffic class and hop limit into ancbufsize bytes of
+ # ancillary data space. Check that data is MSG, ancillary
+ # data is not truncated (but ignore any flags in ignoreflags),
+ # and traffic class and hop limit are in range (hop limit no
+ # more than maxhop).
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVHOPLIMIT, 1)
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVTCLASS, 1)
+ self.misc_event.set()
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), ancbufsize)
+
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkunset=socket.MSG_CTRUNC,
+ ignore=ignoreflags)
+ self.assertEqual(len(ancdata), 2)
+ ancmap = self.ancillaryMapping(ancdata)
+
+ tcdata = ancmap[(socket.IPPROTO_IPV6, socket.IPV6_TCLASS)]
+ self.assertEqual(len(tcdata), SIZEOF_INT)
+ a = array.array("i")
+ a.frombytes(tcdata)
+ self.assertGreaterEqual(a[0], 0)
+ self.assertLessEqual(a[0], 255)
+
+ hldata = ancmap[(socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT)]
+ self.assertEqual(len(hldata), SIZEOF_INT)
+ a = array.array("i")
+ a.frombytes(hldata)
+ self.assertGreaterEqual(a[0], 0)
+ self.assertLessEqual(a[0], maxhop)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testRecvTrafficClassAndHopLimit(self):
+ # Test receiving traffic class and hop limit as ancillary data.
+ self.checkTrafficClassAndHopLimit(ancbufsize=10240)
+
+ @testRecvTrafficClassAndHopLimit.client_skip
+ def _testRecvTrafficClassAndHopLimit(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testRecvTrafficClassAndHopLimitCMSG_SPACE(self):
+ # Test receiving traffic class and hop limit, using
+ # CMSG_SPACE() to calculate buffer size.
+ self.checkTrafficClassAndHopLimit(
+ ancbufsize=socket.CMSG_SPACE(SIZEOF_INT) * 2)
+
+ @testRecvTrafficClassAndHopLimitCMSG_SPACE.client_skip
+ def _testRecvTrafficClassAndHopLimitCMSG_SPACE(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket.socket, "sendmsg")
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testSetTrafficClassAndHopLimit(self):
+ # Test setting traffic class and hop limit on outgoing packet,
+ # and receiving them at the other end.
+ self.checkTrafficClassAndHopLimit(ancbufsize=10240,
+ maxhop=self.hop_limit)
+
+ @testSetTrafficClassAndHopLimit.client_skip
+ def _testSetTrafficClassAndHopLimit(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.assertEqual(
+ self.sendmsgToServer([MSG],
+ [(socket.IPPROTO_IPV6, socket.IPV6_TCLASS,
+ array.array("i", [self.traffic_class])),
+ (socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
+ array.array("i", [self.hop_limit]))]),
+ len(MSG))
+
+ @requireAttrs(socket.socket, "sendmsg")
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testOddCmsgSize(self):
+ # Try to send ancillary data with first item one byte too
+ # long. Fall back to sending with correct size if this fails,
+ # and check that second item was handled correctly.
+ self.checkTrafficClassAndHopLimit(ancbufsize=10240,
+ maxhop=self.hop_limit)
+
+ @testOddCmsgSize.client_skip
+ def _testOddCmsgSize(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ try:
+ nbytes = self.sendmsgToServer(
+ [MSG],
+ [(socket.IPPROTO_IPV6, socket.IPV6_TCLASS,
+ array.array("i", [self.traffic_class]).tobytes() + b"\x00"),
+ (socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
+ array.array("i", [self.hop_limit]))])
+ except socket.error as e:
+ self.assertIsInstance(e.errno, int)
+ nbytes = self.sendmsgToServer(
+ [MSG],
+ [(socket.IPPROTO_IPV6, socket.IPV6_TCLASS,
+ array.array("i", [self.traffic_class])),
+ (socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
+ array.array("i", [self.hop_limit]))])
+ self.assertEqual(nbytes, len(MSG))
+
+ # Tests for proper handling of truncated ancillary data
+
+ def checkHopLimitTruncatedHeader(self, ancbufsize, ignoreflags=0):
+ # Receive hop limit into ancbufsize bytes of ancillary data
+ # space, which should be too small to contain the ancillary
+ # data header (if ancbufsize is None, pass no second argument
+ # to recvmsg()). Check that data is MSG, MSG_CTRUNC is set
+ # (unless included in ignoreflags), and no ancillary data is
+ # returned.
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVHOPLIMIT, 1)
+ self.misc_event.set()
+ args = () if ancbufsize is None else (ancbufsize,)
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), *args)
+
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC,
+ ignore=ignoreflags)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testCmsgTruncNoBufSize(self):
+ # Check that no ancillary data is received when no ancillary
+ # buffer size is provided.
+ self.checkHopLimitTruncatedHeader(ancbufsize=None,
+ # BSD seems to set
+ # MSG_CTRUNC only if an item
+ # has been partially
+ # received.
+ ignoreflags=socket.MSG_CTRUNC)
+
+ @testCmsgTruncNoBufSize.client_skip
+ def _testCmsgTruncNoBufSize(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testSingleCmsgTrunc0(self):
+ # Check that no ancillary data is received when ancillary
+ # buffer size is zero.
+ self.checkHopLimitTruncatedHeader(ancbufsize=0,
+ ignoreflags=socket.MSG_CTRUNC)
+
+ @testSingleCmsgTrunc0.client_skip
+ def _testSingleCmsgTrunc0(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ # Check that no ancillary data is returned for various non-zero
+ # (but still too small) buffer sizes.
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testSingleCmsgTrunc1(self):
+ self.checkHopLimitTruncatedHeader(ancbufsize=1)
+
+ @testSingleCmsgTrunc1.client_skip
+ def _testSingleCmsgTrunc1(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testSingleCmsgTrunc2Int(self):
+ self.checkHopLimitTruncatedHeader(ancbufsize=2 * SIZEOF_INT)
+
+ @testSingleCmsgTrunc2Int.client_skip
+ def _testSingleCmsgTrunc2Int(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testSingleCmsgTruncLen0Minus1(self):
+ self.checkHopLimitTruncatedHeader(ancbufsize=socket.CMSG_LEN(0) - 1)
+
+ @testSingleCmsgTruncLen0Minus1.client_skip
+ def _testSingleCmsgTruncLen0Minus1(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testSingleCmsgTruncInData(self):
+ # Test truncation of a control message inside its associated
+ # data. The message may be returned with its data truncated,
+ # or not returned at all.
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVHOPLIMIT, 1)
+ self.misc_event.set()
+ msg, ancdata, flags, addr = self.doRecvmsg(
+ self.serv_sock, len(MSG), socket.CMSG_LEN(SIZEOF_INT) - 1)
+
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC)
+
+ self.assertLessEqual(len(ancdata), 1)
+ if ancdata:
+ cmsg_level, cmsg_type, cmsg_data = ancdata[0]
+ self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
+ self.assertEqual(cmsg_type, socket.IPV6_HOPLIMIT)
+ self.assertLess(len(cmsg_data), SIZEOF_INT)
+
+ @testSingleCmsgTruncInData.client_skip
+ def _testSingleCmsgTruncInData(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ def checkTruncatedSecondHeader(self, ancbufsize, ignoreflags=0):
+ # Receive traffic class and hop limit into ancbufsize bytes of
+ # ancillary data space, which should be large enough to
+ # contain the first item, but too small to contain the header
+ # of the second. Check that data is MSG, MSG_CTRUNC is set
+ # (unless included in ignoreflags), and only one ancillary
+ # data item is returned.
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVHOPLIMIT, 1)
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVTCLASS, 1)
+ self.misc_event.set()
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), ancbufsize)
+
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC,
+ ignore=ignoreflags)
+
+ self.assertEqual(len(ancdata), 1)
+ cmsg_level, cmsg_type, cmsg_data = ancdata[0]
+ self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
+ self.assertIn(cmsg_type, {socket.IPV6_TCLASS, socket.IPV6_HOPLIMIT})
+ self.assertEqual(len(cmsg_data), SIZEOF_INT)
+ a = array.array("i")
+ a.frombytes(cmsg_data)
+ self.assertGreaterEqual(a[0], 0)
+ self.assertLessEqual(a[0], 255)
+
+ # Try the above test with various buffer sizes.
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testSecondCmsgTrunc0(self):
+ self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT),
+ ignoreflags=socket.MSG_CTRUNC)
+
+ @testSecondCmsgTrunc0.client_skip
+ def _testSecondCmsgTrunc0(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testSecondCmsgTrunc1(self):
+ self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT) + 1)
+
+ @testSecondCmsgTrunc1.client_skip
+ def _testSecondCmsgTrunc1(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testSecondCmsgTrunc2Int(self):
+ self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT) +
+ 2 * SIZEOF_INT)
+
+ @testSecondCmsgTrunc2Int.client_skip
+ def _testSecondCmsgTrunc2Int(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testSecondCmsgTruncLen0Minus1(self):
+ self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT) +
+ socket.CMSG_LEN(0) - 1)
+
+ @testSecondCmsgTruncLen0Minus1.client_skip
+ def _testSecondCmsgTruncLen0Minus1(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testSecomdCmsgTruncInData(self):
+ # Test truncation of the second of two control messages inside
+ # its associated data.
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVHOPLIMIT, 1)
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVTCLASS, 1)
+ self.misc_event.set()
+ msg, ancdata, flags, addr = self.doRecvmsg(
+ self.serv_sock, len(MSG),
+ socket.CMSG_SPACE(SIZEOF_INT) + socket.CMSG_LEN(SIZEOF_INT) - 1)
+
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC)
+
+ cmsg_types = {socket.IPV6_TCLASS, socket.IPV6_HOPLIMIT}
+
+ cmsg_level, cmsg_type, cmsg_data = ancdata.pop(0)
+ self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
+ cmsg_types.remove(cmsg_type)
+ self.assertEqual(len(cmsg_data), SIZEOF_INT)
+ a = array.array("i")
+ a.frombytes(cmsg_data)
+ self.assertGreaterEqual(a[0], 0)
+ self.assertLessEqual(a[0], 255)
+
+ if ancdata:
+ cmsg_level, cmsg_type, cmsg_data = ancdata.pop(0)
+ self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
+ cmsg_types.remove(cmsg_type)
+ self.assertLess(len(cmsg_data), SIZEOF_INT)
+
+ self.assertEqual(ancdata, [])
+
+ @testSecomdCmsgTruncInData.client_skip
+ def _testSecomdCmsgTruncInData(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+
+# Derive concrete test classes for different socket types.
+
+class SendrecvmsgUDPTestBase(SendrecvmsgDgramFlagsBase,
+ SendrecvmsgConnectionlessBase,
+ ThreadedSocketTestMixin, UDPTestBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class SendmsgUDPTest(SendmsgConnectionlessTests, SendrecvmsgUDPTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgUDPTest(RecvmsgTests, SendrecvmsgUDPTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg_into")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoUDPTest(RecvmsgIntoTests, SendrecvmsgUDPTestBase):
+ pass
+
+
+class SendrecvmsgUDP6TestBase(SendrecvmsgDgramFlagsBase,
+ SendrecvmsgConnectionlessBase,
+ ThreadedSocketTestMixin, UDP6TestBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg")
+@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.')
+@requireSocket("AF_INET6", "SOCK_DGRAM")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class SendmsgUDP6Test(SendmsgConnectionlessTests, SendrecvmsgUDP6TestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg")
+@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.')
+@requireSocket("AF_INET6", "SOCK_DGRAM")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgUDP6Test(RecvmsgTests, SendrecvmsgUDP6TestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg_into")
+@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.')
+@requireSocket("AF_INET6", "SOCK_DGRAM")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoUDP6Test(RecvmsgIntoTests, SendrecvmsgUDP6TestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg")
+@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.')
+@requireAttrs(socket, "IPPROTO_IPV6")
+@requireSocket("AF_INET6", "SOCK_DGRAM")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgRFC3542AncillaryUDP6Test(RFC3542AncillaryTest,
+ SendrecvmsgUDP6TestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg_into")
+@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.')
+@requireAttrs(socket, "IPPROTO_IPV6")
+@requireSocket("AF_INET6", "SOCK_DGRAM")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoRFC3542AncillaryUDP6Test(RecvmsgIntoMixin,
+ RFC3542AncillaryTest,
+ SendrecvmsgUDP6TestBase):
+ pass
+
+
+class SendrecvmsgTCPTestBase(SendrecvmsgConnectedBase,
+ ConnectedStreamTestMixin, TCPTestBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class SendmsgTCPTest(SendmsgStreamTests, SendrecvmsgTCPTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgTCPTest(RecvmsgTests, RecvmsgGenericStreamTests,
+ SendrecvmsgTCPTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg_into")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoTCPTest(RecvmsgIntoTests, RecvmsgGenericStreamTests,
+ SendrecvmsgTCPTestBase):
+ pass
+
+
+class SendrecvmsgSCTPStreamTestBase(SendrecvmsgSCTPFlagsBase,
+ SendrecvmsgConnectedBase,
+ ConnectedStreamTestMixin, SCTPStreamBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg")
+@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class SendmsgSCTPStreamTest(SendmsgStreamTests, SendrecvmsgSCTPStreamTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg")
+@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgSCTPStreamTest(RecvmsgTests, RecvmsgGenericStreamTests,
+ SendrecvmsgSCTPStreamTestBase):
+
+ def testRecvmsgEOF(self):
+ try:
+ super(RecvmsgSCTPStreamTest, self).testRecvmsgEOF()
+ except OSError as e:
+ if e.errno != errno.ENOTCONN:
+ raise
+ self.skipTest("sporadic ENOTCONN (kernel issue?) - see issue #13876")
+
+@requireAttrs(socket.socket, "recvmsg_into")
+@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoSCTPStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests,
+ SendrecvmsgSCTPStreamTestBase):
+
+ def testRecvmsgEOF(self):
+ try:
+ super(RecvmsgIntoSCTPStreamTest, self).testRecvmsgEOF()
+ except OSError as e:
+ if e.errno != errno.ENOTCONN:
+ raise
+ self.skipTest("sporadic ENOTCONN (kernel issue?) - see issue #13876")
+
+
+class SendrecvmsgUnixStreamTestBase(SendrecvmsgConnectedBase,
+ ConnectedStreamTestMixin, UnixStreamBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg")
+@requireAttrs(socket, "AF_UNIX")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class SendmsgUnixStreamTest(SendmsgStreamTests, SendrecvmsgUnixStreamTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg")
+@requireAttrs(socket, "AF_UNIX")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgUnixStreamTest(RecvmsgTests, RecvmsgGenericStreamTests,
+ SendrecvmsgUnixStreamTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg_into")
+@requireAttrs(socket, "AF_UNIX")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoUnixStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests,
+ SendrecvmsgUnixStreamTestBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg", "recvmsg")
+@requireAttrs(socket, "AF_UNIX", "SOL_SOCKET", "SCM_RIGHTS")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgSCMRightsStreamTest(SCMRightsTest, SendrecvmsgUnixStreamTestBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg", "recvmsg_into")
+@requireAttrs(socket, "AF_UNIX", "SOL_SOCKET", "SCM_RIGHTS")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoSCMRightsStreamTest(RecvmsgIntoMixin, SCMRightsTest,
+ SendrecvmsgUnixStreamTestBase):
+ pass
+
+
+# Test interrupting the interruptible send/receive methods with a
+# signal when a timeout is set. These tests avoid having multiple
+# threads alive during the test so that the OS cannot deliver the
+# signal to the wrong one.
+
+class InterruptedTimeoutBase(unittest.TestCase):
+ # Base class for interrupted send/receive tests. Installs an
+ # empty handler for SIGALRM and removes it on teardown, along with
+ # any scheduled alarms.
+
+ def setUp(self):
+ super().setUp()
+ orig_alrm_handler = signal.signal(signal.SIGALRM,
+ lambda signum, frame: None)
+ self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)
+ self.addCleanup(self.setAlarm, 0)
+
+ # Timeout for socket operations
+ timeout = 4.0
+
+ # Provide setAlarm() method to schedule delivery of SIGALRM after
+ # given number of seconds, or cancel it if zero, and an
+ # appropriate time value to use. Use setitimer() if available.
+ if hasattr(signal, "setitimer"):
+ alarm_time = 0.05
+
+ def setAlarm(self, seconds):
+ signal.setitimer(signal.ITIMER_REAL, seconds)
+ else:
+ # Old systems may deliver the alarm up to one second early
+ alarm_time = 2
+
+ def setAlarm(self, seconds):
+ signal.alarm(seconds)
+
+
+# Require siginterrupt() in order to ensure that system calls are
+# interrupted by default.
+@requireAttrs(signal, "siginterrupt")
+@unittest.skipUnless(hasattr(signal, "alarm") or hasattr(signal, "setitimer"),
+ "Don't have signal.alarm or signal.setitimer")
+class InterruptedRecvTimeoutTest(InterruptedTimeoutBase, UDPTestBase):
+ # Test interrupting the recv*() methods with signals when a
+ # timeout is set.
+
+ def setUp(self):
+ super().setUp()
+ self.serv.settimeout(self.timeout)
+
+ def checkInterruptedRecv(self, func, *args, **kwargs):
+ # Check that func(*args, **kwargs) raises socket.error with an
+ # errno of EINTR when interrupted by a signal.
+ self.setAlarm(self.alarm_time)
+ with self.assertRaises(socket.error) as cm:
+ func(*args, **kwargs)
+ self.assertNotIsInstance(cm.exception, socket.timeout)
+ self.assertEqual(cm.exception.errno, errno.EINTR)
+
+ def testInterruptedRecvTimeout(self):
+ self.checkInterruptedRecv(self.serv.recv, 1024)
+
+ def testInterruptedRecvIntoTimeout(self):
+ self.checkInterruptedRecv(self.serv.recv_into, bytearray(1024))
+
+ def testInterruptedRecvfromTimeout(self):
+ self.checkInterruptedRecv(self.serv.recvfrom, 1024)
+
+ def testInterruptedRecvfromIntoTimeout(self):
+ self.checkInterruptedRecv(self.serv.recvfrom_into, bytearray(1024))
+
+ @requireAttrs(socket.socket, "recvmsg")
+ def testInterruptedRecvmsgTimeout(self):
+ self.checkInterruptedRecv(self.serv.recvmsg, 1024)
+
+ @requireAttrs(socket.socket, "recvmsg_into")
+ def testInterruptedRecvmsgIntoTimeout(self):
+ self.checkInterruptedRecv(self.serv.recvmsg_into, [bytearray(1024)])
+
+
+# Require siginterrupt() in order to ensure that system calls are
+# interrupted by default.
+@requireAttrs(signal, "siginterrupt")
+@unittest.skipUnless(hasattr(signal, "alarm") or hasattr(signal, "setitimer"),
+ "Don't have signal.alarm or signal.setitimer")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class InterruptedSendTimeoutTest(InterruptedTimeoutBase,
+ ThreadSafeCleanupTestCase,
+ SocketListeningTestMixin, TCPTestBase):
+ # Test interrupting the interruptible send*() methods with signals
+ # when a timeout is set.
+
+ def setUp(self):
+ super().setUp()
+ self.serv_conn = self.newSocket()
+ self.addCleanup(self.serv_conn.close)
+ # Use a thread to complete the connection, but wait for it to
+ # terminate before running the test, so that there is only one
+ # thread to accept the signal.
+ cli_thread = threading.Thread(target=self.doConnect)
+ cli_thread.start()
+ self.cli_conn, addr = self.serv.accept()
+ self.addCleanup(self.cli_conn.close)
+ cli_thread.join()
+ self.serv_conn.settimeout(self.timeout)
+
+ def doConnect(self):
+ self.serv_conn.connect(self.serv_addr)
+
+ def checkInterruptedSend(self, func, *args, **kwargs):
+ # Check that func(*args, **kwargs), run in a loop, raises
+ # socket.error with an errno of EINTR when interrupted by a
+ # signal.
+ with self.assertRaises(socket.error) as cm:
+ while True:
+ self.setAlarm(self.alarm_time)
+ func(*args, **kwargs)
+ self.assertNotIsInstance(cm.exception, socket.timeout)
+ self.assertEqual(cm.exception.errno, errno.EINTR)
+
+ # Issue #12958: The following tests have problems on Mac OS X
+ @support.anticipate_failure(sys.platform == "darwin")
+ def testInterruptedSendTimeout(self):
+ self.checkInterruptedSend(self.serv_conn.send, b"a"*512)
+
+ @support.anticipate_failure(sys.platform == "darwin")
+ def testInterruptedSendtoTimeout(self):
+ # Passing an actual address here as Python's wrapper for
+ # sendto() doesn't allow passing a zero-length one; POSIX
+ # requires that the address is ignored since the socket is
+ # connection-mode, however.
+ self.checkInterruptedSend(self.serv_conn.sendto, b"a"*512,
+ self.serv_addr)
+
+ @support.anticipate_failure(sys.platform == "darwin")
+ @requireAttrs(socket.socket, "sendmsg")
+ def testInterruptedSendmsgTimeout(self):
+ self.checkInterruptedSend(self.serv_conn.sendmsg, [b"a"*512])
+
+
@unittest.skipUnless(thread, 'Threading required for this test.')
class TCPCloserTest(ThreadedTCPSocketTest):
@@ -1099,11 +3586,8 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest):
pass
if hasattr(socket, "SOCK_NONBLOCK"):
+ @support.requires_linux_version(2, 6, 28)
def testInitNonBlocking(self):
- v = linux_version()
- if v < (2, 6, 28):
- self.skipTest("Linux kernel 2.6.28 or higher required, not %s"
- % ".".join(map(str, v)))
# reinit server socket
self.serv.close()
self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM |
@@ -1205,7 +3689,7 @@ class FileObjectClassTestCase(SocketConnectedTest):
"""
bufsize = -1 # Use default buffer size
- encoding = 'utf8'
+ encoding = 'utf-8'
errors = 'strict'
newline = None
@@ -1386,7 +3870,7 @@ class FileObjectInterruptedTestCase(unittest.TestCase):
@staticmethod
def _raise_eintr():
- raise socket.error(errno.EINTR)
+ raise socket.error(errno.EINTR, "interrupted")
def _textiowrap_mock_socket(self, mock, buffering=-1):
raw = socket.SocketIO(mock, "r")
@@ -1426,7 +3910,7 @@ class FileObjectInterruptedTestCase(unittest.TestCase):
data = b''
else:
data = ''
- expecting = expecting.decode('utf8')
+ expecting = expecting.decode('utf-8')
while len(data) != len(expecting):
part = fo.read(size)
if not part:
@@ -1588,7 +4072,7 @@ class UnicodeReadFileObjectClassTestCase(FileObjectClassTestCase):
"""Tests for socket.makefile() in text mode (rather than binary)"""
read_mode = 'r'
- read_msg = MSG.decode('utf8')
+ read_msg = MSG.decode('utf-8')
write_mode = 'wb'
write_msg = MSG
newline = ''
@@ -1600,7 +4084,7 @@ class UnicodeWriteFileObjectClassTestCase(FileObjectClassTestCase):
read_mode = 'rb'
read_msg = MSG
write_mode = 'w'
- write_msg = MSG.decode('utf8')
+ write_msg = MSG.decode('utf-8')
newline = ''
@@ -1608,9 +4092,9 @@ class UnicodeReadWriteFileObjectClassTestCase(FileObjectClassTestCase):
"""Tests for socket.makefile() in text mode (rather than binary)"""
read_mode = 'r'
- read_msg = MSG.decode('utf8')
+ read_msg = MSG.decode('utf-8')
write_mode = 'w'
- write_msg = MSG.decode('utf8')
+ write_msg = MSG.decode('utf-8')
newline = ''
@@ -1901,6 +4385,78 @@ class TestLinuxAbstractNamespace(unittest.TestCase):
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s:
self.assertRaises(socket.error, s.bind, address)
+ def testStrName(self):
+ # Check that an abstract name can be passed as a string.
+ s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ try:
+ s.bind("\x00python\x00test\x00")
+ self.assertEqual(s.getsockname(), b"\x00python\x00test\x00")
+ finally:
+ s.close()
+
+class TestUnixDomain(unittest.TestCase):
+
+ def setUp(self):
+ self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+
+ def tearDown(self):
+ self.sock.close()
+
+ def encoded(self, path):
+ # Return the given path encoded in the file system encoding,
+ # or skip the test if this is not possible.
+ try:
+ return os.fsencode(path)
+ except UnicodeEncodeError:
+ self.skipTest(
+ "Pathname {0!a} cannot be represented in file "
+ "system encoding {1!r}".format(
+ path, sys.getfilesystemencoding()))
+
+ def bind(self, sock, path):
+ # Bind the socket
+ try:
+ sock.bind(path)
+ except OSError as e:
+ if str(e) == "AF_UNIX path too long":
+ self.skipTest(
+ "Pathname {0!a} is too long to serve as a AF_UNIX path"
+ .format(path))
+ else:
+ raise
+
+ def testStrAddr(self):
+ # Test binding to and retrieving a normal string pathname.
+ path = os.path.abspath(support.TESTFN)
+ self.bind(self.sock, path)
+ self.addCleanup(support.unlink, path)
+ self.assertEqual(self.sock.getsockname(), path)
+
+ def testBytesAddr(self):
+ # Test binding to a bytes pathname.
+ path = os.path.abspath(support.TESTFN)
+ self.bind(self.sock, self.encoded(path))
+ self.addCleanup(support.unlink, path)
+ self.assertEqual(self.sock.getsockname(), path)
+
+ def testSurrogateescapeBind(self):
+ # Test binding to a valid non-ASCII pathname, with the
+ # non-ASCII bytes supplied using surrogateescape encoding.
+ path = os.path.abspath(support.TESTFN_UNICODE)
+ b = self.encoded(path)
+ self.bind(self.sock, b.decode("ascii", "surrogateescape"))
+ self.addCleanup(support.unlink, path)
+ self.assertEqual(self.sock.getsockname(), path)
+
+ def testUnencodableAddr(self):
+ # Test binding to a pathname that cannot be encoded in the
+ # file system encoding.
+ if support.TESTFN_UNENCODABLE is None:
+ self.skipTest("No unencodable filename available")
+ path = os.path.abspath(support.TESTFN_UNENCODABLE)
+ self.bind(self.sock, path)
+ self.addCleanup(support.unlink, path)
+ self.assertEqual(self.sock.getsockname(), path)
@unittest.skipUnless(thread, 'Threading required for this test.')
class BufferIOTest(SocketConnectedTest):
@@ -2101,11 +4657,8 @@ class ContextManagersTest(ThreadedTCPSocketTest):
"SOCK_CLOEXEC not defined")
@unittest.skipUnless(fcntl, "module fcntl not available")
class CloexecConstantTest(unittest.TestCase):
+ @support.requires_linux_version(2, 6, 28)
def test_SOCK_CLOEXEC(self):
- v = linux_version()
- if v < (2, 6, 28):
- self.skipTest("Linux kernel 2.6.28 or higher required, not %s"
- % ".".join(map(str, v)))
with socket.socket(socket.AF_INET,
socket.SOCK_STREAM | socket.SOCK_CLOEXEC) as s:
self.assertTrue(s.type & socket.SOCK_CLOEXEC)
@@ -2123,11 +4676,8 @@ class NonblockConstantTest(unittest.TestCase):
self.assertFalse(s.type & socket.SOCK_NONBLOCK)
self.assertEqual(s.gettimeout(), None)
+ @support.requires_linux_version(2, 6, 28)
def test_SOCK_NONBLOCK(self):
- v = linux_version()
- if v < (2, 6, 28):
- self.skipTest("Linux kernel 2.6.28 or higher required, not %s"
- % ".".join(map(str, v)))
# a lot of it seems silly and redundant, but I wanted to test that
# changing back and forth worked ok
with socket.socket(socket.AF_INET,
@@ -2160,6 +4710,109 @@ class NonblockConstantTest(unittest.TestCase):
socket.setdefaulttimeout(t)
+@unittest.skipUnless(os.name == "nt", "Windows specific")
+@unittest.skipUnless(multiprocessing, "need multiprocessing")
+class TestSocketSharing(SocketTCPTest):
+ # This must be classmethod and not staticmethod or multiprocessing
+ # won't be able to bootstrap it.
+ @classmethod
+ def remoteProcessServer(cls, q):
+ # Recreate socket from shared data
+ sdata = q.get()
+ message = q.get()
+
+ s = socket.fromshare(sdata)
+ s2, c = s.accept()
+
+ # Send the message
+ s2.sendall(message)
+ s2.close()
+ s.close()
+
+ def testShare(self):
+ # Transfer the listening server socket to another process
+ # and service it from there.
+
+ # Create process:
+ q = multiprocessing.Queue()
+ p = multiprocessing.Process(target=self.remoteProcessServer, args=(q,))
+ p.start()
+
+ # Get the shared socket data
+ data = self.serv.share(p.pid)
+
+ # Pass the shared socket to the other process
+ addr = self.serv.getsockname()
+ self.serv.close()
+ q.put(data)
+
+ # The data that the server will send us
+ message = b"slapmahfro"
+ q.put(message)
+
+ # Connect
+ s = socket.create_connection(addr)
+ # listen for the data
+ m = []
+ while True:
+ data = s.recv(100)
+ if not data:
+ break
+ m.append(data)
+ s.close()
+ received = b"".join(m)
+ self.assertEqual(received, message)
+ p.join()
+
+ def testShareLength(self):
+ data = self.serv.share(os.getpid())
+ self.assertRaises(ValueError, socket.fromshare, data[:-1])
+ self.assertRaises(ValueError, socket.fromshare, data+b"foo")
+
+ def compareSockets(self, org, other):
+ # socket sharing is expected to work only for blocking socket
+ # since the internal python timout value isn't transfered.
+ self.assertEqual(org.gettimeout(), None)
+ self.assertEqual(org.gettimeout(), other.gettimeout())
+
+ self.assertEqual(org.family, other.family)
+ self.assertEqual(org.type, other.type)
+ # If the user specified "0" for proto, then
+ # internally windows will have picked the correct value.
+ # Python introspection on the socket however will still return
+ # 0. For the shared socket, the python value is recreated
+ # from the actual value, so it may not compare correctly.
+ if org.proto != 0:
+ self.assertEqual(org.proto, other.proto)
+
+ def testShareLocal(self):
+ data = self.serv.share(os.getpid())
+ s = socket.fromshare(data)
+ try:
+ self.compareSockets(self.serv, s)
+ finally:
+ s.close()
+
+ def testTypes(self):
+ families = [socket.AF_INET, socket.AF_INET6]
+ types = [socket.SOCK_STREAM, socket.SOCK_DGRAM]
+ for f in families:
+ for t in types:
+ try:
+ source = socket.socket(f, t)
+ except OSError:
+ continue # This combination is not supported
+ try:
+ data = source.share(os.getpid())
+ shared = socket.fromshare(data)
+ try:
+ self.compareSockets(source, shared)
+ finally:
+ shared.close()
+ finally:
+ source.close()
+
+
def test_main():
tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ]
@@ -2183,11 +4836,41 @@ def test_main():
])
if hasattr(socket, "socketpair"):
tests.append(BasicSocketPairTest)
- if sys.platform == 'linux2':
+ if hasattr(socket, "AF_UNIX"):
+ tests.append(TestUnixDomain)
+ if sys.platform == 'linux':
tests.append(TestLinuxAbstractNamespace)
if isTipcAvailable():
tests.append(TIPCTest)
tests.append(TIPCThreadableTest)
+ tests.extend([BasicCANTest, CANTest])
+ tests.extend([BasicRDSTest, RDSTest])
+ tests.extend([
+ CmsgMacroTests,
+ SendmsgUDPTest,
+ RecvmsgUDPTest,
+ RecvmsgIntoUDPTest,
+ SendmsgUDP6Test,
+ RecvmsgUDP6Test,
+ RecvmsgRFC3542AncillaryUDP6Test,
+ RecvmsgIntoRFC3542AncillaryUDP6Test,
+ RecvmsgIntoUDP6Test,
+ SendmsgTCPTest,
+ RecvmsgTCPTest,
+ RecvmsgIntoTCPTest,
+ SendmsgSCTPStreamTest,
+ RecvmsgSCTPStreamTest,
+ RecvmsgIntoSCTPStreamTest,
+ SendmsgUnixStreamTest,
+ RecvmsgUnixStreamTest,
+ RecvmsgIntoUnixStreamTest,
+ RecvmsgSCMRightsStreamTest,
+ RecvmsgIntoSCMRightsStreamTest,
+ # These are slow when setitimer() is not available
+ InterruptedRecvTimeoutTest,
+ InterruptedSendTimeoutTest,
+ TestSocketSharing,
+ ])
thread_info = support.threading_setup()
support.run_unittest(*tests)
diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py
index 160f5b8..464057e 100644
--- a/Lib/test/test_socketserver.py
+++ b/Lib/test/test_socketserver.py
@@ -244,7 +244,7 @@ class SocketServerTest(unittest.TestCase):
self.called += 1
if self.called == 1:
# raise the exception on first call
- raise select.error(errno.EINTR, os.strerror(errno.EINTR))
+ raise OSError(errno.EINTR, os.strerror(errno.EINTR))
else:
# Return real select value for consecutive calls
return old_select(*args)
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 4f254a9..4bf7ad8 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -42,6 +42,9 @@ ONLYCERT = data_file("ssl_cert.pem")
ONLYKEY = data_file("ssl_key.pem")
BYTES_ONLYCERT = os.fsencode(ONLYCERT)
BYTES_ONLYKEY = os.fsencode(ONLYKEY)
+CERTFILE_PROTECTED = data_file("keycert.passwd.pem")
+ONLYKEY_PROTECTED = data_file("ssl_key.passwd.pem")
+KEY_PASSWORD = "somepass"
CAPATH = data_file("capath")
BYTES_CAPATH = os.fsencode(CAPATH)
@@ -53,6 +56,8 @@ WRONGCERT = data_file("XXXnonexisting.pem")
BADKEY = data_file("badkey.pem")
NOKIACERT = data_file("nokia.pem")
+DHFILE = data_file("dh512.pem")
+BYTES_DHFILE = os.fsencode(DHFILE)
def handle_error(prefix):
exc_format = ' '.join(traceback.format_exception(*sys.exc_info()))
@@ -95,7 +100,14 @@ class BasicSocketTests(unittest.TestCase):
ssl.CERT_NONE
ssl.CERT_OPTIONAL
ssl.CERT_REQUIRED
+ ssl.OP_CIPHER_SERVER_PREFERENCE
+ ssl.OP_SINGLE_DH_USE
+ if ssl.HAS_ECDH:
+ ssl.OP_SINGLE_ECDH_USE
+ if ssl.OPENSSL_VERSION_INFO >= (1, 0):
+ ssl.OP_NO_COMPRESSION
self.assertIn(ssl.HAS_SNI, {True, False})
+ self.assertIn(ssl.HAS_ECDH, {True, False})
def test_random(self):
v = ssl.RAND_status()
@@ -103,6 +115,16 @@ class BasicSocketTests(unittest.TestCase):
sys.stdout.write("\n RAND_status is %d (%s)\n"
% (v, (v and "sufficient randomness") or
"insufficient randomness"))
+
+ data, is_cryptographic = ssl.RAND_pseudo_bytes(16)
+ self.assertEqual(len(data), 16)
+ self.assertEqual(is_cryptographic, v == 1)
+ if v:
+ data = ssl.RAND_bytes(16)
+ self.assertEqual(len(data), 16)
+ else:
+ self.assertRaises(ssl.SSLError, ssl.RAND_bytes, 16)
+
self.assertRaises(TypeError, ssl.RAND_egd, 1)
self.assertRaises(TypeError, ssl.RAND_egd, 'foo', 1)
ssl.RAND_add("this is a random string", 75.0)
@@ -186,20 +208,21 @@ class BasicSocketTests(unittest.TestCase):
s = socket.socket(socket.AF_INET)
ss = ssl.wrap_socket(s)
wr = weakref.ref(ss)
- del ss
- self.assertEqual(wr(), None)
+ with support.check_warnings(("", ResourceWarning)):
+ del ss
+ self.assertEqual(wr(), None)
def test_wrapped_unconnected(self):
# Methods on an unconnected SSLSocket propagate the original
# socket.error raise by the underlying socket object.
s = socket.socket(socket.AF_INET)
- ss = ssl.wrap_socket(s)
- self.assertRaises(socket.error, ss.recv, 1)
- self.assertRaises(socket.error, ss.recv_into, bytearray(b'x'))
- self.assertRaises(socket.error, ss.recvfrom, 1)
- self.assertRaises(socket.error, ss.recvfrom_into, bytearray(b'x'), 1)
- self.assertRaises(socket.error, ss.send, b'x')
- self.assertRaises(socket.error, ss.sendto, b'x', ('0.0.0.0', 0))
+ with ssl.wrap_socket(s) as ss:
+ self.assertRaises(socket.error, ss.recv, 1)
+ self.assertRaises(socket.error, ss.recv_into, bytearray(b'x'))
+ self.assertRaises(socket.error, ss.recvfrom, 1)
+ self.assertRaises(socket.error, ss.recvfrom_into, bytearray(b'x'), 1)
+ self.assertRaises(socket.error, ss.send, b'x')
+ self.assertRaises(socket.error, ss.sendto, b'x', ('0.0.0.0', 0))
def test_timeout(self):
# Issue #8524: when creating an SSL socket, the timeout of the
@@ -207,8 +230,8 @@ class BasicSocketTests(unittest.TestCase):
for timeout in (None, 0.0, 5.0):
s = socket.socket(socket.AF_INET)
s.settimeout(timeout)
- ss = ssl.wrap_socket(s)
- self.assertEqual(timeout, ss.gettimeout())
+ with ssl.wrap_socket(s) as ss:
+ self.assertEqual(timeout, ss.gettimeout())
def test_errors(self):
sock = socket.socket()
@@ -221,9 +244,9 @@ class BasicSocketTests(unittest.TestCase):
self.assertRaisesRegex(ValueError,
"certfile must be specified for server-side operations",
ssl.wrap_socket, sock, server_side=True, certfile="")
- s = ssl.wrap_socket(sock, server_side=True, certfile=CERTFILE)
- self.assertRaisesRegex(ValueError, "can't connect in server-side mode",
- s.connect, (HOST, 8080))
+ with ssl.wrap_socket(sock, server_side=True, certfile=CERTFILE) as s:
+ self.assertRaisesRegex(ValueError, "can't connect in server-side mode",
+ s.connect, (HOST, 8080))
with self.assertRaises(IOError) as cm:
with socket.socket() as sock:
ssl.wrap_socket(sock, certfile=WRONGCERT)
@@ -333,6 +356,33 @@ class BasicSocketTests(unittest.TestCase):
self.assertRaises(ValueError, ctx.wrap_socket, sock, True,
server_hostname="some.hostname")
+ def test_unknown_channel_binding(self):
+ # should raise ValueError for unknown type
+ s = socket.socket(socket.AF_INET)
+ with ssl.wrap_socket(s) as ss:
+ with self.assertRaises(ValueError):
+ ss.get_channel_binding("unknown-type")
+
+ @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
+ "'tls-unique' channel binding not available")
+ def test_tls_unique_channel_binding(self):
+ # unconnected should return None for known type
+ s = socket.socket(socket.AF_INET)
+ with ssl.wrap_socket(s) as ss:
+ self.assertIsNone(ss.get_channel_binding("tls-unique"))
+ # the same for server-side
+ s = socket.socket(socket.AF_INET)
+ with ssl.wrap_socket(s, server_side=True, certfile=CERTFILE) as ss:
+ self.assertIsNone(ss.get_channel_binding("tls-unique"))
+
+ def test_dealloc_warn(self):
+ ss = ssl.wrap_socket(socket.socket(socket.AF_INET))
+ r = repr(ss)
+ with self.assertWarns(ResourceWarning) as cm:
+ ss = None
+ support.gc_collect()
+ self.assertIn(r, str(cm.warning.args[0]))
+
class ContextTests(unittest.TestCase):
@skip_if_broken_ubuntu_ssl
@@ -423,6 +473,60 @@ class ContextTests(unittest.TestCase):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
with self.assertRaisesRegex(ssl.SSLError, "key values mismatch"):
ctx.load_cert_chain(SVN_PYTHON_ORG_ROOT_CERT, ONLYKEY)
+ # Password protected key and cert
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD)
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD.encode())
+ ctx.load_cert_chain(CERTFILE_PROTECTED,
+ password=bytearray(KEY_PASSWORD.encode()))
+ ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD)
+ ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD.encode())
+ ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED,
+ bytearray(KEY_PASSWORD.encode()))
+ with self.assertRaisesRegex(TypeError, "should be a string"):
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=True)
+ with self.assertRaises(ssl.SSLError):
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password="badpass")
+ with self.assertRaisesRegex(ValueError, "cannot be longer"):
+ # openssl has a fixed limit on the password buffer.
+ # PEM_BUFSIZE is generally set to 1kb.
+ # Return a string larger than this.
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=b'a' * 102400)
+ # Password callback
+ def getpass_unicode():
+ return KEY_PASSWORD
+ def getpass_bytes():
+ return KEY_PASSWORD.encode()
+ def getpass_bytearray():
+ return bytearray(KEY_PASSWORD.encode())
+ def getpass_badpass():
+ return "badpass"
+ def getpass_huge():
+ return b'a' * (1024 * 1024)
+ def getpass_bad_type():
+ return 9
+ def getpass_exception():
+ raise Exception('getpass error')
+ class GetPassCallable:
+ def __call__(self):
+ return KEY_PASSWORD
+ def getpass(self):
+ return KEY_PASSWORD
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_unicode)
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytes)
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytearray)
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=GetPassCallable())
+ ctx.load_cert_chain(CERTFILE_PROTECTED,
+ password=GetPassCallable().getpass)
+ with self.assertRaises(ssl.SSLError):
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_badpass)
+ with self.assertRaisesRegex(ValueError, "cannot be longer"):
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_huge)
+ with self.assertRaisesRegex(TypeError, "must return a string"):
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bad_type)
+ with self.assertRaisesRegex(Exception, "getpass error"):
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_exception)
+ # Make sure the password function isn't called if it isn't needed
+ ctx.load_cert_chain(CERTFILE, password=getpass_exception)
def test_load_verify_locations(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
@@ -443,6 +547,19 @@ class ContextTests(unittest.TestCase):
# Issue #10989: crash if the second argument type is invalid
self.assertRaises(TypeError, ctx.load_verify_locations, None, True)
+ def test_load_dh_params(self):
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx.load_dh_params(DHFILE)
+ if os.name != 'nt':
+ ctx.load_dh_params(BYTES_DHFILE)
+ self.assertRaises(TypeError, ctx.load_dh_params)
+ self.assertRaises(TypeError, ctx.load_dh_params, None)
+ with self.assertRaises(FileNotFoundError) as cm:
+ ctx.load_dh_params(WRONGCERT)
+ self.assertEqual(cm.exception.errno, errno.ENOENT)
+ with self.assertRaises(ssl.SSLError) as cm:
+ ctx.load_dh_params(CERTFILE)
+
@skip_if_broken_ubuntu_ssl
def test_session_stats(self):
for proto in PROTOCOLS:
@@ -467,6 +584,57 @@ class ContextTests(unittest.TestCase):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
ctx.set_default_verify_paths()
+ @unittest.skipUnless(ssl.HAS_ECDH, "ECDH disabled on this OpenSSL build")
+ def test_set_ecdh_curve(self):
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx.set_ecdh_curve("prime256v1")
+ ctx.set_ecdh_curve(b"prime256v1")
+ self.assertRaises(TypeError, ctx.set_ecdh_curve)
+ self.assertRaises(TypeError, ctx.set_ecdh_curve, None)
+ self.assertRaises(ValueError, ctx.set_ecdh_curve, "foo")
+ self.assertRaises(ValueError, ctx.set_ecdh_curve, b"foo")
+
+
+class SSLErrorTests(unittest.TestCase):
+
+ def test_str(self):
+ # The str() of a SSLError doesn't include the errno
+ e = ssl.SSLError(1, "foo")
+ self.assertEqual(str(e), "foo")
+ self.assertEqual(e.errno, 1)
+ # Same for a subclass
+ e = ssl.SSLZeroReturnError(1, "foo")
+ self.assertEqual(str(e), "foo")
+ self.assertEqual(e.errno, 1)
+
+ def test_lib_reason(self):
+ # Test the library and reason attributes
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ with self.assertRaises(ssl.SSLError) as cm:
+ ctx.load_dh_params(CERTFILE)
+ self.assertEqual(cm.exception.library, 'PEM')
+ self.assertEqual(cm.exception.reason, 'NO_START_LINE')
+ s = str(cm.exception)
+ self.assertTrue(s.startswith("[PEM: NO_START_LINE] no start line"), s)
+
+ def test_subclass(self):
+ # Check that the appropriate SSLError subclass is raised
+ # (this only tests one of them)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ with socket.socket() as s:
+ s.bind(("127.0.0.1", 0))
+ s.listen(5)
+ c = socket.socket()
+ c.connect(s.getsockname())
+ c.setblocking(False)
+ with ctx.wrap_socket(c, False, do_handshake_on_connect=False) as c:
+ with self.assertRaises(ssl.SSLWantReadError) as cm:
+ c.do_handshake()
+ s = str(cm.exception)
+ self.assertTrue(s.startswith("The operation did not complete (read)"), s)
+ # For compatibility
+ self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ)
+
class NetworkedTests(unittest.TestCase):
@@ -529,13 +697,10 @@ class NetworkedTests(unittest.TestCase):
try:
s.do_handshake()
break
- except ssl.SSLError as err:
- if err.args[0] == ssl.SSL_ERROR_WANT_READ:
- select.select([s], [], [], 5.0)
- elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
- select.select([], [s], [], 5.0)
- else:
- raise
+ except ssl.SSLWantReadError:
+ select.select([s], [], [], 5.0)
+ except ssl.SSLWantWriteError:
+ select.select([], [s], [], 5.0)
# SSL established
self.assertTrue(s.getpeercert())
finally:
@@ -666,47 +831,49 @@ class NetworkedTests(unittest.TestCase):
count += 1
s.do_handshake()
break
- except ssl.SSLError as err:
- if err.args[0] == ssl.SSL_ERROR_WANT_READ:
- select.select([s], [], [])
- elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
- select.select([], [s], [])
- else:
- raise
+ except ssl.SSLWantReadError:
+ select.select([s], [], [])
+ except ssl.SSLWantWriteError:
+ select.select([], [s], [])
s.close()
if support.verbose:
sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
def test_get_server_certificate(self):
- with support.transient_internet("svn.python.org"):
- pem = ssl.get_server_certificate(("svn.python.org", 443))
- if not pem:
- self.fail("No server certificate on svn.python.org:443!")
+ def _test_get_server_certificate(host, port, cert=None):
+ with support.transient_internet(host):
+ pem = ssl.get_server_certificate((host, port))
+ if not pem:
+ self.fail("No server certificate on %s:%s!" % (host, port))
- try:
- pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE)
- except ssl.SSLError as x:
- #should fail
+ try:
+ pem = ssl.get_server_certificate((host, port), ca_certs=CERTFILE)
+ except ssl.SSLError as x:
+ #should fail
+ if support.verbose:
+ sys.stdout.write("%s\n" % x)
+ else:
+ self.fail("Got server certificate %s for %s:%s!" % (pem, host, port))
+
+ pem = ssl.get_server_certificate((host, port), ca_certs=cert)
+ if not pem:
+ self.fail("No server certificate on %s:%s!" % (host, port))
if support.verbose:
- sys.stdout.write("%s\n" % x)
- else:
- self.fail("Got server certificate %s for svn.python.org!" % pem)
+ sys.stdout.write("\nVerified certificate for %s:%s is\n%s\n" % (host, port ,pem))
- pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
- if not pem:
- self.fail("No server certificate on svn.python.org:443!")
- if support.verbose:
- sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem)
+ _test_get_server_certificate('svn.python.org', 443, SVN_PYTHON_ORG_ROOT_CERT)
+ if support.IPV6_ENABLED:
+ _test_get_server_certificate('ipv6.google.com', 443)
def test_ciphers(self):
remote = ("svn.python.org", 443)
with support.transient_internet(remote[0]):
- s = ssl.wrap_socket(socket.socket(socket.AF_INET),
- cert_reqs=ssl.CERT_NONE, ciphers="ALL")
- s.connect(remote)
- s = ssl.wrap_socket(socket.socket(socket.AF_INET),
- cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT")
- s.connect(remote)
+ with ssl.wrap_socket(socket.socket(socket.AF_INET),
+ cert_reqs=ssl.CERT_NONE, ciphers="ALL") as s:
+ s.connect(remote)
+ with ssl.wrap_socket(socket.socket(socket.AF_INET),
+ cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT") as s:
+ s.connect(remote)
# Error checking can happen at instantiation or when connecting
with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"):
with socket.socket(socket.AF_INET) as sock:
@@ -774,6 +941,7 @@ else:
try:
self.sslconn = self.server.context.wrap_socket(
self.sock, server_side=True)
+ self.server.selected_protocols.append(self.sslconn.selected_npn_protocol())
except ssl.SSLError as e:
# XXX Various errors can have happened here, for example
# a mismatching protocol version, an invalid certificate,
@@ -796,6 +964,8 @@ else:
cipher = self.sslconn.cipher()
if support.verbose and self.server.chatty:
sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
+ sys.stdout.write(" server: selected protocol is now "
+ + str(self.sslconn.selected_npn_protocol()) + "\n")
return True
def read(self):
@@ -850,6 +1020,11 @@ else:
self.sslconn = None
if support.verbose and self.server.connectionchatty:
sys.stdout.write(" server: connection is now unencrypted...\n")
+ elif stripped == b'CB tls-unique':
+ if support.verbose and self.server.connectionchatty:
+ sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n")
+ data = self.sslconn.get_channel_binding("tls-unique")
+ self.write(repr(data).encode("us-ascii") + b"\n")
else:
if (support.verbose and
self.server.connectionchatty):
@@ -869,7 +1044,7 @@ else:
def __init__(self, certificate=None, ssl_version=None,
certreqs=None, cacerts=None,
chatty=True, connectionchatty=False, starttls_server=False,
- ciphers=None, context=None):
+ npn_protocols=None, ciphers=None, context=None):
if context:
self.context = context
else:
@@ -882,6 +1057,8 @@ else:
self.context.load_verify_locations(cacerts)
if certificate:
self.context.load_cert_chain(certificate)
+ if npn_protocols:
+ self.context.set_npn_protocols(npn_protocols)
if ciphers:
self.context.set_ciphers(ciphers)
self.chatty = chatty
@@ -891,6 +1068,7 @@ else:
self.port = support.bind_port(self.sock)
self.flag = None
self.active = False
+ self.selected_protocols = []
self.conn_errors = []
threading.Thread.__init__(self)
self.daemon = True
@@ -958,12 +1136,11 @@ else:
def _do_ssl_handshake(self):
try:
self.socket.do_handshake()
- except ssl.SSLError as err:
- if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
- ssl.SSL_ERROR_WANT_WRITE):
- return
- elif err.args[0] == ssl.SSL_ERROR_EOF:
- return self.handle_close()
+ except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
+ return
+ except ssl.SSLEOFError:
+ return self.handle_close()
+ except ssl.SSLError:
raise
except socket.error as err:
if err.args[0] == errno.ECONNABORTED:
@@ -1086,6 +1263,7 @@ else:
Launch a server, connect a client to it and try various reads
and writes.
"""
+ stats = {}
server = ThreadedEchoServer(context=server_context,
chatty=chatty,
connectionchatty=False)
@@ -1111,7 +1289,14 @@ else:
if connectionchatty:
if support.verbose:
sys.stdout.write(" client: closing connection.\n")
+ stats.update({
+ 'compression': s.compression(),
+ 'cipher': s.cipher(),
+ 'client_npn_protocol': s.selected_npn_protocol()
+ })
s.close()
+ stats['server_npn_protocols'] = server.selected_protocols
+ return stats
def try_protocol_combo(server_protocol, client_protocol, expect_success,
certsreqs=None, server_options=0, client_options=0):
@@ -1263,7 +1448,8 @@ else:
t.join()
@skip_if_broken_ubuntu_ssl
- @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'), "need SSLv2")
+ @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'),
+ "OpenSSL is compiled without SSLv2 support")
def test_protocol_sslv2(self):
"""Connecting to an SSLv2 server with various client options"""
if support.verbose:
@@ -1569,6 +1755,15 @@ else:
)
# consume data
s.read()
+
+ # Make sure sendmsg et al are disallowed to avoid
+ # inadvertent disclosure of data and/or corruption
+ # of the encrypted data stream
+ self.assertRaises(NotImplementedError, s.sendmsg, [b"data"])
+ self.assertRaises(NotImplementedError, s.recvmsg, 100)
+ self.assertRaises(NotImplementedError,
+ s.recvmsg_into, bytearray(100))
+
s.write(b"over\n")
s.close()
@@ -1653,6 +1848,8 @@ else:
client_addr = client.getsockname()
client.close()
t.join()
+ remote.close()
+ server.close()
# Sanity checks.
self.assertIsInstance(remote, ssl.SSLSocket)
self.assertEqual(peer, client_addr)
@@ -1667,12 +1864,140 @@ else:
with ThreadedEchoServer(CERTFILE,
ssl_version=ssl.PROTOCOL_SSLv23,
chatty=False) as server:
- with socket.socket() as sock:
- s = context.wrap_socket(sock)
+ with context.wrap_socket(socket.socket()) as s:
with self.assertRaises((OSError, ssl.SSLError)):
s.connect((HOST, server.port))
self.assertIn("no shared cipher", str(server.conn_errors[0]))
+ @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
+ "'tls-unique' channel binding not available")
+ def test_tls_unique_channel_binding(self):
+ """Test tls-unique channel binding."""
+ if support.verbose:
+ sys.stdout.write("\n")
+
+ server = ThreadedEchoServer(CERTFILE,
+ certreqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1,
+ cacerts=CERTFILE,
+ chatty=True,
+ connectionchatty=False)
+ with server:
+ s = ssl.wrap_socket(socket.socket(),
+ server_side=False,
+ certfile=CERTFILE,
+ ca_certs=CERTFILE,
+ cert_reqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1)
+ s.connect((HOST, server.port))
+ # get the data
+ cb_data = s.get_channel_binding("tls-unique")
+ if support.verbose:
+ sys.stdout.write(" got channel binding data: {0!r}\n"
+ .format(cb_data))
+
+ # check if it is sane
+ self.assertIsNotNone(cb_data)
+ self.assertEqual(len(cb_data), 12) # True for TLSv1
+
+ # and compare with the peers version
+ s.write(b"CB tls-unique\n")
+ peer_data_repr = s.read().strip()
+ self.assertEqual(peer_data_repr,
+ repr(cb_data).encode("us-ascii"))
+ s.close()
+
+ # now, again
+ s = ssl.wrap_socket(socket.socket(),
+ server_side=False,
+ certfile=CERTFILE,
+ ca_certs=CERTFILE,
+ cert_reqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1)
+ s.connect((HOST, server.port))
+ new_cb_data = s.get_channel_binding("tls-unique")
+ if support.verbose:
+ sys.stdout.write(" got another channel binding data: {0!r}\n"
+ .format(new_cb_data))
+ # is it really unique
+ self.assertNotEqual(cb_data, new_cb_data)
+ self.assertIsNotNone(cb_data)
+ self.assertEqual(len(cb_data), 12) # True for TLSv1
+ s.write(b"CB tls-unique\n")
+ peer_data_repr = s.read().strip()
+ self.assertEqual(peer_data_repr,
+ repr(new_cb_data).encode("us-ascii"))
+ s.close()
+
+ def test_compression(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ if support.verbose:
+ sys.stdout.write(" got compression: {!r}\n".format(stats['compression']))
+ self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' })
+
+ @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'),
+ "ssl.OP_NO_COMPRESSION needed for this test")
+ def test_compression_disabled(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ context.options |= ssl.OP_NO_COMPRESSION
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ self.assertIs(stats['compression'], None)
+
+ def test_dh_params(self):
+ # Check we can get a connection with ephemeral Diffie-Hellman
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ context.load_dh_params(DHFILE)
+ context.set_ciphers("kEDH")
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ cipher = stats["cipher"][0]
+ parts = cipher.split("-")
+ if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
+ self.fail("Non-DH cipher: " + cipher[0])
+
+ def test_selected_npn_protocol(self):
+ # selected_npn_protocol() is None unless NPN is used
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ self.assertIs(stats['client_npn_protocol'], None)
+
+ @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test")
+ def test_npn_protocols(self):
+ server_protocols = ['http/1.1', 'spdy/2']
+ protocol_tests = [
+ (['http/1.1', 'spdy/2'], 'http/1.1'),
+ (['spdy/2', 'http/1.1'], 'http/1.1'),
+ (['spdy/2', 'test'], 'spdy/2'),
+ (['abc', 'def'], 'abc')
+ ]
+ for client_protocols, expected in protocol_tests:
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(CERTFILE)
+ server_context.set_npn_protocols(server_protocols)
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context.load_cert_chain(CERTFILE)
+ client_context.set_npn_protocols(client_protocols)
+ stats = server_params_test(client_context, server_context,
+ chatty=True, connectionchatty=True)
+
+ msg = "failed trying %s (s) and %s (c).\n" \
+ "was expecting %s, but got %%s from the %%s" \
+ % (str(server_protocols), str(client_protocols),
+ str(expected))
+ client_result = stats['client_npn_protocol']
+ self.assertEqual(client_result, expected, msg % (client_result, "client"))
+ server_result = stats['server_npn_protocols'][-1] \
+ if len(stats['server_npn_protocols']) else 'nothing'
+ self.assertEqual(server_result, expected, msg % (server_result, "server"))
+
def test_main(verbose=False):
if support.verbose:
@@ -1700,14 +2025,14 @@ def test_main(verbose=False):
if not os.path.exists(filename):
raise support.TestFailed("Can't read certificate file %r" % filename)
- tests = [ContextTests, BasicSocketTests]
+ tests = [ContextTests, BasicSocketTests, SSLErrorTests]
if support.is_resource_enabled('network'):
tests.append(NetworkedTests)
if _have_threads:
thread_info = support.threading_setup()
- if thread_info and support.is_resource_enabled('network'):
+ if thread_info:
tests.append(ThreadedTests)
try:
diff --git a/Lib/test/test_startfile.py b/Lib/test/test_startfile.py
index ae9aeb9..5a9c2de 100644
--- a/Lib/test/test_startfile.py
+++ b/Lib/test/test_startfile.py
@@ -5,7 +5,7 @@
#
# A possible improvement would be to have empty.vbs do something that
# we can detect here, to make sure that not only the os.startfile()
-# call succeeded, but also the the script actually has run.
+# call succeeded, but also the script actually has run.
import unittest
from test import support
diff --git a/Lib/test/test_stat.py b/Lib/test/test_stat.py
new file mode 100644
index 0000000..48d5e38
--- /dev/null
+++ b/Lib/test/test_stat.py
@@ -0,0 +1,66 @@
+import unittest
+import os
+import stat
+from test.support import TESTFN, run_unittest
+
+
+def get_mode(fname=TESTFN):
+ return stat.filemode(os.lstat(fname).st_mode)
+
+
+class TestFilemode(unittest.TestCase):
+
+ def setUp(self):
+ try:
+ os.remove(TESTFN)
+ except OSError:
+ try:
+ os.rmdir(TESTFN)
+ except OSError:
+ pass
+ tearDown = setUp
+
+ def test_mode(self):
+ with open(TESTFN, 'w'):
+ pass
+ if os.name == 'posix':
+ os.chmod(TESTFN, 0o700)
+ self.assertEqual(get_mode(), '-rwx------')
+ os.chmod(TESTFN, 0o070)
+ self.assertEqual(get_mode(), '----rwx---')
+ os.chmod(TESTFN, 0o007)
+ self.assertEqual(get_mode(), '-------rwx')
+ os.chmod(TESTFN, 0o444)
+ self.assertEqual(get_mode(), '-r--r--r--')
+ else:
+ os.chmod(TESTFN, 0o700)
+ self.assertEqual(get_mode()[:3], '-rw')
+
+ def test_directory(self):
+ os.mkdir(TESTFN)
+ os.chmod(TESTFN, 0o700)
+ if os.name == 'posix':
+ self.assertEqual(get_mode(), 'drwx------')
+ else:
+ self.assertEqual(get_mode()[0], 'd')
+
+ @unittest.skipUnless(hasattr(os, 'symlink'), 'os.symlink not available')
+ def test_link(self):
+ try:
+ os.symlink(os.getcwd(), TESTFN)
+ except (OSError, NotImplementedError) as err:
+ raise unittest.SkipTest(str(err))
+ else:
+ self.assertEqual(get_mode()[0], 'l')
+
+ @unittest.skipUnless(hasattr(os, 'mkfifo'), 'os.mkfifo not available')
+ def test_fifo(self):
+ os.mkfifo(TESTFN, 0o700)
+ self.assertEqual(get_mode(), 'prwx------')
+
+
+def test_main():
+ run_unittest(TestFilemode)
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_string.py b/Lib/test/test_string.py
index 34d1dcf..c2bdfdb 100644
--- a/Lib/test/test_string.py
+++ b/Lib/test/test_string.py
@@ -26,6 +26,12 @@ class ModuleTest(unittest.TestCase):
self.assertEqual(string.capwords('\taBc\tDeF\t'), 'Abc Def')
self.assertEqual(string.capwords('\taBc\tDeF\t', '\t'), '\tAbc\tDef\t')
+ def test_basic_formatter(self):
+ fmt = string.Formatter()
+ self.assertEqual(fmt.format("foo"), "foo")
+ self.assertEqual(fmt.format("foo{0}", "bar"), "foobar")
+ self.assertEqual(fmt.format("foo{1}{0}-{1}", "bar", 6), "foo6bar-6")
+
def test_conversion_specifiers(self):
fmt = string.Formatter()
self.assertEqual(fmt.format("-{arg!r}-", arg='test'), "-'test'-")
@@ -38,15 +44,26 @@ class ModuleTest(unittest.TestCase):
self.assertEqual(fmt.format("{0!a}", chr(255)), "'\\xff'")
self.assertEqual(fmt.format("{0!a}", chr(256)), "'\\u0100'")
- def test_formatter(self):
+ def test_name_lookup(self):
fmt = string.Formatter()
- self.assertEqual(fmt.format("foo"), "foo")
-
- self.assertEqual(fmt.format("foo{0}", "bar"), "foobar")
- self.assertEqual(fmt.format("foo{1}{0}-{1}", "bar", 6), "foo6bar-6")
- self.assertEqual(fmt.format("-{arg!r}-", arg='test'), "-'test'-")
+ class AnyAttr:
+ def __getattr__(self, attr):
+ return attr
+ x = AnyAttr()
+ self.assertEqual(fmt.format("{0.lumber}{0.jack}", x), 'lumberjack')
+ with self.assertRaises(AttributeError):
+ fmt.format("{0.lumber}{0.jack}", '')
+
+ def test_index_lookup(self):
+ fmt = string.Formatter()
+ lookup = ["eggs", "and", "spam"]
+ self.assertEqual(fmt.format("{0[2]}{0[0]}", lookup), 'spameggs')
+ with self.assertRaises(IndexError):
+ fmt.format("{0[2]}{0[0]}", [])
+ with self.assertRaises(KeyError):
+ fmt.format("{0[2]}{0[0]}", {})
- # override get_value ############################################
+ def test_override_get_value(self):
class NamespaceFormatter(string.Formatter):
def __init__(self, namespace={}):
string.Formatter.__init__(self)
@@ -66,7 +83,7 @@ class ModuleTest(unittest.TestCase):
self.assertEqual(fmt.format("{greeting}, world!"), 'hello, world!')
- # override format_field #########################################
+ def test_override_format_field(self):
class CallFormatter(string.Formatter):
def format_field(self, value, format_spec):
return format(value(), format_spec)
@@ -75,18 +92,18 @@ class ModuleTest(unittest.TestCase):
self.assertEqual(fmt.format('*{0}*', lambda : 'result'), '*result*')
- # override convert_field ########################################
+ def test_override_convert_field(self):
class XFormatter(string.Formatter):
def convert_field(self, value, conversion):
if conversion == 'x':
return None
- return super(XFormatter, self).convert_field(value, conversion)
+ return super().convert_field(value, conversion)
fmt = XFormatter()
self.assertEqual(fmt.format("{0!r}:{0!x}", 'foo', 'foo'), "'foo':None")
- # override parse ################################################
+ def test_override_parse(self):
class BarFormatter(string.Formatter):
# returns an iterable that contains tuples of the form:
# (literal_text, field_name, format_spec, conversion)
@@ -102,7 +119,7 @@ class ModuleTest(unittest.TestCase):
fmt = BarFormatter()
self.assertEqual(fmt.format('*|+0:^10s|*', 'foo'), '* foo *')
- # test all parameters used
+ def test_check_unused_args(self):
class CheckAllUsedFormatter(string.Formatter):
def check_unused_args(self, used_args, args, kwargs):
# Track which arguments actually got used
@@ -124,28 +141,13 @@ class ModuleTest(unittest.TestCase):
self.assertRaises(ValueError, fmt.format, "{0}", 10, 20, i=100)
self.assertRaises(ValueError, fmt.format, "{i}", 10, 20, i=100)
- def test_vformat_assert(self):
- cls = string.Formatter()
- kwargs = {
- "i": 100
- }
- self.assertRaises(ValueError, cls._vformat,
- cls.format, "{0}", kwargs, set(), -2)
-
- def test_convert_field(self):
- cls = string.Formatter()
- self.assertEqual(cls.format("{0!s}", 'foo'), 'foo')
- self.assertRaises(ValueError, cls.format, "{0!h}", 'foo')
-
- def test_get_field(self):
- cls = string.Formatter()
- class MyClass:
- name = 'lumberjack'
- x = MyClass()
- self.assertEqual(cls.format("{0.name}", x), 'lumberjack')
-
- lookup = ["eggs", "and", "spam"]
- self.assertEqual(cls.format("{0[2]}", lookup), 'spam')
+ def test_vformat_recursion_limit(self):
+ fmt = string.Formatter()
+ args = ()
+ kwargs = dict(i=100)
+ with self.assertRaises(ValueError) as err:
+ fmt._vformat("{i}", args, kwargs, set(), -1)
+ self.assertIn("recursion", str(err.exception))
def test_main():
diff --git a/Lib/test/test_strlit.py b/Lib/test/test_strlit.py
index a4ae198..d01322f 100644
--- a/Lib/test/test_strlit.py
+++ b/Lib/test/test_strlit.py
@@ -2,10 +2,10 @@ r"""Test correct treatment of various string literals by the parser.
There are four types of string literals:
- 'abc' -- normal str
- r'abc' -- raw str
- b'xyz' -- normal bytes
- br'xyz' -- raw bytes
+ 'abc' -- normal str
+ r'abc' -- raw str
+ b'xyz' -- normal bytes
+ br'xyz' | rb'xyz' -- raw bytes
The difference between normal and raw strings is of course that in a
raw string, \ escapes (while still used to determine the end of the
@@ -133,18 +133,38 @@ class TestLiterals(unittest.TestCase):
def test_eval_bytes_raw(self):
self.assertEqual(eval(""" br'x' """), b'x')
+ self.assertEqual(eval(""" rb'x' """), b'x')
self.assertEqual(eval(r""" br'\x01' """), b'\\' + b'x01')
+ self.assertEqual(eval(r""" rb'\x01' """), b'\\' + b'x01')
self.assertEqual(eval(""" br'\x01' """), byte(1))
+ self.assertEqual(eval(""" rb'\x01' """), byte(1))
self.assertEqual(eval(r""" br'\x81' """), b"\\" + b"x81")
+ self.assertEqual(eval(r""" rb'\x81' """), b"\\" + b"x81")
self.assertRaises(SyntaxError, eval, """ br'\x81' """)
+ self.assertRaises(SyntaxError, eval, """ rb'\x81' """)
self.assertEqual(eval(r""" br'\u1881' """), b"\\" + b"u1881")
+ self.assertEqual(eval(r""" rb'\u1881' """), b"\\" + b"u1881")
self.assertRaises(SyntaxError, eval, """ br'\u1881' """)
+ self.assertRaises(SyntaxError, eval, """ rb'\u1881' """)
self.assertEqual(eval(r""" br'\U0001d120' """), b"\\" + b"U0001d120")
+ self.assertEqual(eval(r""" rb'\U0001d120' """), b"\\" + b"U0001d120")
self.assertRaises(SyntaxError, eval, """ br'\U0001d120' """)
- self.assertRaises(SyntaxError, eval, """ rb'' """)
+ self.assertRaises(SyntaxError, eval, """ rb'\U0001d120' """)
self.assertRaises(SyntaxError, eval, """ bb'' """)
self.assertRaises(SyntaxError, eval, """ rr'' """)
self.assertRaises(SyntaxError, eval, """ brr'' """)
+ self.assertRaises(SyntaxError, eval, """ bbr'' """)
+ self.assertRaises(SyntaxError, eval, """ rrb'' """)
+ self.assertRaises(SyntaxError, eval, """ rbb'' """)
+
+ def test_eval_str_u(self):
+ self.assertEqual(eval(""" u'x' """), 'x')
+ self.assertEqual(eval(""" U'\u00e4' """), 'ä')
+ self.assertEqual(eval(""" u'\N{LATIN SMALL LETTER A WITH DIAERESIS}' """), 'ä')
+ self.assertRaises(SyntaxError, eval, """ ur'' """)
+ self.assertRaises(SyntaxError, eval, """ ru'' """)
+ self.assertRaises(SyntaxError, eval, """ bu'' """)
+ self.assertRaises(SyntaxError, eval, """ ub'' """)
def check_encoding(self, encoding, extra=""):
modname = "xx_" + encoding.replace("-", "_")
@@ -167,7 +187,7 @@ class TestLiterals(unittest.TestCase):
self.assertRaises(SyntaxError, self.check_encoding, "utf-8", extra)
def test_file_utf8(self):
- self.check_encoding("utf8")
+ self.check_encoding("utf-8")
def test_file_iso_8859_1(self):
self.check_encoding("iso-8859-1")
diff --git a/Lib/test/test_struct.py b/Lib/test/test_struct.py
index dcc73ab..22374d2 100644
--- a/Lib/test/test_struct.py
+++ b/Lib/test/test_struct.py
@@ -8,9 +8,19 @@ from test import support
ISBIGENDIAN = sys.byteorder == "big"
IS32BIT = sys.maxsize == 0x7fffffff
-integer_codes = 'b', 'B', 'h', 'H', 'i', 'I', 'l', 'L', 'q', 'Q'
+integer_codes = 'b', 'B', 'h', 'H', 'i', 'I', 'l', 'L', 'q', 'Q', 'n', 'N'
byteorders = '', '@', '=', '<', '>', '!'
+def iter_integer_formats(byteorders=byteorders):
+ for code in integer_codes:
+ for byteorder in byteorders:
+ if (byteorder in ('', '@') and code in ('q', 'Q') and
+ not HAVE_LONG_LONG):
+ continue
+ if (byteorder not in ('', '@') and code in ('n', 'N')):
+ continue
+ yield code, byteorder
+
# Native 'q' packing isn't available on systems that don't have the C
# long long type.
try:
@@ -30,7 +40,6 @@ def bigendian_to_native(value):
return string_reverse(value)
class StructTest(unittest.TestCase):
-
def test_isbigendian(self):
self.assertEqual((struct.pack('=i', 1)[0] == 0), ISBIGENDIAN)
@@ -142,14 +151,13 @@ class StructTest(unittest.TestCase):
}
# standard integer sizes
- for code in integer_codes:
- for byteorder in '=', '<', '>', '!':
- format = byteorder+code
- size = struct.calcsize(format)
- self.assertEqual(size, expected_size[code])
+ for code, byteorder in iter_integer_formats(('=', '<', '>', '!')):
+ format = byteorder+code
+ size = struct.calcsize(format)
+ self.assertEqual(size, expected_size[code])
# native integer sizes
- native_pairs = 'bB', 'hH', 'iI', 'lL'
+ native_pairs = 'bB', 'hH', 'iI', 'lL', 'nN'
if HAVE_LONG_LONG:
native_pairs += 'qQ',
for format_pair in native_pairs:
@@ -167,9 +175,11 @@ class StructTest(unittest.TestCase):
if HAVE_LONG_LONG:
self.assertLessEqual(8, struct.calcsize('q'))
self.assertLessEqual(struct.calcsize('l'), struct.calcsize('q'))
+ self.assertGreaterEqual(struct.calcsize('n'), struct.calcsize('i'))
+ self.assertGreaterEqual(struct.calcsize('n'), struct.calcsize('P'))
def test_integers(self):
- # Integer tests (bBhHiIlLqQ).
+ # Integer tests (bBhHiIlLqQnN).
import binascii
class IntTester(unittest.TestCase):
@@ -183,11 +193,11 @@ class StructTest(unittest.TestCase):
self.byteorder)
self.bytesize = struct.calcsize(format)
self.bitsize = self.bytesize * 8
- if self.code in tuple('bhilq'):
+ if self.code in tuple('bhilqn'):
self.signed = True
self.min_value = -(2**(self.bitsize-1))
self.max_value = 2**(self.bitsize-1) - 1
- elif self.code in tuple('BHILQ'):
+ elif self.code in tuple('BHILQN'):
self.signed = False
self.min_value = 0
self.max_value = 2**self.bitsize - 1
@@ -317,14 +327,23 @@ class StructTest(unittest.TestCase):
struct.pack, self.format,
obj)
- for code in integer_codes:
- for byteorder in byteorders:
- if (byteorder in ('', '@') and code in ('q', 'Q') and
- not HAVE_LONG_LONG):
- continue
+ for code, byteorder in iter_integer_formats():
+ format = byteorder+code
+ t = IntTester(format)
+ t.run()
+
+ def test_nN_code(self):
+ # n and N don't exist in standard sizes
+ def assertStructError(func, *args, **kwargs):
+ with self.assertRaises(struct.error) as cm:
+ func(*args, **kwargs)
+ self.assertIn("bad char in struct format", str(cm.exception))
+ for code in 'nN':
+ for byteorder in ('=', '<', '>', '!'):
format = byteorder+code
- t = IntTester(format)
- t.run()
+ assertStructError(struct.calcsize, format)
+ assertStructError(struct.pack, format, 0)
+ assertStructError(struct.unpack, format, b"")
def test_p_code(self):
# Test p ("Pascal string") code.
@@ -378,14 +397,10 @@ class StructTest(unittest.TestCase):
self.assertRaises(OverflowError, struct.pack, ">f", big)
def test_1530559(self):
- for byteorder in '', '@', '=', '<', '>', '!':
- for code in integer_codes:
- if (byteorder in ('', '@') and code in ('q', 'Q') and
- not HAVE_LONG_LONG):
- continue
- format = byteorder + code
- self.assertRaises(struct.error, struct.pack, format, 1.0)
- self.assertRaises(struct.error, struct.pack, format, 1.5)
+ for code, byteorder in iter_integer_formats():
+ format = byteorder + code
+ self.assertRaises(struct.error, struct.pack, format, 1.0)
+ self.assertRaises(struct.error, struct.pack, format, 1.5)
self.assertRaises(struct.error, struct.pack, 'P', 1.0)
self.assertRaises(struct.error, struct.pack, 'P', 1.5)
@@ -559,9 +574,9 @@ class StructTest(unittest.TestCase):
def check_sizeof(self, format_str, number_of_codes):
# The size of 'PyStructObject'
- totalsize = support.calcobjsize('5P')
+ totalsize = support.calcobjsize('2n3P')
# The size taken up by the 'formatcode' dynamic array
- totalsize += struct.calcsize('3P') * (number_of_codes + 1)
+ totalsize += struct.calcsize('P2n0P') * (number_of_codes + 1)
support.check_sizeof(self, struct.Struct(format_str), totalsize)
@support.cpython_only
diff --git a/Lib/test/test_structseq.py b/Lib/test/test_structseq.py
index d6c63b7..a89e955 100644
--- a/Lib/test/test_structseq.py
+++ b/Lib/test/test_structseq.py
@@ -78,8 +78,9 @@ class StructSeqTest(unittest.TestCase):
def test_fields(self):
t = time.gmtime()
- self.assertEqual(len(t), t.n_fields)
- self.assertEqual(t.n_fields, t.n_sequence_fields+t.n_unnamed_fields)
+ self.assertEqual(len(t), t.n_sequence_fields)
+ self.assertEqual(t.n_unnamed_fields, 0)
+ self.assertEqual(t.n_fields, time._STRUCT_TM_ITEMS)
def test_constructor(self):
t = time.struct_time
diff --git a/Lib/test/test_subprocess.py b/Lib/test/test_subprocess.py
index 45ba373..ff74e87 100644
--- a/Lib/test/test_subprocess.py
+++ b/Lib/test/test_subprocess.py
@@ -16,6 +16,7 @@ import warnings
import select
import shutil
import gc
+import textwrap
try:
import resource
@@ -62,6 +63,8 @@ class BaseTestCase(unittest.TestCase):
# shutdown time. That frustrates tests trying to check stderr produced
# from a spawned Python process.
actual = support.strip_python_stderr(stderr)
+ # strip_python_stderr also strips whitespace, so we do too.
+ expected = expected.strip()
self.assertEqual(actual, expected, msg)
@@ -85,6 +88,15 @@ class ProcessTestCase(BaseTestCase):
"import sys; sys.exit(47)"])
self.assertEqual(rc, 47)
+ def test_call_timeout(self):
+ # call() function with timeout argument; we want to test that the child
+ # process gets killed when the timeout expires. If the child isn't
+ # killed, this call will deadlock since subprocess.call waits for the
+ # child.
+ self.assertRaises(subprocess.TimeoutExpired, subprocess.call,
+ [sys.executable, "-c", "while True: pass"],
+ timeout=0.1)
+
def test_check_call_zero(self):
# check_call() function with zero return code
rc = subprocess.check_call([sys.executable, "-c",
@@ -127,6 +139,21 @@ class ProcessTestCase(BaseTestCase):
self.fail("Expected ValueError when stdout arg supplied.")
self.assertIn('stdout', c.exception.args[0])
+ def test_check_output_timeout(self):
+ # check_output() function with timeout arg
+ with self.assertRaises(subprocess.TimeoutExpired) as c:
+ output = subprocess.check_output(
+ [sys.executable, "-c",
+ "import sys, time\n"
+ "sys.stdout.write('BDFL')\n"
+ "sys.stdout.flush()\n"
+ "time.sleep(3600)"],
+ # Some heavily loaded buildbots (sparc Debian 3.x) require
+ # this much time to start and print.
+ timeout=3)
+ self.fail("Expected TimeoutExpired.")
+ self.assertEqual(c.exception.output, b'BDFL')
+
def test_call_kwargs(self):
# call() function with keyword args
newenv = os.environ.copy()
@@ -177,6 +204,40 @@ class ProcessTestCase(BaseTestCase):
p.wait()
self.assertEqual(p.stderr, None)
+ def _assert_python(self, pre_args, **kwargs):
+ # We include sys.exit() to prevent the test runner from hanging
+ # whenever python is found.
+ args = pre_args + ["import sys; sys.exit(47)"]
+ p = subprocess.Popen(args, **kwargs)
+ p.wait()
+ self.assertEqual(47, p.returncode)
+
+ def test_executable(self):
+ # Check that the executable argument works.
+ #
+ # On Unix (non-Mac and non-Windows), Python looks at args[0] to
+ # determine where its standard library is, so we need the directory
+ # of args[0] to be valid for the Popen() call to Python to succeed.
+ # See also issue #16170 and issue #7774.
+ doesnotexist = os.path.join(os.path.dirname(sys.executable),
+ "doesnotexist")
+ self._assert_python([doesnotexist, "-c"], executable=sys.executable)
+
+ def test_executable_takes_precedence(self):
+ # Check that the executable argument takes precedence over args[0].
+ #
+ # Verify first that the call succeeds without the executable arg.
+ pre_args = [sys.executable, "-c"]
+ self._assert_python(pre_args)
+ self.assertRaises(FileNotFoundError, self._assert_python, pre_args,
+ executable="doesnotexist")
+
+ @unittest.skipIf(mswindows, "executable argument replaces shell")
+ def test_executable_replaces_shell(self):
+ # Check that the executable argument replaces the default shell
+ # when shell=True.
+ self._assert_python([], executable=sys.executable, shell=True)
+
# For use in the test_cwd* tests below.
def _normalize_cwd(self, cwd):
# Normalize an expected cwd (for Tru64 support).
@@ -227,9 +288,9 @@ class ProcessTestCase(BaseTestCase):
with support.temp_cwd() as wrong_dir:
# Before calling with the correct cwd, confirm that the call fails
# without cwd and with the wrong cwd.
- self.assertRaises(OSError, subprocess.Popen,
+ self.assertRaises(FileNotFoundError, subprocess.Popen,
[rel_python])
- self.assertRaises(OSError, subprocess.Popen,
+ self.assertRaises(FileNotFoundError, subprocess.Popen,
[rel_python], cwd=wrong_dir)
python_dir = self._normalize_cwd(python_dir)
self._assert_cwd(python_dir, rel_python, cwd=python_dir)
@@ -244,9 +305,9 @@ class ProcessTestCase(BaseTestCase):
with support.temp_cwd() as wrong_dir:
# Before calling with the correct cwd, confirm that the call fails
# without cwd and with the wrong cwd.
- self.assertRaises(OSError, subprocess.Popen,
+ self.assertRaises(FileNotFoundError, subprocess.Popen,
[doesntexist], executable=rel_python)
- self.assertRaises(OSError, subprocess.Popen,
+ self.assertRaises(FileNotFoundError, subprocess.Popen,
[doesntexist], executable=rel_python,
cwd=wrong_dir)
python_dir = self._normalize_cwd(python_dir)
@@ -262,17 +323,21 @@ class ProcessTestCase(BaseTestCase):
with script_helper.temp_dir() as wrong_dir:
# Before calling with an absolute path, confirm that using a
# relative path fails.
- self.assertRaises(OSError, subprocess.Popen,
+ self.assertRaises(FileNotFoundError, subprocess.Popen,
[rel_python], cwd=wrong_dir)
wrong_dir = self._normalize_cwd(wrong_dir)
self._assert_cwd(wrong_dir, abs_python, cwd=wrong_dir)
+ @unittest.skipIf(sys.base_prefix != sys.prefix,
+ 'Test is not venv-compatible')
def test_executable_with_cwd(self):
python_dir, python_base = self._split_python_path()
python_dir = self._normalize_cwd(python_dir)
self._assert_cwd(python_dir, "somethingyoudonthave",
executable=sys.executable, cwd=python_dir)
+ @unittest.skipIf(sys.base_prefix != sys.prefix,
+ 'Test is not venv-compatible')
@unittest.skipIf(sysconfig.is_python_build(),
"need an installed Python. See #7774")
def test_executable_without_cwd(self):
@@ -410,6 +475,31 @@ class ProcessTestCase(BaseTestCase):
rc = subprocess.call([sys.executable, "-c", cmd], stdout=1)
self.assertEqual(rc, 2)
+ def test_stdout_devnull(self):
+ p = subprocess.Popen([sys.executable, "-c",
+ 'for i in range(10240):'
+ 'print("x" * 1024)'],
+ stdout=subprocess.DEVNULL)
+ p.wait()
+ self.assertEqual(p.stdout, None)
+
+ def test_stderr_devnull(self):
+ p = subprocess.Popen([sys.executable, "-c",
+ 'import sys\n'
+ 'for i in range(10240):'
+ 'sys.stderr.write("x" * 1024)'],
+ stderr=subprocess.DEVNULL)
+ p.wait()
+ self.assertEqual(p.stderr, None)
+
+ def test_stdin_devnull(self):
+ p = subprocess.Popen([sys.executable, "-c",
+ 'import sys;'
+ 'sys.stdin.read(1)'],
+ stdin=subprocess.DEVNULL)
+ p.wait()
+ self.assertEqual(p.stdin, None)
+
def test_env(self):
newenv = os.environ.copy()
newenv["FRUIT"] = "orange"
@@ -480,6 +570,41 @@ class ProcessTestCase(BaseTestCase):
self.assertEqual(stdout, b"banana")
self.assertStderrEqual(stderr, b"pineapple")
+ def test_communicate_timeout(self):
+ p = subprocess.Popen([sys.executable, "-c",
+ 'import sys,os,time;'
+ 'sys.stderr.write("pineapple\\n");'
+ 'time.sleep(1);'
+ 'sys.stderr.write("pear\\n");'
+ 'sys.stdout.write(sys.stdin.read())'],
+ universal_newlines=True,
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ self.assertRaises(subprocess.TimeoutExpired, p.communicate, "banana",
+ timeout=0.3)
+ # Make sure we can keep waiting for it, and that we get the whole output
+ # after it completes.
+ (stdout, stderr) = p.communicate()
+ self.assertEqual(stdout, "banana")
+ self.assertStderrEqual(stderr.encode(), b"pineapple\npear\n")
+
+ def test_communicate_timeout_large_ouput(self):
+ # Test an expiring timeout while the child is outputting lots of data.
+ p = subprocess.Popen([sys.executable, "-c",
+ 'import sys,os,time;'
+ 'sys.stdout.write("a" * (64 * 1024));'
+ 'time.sleep(0.2);'
+ 'sys.stdout.write("a" * (64 * 1024));'
+ 'time.sleep(0.2);'
+ 'sys.stdout.write("a" * (64 * 1024));'
+ 'time.sleep(0.2);'
+ 'sys.stdout.write("a" * (64 * 1024));'],
+ stdout=subprocess.PIPE)
+ self.assertRaises(subprocess.TimeoutExpired, p.communicate, timeout=0.4)
+ (stdout, _) = p.communicate()
+ self.assertEqual(len(stdout), 4 * 64 * 1024)
+
# Test for the fd leak reported in http://bugs.python.org/issue2791.
def test_communicate_pipe_fd_leak(self):
for stdin_pipe in (False, True):
@@ -516,24 +641,21 @@ class ProcessTestCase(BaseTestCase):
# This test will probably deadlock rather than fail, if
# communicate() does not work properly.
x, y = os.pipe()
- if mswindows:
- pipe_buf = 512
- else:
- pipe_buf = os.fpathconf(x, "PC_PIPE_BUF")
os.close(x)
os.close(y)
p = subprocess.Popen([sys.executable, "-c",
'import sys,os;'
'sys.stdout.write(sys.stdin.read(47));'
- 'sys.stderr.write("xyz"*%d);'
- 'sys.stdout.write(sys.stdin.read())' % pipe_buf],
+ 'sys.stderr.write("x" * %d);'
+ 'sys.stdout.write(sys.stdin.read())' %
+ support.PIPE_MAX_SIZE],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
self.addCleanup(p.stdout.close)
self.addCleanup(p.stderr.close)
self.addCleanup(p.stdin.close)
- string_to_write = b"abc"*pipe_buf
+ string_to_write = b"a" * support.PIPE_MAX_SIZE
(stdout, stderr) = p.communicate(string_to_write)
self.assertEqual(stdout, string_to_write)
@@ -615,12 +737,12 @@ class ProcessTestCase(BaseTestCase):
def test_universal_newlines_communicate_stdin(self):
# universal newlines through communicate(), with only stdin
p = subprocess.Popen([sys.executable, "-c",
- 'import sys,os;' + SETBINARY + '''\nif True:
- s = sys.stdin.readline()
- assert s == "line1\\n", repr(s)
- s = sys.stdin.read()
- assert s == "line3\\n", repr(s)
- '''],
+ 'import sys,os;' + SETBINARY + textwrap.dedent('''
+ s = sys.stdin.readline()
+ assert s == "line1\\n", repr(s)
+ s = sys.stdin.read()
+ assert s == "line3\\n", repr(s)
+ ''')],
stdin=subprocess.PIPE,
universal_newlines=1)
(stdout, stderr) = p.communicate("line1\nline3\n")
@@ -641,18 +763,18 @@ class ProcessTestCase(BaseTestCase):
def test_universal_newlines_communicate_stdin_stdout_stderr(self):
# universal newlines through communicate(), with stdin, stdout, stderr
p = subprocess.Popen([sys.executable, "-c",
- 'import sys,os;' + SETBINARY + '''\nif True:
- s = sys.stdin.buffer.readline()
- sys.stdout.buffer.write(s)
- sys.stdout.buffer.write(b"line2\\r")
- sys.stderr.buffer.write(b"eline2\\n")
- s = sys.stdin.buffer.read()
- sys.stdout.buffer.write(s)
- sys.stdout.buffer.write(b"line4\\n")
- sys.stdout.buffer.write(b"line5\\r\\n")
- sys.stderr.buffer.write(b"eline6\\r")
- sys.stderr.buffer.write(b"eline7\\r\\nz")
- '''],
+ 'import sys,os;' + SETBINARY + textwrap.dedent('''
+ s = sys.stdin.buffer.readline()
+ sys.stdout.buffer.write(s)
+ sys.stdout.buffer.write(b"line2\\r")
+ sys.stderr.buffer.write(b"eline2\\n")
+ s = sys.stdin.buffer.read()
+ sys.stdout.buffer.write(s)
+ sys.stdout.buffer.write(b"line4\\n")
+ sys.stdout.buffer.write(b"line5\\r\\n")
+ sys.stderr.buffer.write(b"eline6\\r")
+ sys.stderr.buffer.write(b"eline7\\r\\nz")
+ ''')],
stdin=subprocess.PIPE,
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
@@ -696,7 +818,6 @@ class ProcessTestCase(BaseTestCase):
stdout, stderr = popen.communicate(input='')
finally:
locale.getpreferredencoding = old_getpreferredencoding
-
self.assertEqual(stdout, '1\n2\n3\n4')
def test_no_leaking(self):
@@ -756,30 +877,32 @@ class ProcessTestCase(BaseTestCase):
self.assertEqual(subprocess.list2cmdline(['ab', '']),
'ab ""')
-
def test_poll(self):
- p = subprocess.Popen([sys.executable,
- "-c", "import time; time.sleep(1)"])
- count = 0
- while p.poll() is None:
- time.sleep(0.1)
- count += 1
- # We expect that the poll loop probably went around about 10 times,
- # but, based on system scheduling we can't control, it's possible
- # poll() never returned None. It "should be" very rare that it
- # didn't go around at least twice.
- self.assertGreaterEqual(count, 2)
+ p = subprocess.Popen([sys.executable, "-c",
+ "import os; os.read(0, 1)"],
+ stdin=subprocess.PIPE)
+ self.addCleanup(p.stdin.close)
+ self.assertIsNone(p.poll())
+ os.write(p.stdin.fileno(), b'A')
+ p.wait()
# Subsequent invocations should just return the returncode
self.assertEqual(p.poll(), 0)
-
def test_wait(self):
- p = subprocess.Popen([sys.executable,
- "-c", "import time; time.sleep(2)"])
+ p = subprocess.Popen([sys.executable, "-c", "pass"])
self.assertEqual(p.wait(), 0)
# Subsequent invocations should just return the returncode
self.assertEqual(p.wait(), 0)
+ def test_wait_timeout(self):
+ p = subprocess.Popen([sys.executable,
+ "-c", "import time; time.sleep(0.1)"])
+ with self.assertRaises(subprocess.TimeoutExpired) as c:
+ p.wait(timeout=0.01)
+ self.assertIn("0.01", str(c.exception)) # For coverage of __str__.
+ # Some heavily loaded buildbots (sparc Debian 3.x) require this much
+ # time to start.
+ self.assertEqual(p.wait(timeout=3), 0)
def test_invalid_bufsize(self):
# an invalid type of the bufsize argument should raise
@@ -858,25 +981,29 @@ class ProcessTestCase(BaseTestCase):
p = subprocess.Popen([sys.executable, "-c", 'pass'],
stdin=subprocess.PIPE)
self.addCleanup(p.stdin.close)
- time.sleep(2)
+ p.wait()
p.communicate(b"x" * 2**20)
- @unittest.skipUnless(hasattr(signal, 'SIGALRM'),
- "Requires signal.SIGALRM")
+ @unittest.skipUnless(hasattr(signal, 'SIGUSR1'),
+ "Requires signal.SIGUSR1")
+ @unittest.skipUnless(hasattr(os, 'kill'),
+ "Requires os.kill")
+ @unittest.skipUnless(hasattr(os, 'getppid'),
+ "Requires os.getppid")
def test_communicate_eintr(self):
# Issue #12493: communicate() should handle EINTR
def handler(signum, frame):
pass
- old_handler = signal.signal(signal.SIGALRM, handler)
- self.addCleanup(signal.signal, signal.SIGALRM, old_handler)
+ old_handler = signal.signal(signal.SIGUSR1, handler)
+ self.addCleanup(signal.signal, signal.SIGUSR1, old_handler)
- # the process is running for 2 seconds
- args = [sys.executable, "-c", 'import time; time.sleep(2)']
+ args = [sys.executable, "-c",
+ 'import os, signal;'
+ 'os.kill(os.getppid(), signal.SIGUSR1)']
for stream in ('stdout', 'stderr'):
kw = {stream: subprocess.PIPE}
with subprocess.Popen(args, **kw) as process:
- signal.alarm(1)
- # communicate() will be interrupted by SIGALRM
+ # communicate() will be interrupted by SIGUSR1
process.communicate()
@@ -1500,6 +1627,11 @@ class POSIXProcessTestCase(BaseTestCase):
exitcode = subprocess.call([abs_program, "-c", "pass"])
self.assertEqual(exitcode, 0)
+ # absolute bytes path as a string
+ cmd = b"'" + abs_program + b"' -c pass"
+ exitcode = subprocess.call(cmd, shell=True)
+ self.assertEqual(exitcode, 0)
+
# bytes program, unicode PATH
env = os.environ.copy()
env["PATH"] = path
@@ -1683,7 +1815,7 @@ class POSIXProcessTestCase(BaseTestCase):
stdout, stderr = p.communicate()
self.assertEqual(0, p.returncode, "sigchild_ignore.py exited"
" non-zero with this error:\n%s" %
- stderr.decode('utf8'))
+ stderr.decode('utf-8'))
def test_select_unbuffered(self):
# Issue #11459: bufsize=0 should really set the pipes as
@@ -1930,28 +2062,6 @@ class ProcessTestCaseNoPoll(ProcessTestCase):
ProcessTestCase.tearDown(self)
-@unittest.skipUnless(getattr(subprocess, '_posixsubprocess', False),
- "_posixsubprocess extension module not found.")
-class ProcessTestCasePOSIXPurePython(ProcessTestCase, POSIXProcessTestCase):
- @classmethod
- def setUpClass(cls):
- global subprocess
- assert subprocess._posixsubprocess
- # Reimport subprocess while forcing _posixsubprocess to not exist.
- with support.check_warnings(('.*_posixsubprocess .* not being used.*',
- RuntimeWarning)):
- subprocess = support.import_fresh_module(
- 'subprocess', blocked=['_posixsubprocess'])
- assert not subprocess._posixsubprocess
-
- @classmethod
- def tearDownClass(cls):
- global subprocess
- # Reimport subprocess as it should be, restoring order to the universe.
- subprocess = support.import_fresh_module('subprocess')
- assert subprocess._posixsubprocess
-
-
class HelperFunctionTests(unittest.TestCase):
@unittest.skipIf(mswindows, "errno and EINTR make no sense on windows")
def test_eintr_retry_call(self):
@@ -2046,20 +2156,17 @@ class ContextManagerTests(BaseTestCase):
self.assertEqual(proc.returncode, 1)
def test_invalid_args(self):
- with self.assertRaises(EnvironmentError) as c:
+ with self.assertRaises(FileNotFoundError) as c:
with subprocess.Popen(['nonexisting_i_hope'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE) as proc:
pass
- self.assertEqual(c.exception.errno, errno.ENOENT)
-
def test_main():
unit_tests = (ProcessTestCase,
POSIXProcessTestCase,
Win32ProcessTestCase,
- ProcessTestCasePOSIXPurePython,
CommandTests,
ProcessTestCaseNoPoll,
HelperFunctionTests,
diff --git a/Lib/test/test_sundry.py b/Lib/test/test_sundry.py
index 07802d6..fcccdf7 100644
--- a/Lib/test/test_sundry.py
+++ b/Lib/test/test_sundry.py
@@ -9,7 +9,6 @@ class TestUntestedModules(unittest.TestCase):
with support.check_warnings(quiet=True):
import bdb
import cgitb
- import code
import distutils.bcppcompiler
import distutils.ccompiler
diff --git a/Lib/test/test_super.py b/Lib/test/test_super.py
index 914216d..f6469cf 100644
--- a/Lib/test/test_super.py
+++ b/Lib/test/test_super.py
@@ -81,6 +81,55 @@ class TestSuper(unittest.TestCase):
self.assertEqual(E().f(), 'AE')
+ @unittest.expectedFailure
+ def test___class___set(self):
+ # See issue #12370
+ class X(A):
+ def f(self):
+ return super().f()
+ __class__ = 413
+ x = X()
+ self.assertEqual(x.f(), 'A')
+ self.assertEqual(x.__class__, 413)
+
+ def test___class___instancemethod(self):
+ # See issue #14857
+ class X:
+ def f(self):
+ return __class__
+ self.assertIs(X().f(), X)
+
+ def test___class___classmethod(self):
+ # See issue #14857
+ class X:
+ @classmethod
+ def f(cls):
+ return __class__
+ self.assertIs(X.f(), X)
+
+ def test___class___staticmethod(self):
+ # See issue #14857
+ class X:
+ @staticmethod
+ def f():
+ return __class__
+ self.assertIs(X.f(), X)
+
+ def test_obscure_super_errors(self):
+ def f():
+ super()
+ self.assertRaises(RuntimeError, f)
+ def f(x):
+ del x
+ super()
+ self.assertRaises(RuntimeError, f, None)
+ class X:
+ def f(x):
+ nonlocal __class__
+ del __class__
+ super()
+ self.assertRaises(RuntimeError, X().f)
+
def test_main():
support.run_unittest(TestSuper)
diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py
new file mode 100644
index 0000000..f6ef5f6
--- /dev/null
+++ b/Lib/test/test_support.py
@@ -0,0 +1,199 @@
+#!/usr/bin/env python
+
+import importlib
+import sys
+import os
+import unittest
+import socket
+import tempfile
+import errno
+from test import support
+
+TESTFN = support.TESTFN
+TESTDIRN = os.path.basename(tempfile.mkdtemp(dir='.'))
+
+
+class TestSupport(unittest.TestCase):
+ def setUp(self):
+ support.unlink(TESTFN)
+ support.rmtree(TESTDIRN)
+ tearDown = setUp
+
+ def test_import_module(self):
+ support.import_module("ftplib")
+ self.assertRaises(unittest.SkipTest, support.import_module, "foo")
+
+ def test_import_fresh_module(self):
+ support.import_fresh_module("ftplib")
+
+ def test_get_attribute(self):
+ self.assertEqual(support.get_attribute(self, "test_get_attribute"),
+ self.test_get_attribute)
+ self.assertRaises(unittest.SkipTest, support.get_attribute, self, "foo")
+
+ @unittest.skip("failing buildbots")
+ def test_get_original_stdout(self):
+ self.assertEqual(support.get_original_stdout(), sys.stdout)
+
+ def test_unload(self):
+ import sched
+ self.assertIn("sched", sys.modules)
+ support.unload("sched")
+ self.assertNotIn("sched", sys.modules)
+
+ def test_unlink(self):
+ with open(TESTFN, "w") as f:
+ pass
+ support.unlink(TESTFN)
+ self.assertFalse(os.path.exists(TESTFN))
+ support.unlink(TESTFN)
+
+ def test_rmtree(self):
+ os.mkdir(TESTDIRN)
+ os.mkdir(os.path.join(TESTDIRN, TESTDIRN))
+ support.rmtree(TESTDIRN)
+ self.assertFalse(os.path.exists(TESTDIRN))
+ support.rmtree(TESTDIRN)
+
+ def test_forget(self):
+ mod_filename = TESTFN + '.py'
+ with open(mod_filename, 'w') as f:
+ print('foo = 1', file=f)
+ sys.path.insert(0, os.curdir)
+ importlib.invalidate_caches()
+ try:
+ mod = __import__(TESTFN)
+ self.assertIn(TESTFN, sys.modules)
+
+ support.forget(TESTFN)
+ self.assertNotIn(TESTFN, sys.modules)
+ finally:
+ del sys.path[0]
+ support.unlink(mod_filename)
+
+ def test_HOST(self):
+ s = socket.socket()
+ s.bind((support.HOST, 0))
+ s.close()
+
+ def test_find_unused_port(self):
+ port = support.find_unused_port()
+ s = socket.socket()
+ s.bind((support.HOST, port))
+ s.close()
+
+ def test_bind_port(self):
+ s = socket.socket()
+ support.bind_port(s)
+ s.listen(1)
+ s.close()
+
+ def test_temp_cwd(self):
+ here = os.getcwd()
+ with support.temp_cwd(name=TESTFN):
+ self.assertEqual(os.path.basename(os.getcwd()), TESTFN)
+ self.assertFalse(os.path.exists(TESTFN))
+ self.assertTrue(os.path.basename(os.getcwd()), here)
+
+ def test_temp_cwd__chdir_warning(self):
+ """Check the warning message when os.chdir() fails."""
+ path = TESTFN + '_does_not_exist'
+ with support.check_warnings() as recorder:
+ with support.temp_cwd(path=path, quiet=True):
+ pass
+ messages = [str(w.message) for w in recorder.warnings]
+ self.assertEqual(messages, ['tests may fail, unable to change the CWD to ' + path])
+
+ def test_sortdict(self):
+ self.assertEqual(support.sortdict({3:3, 2:2, 1:1}), "{1: 1, 2: 2, 3: 3}")
+
+ def test_make_bad_fd(self):
+ fd = support.make_bad_fd()
+ with self.assertRaises(OSError) as cm:
+ os.write(fd, b"foo")
+ self.assertEqual(cm.exception.errno, errno.EBADF)
+
+ def test_check_syntax_error(self):
+ support.check_syntax_error(self, "def class")
+ self.assertRaises(AssertionError, support.check_syntax_error, self, "1")
+
+ def test_CleanImport(self):
+ import importlib
+ with support.CleanImport("asyncore"):
+ importlib.import_module("asyncore")
+
+ def test_DirsOnSysPath(self):
+ with support.DirsOnSysPath('foo', 'bar'):
+ self.assertIn("foo", sys.path)
+ self.assertIn("bar", sys.path)
+ self.assertNotIn("foo", sys.path)
+ self.assertNotIn("bar", sys.path)
+
+ def test_captured_stdout(self):
+ with support.captured_stdout() as s:
+ print("hello")
+ self.assertEqual(s.getvalue(), "hello\n")
+
+ def test_captured_stderr(self):
+ with support.captured_stderr() as s:
+ print("hello", file=sys.stderr)
+ self.assertEqual(s.getvalue(), "hello\n")
+
+ def test_captured_stdin(self):
+ with support.captured_stdin() as s:
+ print("hello", file=sys.stdin)
+ self.assertEqual(s.getvalue(), "hello\n")
+
+ def test_gc_collect(self):
+ support.gc_collect()
+
+ def test_python_is_optimized(self):
+ self.assertIsInstance(support.python_is_optimized(), bool)
+
+ def test_swap_attr(self):
+ class Obj:
+ x = 1
+ obj = Obj()
+ with support.swap_attr(obj, "x", 5):
+ self.assertEqual(obj.x, 5)
+ self.assertEqual(obj.x, 1)
+
+ def test_swap_item(self):
+ D = {"item":1}
+ with support.swap_item(D, "item", 5):
+ self.assertEqual(D["item"], 5)
+ self.assertEqual(D["item"], 1)
+
+ # XXX -follows a list of untested API
+ # make_legacy_pyc
+ # is_resource_enabled
+ # requires
+ # fcmp
+ # umaks
+ # findfile
+ # check_warnings
+ # EnvironmentVarGuard
+ # TransientResource
+ # transient_internet
+ # run_with_locale
+ # set_memlimit
+ # bigmemtest
+ # precisionbigmemtest
+ # bigaddrspacetest
+ # requires_resource
+ # run_doctest
+ # threading_cleanup
+ # reap_threads
+ # reap_children
+ # strip_python_stderr
+ # args_from_interpreter_flags
+ # can_symlink
+ # skip_unless_symlink
+
+
+def test_main():
+ tests = [TestSupport]
+ support.run_unittest(*tests)
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_sys.py b/Lib/test/test_sys.py
index a52767a..e5ec85c 100644
--- a/Lib/test/test_sys.py
+++ b/Lib/test/test_sys.py
@@ -303,6 +303,7 @@ class SysModuleTest(unittest.TestCase):
self.assertEqual(sys.getdlopenflags(), oldflags+1)
sys.setdlopenflags(oldflags)
+ @test.support.refcount_test
def test_refcount(self):
# n here must be a global in order for this test to pass while
# tracing with a python function. Tracing calls PyFrame_FastToLocals
@@ -342,7 +343,7 @@ class SysModuleTest(unittest.TestCase):
# Test sys._current_frames() in a WITH_THREADS build.
@test.support.reap_threads
def current_frames_with_threads(self):
- import threading, _thread
+ import threading
import traceback
# Spawn a thread that blocks at a known place. Then the main
@@ -356,7 +357,7 @@ class SysModuleTest(unittest.TestCase):
g456()
def g456():
- thread_info.append(_thread.get_ident())
+ thread_info.append(threading.get_ident())
entered_g.set()
leave_g.wait()
@@ -372,7 +373,7 @@ class SysModuleTest(unittest.TestCase):
d = sys._current_frames()
- main_id = _thread.get_ident()
+ main_id = threading.get_ident()
self.assertIn(main_id, d)
self.assertIn(thread_id, d)
@@ -418,6 +419,7 @@ class SysModuleTest(unittest.TestCase):
self.assertIsInstance(sys.builtin_module_names, tuple)
self.assertIsInstance(sys.copyright, str)
self.assertIsInstance(sys.exec_prefix, str)
+ self.assertIsInstance(sys.base_exec_prefix, str)
self.assertIsInstance(sys.executable, str)
self.assertEqual(len(sys.float_info), 11)
self.assertEqual(sys.float_info.radix, 2)
@@ -446,8 +448,10 @@ class SysModuleTest(unittest.TestCase):
self.assertIsInstance(sys.maxsize, int)
self.assertIsInstance(sys.maxunicode, int)
+ self.assertEqual(sys.maxunicode, 0x10FFFF)
self.assertIsInstance(sys.platform, str)
self.assertIsInstance(sys.prefix, str)
+ self.assertIsInstance(sys.base_prefix, str)
self.assertIsInstance(sys.version, str)
vi = sys.version_info
self.assertIsInstance(vi[:], tuple)
@@ -473,6 +477,14 @@ class SysModuleTest(unittest.TestCase):
if not sys.platform.startswith('win'):
self.assertIsInstance(sys.abiflags, str)
+ @unittest.skipUnless(hasattr(sys, 'thread_info'),
+ 'Threading required for this test.')
+ def test_thread_info(self):
+ info = sys.thread_info
+ self.assertEqual(len(info), 3)
+ self.assertIn(info.name, ('nt', 'os2', 'pthread', 'solaris', None))
+ self.assertIn(info.lock, ('semaphore', 'mutex+cond', None))
+
def test_43581(self):
# Can't use sys.stdout, as this is a StringIO object when
# the test runs under regrtest.
@@ -500,7 +512,7 @@ class SysModuleTest(unittest.TestCase):
def test_sys_flags(self):
self.assertTrue(sys.flags)
- attrs = ("debug", "division_warning",
+ attrs = ("debug",
"inspect", "interactive", "optimize", "dont_write_bytecode",
"no_user_site", "no_site", "ignore_environment", "verbose",
"bytes_warning", "quiet", "hash_randomization")
@@ -532,6 +544,8 @@ class SysModuleTest(unittest.TestCase):
out = p.communicate()[0].strip()
self.assertEqual(out, b'?')
+ @unittest.skipIf(sys.base_prefix != sys.prefix,
+ 'Test is not venv-compatible')
def test_executable(self):
# sys.executable should be absolute
self.assertEqual(os.path.abspath(sys.executable), sys.executable)
@@ -568,6 +582,34 @@ class SysModuleTest(unittest.TestCase):
expected = None
self.check_fsencoding(fs_encoding, expected)
+ def test_implementation(self):
+ # This test applies to all implementations equally.
+
+ levels = {'alpha': 0xA, 'beta': 0xB, 'candidate': 0xC, 'final': 0xF}
+
+ self.assertTrue(hasattr(sys.implementation, 'name'))
+ self.assertTrue(hasattr(sys.implementation, 'version'))
+ self.assertTrue(hasattr(sys.implementation, 'hexversion'))
+ self.assertTrue(hasattr(sys.implementation, 'cache_tag'))
+
+ version = sys.implementation.version
+ self.assertEqual(version[:2], (version.major, version.minor))
+
+ hexversion = (version.major << 24 | version.minor << 16 |
+ version.micro << 8 | levels[version.releaselevel] << 4 |
+ version.serial << 0)
+ self.assertEqual(sys.implementation.hexversion, hexversion)
+
+ # PEP 421 requires that .name be lower case.
+ self.assertEqual(sys.implementation.name,
+ sys.implementation.name.lower())
+
+ def test_debugmallocstats(self):
+ # Test sys._debugmallocstats()
+ from test.script_helper import assert_python_ok
+ args = ['-c', 'import sys; sys._debugmallocstats()']
+ ret, out, err = assert_python_ok(*args)
+ self.assertIn(b"free PyDictObjects", err)
class SizeofTest(unittest.TestCase):
@@ -591,12 +633,12 @@ class SizeofTest(unittest.TestCase):
# bool objects are not gc tracked
self.assertEqual(sys.getsizeof(True), vsize('') + self.longdigit)
# but lists are
- self.assertEqual(sys.getsizeof([]), vsize('PP') + gc_header_size)
+ self.assertEqual(sys.getsizeof([]), vsize('Pn') + gc_header_size)
def test_default(self):
- vsize = test.support.calcvobjsize
- self.assertEqual(sys.getsizeof(True), vsize('') + self.longdigit)
- self.assertEqual(sys.getsizeof(True, -1), vsize('') + self.longdigit)
+ size = test.support.calcvobjsize
+ self.assertEqual(sys.getsizeof(True), size('') + self.longdigit)
+ self.assertEqual(sys.getsizeof(True, -1), size('') + self.longdigit)
def test_objecttypes(self):
# check all types defined in Objects/
@@ -613,9 +655,9 @@ class SizeofTest(unittest.TestCase):
samples = [b'', b'u'*100000]
for sample in samples:
x = bytearray(sample)
- check(x, vsize('iPP') + x.__alloc__())
+ check(x, vsize('inP') + x.__alloc__())
# bytearray_iterator
- check(iter(bytearray()), size('PP'))
+ check(iter(bytearray()), size('nP'))
# cell
def get_cell():
x = 42
@@ -624,27 +666,33 @@ class SizeofTest(unittest.TestCase):
return inner
check(get_cell().__closure__[0], size('P'))
# code
- check(get_cell().__code__, size('5i8Pi3P'))
+ check(get_cell().__code__, size('5i9Pi3P'))
+ check(get_cell.__code__, size('5i9Pi3P'))
+ def get_cell2(x):
+ def inner():
+ return x
+ return inner
+ check(get_cell2.__code__, size('5i9Pi3P') + 1)
# complex
check(complex(0,1), size('2d'))
# method_descriptor (descriptor object)
- check(str.lower, size('2PP'))
+ check(str.lower, size('3PP'))
# classmethod_descriptor (descriptor object)
# XXX
# member_descriptor (descriptor object)
import datetime
- check(datetime.timedelta.days, size('2PP'))
+ check(datetime.timedelta.days, size('3PP'))
# getset_descriptor (descriptor object)
import collections
- check(collections.defaultdict.default_factory, size('2PP'))
+ check(collections.defaultdict.default_factory, size('3PP'))
# wrapper_descriptor (descriptor object)
- check(int.__add__, size('2P2P'))
+ check(int.__add__, size('3P2P'))
# method-wrapper (descriptor object)
check({}.__iter__, size('2P'))
# dict
- check({}, size('3P2P' + 8*'P2P'))
+ check({}, size('n2P' + '2nPn' + 8*'n2P'))
longdict = {1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8}
- check(longdict, size('3P2P' + 8*'P2P') + 16*struct.calcsize('P2P'))
+ check(longdict, size('n2P' + '2nPn') + 16*struct.calcsize('n2P'))
# dictionary-keyiterator
check({}.keys(), size('P'))
# dictionary-valueiterator
@@ -652,18 +700,18 @@ class SizeofTest(unittest.TestCase):
# dictionary-itemiterator
check({}.items(), size('P'))
# dictionary iterator
- check(iter({}), size('P2PPP'))
+ check(iter({}), size('P2nPn'))
# dictproxy
class C(object): pass
check(C.__dict__, size('P'))
# BaseException
- check(BaseException(), size('5P'))
+ check(BaseException(), size('5Pi'))
# UnicodeEncodeError
- check(UnicodeEncodeError("", "", 0, 0, ""), size('5P 2P2PP'))
+ check(UnicodeEncodeError("", "", 0, 0, ""), size('5Pi 2P2nP'))
# UnicodeDecodeError
- check(UnicodeDecodeError("", b"", 0, 0, ""), size('5P 2P2PP'))
+ check(UnicodeDecodeError("", b"", 0, 0, ""), size('5Pi 2P2nP'))
# UnicodeTranslateError
- check(UnicodeTranslateError("", 0, 1, ""), size('5P 2P2PP'))
+ check(UnicodeTranslateError("", 0, 1, ""), size('5Pi 2P2nP'))
# ellipses
check(Ellipsis, size(''))
# EncodingMap
@@ -671,9 +719,9 @@ class SizeofTest(unittest.TestCase):
x = codecs.charmap_build(encodings.iso8859_3.decoding_table)
check(x, size('32B2iB'))
# enumerate
- check(enumerate([]), size('l3P'))
+ check(enumerate([]), size('n3P'))
# reverse
- check(reversed(''), size('PP'))
+ check(reversed(''), size('nP'))
# float
check(float(0), size('d'))
# sys.floatinfo
@@ -689,7 +737,7 @@ class SizeofTest(unittest.TestCase):
check(x, vsize('12P3i' + CO_MAXBLOCKS*'3i' + 'P' + extras*'P'))
# function
def func(): pass
- check(func, size('11P'))
+ check(func, size('12P'))
class c():
@staticmethod
def foo():
@@ -698,12 +746,12 @@ class SizeofTest(unittest.TestCase):
def bar(cls):
pass
# staticmethod
- check(foo, size('P'))
+ check(foo, size('PP'))
# classmethod
- check(bar, size('P'))
+ check(bar, size('PP'))
# generator
def get_gen(): yield 1
- check(get_gen(), size('Pi2P'))
+ check(get_gen(), size('Pb2P'))
# iterator
check(iter('abc'), size('lP'))
# callable-iterator
@@ -712,7 +760,7 @@ class SizeofTest(unittest.TestCase):
# list
samples = [[], [1,2,3], ['1', '2', '3']]
for sample in samples:
- check(sample, vsize('PP') + len(sample)*self.P)
+ check(sample, vsize('Pn') + len(sample)*self.P)
# sortwrapper (list)
# XXX
# cmpwrapper (list)
@@ -720,7 +768,7 @@ class SizeofTest(unittest.TestCase):
# listiterator (list)
check(iter([]), size('lP'))
# listreverseiterator (list)
- check(reversed([]), size('lP'))
+ check(reversed([]), size('nP'))
# long
check(0, vsize(''))
check(1, vsize('') + self.longdigit)
@@ -730,9 +778,9 @@ class SizeofTest(unittest.TestCase):
check(int(PyLong_BASE**2-1), vsize('') + 2*self.longdigit)
check(int(PyLong_BASE**2), vsize('') + 3*self.longdigit)
# memoryview
- check(memoryview(b''), size('PP2P2i7P'))
+ check(memoryview(b''), size('Pnin 2P2n2i5P 3cPn'))
# module
- check(unittest, size('3P'))
+ check(unittest, size('PnP'))
# None
check(None, size(''))
# NotImplementedType
@@ -751,7 +799,7 @@ class SizeofTest(unittest.TestCase):
# rangeiterator
check(iter(range(1)), size('4l'))
# reverse
- check(reversed(''), size('PP'))
+ check(reversed(''), size('nP'))
# range
check(range(1), size('4P'))
check(range(66000), size('4P'))
@@ -759,7 +807,7 @@ class SizeofTest(unittest.TestCase):
# frozenset
PySet_MINSIZE = 8
samples = [[], range(10), range(50)]
- s = size('3P2P' + PySet_MINSIZE*'PP' + 'PP')
+ s = size('3n2P' + PySet_MINSIZE*'nP' + 'nP')
for sample in samples:
minused = len(sample)
if minused == 0: tmp = 1
@@ -773,10 +821,10 @@ class SizeofTest(unittest.TestCase):
check(set(sample), s)
check(frozenset(sample), s)
else:
- check(set(sample), s + newsize*struct.calcsize('lP'))
- check(frozenset(sample), s + newsize*struct.calcsize('lP'))
+ check(set(sample), s + newsize*struct.calcsize('nP'))
+ check(frozenset(sample), s + newsize*struct.calcsize('nP'))
# setiterator
- check(iter(set()), size('P3P'))
+ check(iter(set()), size('P3n'))
# slice
check(slice(0), size('3P'))
# super
@@ -785,28 +833,56 @@ class SizeofTest(unittest.TestCase):
check((), vsize(''))
check((1,2,3), vsize('') + 3*self.P)
# type
- # (PyTypeObject + PyNumberMethods + PyMappingMethods +
- # PySequenceMethods + PyBufferProcs)
- s = vsize('P2P15Pl4PP9PP11PI') + struct.calcsize('16Pi17P 3P 10P 2P 2P')
+ # static type: PyTypeObject
+ s = vsize('P2n15Pl4Pn9Pn11PI')
check(int, s)
+ # (PyTypeObject + PyNumberMethods + PyMappingMethods +
+ # PySequenceMethods + PyBufferProcs + 4P)
+ s = vsize('P2n15Pl4Pn9Pn11PI') + struct.calcsize('34P 3P 10P 2P 4P')
+ # Separate block for PyDictKeysObject with 4 entries
+ s += struct.calcsize("2nPn") + 4*struct.calcsize("n2P")
# class
class newstyleclass(object): pass
check(newstyleclass, s)
+ # dict with shared keys
+ check(newstyleclass().__dict__, size('n2P' + '2nPn'))
# unicode
- usize = len('\0'.encode('unicode-internal'))
- samples = ['', '1'*100]
- # we need to test for both sizes, because we don't know if the string
- # has been cached
+ # each tuple contains a string and its expected character size
+ # don't put any static strings here, as they may contain
+ # wchar_t or UTF-8 representations
+ samples = ['1'*100, '\xff'*50,
+ '\u0100'*40, '\uffff'*100,
+ '\U00010000'*30, '\U0010ffff'*100]
+ asciifields = "nniP"
+ compactfields = asciifields + "nPn"
+ unicodefields = compactfields + "P"
for s in samples:
- basicsize = size('PPPiP') + usize * (len(s) + 1)
- check(s, basicsize)
+ maxchar = ord(max(s))
+ if maxchar < 128:
+ L = size(asciifields) + len(s) + 1
+ elif maxchar < 256:
+ L = size(compactfields) + len(s) + 1
+ elif maxchar < 65536:
+ L = size(compactfields) + 2*(len(s) + 1)
+ else:
+ L = size(compactfields) + 4*(len(s) + 1)
+ check(s, L)
+ # verify that the UTF-8 size is accounted for
+ s = chr(0x4000) # 4 bytes canonical representation
+ check(s, size(compactfields) + 4)
+ # compile() will trigger the generation of the UTF-8
+ # representation as a side effect
+ compile(s, "<stdin>", "eval")
+ check(s, size(compactfields) + 4 + 4)
+ # TODO: add check that forces the presence of wchar_t representation
+ # TODO: add check that forces layout of unicodefields
# weakref
import weakref
- check(weakref.ref(int), size('2Pl2P'))
+ check(weakref.ref(int), size('2Pn2P'))
# weakproxy
# XXX
# weakcallableproxy
- check(weakref.proxy(int), size('2Pl2P'))
+ check(weakref.proxy(int), size('2Pn2P'))
def test_pythontypes(self):
# check all types defined in Python/
@@ -815,16 +891,13 @@ class SizeofTest(unittest.TestCase):
check = self.check_sizeof
# _ast.AST
import _ast
- check(_ast.AST(), size(''))
- # imp.NullImporter
- import imp
- check(imp.NullImporter(self.file.name), size(''))
+ check(_ast.AST(), size('P'))
try:
raise TypeError
except TypeError:
tb = sys.exc_info()[2]
# traceback
- if tb != None:
+ if tb is not None:
check(tb, size('2P2i'))
# symtable entry
# XXX
diff --git a/Lib/test/test_sys_settrace.py b/Lib/test/test_sys_settrace.py
index ba3bc2e..63ae1b7 100644
--- a/Lib/test/test_sys_settrace.py
+++ b/Lib/test/test_sys_settrace.py
@@ -251,6 +251,7 @@ class TraceTestCase(unittest.TestCase):
def setUp(self):
self.using_gc = gc.isenabled()
gc.disable()
+ self.addCleanup(sys.settrace, sys.gettrace())
def tearDown(self):
if self.using_gc:
@@ -389,6 +390,9 @@ class TraceTestCase(unittest.TestCase):
class RaisingTraceFuncTestCase(unittest.TestCase):
+ def setUp(self):
+ self.addCleanup(sys.settrace, sys.gettrace())
+
def trace(self, frame, event, arg):
"""A trace function that raises an exception in response to a
specific trace event."""
@@ -696,6 +700,10 @@ def no_jump_without_trace_function():
class JumpTestCase(unittest.TestCase):
+ def setUp(self):
+ self.addCleanup(sys.settrace, sys.gettrace())
+ sys.settrace(None)
+
def compare_jump_output(self, expected, received):
if received != expected:
self.fail( "Outputs don't match:\n" +
@@ -747,6 +755,8 @@ class JumpTestCase(unittest.TestCase):
def test_18_no_jump_to_non_integers(self):
self.run_test(no_jump_to_non_integers)
def test_19_no_jump_without_trace_function(self):
+ # Must set sys.settrace(None) in setUp(), else condition is not
+ # triggered.
no_jump_without_trace_function()
def test_jump_across_with(self):
self.addCleanup(support.unlink, support.TESTFN)
diff --git a/Lib/test/test_sysconfig.py b/Lib/test/test_sysconfig.py
index 1b6897d..9219360 100644
--- a/Lib/test/test_sysconfig.py
+++ b/Lib/test/test_sysconfig.py
@@ -1,11 +1,9 @@
-"""Tests for sysconfig."""
-
import unittest
import sys
import os
import subprocess
import shutil
-from copy import copy, deepcopy
+from copy import copy
from test.support import (run_unittest, TESTFN, unlink,
captured_stdout, skip_unless_symlink)
@@ -20,7 +18,6 @@ import _osx_support
class TestSysConfig(unittest.TestCase):
def setUp(self):
- """Make a copy of sys.path"""
super(TestSysConfig, self).setUp()
self.sys_path = sys.path[:]
# patching os.uname
@@ -29,7 +26,7 @@ class TestSysConfig(unittest.TestCase):
self._uname = os.uname()
else:
self.uname = None
- self._uname = None
+ self._set_uname(('',)*5)
os.uname = self._get_uname
# saving the environment
self.name = os.name
@@ -39,11 +36,16 @@ class TestSysConfig(unittest.TestCase):
self.join = os.path.join
self.isabs = os.path.isabs
self.splitdrive = os.path.splitdrive
- self._config_vars = copy(sysconfig._CONFIG_VARS)
- self.old_environ = deepcopy(os.environ)
+ self._config_vars = sysconfig._CONFIG_VARS, copy(sysconfig._CONFIG_VARS)
+ self._added_envvars = []
+ self._changed_envvars = []
+ for var in ('MACOSX_DEPLOYMENT_TARGET', 'PATH'):
+ if var in os.environ:
+ self._changed_envvars.append((var, os.environ[var]))
+ else:
+ self._added_envvars.append(var)
def tearDown(self):
- """Restore sys.path"""
sys.path[:] = self.sys_path
self._cleanup_testfn()
if self.uname is not None:
@@ -57,19 +59,18 @@ class TestSysConfig(unittest.TestCase):
os.path.join = self.join
os.path.isabs = self.isabs
os.path.splitdrive = self.splitdrive
- sysconfig._CONFIG_VARS = copy(self._config_vars)
- for key, value in self.old_environ.items():
- if os.environ.get(key) != value:
- os.environ[key] = value
-
- for key in list(os.environ.keys()):
- if key not in self.old_environ:
- del os.environ[key]
+ sysconfig._CONFIG_VARS = self._config_vars[0]
+ sysconfig._CONFIG_VARS.clear()
+ sysconfig._CONFIG_VARS.update(self._config_vars[1])
+ for var, value in self._changed_envvars:
+ os.environ[var] = value
+ for var in self._added_envvars:
+ os.environ.pop(var, None)
super(TestSysConfig, self).tearDown()
def _set_uname(self, uname):
- self._uname = uname
+ self._uname = os.uname_result(uname)
def _get_uname(self):
return self._uname
@@ -88,21 +89,19 @@ class TestSysConfig(unittest.TestCase):
scheme = get_paths()
default_scheme = _get_default_scheme()
wanted = _expand_vars(default_scheme, None)
- wanted = list(wanted.items())
- wanted.sort()
- scheme = list(scheme.items())
- scheme.sort()
+ wanted = sorted(wanted.items())
+ scheme = sorted(scheme.items())
self.assertEqual(scheme, wanted)
def test_get_path(self):
- # xxx make real tests here
+ # XXX make real tests here
for scheme in _INSTALL_SCHEMES:
for name in _INSTALL_SCHEMES[scheme]:
res = get_path(name, scheme)
def test_get_config_vars(self):
cvars = get_config_vars()
- self.assertTrue(isinstance(cvars, dict))
+ self.assertIsInstance(cvars, dict)
self.assertTrue(cvars)
def test_get_platform(self):
@@ -244,8 +243,8 @@ class TestSysConfig(unittest.TestCase):
# On Windows, the EXE needs to know where pythonXY.dll is at so we have
# to add the directory to the path.
if sys.platform == "win32":
- os.environ["Path"] = "{};{}".format(
- os.path.dirname(sys.executable), os.environ["Path"])
+ os.environ["PATH"] = "{};{}".format(
+ os.path.dirname(sys.executable), os.environ["PATH"])
# Issue 7880
def get(python):
@@ -269,12 +268,17 @@ class TestSysConfig(unittest.TestCase):
# the global scheme mirrors the distinction between prefix and
# exec-prefix but not the user scheme, so we have to adapt the paths
# before comparing (issue #9100)
- adapt = sys.prefix != sys.exec_prefix
+ adapt = sys.base_prefix != sys.base_exec_prefix
for name in ('stdlib', 'platstdlib', 'purelib', 'platlib'):
global_path = get_path(name, 'posix_prefix')
if adapt:
- global_path = global_path.replace(sys.exec_prefix, sys.prefix)
- base = base.replace(sys.exec_prefix, sys.prefix)
+ global_path = global_path.replace(sys.exec_prefix, sys.base_prefix)
+ base = base.replace(sys.exec_prefix, sys.base_prefix)
+ elif sys.base_prefix != sys.prefix:
+ # virtual environment? Likewise, we have to adapt the paths
+ # before comparing
+ global_path = global_path.replace(sys.base_prefix, sys.prefix)
+ base = base.replace(sys.base_prefix, sys.prefix)
user_path = get_path(name, 'posix_user')
self.assertEqual(user_path, global_path.replace(base, user, 1))
@@ -291,7 +295,6 @@ class TestSysConfig(unittest.TestCase):
self.assertIn(ldflags, ldshared)
-
@unittest.skipUnless(sys.platform == "darwin", "test only relevant on MacOSX")
def test_platform_in_subprocess(self):
my_platform = sysconfig.get_platform()
@@ -317,28 +320,58 @@ class TestSysConfig(unittest.TestCase):
self.assertEqual(status, 0)
self.assertEqual(my_platform, test_platform)
-
# Test with MACOSX_DEPLOYMENT_TARGET in the environment, and
# using a value that is unlikely to be the default one.
env = os.environ.copy()
env['MACOSX_DEPLOYMENT_TARGET'] = '10.1'
- p = subprocess.Popen([
- sys.executable, '-c',
- 'import sysconfig; print(sysconfig.get_platform())',
- ],
- stdout=subprocess.PIPE,
- stderr=open('/dev/null'),
- env=env)
- test_platform = p.communicate()[0].strip()
- test_platform = test_platform.decode('utf-8')
- status = p.wait()
-
- self.assertEqual(status, 0)
- self.assertEqual(my_platform, test_platform)
+ with open('/dev/null') as dev_null:
+ p = subprocess.Popen([
+ sys.executable, '-c',
+ 'import sysconfig; print(sysconfig.get_platform())',
+ ],
+ stdout=subprocess.PIPE,
+ stderr=dev_null,
+ env=env)
+ test_platform = p.communicate()[0].strip()
+ test_platform = test_platform.decode('utf-8')
+ status = p.wait()
+
+ self.assertEqual(status, 0)
+ self.assertEqual(my_platform, test_platform)
+
+ def test_srcdir(self):
+ # See Issues #15322, #15364.
+ srcdir = sysconfig.get_config_var('srcdir')
+
+ self.assertTrue(os.path.isabs(srcdir), srcdir)
+ self.assertTrue(os.path.isdir(srcdir), srcdir)
+
+ if sysconfig._PYTHON_BUILD:
+ # The python executable has not been installed so srcdir
+ # should be a full source checkout.
+ Python_h = os.path.join(srcdir, 'Include', 'Python.h')
+ self.assertTrue(os.path.exists(Python_h), Python_h)
+ self.assertTrue(sysconfig._is_python_source_dir(srcdir))
+ elif os.name == 'posix':
+ self.assertEqual(os.path.dirname(sysconfig.get_makefile_filename()),
+ srcdir)
+
+ def test_srcdir_independent_of_cwd(self):
+ # srcdir should be independent of the current working directory
+ # See Issues #15322, #15364.
+ srcdir = sysconfig.get_config_var('srcdir')
+ cwd = os.getcwd()
+ try:
+ os.chdir('..')
+ srcdir2 = sysconfig.get_config_var('srcdir')
+ finally:
+ os.chdir(cwd)
+ self.assertEqual(srcdir, srcdir2)
class MakefileTests(unittest.TestCase):
+
@unittest.skipIf(sys.platform.startswith('win'),
'Test is not Windows compatible')
def test_get_makefile_filename(self):
diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py
index a53181e..b224bf0 100644
--- a/Lib/test/test_tarfile.py
+++ b/Lib/test/test_tarfile.py
@@ -21,6 +21,10 @@ try:
import bz2
except ImportError:
bz2 = None
+try:
+ import lzma
+except ImportError:
+ lzma = None
def md5sum(data):
return md5(data).hexdigest()
@@ -29,6 +33,7 @@ TEMPDIR = os.path.abspath(support.TESTFN) + "-tardir"
tarname = support.findfile("testtar.tar")
gzipname = os.path.join(TEMPDIR, "testtar.tar.gz")
bz2name = os.path.join(TEMPDIR, "testtar.tar.bz2")
+xzname = os.path.join(TEMPDIR, "testtar.tar.xz")
tmpname = os.path.join(TEMPDIR, "tmp.tar")
md5_regtype = "65f477c818ad9e15f7feab0c6d37742f"
@@ -51,13 +56,10 @@ class UstarReadTest(ReadTest):
def test_fileobj_regular_file(self):
tarinfo = self.tar.getmember("ustar/regtype")
- fobj = self.tar.extractfile(tarinfo)
- try:
+ with self.tar.extractfile(tarinfo) as fobj:
data = fobj.read()
self.assertTrue((len(data), md5sum(data)) == (tarinfo.size, md5_regtype),
"regular file extraction failed")
- finally:
- fobj.close()
def test_fileobj_readlines(self):
self.tar.extract("ustar/regtype", TEMPDIR)
@@ -65,8 +67,7 @@ class UstarReadTest(ReadTest):
with open(os.path.join(TEMPDIR, "ustar/regtype"), "r") as fobj1:
lines1 = fobj1.readlines()
- fobj = self.tar.extractfile(tarinfo)
- try:
+ with self.tar.extractfile(tarinfo) as fobj:
fobj2 = io.TextIOWrapper(fobj)
lines2 = fobj2.readlines()
self.assertTrue(lines1 == lines2,
@@ -76,21 +77,16 @@ class UstarReadTest(ReadTest):
self.assertTrue(lines2[83] ==
"I will gladly admit that Python is not the fastest running scripting language.\n",
"fileobj.readlines() failed")
- finally:
- fobj.close()
def test_fileobj_iter(self):
self.tar.extract("ustar/regtype", TEMPDIR)
tarinfo = self.tar.getmember("ustar/regtype")
- with open(os.path.join(TEMPDIR, "ustar/regtype"), "rU") as fobj1:
+ with open(os.path.join(TEMPDIR, "ustar/regtype"), "r") as fobj1:
lines1 = fobj1.readlines()
- fobj2 = self.tar.extractfile(tarinfo)
- try:
+ with self.tar.extractfile(tarinfo) as fobj2:
lines2 = list(io.TextIOWrapper(fobj2))
self.assertTrue(lines1 == lines2,
"fileobj.__iter__() failed")
- finally:
- fobj2.close()
def test_fileobj_seek(self):
self.tar.extract("ustar/regtype", TEMPDIR)
@@ -142,17 +138,24 @@ class UstarReadTest(ReadTest):
"read() after readline() failed")
fobj.close()
+ def test_fileobj_text(self):
+ with self.tar.extractfile("ustar/regtype") as fobj:
+ fobj = io.TextIOWrapper(fobj)
+ data = fobj.read().encode("iso8859-1")
+ self.assertEqual(md5sum(data), md5_regtype)
+ try:
+ fobj.seek(100)
+ except AttributeError:
+ # Issue #13815: seek() complained about a missing
+ # flush() method.
+ self.fail("seeking failed in text mode")
+
# Test if symbolic and hard links are resolved by extractfile(). The
# test link members each point to a regular member whose data is
# supposed to be exported.
def _test_fileobj_link(self, lnktype, regtype):
- a = self.tar.extractfile(lnktype)
- b = self.tar.extractfile(regtype)
- try:
+ with self.tar.extractfile(lnktype) as a, self.tar.extractfile(regtype) as b:
self.assertEqual(a.name, b.name)
- finally:
- a.close()
- b.close()
def test_fileobj_link1(self):
self._test_fileobj_link("ustar/lnktype", "ustar/regtype")
@@ -204,13 +207,15 @@ class CommonReadTest(ReadTest):
_open = gzip.GzipFile
elif self.mode.endswith(":bz2"):
_open = bz2.BZ2File
+ elif self.mode.endswith(":xz"):
+ _open = lzma.LZMAFile
else:
- _open = open
+ _open = io.FileIO
for char in (b'\0', b'a'):
# Test if EOFHeaderError ('\0') and InvalidHeaderError ('a')
# are ignored correctly.
- with _open(tmpname, "wb") as fobj:
+ with _open(tmpname, "w") as fobj:
fobj.write(char * 1024)
fobj.write(tarfile.TarInfo("foo").tobuf())
@@ -225,6 +230,10 @@ class CommonReadTest(ReadTest):
class MiscReadTest(CommonReadTest):
def test_no_name_argument(self):
+ if self.mode.endswith(("bz2", "xz")):
+ # BZ2File and LZMAFile have no name attribute.
+ self.skipTest("no name attribute")
+
with open(self.tarname, "rb") as fobj:
tar = tarfile.open(fileobj=fobj, mode=self.mode)
self.assertEqual(tar.name, os.path.abspath(fobj.name))
@@ -254,9 +263,8 @@ class MiscReadTest(CommonReadTest):
t = tar.next()
name = t.name
offset = t.offset
- f = tar.extractfile(t)
- data = f.read()
- f.close()
+ with tar.extractfile(t) as f:
+ data = f.read()
finally:
tar.close()
@@ -265,10 +273,12 @@ class MiscReadTest(CommonReadTest):
_open = gzip.GzipFile
elif self.mode.endswith(":bz2"):
_open = bz2.BZ2File
+ elif self.mode.endswith(":xz"):
+ _open = lzma.LZMAFile
else:
- _open = open
- fobj = _open(self.tarname, "rb")
- try:
+ _open = io.FileIO
+
+ with _open(self.tarname) as fobj:
fobj.seek(offset)
# Test if the tarfile starts with the second member.
@@ -281,8 +291,6 @@ class MiscReadTest(CommonReadTest):
self.assertEqual(tar.extractfile(t).read(), data,
"seek back did not work")
tar.close()
- finally:
- fobj.close()
def test_fail_comp(self):
# For Gzip and Bz2 Tests: fail with a ReadError on an uncompressed file.
@@ -418,27 +426,26 @@ class StreamReadTest(CommonReadTest):
for tarinfo in self.tar:
if not tarinfo.isreg():
continue
- fobj = self.tar.extractfile(tarinfo)
- while True:
- try:
- buf = fobj.read(512)
- except tarfile.StreamError:
- self.fail("simple read-through using TarFile.extractfile() failed")
- if not buf:
- break
- fobj.close()
+ with self.tar.extractfile(tarinfo) as fobj:
+ while True:
+ try:
+ buf = fobj.read(512)
+ except tarfile.StreamError:
+ self.fail("simple read-through using TarFile.extractfile() failed")
+ if not buf:
+ break
def test_fileobj_regular_file(self):
tarinfo = self.tar.next() # get "regtype" (can't use getmember)
- fobj = self.tar.extractfile(tarinfo)
- data = fobj.read()
+ with self.tar.extractfile(tarinfo) as fobj:
+ data = fobj.read()
self.assertTrue((len(data), md5sum(data)) == (tarinfo.size, md5_regtype),
"regular file extraction failed")
def test_provoke_stream_error(self):
tarinfos = self.tar.getmembers()
- f = self.tar.extractfile(tarinfos[0]) # read the first member
- self.assertRaises(tarfile.StreamError, f.read)
+ with self.tar.extractfile(tarinfos[0]) as f: # read the first member
+ self.assertRaises(tarfile.StreamError, f.read)
def test_compare_members(self):
tar1 = tarfile.open(tarname, encoding="iso8859-1")
@@ -516,6 +523,18 @@ class DetectReadTest(unittest.TestCase):
testfunc(bz2name, "r|*")
testfunc(bz2name, "r|bz2")
+ if lzma:
+ self.assertRaises(tarfile.ReadError, tarfile.open, tarname, mode="r:xz")
+ self.assertRaises(tarfile.ReadError, tarfile.open, tarname, mode="r|xz")
+ self.assertRaises(tarfile.ReadError, tarfile.open, xzname, mode="r:")
+ self.assertRaises(tarfile.ReadError, tarfile.open, xzname, mode="r|")
+
+ testfunc(xzname, "r")
+ testfunc(xzname, "r:*")
+ testfunc(xzname, "r:xz")
+ testfunc(xzname, "r|*")
+ testfunc(xzname, "r|xz")
+
def test_detect_file(self):
self._test_modes(self._testfunc_file)
@@ -713,7 +732,7 @@ class GNUReadTest(LongnameTest):
# Return True if the platform knows the st_blocks stat attribute and
# uses st_blocks units of 512 bytes, and if the filesystem is able to
# store holes in files.
- if sys.platform == "linux2":
+ if sys.platform.startswith("linux"):
# Linux evidentially has 512 byte st_blocks units.
name = os.path.join(TEMPDIR, "sparse-test")
with open(name, "wb") as fobj:
@@ -903,7 +922,7 @@ class WriteTest(WriteTestBase):
try:
for name in ("foo", "bar", "baz"):
name = os.path.join(tempdir, name)
- open(name, "wb").close()
+ support.create_empty_file(name)
exclude = os.path.isfile
@@ -930,7 +949,7 @@ class WriteTest(WriteTestBase):
try:
for name in ("foo", "bar", "baz"):
name = os.path.join(tempdir, name)
- open(name, "wb").close()
+ support.create_empty_file(name)
def filter(tarinfo):
if os.path.basename(tarinfo.name) == "bar":
@@ -969,7 +988,7 @@ class WriteTest(WriteTestBase):
# and compare the stored name with the original.
foo = os.path.join(TEMPDIR, "foo")
if not dir:
- open(foo, "w").close()
+ support.create_empty_file(foo)
else:
os.mkdir(foo)
@@ -1086,6 +1105,9 @@ class StreamWriteTest(WriteTestBase):
data = dec.decompress(data)
self.assertTrue(len(dec.unused_data) == 0,
"found trailing data")
+ elif self.mode.endswith("xz"):
+ with lzma.LZMAFile(tmpname) as fobj:
+ data = fobj.read()
else:
with open(tmpname, "rb") as fobj:
data = fobj.read()
@@ -1329,7 +1351,7 @@ class UstarUnicodeTest(unittest.TestCase):
self._test_unicode_filename("utf7")
def test_utf8_filename(self):
- self._test_unicode_filename("utf8")
+ self._test_unicode_filename("utf-8")
def _test_unicode_filename(self, encoding):
tar = tarfile.open(tmpname, "w", format=self.format, encoding=encoding, errors="strict")
@@ -1408,7 +1430,7 @@ class GNUUnicodeTest(UstarUnicodeTest):
def test_bad_pax_header(self):
# Test for issue #8633. GNU tar <= 1.23 creates raw binary fields
# without a hdrcharset=BINARY header.
- for encoding, name in (("utf8", "pax/bad-pax-\udce4\udcf6\udcfc"),
+ for encoding, name in (("utf-8", "pax/bad-pax-\udce4\udcf6\udcfc"),
("iso8859-1", "pax/bad-pax-\xe4\xf6\xfc"),):
with tarfile.open(tarname, encoding=encoding, errors="surrogateescape") as tar:
try:
@@ -1423,7 +1445,7 @@ class PAXUnicodeTest(UstarUnicodeTest):
def test_binary_header(self):
# Test a POSIX.1-2008 compatible header with a hdrcharset=BINARY field.
- for encoding, name in (("utf8", "pax/hdrcharset-\udce4\udcf6\udcfc"),
+ for encoding, name in (("utf-8", "pax/hdrcharset-\udce4\udcf6\udcfc"),
("iso8859-1", "pax/hdrcharset-\xe4\xf6\xfc"),):
with tarfile.open(tarname, encoding=encoding, errors="surrogateescape") as tar:
try:
@@ -1448,12 +1470,9 @@ class AppendTest(unittest.TestCase):
with tarfile.open(tarname, encoding="iso8859-1") as src:
t = src.getmember("ustar/regtype")
t.name = "foo"
- f = src.extractfile(t)
- try:
+ with src.extractfile(t) as f:
with tarfile.open(self.tarname, mode) as tar:
tar.addfile(t, f)
- finally:
- f.close()
def _test(self, names=["bar"], fileobj=None):
with tarfile.open(self.tarname, fileobj=fileobj) as tar:
@@ -1500,6 +1519,12 @@ class AppendTest(unittest.TestCase):
self._create_testtar("w:bz2")
self.assertRaises(tarfile.ReadError, tarfile.open, tmpname, "a")
+ def test_append_lzma(self):
+ if lzma is None:
+ self.skipTest("lzma module not available")
+ self._create_testtar("w:xz")
+ self.assertRaises(tarfile.ReadError, tarfile.open, tmpname, "a")
+
# Append mode is supposed to fail if the tarfile to append to
# does not end with a zero block.
def _test_error(self, data):
@@ -1778,6 +1803,21 @@ class Bz2PartialReadTest(unittest.TestCase):
self._test_partial_input("r:bz2")
+class LzmaMiscReadTest(MiscReadTest):
+ tarname = xzname
+ mode = "r:xz"
+class LzmaUstarReadTest(UstarReadTest):
+ tarname = xzname
+ mode = "r:xz"
+class LzmaStreamReadTest(StreamReadTest):
+ tarname = xzname
+ mode = "r|xz"
+class LzmaWriteTest(WriteTest):
+ mode = "w:xz"
+class LzmaStreamWriteTest(StreamWriteTest):
+ mode = "w|xz"
+
+
def test_main():
support.unlink(TEMPDIR)
os.makedirs(TEMPDIR)
@@ -1840,6 +1880,20 @@ def test_main():
Bz2PartialReadTest,
]
+ if lzma:
+ # Create testtar.tar.xz and add lzma-specific tests.
+ support.unlink(xzname)
+ with lzma.LZMAFile(xzname, "w") as tar:
+ tar.write(data)
+
+ tests += [
+ LzmaMiscReadTest,
+ LzmaUstarReadTest,
+ LzmaStreamReadTest,
+ LzmaWriteTest,
+ LzmaStreamWriteTest,
+ ]
+
try:
support.run_unittest(*tests)
finally:
diff --git a/Lib/test/test_telnetlib.py b/Lib/test/test_telnetlib.py
index 95eebf3..11e95e7 100644
--- a/Lib/test/test_telnetlib.py
+++ b/Lib/test/test_telnetlib.py
@@ -36,6 +36,7 @@ class GeneralTests(TestCase):
def tearDown(self):
self.thread.join()
+ del self.thread # Clear out any dangling Thread objects.
def testBasic(self):
# connects
diff --git a/Lib/test/test_tempfile.py b/Lib/test/test_tempfile.py
index cf8b35a..2962939 100644
--- a/Lib/test/test_tempfile.py
+++ b/Lib/test/test_tempfile.py
@@ -23,7 +23,7 @@ has_spawnl = hasattr(os, 'spawnl')
# TEST_FILES may need to be tweaked for systems depending on the maximum
# number of files that can be opened at one time (see ulimit -n)
-if sys.platform in ('openbsd3', 'openbsd4'):
+if sys.platform.startswith('openbsd'):
TEST_FILES = 48
else:
TEST_FILES = 100
@@ -33,7 +33,7 @@ else:
# threads is not done here.
# Common functionality.
-class TC(unittest.TestCase):
+class BaseTestCase(unittest.TestCase):
str_check = re.compile(r"[a-zA-Z0-9_-]{6}$")
@@ -47,11 +47,6 @@ class TC(unittest.TestCase):
self._warnings_manager.__exit__(None, None, None)
- def failOnException(self, what, ei=None):
- if ei is None:
- ei = sys.exc_info()
- self.fail("%s raised %s: %s" % (what, ei[0], ei[1]))
-
def nameCheck(self, name, dir, pre, suf):
(ndir, nbase) = os.path.split(name)
npre = nbase[:len(pre)]
@@ -70,9 +65,8 @@ class TC(unittest.TestCase):
"random string '%s' does not match /^[a-zA-Z0-9_-]{6}$/"
% nbase)
-test_classes = []
-class test_exports(TC):
+class TestExports(BaseTestCase):
def test_exports(self):
# There are no surprising symbols in the tempfile module
dict = tempfile.__dict__
@@ -99,10 +93,8 @@ class test_exports(TC):
self.assertTrue(len(unexp) == 0,
"unexpected keys: %s" % unexp)
-test_classes.append(test_exports)
-
-class test__RandomNameSequence(TC):
+class TestRandomNameSequence(BaseTestCase):
"""Test the internal iterator object _RandomNameSequence."""
def setUp(self):
@@ -130,13 +122,10 @@ class test__RandomNameSequence(TC):
i = 0
r = self.r
- try:
- for s in r:
- i += 1
- if i == 20:
- break
- except:
- self.failOnException("iteration")
+ for s in r:
+ i += 1
+ if i == 20:
+ break
@unittest.skipUnless(hasattr(os, 'fork'),
"os.fork is required for this test")
@@ -169,10 +158,8 @@ class test__RandomNameSequence(TC):
self.assertNotEqual(child_value, parent_value)
-test_classes.append(test__RandomNameSequence)
-
-class test__candidate_tempdir_list(TC):
+class TestCandidateTempdirList(BaseTestCase):
"""Test the internal function _candidate_tempdir_list."""
def test_nonempty_list(self):
@@ -211,11 +198,10 @@ class test__candidate_tempdir_list(TC):
# Not practical to try to verify the presence of OS-specific
# paths in this list.
-test_classes.append(test__candidate_tempdir_list)
# We test _get_default_tempdir some more by testing gettempdir.
-class TestGetDefaultTempdir(TC):
+class TestGetDefaultTempdir(BaseTestCase):
"""Test _get_default_tempdir()."""
def test_no_files_left_behind(self):
@@ -232,13 +218,12 @@ class TestGetDefaultTempdir(TC):
self.assertEqual(os.listdir(our_temp_directory), [])
def raise_OSError(*args, **kwargs):
- raise OSError(-1)
+ raise OSError()
with support.swap_attr(io, "open", raise_OSError):
# test again with failing io.open()
- with self.assertRaises(IOError) as cm:
+ with self.assertRaises(FileNotFoundError):
tempfile._get_default_tempdir()
- self.assertEqual(cm.exception.args[0], errno.ENOENT)
self.assertEqual(os.listdir(our_temp_directory), [])
open = io.open
@@ -249,15 +234,12 @@ class TestGetDefaultTempdir(TC):
with support.swap_attr(io, "open", bad_writer):
# test again with failing write()
- with self.assertRaises(IOError) as cm:
+ with self.assertRaises(FileNotFoundError):
tempfile._get_default_tempdir()
- self.assertEqual(cm.exception.errno, errno.ENOENT)
self.assertEqual(os.listdir(our_temp_directory), [])
-test_classes.append(TestGetDefaultTempdir)
-
-class test__get_candidate_names(TC):
+class TestGetCandidateNames(BaseTestCase):
"""Test the internal function _get_candidate_names."""
def test_retval(self):
@@ -272,10 +254,8 @@ class test__get_candidate_names(TC):
self.assertTrue(a is b)
-test_classes.append(test__get_candidate_names)
-
-class test__mkstemp_inner(TC):
+class TestMkstempInner(BaseTestCase):
"""Test the internal function _mkstemp_inner."""
class mkstemped:
@@ -300,10 +280,7 @@ class test__mkstemp_inner(TC):
def do_create(self, dir=None, pre="", suf="", bin=1):
if dir is None:
dir = tempfile.gettempdir()
- try:
- file = self.mkstemped(dir, pre, suf, bin)
- except:
- self.failOnException("_mkstemp_inner")
+ file = self.mkstemped(dir, pre, suf, bin)
self.nameCheck(file.name, dir, pre, suf)
return file
@@ -395,10 +372,8 @@ class test__mkstemp_inner(TC):
os.lseek(f.fd, 0, os.SEEK_SET)
self.assertEqual(os.read(f.fd, 20), b"blat")
-test_classes.append(test__mkstemp_inner)
-
-class test_gettempprefix(TC):
+class TestGetTempPrefix(BaseTestCase):
"""Test gettempprefix()."""
def test_sane_template(self):
@@ -418,19 +393,14 @@ class test_gettempprefix(TC):
d = tempfile.mkdtemp(prefix="")
try:
p = os.path.join(d, p)
- try:
- fd = os.open(p, os.O_RDWR | os.O_CREAT)
- except:
- self.failOnException("os.open")
+ fd = os.open(p, os.O_RDWR | os.O_CREAT)
os.close(fd)
os.unlink(p)
finally:
os.rmdir(d)
-test_classes.append(test_gettempprefix)
-
-class test_gettempdir(TC):
+class TestGetTempDir(BaseTestCase):
"""Test gettempdir()."""
def test_directory_exists(self):
@@ -448,12 +418,9 @@ class test_gettempdir(TC):
# sneaky: just instantiate a NamedTemporaryFile, which
# defaults to writing into the directory returned by
# gettempdir.
- try:
- file = tempfile.NamedTemporaryFile()
- file.write(b"blat")
- file.close()
- except:
- self.failOnException("create file in %s" % tempfile.gettempdir())
+ file = tempfile.NamedTemporaryFile()
+ file.write(b"blat")
+ file.close()
def test_same_thing(self):
# gettempdir always returns the same object
@@ -462,23 +429,18 @@ class test_gettempdir(TC):
self.assertTrue(a is b)
-test_classes.append(test_gettempdir)
-
-class test_mkstemp(TC):
+class TestMkstemp(BaseTestCase):
"""Test mkstemp()."""
def do_create(self, dir=None, pre="", suf=""):
if dir is None:
dir = tempfile.gettempdir()
- try:
- (fd, name) = tempfile.mkstemp(dir=dir, prefix=pre, suffix=suf)
- (ndir, nbase) = os.path.split(name)
- adir = os.path.abspath(dir)
- self.assertEqual(adir, ndir,
- "Directory '%s' incorrectly returned as '%s'" % (adir, ndir))
- except:
- self.failOnException("mkstemp")
+ (fd, name) = tempfile.mkstemp(dir=dir, prefix=pre, suffix=suf)
+ (ndir, nbase) = os.path.split(name)
+ adir = os.path.abspath(dir)
+ self.assertEqual(adir, ndir,
+ "Directory '%s' incorrectly returned as '%s'" % (adir, ndir))
try:
self.nameCheck(name, dir, pre, suf)
@@ -503,19 +465,14 @@ class test_mkstemp(TC):
finally:
os.rmdir(dir)
-test_classes.append(test_mkstemp)
-
-class test_mkdtemp(TC):
+class TestMkdtemp(BaseTestCase):
"""Test mkdtemp()."""
def do_create(self, dir=None, pre="", suf=""):
if dir is None:
dir = tempfile.gettempdir()
- try:
- name = tempfile.mkdtemp(dir=dir, prefix=pre, suffix=suf)
- except:
- self.failOnException("mkdtemp")
+ name = tempfile.mkdtemp(dir=dir, prefix=pre, suffix=suf)
try:
self.nameCheck(name, dir, pre, suf)
@@ -570,10 +527,8 @@ class test_mkdtemp(TC):
finally:
os.rmdir(dir)
-test_classes.append(test_mkdtemp)
-
-class test_mktemp(TC):
+class TestMktemp(BaseTestCase):
"""Test mktemp()."""
# For safety, all use of mktemp must occur in a private directory.
@@ -602,10 +557,7 @@ class test_mktemp(TC):
self._unlink(self.name)
def do_create(self, pre="", suf=""):
- try:
- file = self.mktemped(self.dir, pre, suf)
- except:
- self.failOnException("mktemp")
+ file = self.mktemped(self.dir, pre, suf)
self.nameCheck(file.name, self.dir, pre, suf)
return file
@@ -632,23 +584,18 @@ class test_mktemp(TC):
## self.assertRaises(RuntimeWarning,
## tempfile.mktemp, dir=self.dir)
-test_classes.append(test_mktemp)
-
# We test _TemporaryFileWrapper by testing NamedTemporaryFile.
-class test_NamedTemporaryFile(TC):
+class TestNamedTemporaryFile(BaseTestCase):
"""Test NamedTemporaryFile()."""
def do_create(self, dir=None, pre="", suf="", delete=True):
if dir is None:
dir = tempfile.gettempdir()
- try:
- file = tempfile.NamedTemporaryFile(dir=dir, prefix=pre, suffix=suf,
- delete=delete)
- except:
- self.failOnException("NamedTemporaryFile")
+ file = tempfile.NamedTemporaryFile(dir=dir, prefix=pre, suffix=suf,
+ delete=delete)
self.nameCheck(file.name, dir, pre, suf)
return file
@@ -701,11 +648,8 @@ class test_NamedTemporaryFile(TC):
f = tempfile.NamedTemporaryFile()
f.write(b'abc\n')
f.close()
- try:
- f.close()
- f.close()
- except:
- self.failOnException("close")
+ f.close()
+ f.close()
def test_context_manager(self):
# A NamedTemporaryFile can be used as a context manager
@@ -719,18 +663,14 @@ class test_NamedTemporaryFile(TC):
# How to test the mode and bufsize parameters?
-test_classes.append(test_NamedTemporaryFile)
-class test_SpooledTemporaryFile(TC):
+class TestSpooledTemporaryFile(BaseTestCase):
"""Test SpooledTemporaryFile()."""
def do_create(self, max_size=0, dir=None, pre="", suf=""):
if dir is None:
dir = tempfile.gettempdir()
- try:
- file = tempfile.SpooledTemporaryFile(max_size=max_size, dir=dir, prefix=pre, suffix=suf)
- except:
- self.failOnException("SpooledTemporaryFile")
+ file = tempfile.SpooledTemporaryFile(max_size=max_size, dir=dir, prefix=pre, suffix=suf)
return file
@@ -818,11 +758,8 @@ class test_SpooledTemporaryFile(TC):
f.write(b'abc\n')
self.assertFalse(f._rolled)
f.close()
- try:
- f.close()
- f.close()
- except:
- self.failOnException("close")
+ f.close()
+ f.close()
def test_multiple_close_after_rollover(self):
# A SpooledTemporaryFile can be closed many times without error
@@ -830,11 +767,8 @@ class test_SpooledTemporaryFile(TC):
f.write(b'abc\n')
self.assertTrue(f._rolled)
f.close()
- try:
- f.close()
- f.close()
- except:
- self.failOnException("close")
+ f.close()
+ f.close()
def test_bound_methods(self):
# It should be OK to steal a bound method from a SpooledTemporaryFile
@@ -959,66 +893,76 @@ class test_SpooledTemporaryFile(TC):
pass
self.assertRaises(ValueError, use_closed)
+ def test_truncate_with_size_parameter(self):
+ # A SpooledTemporaryFile can be truncated to zero size
+ f = tempfile.SpooledTemporaryFile(max_size=10)
+ f.write(b'abcdefg\n')
+ f.seek(0)
+ f.truncate()
+ self.assertFalse(f._rolled)
+ self.assertEqual(f._file.getvalue(), b'')
+ # A SpooledTemporaryFile can be truncated to a specific size
+ f = tempfile.SpooledTemporaryFile(max_size=10)
+ f.write(b'abcdefg\n')
+ f.truncate(4)
+ self.assertFalse(f._rolled)
+ self.assertEqual(f._file.getvalue(), b'abcd')
+ # A SpooledTemporaryFile rolls over if truncated to large size
+ f = tempfile.SpooledTemporaryFile(max_size=10)
+ f.write(b'abcdefg\n')
+ f.truncate(20)
+ self.assertTrue(f._rolled)
+ if has_stat:
+ self.assertEqual(os.fstat(f.fileno()).st_size, 20)
-test_classes.append(test_SpooledTemporaryFile)
+if tempfile.NamedTemporaryFile is not tempfile.TemporaryFile:
-class test_TemporaryFile(TC):
- """Test TemporaryFile()."""
+ class TestTemporaryFile(BaseTestCase):
+ """Test TemporaryFile()."""
- def test_basic(self):
- # TemporaryFile can create files
- # No point in testing the name params - the file has no name.
- try:
+ def test_basic(self):
+ # TemporaryFile can create files
+ # No point in testing the name params - the file has no name.
tempfile.TemporaryFile()
- except:
- self.failOnException("TemporaryFile")
- def test_has_no_name(self):
- # TemporaryFile creates files with no names (on this system)
- dir = tempfile.mkdtemp()
- f = tempfile.TemporaryFile(dir=dir)
- f.write(b'blat')
+ def test_has_no_name(self):
+ # TemporaryFile creates files with no names (on this system)
+ dir = tempfile.mkdtemp()
+ f = tempfile.TemporaryFile(dir=dir)
+ f.write(b'blat')
- # Sneaky: because this file has no name, it should not prevent
- # us from removing the directory it was created in.
- try:
- os.rmdir(dir)
- except:
- ei = sys.exc_info()
- # cleanup
+ # Sneaky: because this file has no name, it should not prevent
+ # us from removing the directory it was created in.
+ try:
+ os.rmdir(dir)
+ except:
+ # cleanup
+ f.close()
+ os.rmdir(dir)
+ raise
+
+ def test_multiple_close(self):
+ # A TemporaryFile can be closed many times without error
+ f = tempfile.TemporaryFile()
+ f.write(b'abc\n')
f.close()
- os.rmdir(dir)
- self.failOnException("rmdir", ei)
-
- def test_multiple_close(self):
- # A TemporaryFile can be closed many times without error
- f = tempfile.TemporaryFile()
- f.write(b'abc\n')
- f.close()
- try:
f.close()
f.close()
- except:
- self.failOnException("close")
- # How to test the mode and bufsize parameters?
- def test_mode_and_encoding(self):
+ # How to test the mode and bufsize parameters?
+ def test_mode_and_encoding(self):
- def roundtrip(input, *args, **kwargs):
- with tempfile.TemporaryFile(*args, **kwargs) as fileobj:
- fileobj.write(input)
- fileobj.seek(0)
- self.assertEqual(input, fileobj.read())
+ def roundtrip(input, *args, **kwargs):
+ with tempfile.TemporaryFile(*args, **kwargs) as fileobj:
+ fileobj.write(input)
+ fileobj.seek(0)
+ self.assertEqual(input, fileobj.read())
- roundtrip(b"1234", "w+b")
- roundtrip("abdc\n", "w+")
- roundtrip("\u039B", "w+", encoding="utf-16")
- roundtrip("foo\r\n", "w+", newline="")
-
-
-if tempfile.NamedTemporaryFile is not tempfile.TemporaryFile:
- test_classes.append(test_TemporaryFile)
+ roundtrip(b"1234", "w+b")
+ roundtrip("abdc\n", "w+")
+ roundtrip("\u039B", "w+", encoding="utf-16")
+ roundtrip("foo\r\n", "w+", newline="")
# Helper for test_del_on_shutdown
@@ -1037,16 +981,13 @@ class NulledModules:
d.clear()
d.update(c)
-class test_TemporaryDirectory(TC):
+class TestTemporaryDirectory(BaseTestCase):
"""Test TemporaryDirectory()."""
def do_create(self, dir=None, pre="", suf="", recurse=1):
if dir is None:
dir = tempfile.gettempdir()
- try:
- tmp = tempfile.TemporaryDirectory(dir=dir, prefix=pre, suffix=suf)
- except:
- self.failOnException("TemporaryDirectory")
+ tmp = tempfile.TemporaryDirectory(dir=dir, prefix=pre, suffix=suf)
self.nameCheck(tmp.name, dir, pre, suf)
# Create a subdirectory and some files
if recurse:
@@ -1061,8 +1002,9 @@ class test_TemporaryDirectory(TC):
# (noted as part of Issue #10188)
with tempfile.TemporaryDirectory() as nonexistent:
pass
- with self.assertRaises(os.error):
+ with self.assertRaises(FileNotFoundError) as cm:
tempfile.TemporaryDirectory(dir=nonexistent)
+ self.assertEqual(cm.exception.errno, errno.ENOENT)
def test_explicit_cleanup(self):
# A TemporaryDirectory is deleted when cleaned up
@@ -1170,11 +1112,8 @@ class test_TemporaryDirectory(TC):
# Can be cleaned-up many times without error
d = self.do_create()
d.cleanup()
- try:
- d.cleanup()
- d.cleanup()
- except:
- self.failOnException("cleanup")
+ d.cleanup()
+ d.cleanup()
def test_context_manager(self):
# Can be used as a context manager
@@ -1185,10 +1124,8 @@ class test_TemporaryDirectory(TC):
self.assertFalse(os.path.exists(name))
-test_classes.append(test_TemporaryDirectory)
-
def test_main():
- support.run_unittest(*test_classes)
+ support.run_unittest(__name__)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_textwrap.py b/Lib/test/test_textwrap.py
index 3901120..c86f5cf 100644
--- a/Lib/test/test_textwrap.py
+++ b/Lib/test/test_textwrap.py
@@ -11,7 +11,7 @@
import unittest
from test import support
-from textwrap import TextWrapper, wrap, fill, dedent
+from textwrap import TextWrapper, wrap, fill, dedent, indent
class BaseTestCase(unittest.TestCase):
@@ -100,6 +100,14 @@ What a mess!
result = wrapper.fill(text)
self.check(result, '\n'.join(expect))
+ text = "\tTest\tdefault\t\ttabsize."
+ expect = [" Test default tabsize."]
+ self.check_wrap(text, 80, expect)
+
+ text = "\tTest\tcustom\t\ttabsize."
+ expect = [" Test custom tabsize."]
+ self.check_wrap(text, 80, expect, tabsize=4)
+
def test_fix_sentence_endings(self):
wrapper = TextWrapper(60, fix_sentence_endings=True)
@@ -634,11 +642,147 @@ def foo():
self.assertEqual(expect, dedent(text))
+# Test textwrap.indent
+class IndentTestCase(unittest.TestCase):
+ # The examples used for tests. If any of these change, the expected
+ # results in the various test cases must also be updated.
+ # The roundtrip cases are separate, because textwrap.dedent doesn't
+ # handle Windows line endings
+ ROUNDTRIP_CASES = (
+ # Basic test case
+ "Hi.\nThis is a test.\nTesting.",
+ # Include a blank line
+ "Hi.\nThis is a test.\n\nTesting.",
+ # Include leading and trailing blank lines
+ "\nHi.\nThis is a test.\nTesting.\n",
+ )
+ CASES = ROUNDTRIP_CASES + (
+ # Use Windows line endings
+ "Hi.\r\nThis is a test.\r\nTesting.\r\n",
+ # Pathological case
+ "\nHi.\r\nThis is a test.\n\r\nTesting.\r\n\n",
+ )
+
+ def test_indent_nomargin_default(self):
+ # indent should do nothing if 'prefix' is empty.
+ for text in self.CASES:
+ self.assertEqual(indent(text, ''), text)
+
+ def test_indent_nomargin_explicit_default(self):
+ # The same as test_indent_nomargin, but explicitly requesting
+ # the default behaviour by passing None as the predicate
+ for text in self.CASES:
+ self.assertEqual(indent(text, '', None), text)
+
+ def test_indent_nomargin_all_lines(self):
+ # The same as test_indent_nomargin, but using the optional
+ # predicate argument
+ predicate = lambda line: True
+ for text in self.CASES:
+ self.assertEqual(indent(text, '', predicate), text)
+
+ def test_indent_no_lines(self):
+ # Explicitly skip indenting any lines
+ predicate = lambda line: False
+ for text in self.CASES:
+ self.assertEqual(indent(text, ' ', predicate), text)
+
+ def test_roundtrip_spaces(self):
+ # A whitespace prefix should roundtrip with dedent
+ for text in self.ROUNDTRIP_CASES:
+ self.assertEqual(dedent(indent(text, ' ')), text)
+
+ def test_roundtrip_tabs(self):
+ # A whitespace prefix should roundtrip with dedent
+ for text in self.ROUNDTRIP_CASES:
+ self.assertEqual(dedent(indent(text, '\t\t')), text)
+
+ def test_roundtrip_mixed(self):
+ # A whitespace prefix should roundtrip with dedent
+ for text in self.ROUNDTRIP_CASES:
+ self.assertEqual(dedent(indent(text, ' \t \t ')), text)
+
+ def test_indent_default(self):
+ # Test default indenting of lines that are not whitespace only
+ prefix = ' '
+ expected = (
+ # Basic test case
+ " Hi.\n This is a test.\n Testing.",
+ # Include a blank line
+ " Hi.\n This is a test.\n\n Testing.",
+ # Include leading and trailing blank lines
+ "\n Hi.\n This is a test.\n Testing.\n",
+ # Use Windows line endings
+ " Hi.\r\n This is a test.\r\n Testing.\r\n",
+ # Pathological case
+ "\n Hi.\r\n This is a test.\n\r\n Testing.\r\n\n",
+ )
+ for text, expect in zip(self.CASES, expected):
+ self.assertEqual(indent(text, prefix), expect)
+
+ def test_indent_explicit_default(self):
+ # Test default indenting of lines that are not whitespace only
+ prefix = ' '
+ expected = (
+ # Basic test case
+ " Hi.\n This is a test.\n Testing.",
+ # Include a blank line
+ " Hi.\n This is a test.\n\n Testing.",
+ # Include leading and trailing blank lines
+ "\n Hi.\n This is a test.\n Testing.\n",
+ # Use Windows line endings
+ " Hi.\r\n This is a test.\r\n Testing.\r\n",
+ # Pathological case
+ "\n Hi.\r\n This is a test.\n\r\n Testing.\r\n\n",
+ )
+ for text, expect in zip(self.CASES, expected):
+ self.assertEqual(indent(text, prefix, None), expect)
+
+ def test_indent_all_lines(self):
+ # Add 'prefix' to all lines, including whitespace-only ones.
+ prefix = ' '
+ expected = (
+ # Basic test case
+ " Hi.\n This is a test.\n Testing.",
+ # Include a blank line
+ " Hi.\n This is a test.\n \n Testing.",
+ # Include leading and trailing blank lines
+ " \n Hi.\n This is a test.\n Testing.\n",
+ # Use Windows line endings
+ " Hi.\r\n This is a test.\r\n Testing.\r\n",
+ # Pathological case
+ " \n Hi.\r\n This is a test.\n \r\n Testing.\r\n \n",
+ )
+ predicate = lambda line: True
+ for text, expect in zip(self.CASES, expected):
+ self.assertEqual(indent(text, prefix, predicate), expect)
+
+ def test_indent_empty_lines(self):
+ # Add 'prefix' solely to whitespace-only lines.
+ prefix = ' '
+ expected = (
+ # Basic test case
+ "Hi.\nThis is a test.\nTesting.",
+ # Include a blank line
+ "Hi.\nThis is a test.\n \nTesting.",
+ # Include leading and trailing blank lines
+ " \nHi.\nThis is a test.\nTesting.\n",
+ # Use Windows line endings
+ "Hi.\r\nThis is a test.\r\nTesting.\r\n",
+ # Pathological case
+ " \nHi.\r\nThis is a test.\n \r\nTesting.\r\n \n",
+ )
+ predicate = lambda line: not line.strip()
+ for text, expect in zip(self.CASES, expected):
+ self.assertEqual(indent(text, prefix, predicate), expect)
+
+
def test_main():
support.run_unittest(WrapTestCase,
LongWordTestCase,
IndentTestCases,
- DedentTestCase)
+ DedentTestCase,
+ IndentTestCase)
if __name__ == '__main__':
test_main()
diff --git a/Lib/test/test_threaded_import.py b/Lib/test/test_threaded_import.py
index 7791935..6c2965b 100644
--- a/Lib/test/test_threaded_import.py
+++ b/Lib/test/test_threaded_import.py
@@ -7,12 +7,13 @@
import os
import imp
+import importlib
import sys
import time
import shutil
import unittest
-from test.support import verbose, import_module, run_unittest, TESTFN
-thread = import_module('_thread')
+from test.support import (
+ verbose, import_module, run_unittest, TESTFN, reap_threads, forget, unlink)
threading = import_module('threading')
def task(N, done, done_tasks, errors):
@@ -30,7 +31,7 @@ def task(N, done, done_tasks, errors):
except Exception as e:
errors.append(e.with_traceback(None))
finally:
- done_tasks.append(thread.get_ident())
+ done_tasks.append(threading.get_ident())
finished = len(done_tasks) == N
if finished:
done.set()
@@ -62,16 +63,17 @@ class Finder:
def __init__(self):
self.numcalls = 0
self.x = 0
- self.lock = thread.allocate_lock()
+ self.lock = threading.Lock()
def find_module(self, name, path=None):
# Simulate some thread-unsafe behaviour. If calls to find_module()
# are properly serialized, `x` will end up the same as `numcalls`.
# Otherwise not.
+ assert imp.lock_held()
with self.lock:
self.numcalls += 1
x = self.x
- time.sleep(0.1)
+ time.sleep(0.01)
self.x = x + 1
class FlushingFinder:
@@ -113,8 +115,10 @@ class ThreadedImportTests(unittest.TestCase):
done_tasks = []
done.clear()
for i in range(N):
- thread.start_new_thread(task, (N, done, done_tasks, errors,))
- done.wait(60)
+ t = threading.Thread(target=task,
+ args=(N, done, done_tasks, errors,))
+ t.start()
+ self.assertTrue(done.wait(60))
self.assertFalse(errors)
if verbose:
print("OK.")
@@ -124,7 +128,7 @@ class ThreadedImportTests(unittest.TestCase):
def test_parallel_meta_path(self):
finder = Finder()
- sys.meta_path.append(finder)
+ sys.meta_path.insert(0, finder)
try:
self.check_parallel_module_init()
self.assertGreater(finder.numcalls, 0)
@@ -143,7 +147,7 @@ class ThreadedImportTests(unittest.TestCase):
def path_hook(path):
finder.find_module('')
raise ImportError
- sys.path_hooks.append(path_hook)
+ sys.path_hooks.insert(0, path_hook)
sys.meta_path.append(flushing_finder)
try:
# Flush the cache a first time
@@ -185,8 +189,9 @@ class ThreadedImportTests(unittest.TestCase):
contents = contents % {'delay': delay}
with open(os.path.join(TESTFN, name + ".py"), "wb") as f:
f.write(contents.encode('utf-8'))
- self.addCleanup(sys.modules.pop, name, None)
+ self.addCleanup(forget, name)
+ importlib.invalidate_caches()
results = []
def import_ab():
import A
@@ -202,9 +207,38 @@ class ThreadedImportTests(unittest.TestCase):
t2.join()
self.assertEqual(set(results), {'a', 'b'})
-
+ def test_side_effect_import(self):
+ code = """if 1:
+ import threading
+ def target():
+ import random
+ t = threading.Thread(target=target)
+ t.start()
+ t.join()"""
+ sys.path.insert(0, os.curdir)
+ self.addCleanup(sys.path.remove, os.curdir)
+ filename = TESTFN + ".py"
+ with open(filename, "wb") as f:
+ f.write(code.encode('utf-8'))
+ self.addCleanup(unlink, filename)
+ self.addCleanup(forget, TESTFN)
+ importlib.invalidate_caches()
+ __import__(TESTFN)
+
+
+@reap_threads
def test_main():
- run_unittest(ThreadedImportTests)
+ old_switchinterval = None
+ try:
+ old_switchinterval = sys.getswitchinterval()
+ sys.setswitchinterval(1e-5)
+ except AttributeError:
+ pass
+ try:
+ run_unittest(ThreadedImportTests)
+ finally:
+ if old_switchinterval is not None:
+ sys.setswitchinterval(old_switchinterval)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py
index 17be84b..11e63d3 100644
--- a/Lib/test/test_threading.py
+++ b/Lib/test/test_threading.py
@@ -175,7 +175,7 @@ class ThreadTests(BaseTestCase):
exception = ctypes.py_object(AsyncExc)
# First check it works when setting the exception from the same thread.
- tid = _thread.get_ident()
+ tid = threading.get_ident()
try:
result = set_async_exc(ctypes.c_long(tid), exception)
@@ -204,7 +204,7 @@ class ThreadTests(BaseTestCase):
class Worker(threading.Thread):
def run(self):
- self.id = _thread.get_ident()
+ self.id = threading.get_ident()
self.finished = False
try:
@@ -409,6 +409,14 @@ class ThreadTests(BaseTestCase):
t.daemon = True
self.assertTrue('daemon' in repr(t))
+ def test_deamon_param(self):
+ t = threading.Thread()
+ self.assertFalse(t.daemon)
+ t = threading.Thread(daemon=False)
+ self.assertFalse(t.daemon)
+ t = threading.Thread(daemon=True)
+ self.assertTrue(t.daemon)
+
@unittest.skipUnless(hasattr(os, 'fork'), 'test needs fork()')
def test_dummy_thread_after_fork(self):
# Issue #14308: a dummy thread in the active list doesn't mess up
@@ -444,7 +452,7 @@ class ThreadJoinOnShutdown(BaseTestCase):
# problems with some operating systems (issue #3863): skip problematic tests
# on platforms known to behave badly.
platforms_to_skip = ('freebsd4', 'freebsd5', 'freebsd6', 'netbsd5',
- 'os2emx')
+ 'os2emx', 'hp-ux11')
def _run_and_join(self, script):
script = """if 1:
@@ -742,6 +750,10 @@ class ThreadingExceptionTests(BaseTestCase):
thread.start()
self.assertRaises(RuntimeError, setattr, thread, "daemon", True)
+ def test_releasing_unacquired_lock(self):
+ lock = threading.Lock()
+ self.assertRaises(RuntimeError, lock.release)
+
@unittest.skipUnless(sys.platform == 'darwin', 'test macosx problem')
def test_recursion_limit(self):
# Issue 9670
@@ -803,6 +815,7 @@ class BoundedSemaphoreTests(lock_tests.BoundedSemaphoreTests):
class BarrierTests(lock_tests.BarrierTests):
barriertype = staticmethod(threading.Barrier)
+
def test_main():
test.support.run_unittest(LockTests, PyRLockTests, CRLockTests, EventTests,
ConditionAsRLockTests, ConditionTests,
@@ -810,7 +823,7 @@ def test_main():
ThreadTests,
ThreadJoinOnShutdown,
ThreadingExceptionTests,
- BarrierTests
+ BarrierTests,
)
if __name__ == "__main__":
diff --git a/Lib/test/test_threadsignals.py b/Lib/test/test_threadsignals.py
index e0af31d..f975a75 100644
--- a/Lib/test/test_threadsignals.py
+++ b/Lib/test/test_threadsignals.py
@@ -14,10 +14,8 @@ if sys.platform[:3] in ('win', 'os2') or sys.platform=='riscos':
process_pid = os.getpid()
signalled_all=thread.allocate_lock()
-# Issue #11223: Locks are implemented using a mutex and a condition variable of
-# the pthread library on FreeBSD6. POSIX condition variables cannot be
-# interrupted by signals (see pthread_cond_wait manual page).
-USING_PTHREAD_COND = (sys.platform == 'freebsd6')
+USING_PTHREAD_COND = (sys.thread_info.name == 'pthread'
+ and sys.thread_info.lock == 'mutex+cond')
def registerSignals(for_usr1, for_usr2, for_alrm):
usr1 = signal.signal(signal.SIGUSR1, for_usr1)
diff --git a/Lib/test/test_time.py b/Lib/test/test_time.py
index d3c49bb..da0f555 100644
--- a/Lib/test/test_time.py
+++ b/Lib/test/test_time.py
@@ -4,7 +4,17 @@ import unittest
import locale
import sysconfig
import sys
-import warnings
+import platform
+try:
+ import threading
+except ImportError:
+ threading = None
+
+# Max year is only limited by the size of C int.
+SIZEOF_INT = sysconfig.get_config_var('SIZEOF_INT') or 4
+TIME_MAXYEAR = (1 << 8 * SIZEOF_INT - 1) - 1
+TIME_MINYEAR = -TIME_MAXYEAR - 1
+
class TimeTestCase(unittest.TestCase):
@@ -17,9 +27,53 @@ class TimeTestCase(unittest.TestCase):
time.timezone
time.tzname
+ def test_time(self):
+ time.time()
+ info = time.get_clock_info('time')
+ self.assertFalse(info.monotonic)
+ self.assertTrue(info.adjustable)
+
def test_clock(self):
time.clock()
+ info = time.get_clock_info('clock')
+ self.assertTrue(info.monotonic)
+ self.assertFalse(info.adjustable)
+
+ @unittest.skipUnless(hasattr(time, 'clock_gettime'),
+ 'need time.clock_gettime()')
+ def test_clock_realtime(self):
+ time.clock_gettime(time.CLOCK_REALTIME)
+
+ @unittest.skipUnless(hasattr(time, 'clock_gettime'),
+ 'need time.clock_gettime()')
+ @unittest.skipUnless(hasattr(time, 'CLOCK_MONOTONIC'),
+ 'need time.CLOCK_MONOTONIC')
+ def test_clock_monotonic(self):
+ a = time.clock_gettime(time.CLOCK_MONOTONIC)
+ b = time.clock_gettime(time.CLOCK_MONOTONIC)
+ self.assertLessEqual(a, b)
+
+ @unittest.skipUnless(hasattr(time, 'clock_getres'),
+ 'need time.clock_getres()')
+ def test_clock_getres(self):
+ res = time.clock_getres(time.CLOCK_REALTIME)
+ self.assertGreater(res, 0.0)
+ self.assertLessEqual(res, 1.0)
+
+ @unittest.skipUnless(hasattr(time, 'clock_settime'),
+ 'need time.clock_settime()')
+ def test_clock_settime(self):
+ t = time.clock_gettime(time.CLOCK_REALTIME)
+ try:
+ time.clock_settime(time.CLOCK_REALTIME, t)
+ except PermissionError:
+ pass
+
+ if hasattr(time, 'CLOCK_MONOTONIC'):
+ self.assertRaises(OSError,
+ time.clock_settime, time.CLOCK_MONOTONIC, 0)
+
def test_conversions(self):
self.assertEqual(time.ctime(self.t),
time.asctime(time.localtime(self.t)))
@@ -27,6 +81,8 @@ class TimeTestCase(unittest.TestCase):
int(self.t))
def test_sleep(self):
+ self.assertRaises(ValueError, time.sleep, -2)
+ self.assertRaises(ValueError, time.sleep, -1)
time.sleep(1.2)
def test_strftime(self):
@@ -47,28 +103,34 @@ class TimeTestCase(unittest.TestCase):
with self.assertRaises(ValueError):
time.strftime('%f')
- def _bounds_checking(self, func=time.strftime):
+ def _bounds_checking(self, func):
# Make sure that strftime() checks the bounds of the various parts
- #of the time tuple (0 is valid for *all* values).
+ # of the time tuple (0 is valid for *all* values).
# The year field is tested by other test cases above
# Check month [1, 12] + zero support
+ func((1900, 0, 1, 0, 0, 0, 0, 1, -1))
+ func((1900, 12, 1, 0, 0, 0, 0, 1, -1))
self.assertRaises(ValueError, func,
(1900, -1, 1, 0, 0, 0, 0, 1, -1))
self.assertRaises(ValueError, func,
(1900, 13, 1, 0, 0, 0, 0, 1, -1))
# Check day of month [1, 31] + zero support
+ func((1900, 1, 0, 0, 0, 0, 0, 1, -1))
+ func((1900, 1, 31, 0, 0, 0, 0, 1, -1))
self.assertRaises(ValueError, func,
(1900, 1, -1, 0, 0, 0, 0, 1, -1))
self.assertRaises(ValueError, func,
(1900, 1, 32, 0, 0, 0, 0, 1, -1))
# Check hour [0, 23]
+ func((1900, 1, 1, 23, 0, 0, 0, 1, -1))
self.assertRaises(ValueError, func,
(1900, 1, 1, -1, 0, 0, 0, 1, -1))
self.assertRaises(ValueError, func,
(1900, 1, 1, 24, 0, 0, 0, 1, -1))
# Check minute [0, 59]
+ func((1900, 1, 1, 0, 59, 0, 0, 1, -1))
self.assertRaises(ValueError, func,
(1900, 1, 1, 0, -1, 0, 0, 1, -1))
self.assertRaises(ValueError, func,
@@ -78,15 +140,21 @@ class TimeTestCase(unittest.TestCase):
(1900, 1, 1, 0, 0, -1, 0, 1, -1))
# C99 only requires allowing for one leap second, but Python's docs say
# allow two leap seconds (0..61)
+ func((1900, 1, 1, 0, 0, 60, 0, 1, -1))
+ func((1900, 1, 1, 0, 0, 61, 0, 1, -1))
self.assertRaises(ValueError, func,
(1900, 1, 1, 0, 0, 62, 0, 1, -1))
# No check for upper-bound day of week;
# value forced into range by a ``% 7`` calculation.
# Start check at -2 since gettmarg() increments value before taking
# modulo.
+ self.assertEqual(func((1900, 1, 1, 0, 0, 0, -1, 1, -1)),
+ func((1900, 1, 1, 0, 0, 0, +6, 1, -1)))
self.assertRaises(ValueError, func,
(1900, 1, 1, 0, 0, 0, -2, 1, -1))
# Check day of the year [1, 366] + zero support
+ func((1900, 1, 1, 0, 0, 0, 0, 0, -1))
+ func((1900, 1, 1, 0, 0, 0, 0, 366, -1))
self.assertRaises(ValueError, func,
(1900, 1, 1, 0, 0, 0, 0, -1, -1))
self.assertRaises(ValueError, func,
@@ -96,12 +164,13 @@ class TimeTestCase(unittest.TestCase):
self._bounds_checking(lambda tup: time.strftime('', tup))
def test_default_values_for_zero(self):
- # Make sure that using all zeros uses the proper default values.
- # No test for daylight savings since strftime() does not change output
- # based on its value.
+ # Make sure that using all zeros uses the proper default
+ # values. No test for daylight savings since strftime() does
+ # not change output based on its value and no test for year
+ # because systems vary in their support for year 0.
expected = "2000 01 01 00 00 00 1 001"
with support.check_warnings():
- result = time.strftime("%Y %m %d %H %M %S %w %j", (0,)*9)
+ result = time.strftime("%Y %m %d %H %M %S %w %j", (2000,)+(0,)*8)
self.assertEqual(expected, result)
def test_strptime(self):
@@ -128,11 +197,13 @@ class TimeTestCase(unittest.TestCase):
time.asctime(time.gmtime(self.t))
# Max year is only limited by the size of C int.
- sizeof_int = sysconfig.get_config_var('SIZEOF_INT') or 4
- bigyear = (1 << 8 * sizeof_int - 1) - 1
- asc = time.asctime((bigyear, 6, 1) + (0,)*6)
- self.assertEqual(asc[-len(str(bigyear)):], str(bigyear))
- self.assertRaises(OverflowError, time.asctime, (bigyear + 1,) + (0,)*8)
+ for bigyear in TIME_MAXYEAR, TIME_MINYEAR:
+ asc = time.asctime((bigyear, 6, 1) + (0,) * 6)
+ self.assertEqual(asc[-len(str(bigyear)):], str(bigyear))
+ self.assertRaises(OverflowError, time.asctime,
+ (TIME_MAXYEAR + 1,) + (0,) * 8)
+ self.assertRaises(OverflowError, time.asctime,
+ (TIME_MINYEAR - 1,) + (0,) * 8)
self.assertRaises(TypeError, time.asctime, 0)
self.assertRaises(TypeError, time.asctime, ())
self.assertRaises(TypeError, time.asctime, (0,) * 10)
@@ -155,8 +226,8 @@ class TimeTestCase(unittest.TestCase):
else:
self.assertEqual(time.ctime(testval)[20:], str(year))
- @unittest.skipIf(not hasattr(time, "tzset"),
- "time module has no attribute tzset")
+ @unittest.skipUnless(hasattr(time, "tzset"),
+ "time module has no attribute tzset")
def test_tzset(self):
from os import environ
@@ -208,11 +279,13 @@ class TimeTestCase(unittest.TestCase):
self.assertNotEqual(time.gmtime(xmas2002), time.localtime(xmas2002))
# Issue #11886: Australian Eastern Standard Time (UTC+10) is called
- # "EST" (as Eastern Standard Time, UTC-5) instead of "AEST" on some
- # operating systems (e.g. FreeBSD), which is wrong. See for example
- # this bug: http://bugs.debian.org/cgi-bin/bugreport.cgi?bug=93810
+ # "EST" (as Eastern Standard Time, UTC-5) instead of "AEST"
+ # (non-DST timezone), and "EDT" instead of "AEDT" (DST timezone),
+ # on some operating systems (e.g. FreeBSD), which is wrong. See for
+ # example this bug:
+ # http://bugs.debian.org/cgi-bin/bugreport.cgi?bug=93810
self.assertIn(time.tzname[0], ('AEST' 'EST'), time.tzname[0])
- self.assertTrue(time.tzname[1] == 'AEDT', str(time.tzname[1]))
+ self.assertTrue(time.tzname[1] in ('AEDT', 'EDT'), str(time.tzname[1]))
self.assertEqual(len(time.tzname), 2)
self.assertEqual(time.daylight, 1)
self.assertEqual(time.timezone, -36000)
@@ -235,7 +308,7 @@ class TimeTestCase(unittest.TestCase):
# results!).
for func in time.ctime, time.gmtime, time.localtime:
for unreasonable in -1e200, 1e200:
- self.assertRaises(ValueError, func, unreasonable)
+ self.assertRaises(OverflowError, func, unreasonable)
def test_ctime_without_arg(self):
# Not sure how to check the values, since the clock could tick
@@ -258,6 +331,117 @@ class TimeTestCase(unittest.TestCase):
t1 = time.mktime(lt1)
self.assertAlmostEqual(t1, t0, delta=0.2)
+ def test_mktime(self):
+ # Issue #1726687
+ for t in (-2, -1, 0, 1):
+ try:
+ tt = time.localtime(t)
+ except (OverflowError, OSError):
+ pass
+ else:
+ self.assertEqual(time.mktime(tt), t)
+
+ # Issue #13309: passing extreme values to mktime() or localtime()
+ # borks the glibc's internal timezone data.
+ @unittest.skipUnless(platform.libc_ver()[0] != 'glibc',
+ "disabled because of a bug in glibc. Issue #13309")
+ def test_mktime_error(self):
+ # It may not be possible to reliably make mktime return error
+ # on all platfom. This will make sure that no other exception
+ # than OverflowError is raised for an extreme value.
+ tt = time.gmtime(self.t)
+ tzname = time.strftime('%Z', tt)
+ self.assertNotEqual(tzname, 'LMT')
+ try:
+ time.mktime((-1, 1, 1, 0, 0, 0, -1, -1, -1))
+ except OverflowError:
+ pass
+ self.assertEqual(time.strftime('%Z', tt), tzname)
+
+ @unittest.skipUnless(hasattr(time, 'monotonic'),
+ 'need time.monotonic')
+ def test_monotonic(self):
+ t1 = time.monotonic()
+ time.sleep(0.1)
+ t2 = time.monotonic()
+ dt = t2 - t1
+ self.assertGreater(t2, t1)
+ self.assertAlmostEqual(dt, 0.1, delta=0.2)
+
+ info = time.get_clock_info('monotonic')
+ self.assertTrue(info.monotonic)
+ self.assertFalse(info.adjustable)
+
+ def test_perf_counter(self):
+ time.perf_counter()
+
+ def test_process_time(self):
+ # process_time() should not include time spend during a sleep
+ start = time.process_time()
+ time.sleep(0.100)
+ stop = time.process_time()
+ # use 20 ms because process_time() has usually a resolution of 15 ms
+ # on Windows
+ self.assertLess(stop - start, 0.020)
+
+ info = time.get_clock_info('process_time')
+ self.assertTrue(info.monotonic)
+ self.assertFalse(info.adjustable)
+
+ @unittest.skipUnless(hasattr(time, 'monotonic'),
+ 'need time.monotonic')
+ @unittest.skipUnless(hasattr(time, 'clock_settime'),
+ 'need time.clock_settime')
+ def test_monotonic_settime(self):
+ t1 = time.monotonic()
+ realtime = time.clock_gettime(time.CLOCK_REALTIME)
+ # jump backward with an offset of 1 hour
+ try:
+ time.clock_settime(time.CLOCK_REALTIME, realtime - 3600)
+ except PermissionError as err:
+ self.skipTest(err)
+ t2 = time.monotonic()
+ time.clock_settime(time.CLOCK_REALTIME, realtime)
+ # monotonic must not be affected by system clock updates
+ self.assertGreaterEqual(t2, t1)
+
+ def test_localtime_failure(self):
+ # Issue #13847: check for localtime() failure
+ invalid_time_t = None
+ for time_t in (-1, 2**30, 2**33, 2**60):
+ try:
+ time.localtime(time_t)
+ except OverflowError:
+ self.skipTest("need 64-bit time_t")
+ except OSError:
+ invalid_time_t = time_t
+ break
+ if invalid_time_t is None:
+ self.skipTest("unable to find an invalid time_t value")
+
+ self.assertRaises(OSError, time.localtime, invalid_time_t)
+ self.assertRaises(OSError, time.ctime, invalid_time_t)
+
+ def test_get_clock_info(self):
+ clocks = ['clock', 'perf_counter', 'process_time', 'time']
+ if hasattr(time, 'monotonic'):
+ clocks.append('monotonic')
+
+ for name in clocks:
+ info = time.get_clock_info(name)
+ #self.assertIsInstance(info, dict)
+ self.assertIsInstance(info.implementation, str)
+ self.assertNotEqual(info.implementation, '')
+ self.assertIsInstance(info.monotonic, bool)
+ self.assertIsInstance(info.resolution, float)
+ # 0.0 < resolution <= 1.0
+ self.assertGreater(info.resolution, 0.0)
+ self.assertLessEqual(info.resolution, 1.0)
+ self.assertIsInstance(info.adjustable, bool)
+
+ self.assertRaises(ValueError, time.get_clock_info, 'xxx')
+
+
class TestLocale(unittest.TestCase):
def setUp(self):
self.oldloc = locale.setlocale(locale.LC_ALL)
@@ -276,19 +460,12 @@ class TestLocale(unittest.TestCase):
class _BaseYearTest(unittest.TestCase):
- accept2dyear = None
-
- def setUp(self):
- self.saved_accept2dyear = time.accept2dyear
- time.accept2dyear = self.accept2dyear
-
- def tearDown(self):
- time.accept2dyear = self.saved_accept2dyear
-
def yearstr(self, y):
raise NotImplementedError()
class _TestAsctimeYear:
+ _format = '%d'
+
def yearstr(self, y):
return time.asctime((y,) + (0,) * 8).split()[-1]
@@ -298,114 +475,211 @@ class _TestAsctimeYear:
self.assertEqual(self.yearstr(123456789), '123456789')
class _TestStrftimeYear:
- def yearstr(self, y):
- return time.strftime('%Y', (y,) + (0,) * 8).split()[-1]
- def test_large_year(self):
+ # Issue 13305: For years < 1000, the value is not always
+ # padded to 4 digits across platforms. The C standard
+ # assumes year >= 1900, so it does not specify the number
+ # of digits.
+
+ if time.strftime('%Y', (1,) + (0,) * 8) == '0001':
+ _format = '%04d'
+ else:
+ _format = '%d'
+
+ def yearstr(self, y):
+ return time.strftime('%Y', (y,) + (0,) * 8)
+
+ def test_4dyear(self):
+ # Check that we can return the zero padded value.
+ if self._format == '%04d':
+ self.test_year('%04d')
+ else:
+ def year4d(y):
+ return time.strftime('%4Y', (y,) + (0,) * 8)
+ self.test_year('%04d', func=year4d)
+
+ def skip_if_not_supported(y):
+ msg = "strftime() is limited to [1; 9999] with Visual Studio"
# Check that it doesn't crash for year > 9999
try:
- text = self.yearstr(12345)
+ time.strftime('%Y', (y,) + (0,) * 8)
except ValueError:
- # strftime() is limited to [1; 9999] with Visual Studio
- return
- self.assertEqual(text, '12345')
- self.assertEqual(self.yearstr(123456789), '123456789')
+ cond = False
+ else:
+ cond = True
+ return unittest.skipUnless(cond, msg)
-class _Test2dYear(_BaseYearTest):
- accept2dyear = 1
+ @skip_if_not_supported(10000)
+ def test_large_year(self):
+ return super().test_large_year()
- def test_year(self):
- with support.check_warnings():
- self.assertEqual(self.yearstr(0), '2000')
- self.assertEqual(self.yearstr(69), '1969')
- self.assertEqual(self.yearstr(68), '2068')
- self.assertEqual(self.yearstr(99), '1999')
+ @skip_if_not_supported(0)
+ def test_negative(self):
+ return super().test_negative()
+
+ del skip_if_not_supported
- def test_invalid(self):
- self.assertRaises(ValueError, self.yearstr, -1)
- self.assertRaises(ValueError, self.yearstr, 100)
- self.assertRaises(ValueError, self.yearstr, 999)
class _Test4dYear(_BaseYearTest):
- accept2dyear = 0
+ _format = '%d'
+
+ def test_year(self, fmt=None, func=None):
+ fmt = fmt or self._format
+ func = func or self.yearstr
+ self.assertEqual(func(1), fmt % 1)
+ self.assertEqual(func(68), fmt % 68)
+ self.assertEqual(func(69), fmt % 69)
+ self.assertEqual(func(99), fmt % 99)
+ self.assertEqual(func(999), fmt % 999)
+ self.assertEqual(func(9999), fmt % 9999)
- def test_year(self):
- self.assertIn(self.yearstr(1), ('1', '0001'))
- self.assertIn(self.yearstr(68), ('68', '0068'))
- self.assertIn(self.yearstr(69), ('69', '0069'))
- self.assertIn(self.yearstr(99), ('99', '0099'))
- self.assertIn(self.yearstr(999), ('999', '0999'))
- self.assertEqual(self.yearstr(9999), '9999')
+ def test_large_year(self):
+ self.assertEqual(self.yearstr(12345), '12345')
+ self.assertEqual(self.yearstr(123456789), '123456789')
+ self.assertEqual(self.yearstr(TIME_MAXYEAR), str(TIME_MAXYEAR))
+ self.assertRaises(OverflowError, self.yearstr, TIME_MAXYEAR + 1)
def test_negative(self):
- try:
- text = self.yearstr(-1)
- except ValueError:
- # strftime() is limited to [1; 9999] with Visual Studio
- return
- self.assertIn(text, ('-1', '-001'))
-
+ self.assertEqual(self.yearstr(-1), self._format % -1)
self.assertEqual(self.yearstr(-1234), '-1234')
self.assertEqual(self.yearstr(-123456), '-123456')
+ self.assertEqual(self.yearstr(-123456789), str(-123456789))
+ self.assertEqual(self.yearstr(-1234567890), str(-1234567890))
+ self.assertEqual(self.yearstr(TIME_MINYEAR + 1900), str(TIME_MINYEAR + 1900))
+ # Issue #13312: it may return wrong value for year < TIME_MINYEAR + 1900
+ # Skip the value test, but check that no error is raised
+ self.yearstr(TIME_MINYEAR)
+ # self.assertEqual(self.yearstr(TIME_MINYEAR), str(TIME_MINYEAR))
+ self.assertRaises(OverflowError, self.yearstr, TIME_MINYEAR - 1)
- def test_mktime(self):
- # Issue #1726687
- for t in (-2, -1, 0, 1):
- try:
- tt = time.localtime(t)
- except (OverflowError, ValueError):
- pass
- else:
- self.assertEqual(time.mktime(tt), t)
- # It may not be possible to reliably make mktime return error
- # on all platfom. This will make sure that no other exception
- # than OverflowError is raised for an extreme value.
- try:
- time.mktime((-1, 1, 1, 0, 0, 0, -1, -1, -1))
- except OverflowError:
- pass
-
-class TestAsctimeAccept2dYear(_TestAsctimeYear, _Test2dYear):
- pass
-
-class TestStrftimeAccept2dYear(_TestStrftimeYear, _Test2dYear):
- pass
-
class TestAsctime4dyear(_TestAsctimeYear, _Test4dYear):
pass
class TestStrftime4dyear(_TestStrftimeYear, _Test4dYear):
pass
-class Test2dyearBool(_TestAsctimeYear, _Test2dYear):
- accept2dyear = True
-
-class Test4dyearBool(_TestAsctimeYear, _Test4dYear):
- accept2dyear = False
-
-class TestAccept2YearBad(_TestAsctimeYear, _BaseYearTest):
- class X:
- def __bool__(self):
- raise RuntimeError('boo')
- accept2dyear = X()
- def test_2dyear(self):
- pass
- def test_invalid(self):
- self.assertRaises(RuntimeError, self.yearstr, 200)
+class TestPytime(unittest.TestCase):
+ def setUp(self):
+ self.invalid_values = (
+ -(2 ** 100), 2 ** 100,
+ -(2.0 ** 100.0), 2.0 ** 100.0,
+ )
+
+ def test_time_t(self):
+ from _testcapi import pytime_object_to_time_t
+ for obj, time_t in (
+ (0, 0),
+ (-1, -1),
+ (-1.0, -1),
+ (-1.9, -1),
+ (1.0, 1),
+ (1.9, 1),
+ ):
+ self.assertEqual(pytime_object_to_time_t(obj), time_t)
+
+ for invalid in self.invalid_values:
+ self.assertRaises(OverflowError, pytime_object_to_time_t, invalid)
+
+ def test_timeval(self):
+ from _testcapi import pytime_object_to_timeval
+ for obj, timeval in (
+ (0, (0, 0)),
+ (-1, (-1, 0)),
+ (-1.0, (-1, 0)),
+ (1e-6, (0, 1)),
+ (-1e-6, (-1, 999999)),
+ (-1.2, (-2, 800000)),
+ (1.1234560, (1, 123456)),
+ (1.1234569, (1, 123456)),
+ (-1.1234560, (-2, 876544)),
+ (-1.1234561, (-2, 876543)),
+ ):
+ self.assertEqual(pytime_object_to_timeval(obj), timeval)
+
+ for invalid in self.invalid_values:
+ self.assertRaises(OverflowError, pytime_object_to_timeval, invalid)
+
+ def test_timespec(self):
+ from _testcapi import pytime_object_to_timespec
+ for obj, timespec in (
+ (0, (0, 0)),
+ (-1, (-1, 0)),
+ (-1.0, (-1, 0)),
+ (1e-9, (0, 1)),
+ (-1e-9, (-1, 999999999)),
+ (-1.2, (-2, 800000000)),
+ (1.1234567890, (1, 123456789)),
+ (1.1234567899, (1, 123456789)),
+ (-1.1234567890, (-2, 876543211)),
+ (-1.1234567891, (-2, 876543210)),
+ ):
+ self.assertEqual(pytime_object_to_timespec(obj), timespec)
+
+ for invalid in self.invalid_values:
+ self.assertRaises(OverflowError, pytime_object_to_timespec, invalid)
+
+ @unittest.skipUnless(time._STRUCT_TM_ITEMS == 11, "needs tm_zone support")
+ def test_localtime_timezone(self):
+
+ # Get the localtime and examine it for the offset and zone.
+ lt = time.localtime()
+ self.assertTrue(hasattr(lt, "tm_gmtoff"))
+ self.assertTrue(hasattr(lt, "tm_zone"))
+
+ # See if the offset and zone are similar to the module
+ # attributes.
+ if lt.tm_gmtoff is None:
+ self.assertTrue(not hasattr(time, "timezone"))
+ else:
+ self.assertEqual(lt.tm_gmtoff, -[time.timezone, time.altzone][lt.tm_isdst])
+ if lt.tm_zone is None:
+ self.assertTrue(not hasattr(time, "tzname"))
+ else:
+ self.assertEqual(lt.tm_zone, time.tzname[lt.tm_isdst])
+
+ # Try and make UNIX times from the localtime and a 9-tuple
+ # created from the localtime. Test to see that the times are
+ # the same.
+ t = time.mktime(lt); t9 = time.mktime(lt[:9])
+ self.assertEqual(t, t9)
+
+ # Make localtimes from the UNIX times and compare them to
+ # the original localtime, thus making a round trip.
+ new_lt = time.localtime(t); new_lt9 = time.localtime(t9)
+ self.assertEqual(new_lt, lt)
+ self.assertEqual(new_lt.tm_gmtoff, lt.tm_gmtoff)
+ self.assertEqual(new_lt.tm_zone, lt.tm_zone)
+ self.assertEqual(new_lt9, lt)
+ self.assertEqual(new_lt.tm_gmtoff, lt.tm_gmtoff)
+ self.assertEqual(new_lt9.tm_zone, lt.tm_zone)
+
+ @unittest.skipUnless(time._STRUCT_TM_ITEMS == 11, "needs tm_zone support")
+ def test_strptime_timezone(self):
+ t = time.strptime("UTC", "%Z")
+ self.assertEqual(t.tm_zone, 'UTC')
+ t = time.strptime("+0500", "%z")
+ self.assertEqual(t.tm_gmtoff, 5 * 3600)
+
+ @unittest.skipUnless(time._STRUCT_TM_ITEMS == 11, "needs tm_zone support")
+ def test_short_times(self):
+
+ import pickle
+
+ # Load a short time structure using pickle.
+ st = b"ctime\nstruct_time\np0\n((I2007\nI8\nI11\nI1\nI24\nI49\nI5\nI223\nI1\ntp1\n(dp2\ntp3\nRp4\n."
+ lt = pickle.loads(st)
+ self.assertIs(lt.tm_gmtoff, None)
+ self.assertIs(lt.tm_zone, None)
def test_main():
support.run_unittest(
TimeTestCase,
TestLocale,
- TestAsctimeAccept2dYear,
- TestStrftimeAccept2dYear,
TestAsctime4dyear,
TestStrftime4dyear,
- Test2dyearBool,
- Test4dyearBool,
- TestAccept2YearBad)
+ TestPytime)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_tokenize.py b/Lib/test/test_tokenize.py
index f9652ce..b4a58f0 100644
--- a/Lib/test/test_tokenize.py
+++ b/Lib/test/test_tokenize.py
@@ -289,6 +289,64 @@ String literals
OP '+' (1, 29) (1, 30)
STRING 'R"ABC"' (1, 31) (1, 37)
+ >>> dump_tokens("u'abc' + U'abc'")
+ ENCODING 'utf-8' (0, 0) (0, 0)
+ STRING "u'abc'" (1, 0) (1, 6)
+ OP '+' (1, 7) (1, 8)
+ STRING "U'abc'" (1, 9) (1, 15)
+ >>> dump_tokens('u"abc" + U"abc"')
+ ENCODING 'utf-8' (0, 0) (0, 0)
+ STRING 'u"abc"' (1, 0) (1, 6)
+ OP '+' (1, 7) (1, 8)
+ STRING 'U"abc"' (1, 9) (1, 15)
+
+ >>> dump_tokens("b'abc' + B'abc'")
+ ENCODING 'utf-8' (0, 0) (0, 0)
+ STRING "b'abc'" (1, 0) (1, 6)
+ OP '+' (1, 7) (1, 8)
+ STRING "B'abc'" (1, 9) (1, 15)
+ >>> dump_tokens('b"abc" + B"abc"')
+ ENCODING 'utf-8' (0, 0) (0, 0)
+ STRING 'b"abc"' (1, 0) (1, 6)
+ OP '+' (1, 7) (1, 8)
+ STRING 'B"abc"' (1, 9) (1, 15)
+ >>> dump_tokens("br'abc' + bR'abc' + Br'abc' + BR'abc'")
+ ENCODING 'utf-8' (0, 0) (0, 0)
+ STRING "br'abc'" (1, 0) (1, 7)
+ OP '+' (1, 8) (1, 9)
+ STRING "bR'abc'" (1, 10) (1, 17)
+ OP '+' (1, 18) (1, 19)
+ STRING "Br'abc'" (1, 20) (1, 27)
+ OP '+' (1, 28) (1, 29)
+ STRING "BR'abc'" (1, 30) (1, 37)
+ >>> dump_tokens('br"abc" + bR"abc" + Br"abc" + BR"abc"')
+ ENCODING 'utf-8' (0, 0) (0, 0)
+ STRING 'br"abc"' (1, 0) (1, 7)
+ OP '+' (1, 8) (1, 9)
+ STRING 'bR"abc"' (1, 10) (1, 17)
+ OP '+' (1, 18) (1, 19)
+ STRING 'Br"abc"' (1, 20) (1, 27)
+ OP '+' (1, 28) (1, 29)
+ STRING 'BR"abc"' (1, 30) (1, 37)
+ >>> dump_tokens("rb'abc' + rB'abc' + Rb'abc' + RB'abc'")
+ ENCODING 'utf-8' (0, 0) (0, 0)
+ STRING "rb'abc'" (1, 0) (1, 7)
+ OP '+' (1, 8) (1, 9)
+ STRING "rB'abc'" (1, 10) (1, 17)
+ OP '+' (1, 18) (1, 19)
+ STRING "Rb'abc'" (1, 20) (1, 27)
+ OP '+' (1, 28) (1, 29)
+ STRING "RB'abc'" (1, 30) (1, 37)
+ >>> dump_tokens('rb"abc" + rB"abc" + Rb"abc" + RB"abc"')
+ ENCODING 'utf-8' (0, 0) (0, 0)
+ STRING 'rb"abc"' (1, 0) (1, 7)
+ OP '+' (1, 8) (1, 9)
+ STRING 'rB"abc"' (1, 10) (1, 17)
+ OP '+' (1, 18) (1, 19)
+ STRING 'Rb"abc"' (1, 20) (1, 27)
+ OP '+' (1, 28) (1, 29)
+ STRING 'RB"abc"' (1, 30) (1, 37)
+
Operators
>>> dump_tokens("def d22(a, b, c=2, d=2, *k): pass")
@@ -552,11 +610,6 @@ Evil tabs
DEDENT '' (4, 0) (4, 0)
DEDENT '' (4, 0) (4, 0)
-Pathological whitespace (http://bugs.python.org/issue16152)
- >>> dump_tokens("@ ")
- ENCODING 'utf-8' (0, 0) (0, 0)
- OP '@' (1, 0) (1, 1)
-
Non-ascii identifiers
>>> dump_tokens("Örter = 'places'\\ngrün = 'green'")
@@ -568,15 +621,28 @@ Non-ascii identifiers
NAME 'grün' (2, 0) (2, 4)
OP '=' (2, 5) (2, 6)
STRING "'green'" (2, 7) (2, 14)
+
+Legacy unicode literals:
+
+ >>> dump_tokens("Örter = u'places'\\ngrün = U'green'")
+ ENCODING 'utf-8' (0, 0) (0, 0)
+ NAME 'Örter' (1, 0) (1, 5)
+ OP '=' (1, 6) (1, 7)
+ STRING "u'places'" (1, 8) (1, 17)
+ NEWLINE '\\n' (1, 17) (1, 18)
+ NAME 'grün' (2, 0) (2, 4)
+ OP '=' (2, 5) (2, 6)
+ STRING "U'green'" (2, 7) (2, 15)
"""
from test import support
from tokenize import (tokenize, _tokenize, untokenize, NUMBER, NAME, OP,
- STRING, ENDMARKER, tok_name, detect_encoding,
+ STRING, ENDMARKER, ENCODING, tok_name, detect_encoding,
open as tokenize_open)
from io import BytesIO
from unittest import TestCase
import os, sys, glob
+import token
def dump_tokens(s):
"""Print out the tokens in s in a table format.
@@ -605,7 +671,7 @@ def roundtrip(f):
f.close()
tokens1 = [tok[:2] for tok in token_list]
new_bytes = untokenize(tokens1)
- readline = (line for line in new_bytes.splitlines(1)).__next__
+ readline = (line for line in new_bytes.splitlines(keepends=True)).__next__
tokens2 = [tok[:2] for tok in tokenize(readline)]
return tokens1 == tokens2
@@ -900,6 +966,35 @@ class TestDetectEncoding(TestCase):
self.assertEqual(fp.encoding, 'utf-8-sig')
self.assertEqual(fp.mode, 'r')
+ def test_filename_in_exception(self):
+ # When possible, include the file name in the exception.
+ path = 'some_file_path'
+ lines = (
+ b'print("\xdf")', # Latin-1: LATIN SMALL LETTER SHARP S
+ )
+ class Bunk:
+ def __init__(self, lines, path):
+ self.name = path
+ self._lines = lines
+ self._index = 0
+
+ def readline(self):
+ if self._index == len(lines):
+ raise StopIteration
+ line = lines[self._index]
+ self._index += 1
+ return line
+
+ with self.assertRaises(SyntaxError):
+ ins = Bunk(lines, path)
+ # Make sure lacking a name isn't an issue.
+ del ins.name
+ detect_encoding(ins.readline)
+ with self.assertRaisesRegex(SyntaxError, '.*{}'.format(path)):
+ ins = Bunk(lines, path)
+ detect_encoding(ins.readline)
+
+
class TestTokenize(TestCase):
def test_tokenize(self):
@@ -941,6 +1036,82 @@ class TestTokenize(TestCase):
self.assertTrue(encoding_used, encoding)
+ def assertExactTypeEqual(self, opstr, *optypes):
+ tokens = list(tokenize(BytesIO(opstr.encode('utf-8')).readline))
+ num_optypes = len(optypes)
+ self.assertEqual(len(tokens), 2 + num_optypes)
+ self.assertEqual(token.tok_name[tokens[0].exact_type],
+ token.tok_name[ENCODING])
+ for i in range(num_optypes):
+ self.assertEqual(token.tok_name[tokens[i + 1].exact_type],
+ token.tok_name[optypes[i]])
+ self.assertEqual(token.tok_name[tokens[1 + num_optypes].exact_type],
+ token.tok_name[token.ENDMARKER])
+
+ def test_exact_type(self):
+ self.assertExactTypeEqual('()', token.LPAR, token.RPAR)
+ self.assertExactTypeEqual('[]', token.LSQB, token.RSQB)
+ self.assertExactTypeEqual(':', token.COLON)
+ self.assertExactTypeEqual(',', token.COMMA)
+ self.assertExactTypeEqual(';', token.SEMI)
+ self.assertExactTypeEqual('+', token.PLUS)
+ self.assertExactTypeEqual('-', token.MINUS)
+ self.assertExactTypeEqual('*', token.STAR)
+ self.assertExactTypeEqual('/', token.SLASH)
+ self.assertExactTypeEqual('|', token.VBAR)
+ self.assertExactTypeEqual('&', token.AMPER)
+ self.assertExactTypeEqual('<', token.LESS)
+ self.assertExactTypeEqual('>', token.GREATER)
+ self.assertExactTypeEqual('=', token.EQUAL)
+ self.assertExactTypeEqual('.', token.DOT)
+ self.assertExactTypeEqual('%', token.PERCENT)
+ self.assertExactTypeEqual('{}', token.LBRACE, token.RBRACE)
+ self.assertExactTypeEqual('==', token.EQEQUAL)
+ self.assertExactTypeEqual('!=', token.NOTEQUAL)
+ self.assertExactTypeEqual('<=', token.LESSEQUAL)
+ self.assertExactTypeEqual('>=', token.GREATEREQUAL)
+ self.assertExactTypeEqual('~', token.TILDE)
+ self.assertExactTypeEqual('^', token.CIRCUMFLEX)
+ self.assertExactTypeEqual('<<', token.LEFTSHIFT)
+ self.assertExactTypeEqual('>>', token.RIGHTSHIFT)
+ self.assertExactTypeEqual('**', token.DOUBLESTAR)
+ self.assertExactTypeEqual('+=', token.PLUSEQUAL)
+ self.assertExactTypeEqual('-=', token.MINEQUAL)
+ self.assertExactTypeEqual('*=', token.STAREQUAL)
+ self.assertExactTypeEqual('/=', token.SLASHEQUAL)
+ self.assertExactTypeEqual('%=', token.PERCENTEQUAL)
+ self.assertExactTypeEqual('&=', token.AMPEREQUAL)
+ self.assertExactTypeEqual('|=', token.VBAREQUAL)
+ self.assertExactTypeEqual('^=', token.CIRCUMFLEXEQUAL)
+ self.assertExactTypeEqual('^=', token.CIRCUMFLEXEQUAL)
+ self.assertExactTypeEqual('<<=', token.LEFTSHIFTEQUAL)
+ self.assertExactTypeEqual('>>=', token.RIGHTSHIFTEQUAL)
+ self.assertExactTypeEqual('**=', token.DOUBLESTAREQUAL)
+ self.assertExactTypeEqual('//', token.DOUBLESLASH)
+ self.assertExactTypeEqual('//=', token.DOUBLESLASHEQUAL)
+ self.assertExactTypeEqual('@', token.AT)
+
+ self.assertExactTypeEqual('a**2+b**2==c**2',
+ NAME, token.DOUBLESTAR, NUMBER,
+ token.PLUS,
+ NAME, token.DOUBLESTAR, NUMBER,
+ token.EQEQUAL,
+ NAME, token.DOUBLESTAR, NUMBER)
+ self.assertExactTypeEqual('{1, 2, 3}',
+ token.LBRACE,
+ token.NUMBER, token.COMMA,
+ token.NUMBER, token.COMMA,
+ token.NUMBER,
+ token.RBRACE)
+ self.assertExactTypeEqual('^(x & 0x1)',
+ token.CIRCUMFLEX,
+ token.LPAR,
+ token.NAME, token.AMPER, token.NUMBER,
+ token.RPAR)
+
+ def test_pathological_trailing_whitespace(self):
+ # See http://bugs.python.org/issue16152
+ self.assertExactTypeEqual('@ ', token.AT)
__test__ = {"doctests" : doctests, 'decistmt': decistmt}
diff --git a/Lib/test/test_tools.py b/Lib/test/test_tools.py
index 006220a..f971515 100644
--- a/Lib/test/test_tools.py
+++ b/Lib/test/test_tools.py
@@ -6,8 +6,9 @@ Tools directory of a Python checkout or tarball, such as reindent.py.
import os
import sys
-import imp
+import importlib.machinery
import unittest
+from unittest import mock
import shutil
import subprocess
import sysconfig
@@ -365,7 +366,7 @@ class TestSundryScripts(unittest.TestCase):
# added for a script it should be added to the whitelist below.
# scripts that have independent tests.
- whitelist = ['reindent.py']
+ whitelist = ['reindent.py', 'pdeps.py', 'gprof2html']
# scripts that can't be imported without running
blacklist = ['make_ctype.py']
# scripts that use windows-only modules
@@ -404,7 +405,8 @@ class PdepsTests(unittest.TestCase):
@classmethod
def setUpClass(self):
path = os.path.join(scriptsdir, 'pdeps.py')
- self.pdeps = imp.load_source('pdeps', path)
+ loader = importlib.machinery.SourceFileLoader('pdeps', path)
+ self.pdeps = loader.load_module()
@classmethod
def tearDownClass(self):
@@ -424,6 +426,35 @@ class PdepsTests(unittest.TestCase):
self.pdeps.inverse({'a': []})
+class Gprof2htmlTests(unittest.TestCase):
+
+ def setUp(self):
+ path = os.path.join(scriptsdir, 'gprof2html.py')
+ loader = importlib.machinery.SourceFileLoader('gprof2html', path)
+ self.gprof = loader.load_module()
+ oldargv = sys.argv
+ def fixup():
+ sys.argv = oldargv
+ self.addCleanup(fixup)
+ sys.argv = []
+
+ def test_gprof(self):
+ # Issue #14508: this used to fail with an NameError.
+ with mock.patch.object(self.gprof, 'webbrowser') as wmock, \
+ tempfile.TemporaryDirectory() as tmpdir:
+ fn = os.path.join(tmpdir, 'abc')
+ open(fn, 'w').close()
+ sys.argv = ['gprof2html', fn]
+ self.gprof.main()
+ self.assertTrue(wmock.open.called)
+
+
+# Run the tests in Tools/parser/test_unparse.py
+with support.DirsOnSysPath(os.path.join(basepath, 'parser')):
+ from test_unparse import UnparseTestCase
+ from test_unparse import DirectoryTestCase
+
+
def test_main():
support.run_unittest(*[obj for obj in globals().values()
if isinstance(obj, type)])
diff --git a/Lib/test/test_trace.py b/Lib/test/test_trace.py
index 461d1d8..1cec710 100644
--- a/Lib/test/test_trace.py
+++ b/Lib/test/test_trace.py
@@ -1,4 +1,5 @@
import os
+import io
import sys
from test.support import (run_unittest, TESTFN, rmtree, unlink,
captured_stdout)
@@ -102,6 +103,7 @@ class TracedClass(object):
class TestLineCounts(unittest.TestCase):
"""White-box testing of line-counting, via runfunc"""
def setUp(self):
+ self.addCleanup(sys.settrace, sys.gettrace())
self.tracer = Trace(count=1, trace=0, countfuncs=0, countcallers=0)
self.my_py_filename = fix_ext_py(__file__)
@@ -192,6 +194,7 @@ class TestRunExecCounts(unittest.TestCase):
"""A simple sanity test of line-counting, via runctx (exec)"""
def setUp(self):
self.my_py_filename = fix_ext_py(__file__)
+ self.addCleanup(sys.settrace, sys.gettrace())
def test_exec_counts(self):
self.tracer = Trace(count=1, trace=0, countfuncs=0, countcallers=0)
@@ -218,6 +221,7 @@ class TestRunExecCounts(unittest.TestCase):
class TestFuncs(unittest.TestCase):
"""White-box testing of funcs tracing"""
def setUp(self):
+ self.addCleanup(sys.settrace, sys.gettrace())
self.tracer = Trace(count=0, trace=0, countfuncs=1)
self.filemod = my_file_and_modname()
@@ -242,6 +246,8 @@ class TestFuncs(unittest.TestCase):
}
self.assertEqual(self.tracer.results().calledfuncs, expected)
+ @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(),
+ 'pre-existing trace function throws off measurements')
def test_inst_method_calling(self):
obj = TracedClass(20)
self.tracer.runfunc(obj.inst_method_calling, 1)
@@ -257,9 +263,12 @@ class TestFuncs(unittest.TestCase):
class TestCallers(unittest.TestCase):
"""White-box testing of callers tracing"""
def setUp(self):
+ self.addCleanup(sys.settrace, sys.gettrace())
self.tracer = Trace(count=0, trace=0, countcallers=1)
self.filemod = my_file_and_modname()
+ @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(),
+ 'pre-existing trace function throws off measurements')
def test_loop_caller_importing(self):
self.tracer.runfunc(traced_func_importing_caller, 1)
@@ -280,6 +289,9 @@ class TestCallers(unittest.TestCase):
# Created separately for issue #3821
class TestCoverage(unittest.TestCase):
+ def setUp(self):
+ self.addCleanup(sys.settrace, sys.gettrace())
+
def tearDown(self):
rmtree(TESTFN)
unlink(TESTFN)
@@ -305,13 +317,13 @@ class TestCoverage(unittest.TestCase):
# Ignore all files, nothing should be traced nor printed
libpath = os.path.normpath(os.path.dirname(os.__file__))
# sys.prefix does not work when running from a checkout
- tracer = trace.Trace(ignoredirs=[sys.prefix, sys.exec_prefix, libpath],
- trace=0, count=1)
+ tracer = trace.Trace(ignoredirs=[sys.base_prefix, sys.base_exec_prefix,
+ libpath], trace=0, count=1)
with captured_stdout() as stdout:
self._coverage(tracer)
if os.path.exists(TESTFN):
files = os.listdir(TESTFN)
- self.assertEqual(files, [])
+ self.assertEqual(files, ['_importlib.cover']) # Ignore __import__
def test_issue9936(self):
tracer = trace.Trace(trace=0, count=1)
@@ -350,6 +362,51 @@ class Test_Ignore(unittest.TestCase):
self.assertTrue(ignore.names(jn('bar', 'baz.py'), 'baz'))
+class TestDeprecatedMethods(unittest.TestCase):
+
+ def test_deprecated_usage(self):
+ sio = io.StringIO()
+ with self.assertWarns(DeprecationWarning):
+ trace.usage(sio)
+ self.assertIn('Usage:', sio.getvalue())
+
+ def test_deprecated_Ignore(self):
+ with self.assertWarns(DeprecationWarning):
+ trace.Ignore()
+
+ def test_deprecated_modname(self):
+ with self.assertWarns(DeprecationWarning):
+ self.assertEqual("spam", trace.modname("spam"))
+
+ def test_deprecated_fullmodname(self):
+ with self.assertWarns(DeprecationWarning):
+ self.assertEqual("spam", trace.fullmodname("spam"))
+
+ def test_deprecated_find_lines_from_code(self):
+ with self.assertWarns(DeprecationWarning):
+ def foo():
+ pass
+ trace.find_lines_from_code(foo.__code__, ["eggs"])
+
+ def test_deprecated_find_lines(self):
+ with self.assertWarns(DeprecationWarning):
+ def foo():
+ pass
+ trace.find_lines(foo.__code__, ["eggs"])
+
+ def test_deprecated_find_strings(self):
+ with open(TESTFN, 'w') as fd:
+ self.addCleanup(unlink, TESTFN)
+ with self.assertWarns(DeprecationWarning):
+ trace.find_strings(fd.name)
+
+ def test_deprecated_find_executable_linenos(self):
+ with open(TESTFN, 'w') as fd:
+ self.addCleanup(unlink, TESTFN)
+ with self.assertWarns(DeprecationWarning):
+ trace.find_executable_linenos(fd.name)
+
+
def test_main():
run_unittest(__name__)
diff --git a/Lib/test/test_traceback.py b/Lib/test/test_traceback.py
index 4752d37..5bce2af 100644
--- a/Lib/test/test_traceback.py
+++ b/Lib/test/test_traceback.py
@@ -246,6 +246,21 @@ class BaseExceptionReportingTests:
self.check_zero_div(blocks[0])
self.assertIn('inner_raise() # Marker', blocks[2])
+ def test_context_suppression(self):
+ try:
+ try:
+ raise Exception
+ except:
+ raise ZeroDivisionError from None
+ except ZeroDivisionError as _:
+ e = _
+ lines = self.get_report(e).splitlines()
+ self.assertEqual(len(lines), 4)
+ self.assertTrue(lines[0].startswith('Traceback'))
+ self.assertTrue(lines[1].startswith(' File'))
+ self.assertIn('ZeroDivisionError from None', lines[2])
+ self.assertTrue(lines[3].startswith('ZeroDivisionError'))
+
def test_cause_and_context(self):
# When both a cause and a context are set, only the cause should be
# displayed and the context should be muted.
diff --git a/Lib/test/test_tuple.py b/Lib/test/test_tuple.py
index 6e934fb..e41711c 100644
--- a/Lib/test/test_tuple.py
+++ b/Lib/test/test_tuple.py
@@ -1,6 +1,7 @@
from test import support, seq_tests
import gc
+import pickle
class TupleTest(seq_tests.CommonTest):
type2test = tuple
@@ -164,6 +165,34 @@ class TupleTest(seq_tests.CommonTest):
check(10) # check our checking code
check(1000000)
+ def test_iterator_pickle(self):
+ # Userlist iterators don't support pickling yet since
+ # they are based on generators.
+ data = self.type2test([4, 5, 6, 7])
+ itorg = iter(data)
+ d = pickle.dumps(itorg)
+ it = pickle.loads(d)
+ self.assertEqual(type(itorg), type(it))
+ self.assertEqual(self.type2test(it), self.type2test(data))
+
+ it = pickle.loads(d)
+ next(it)
+ d = pickle.dumps(it)
+ self.assertEqual(self.type2test(it), self.type2test(data)[1:])
+
+ def test_reversed_pickle(self):
+ data = self.type2test([4, 5, 6, 7])
+ itorg = reversed(data)
+ d = pickle.dumps(itorg)
+ it = pickle.loads(d)
+ self.assertEqual(type(itorg), type(it))
+ self.assertEqual(self.type2test(it), self.type2test(reversed(data)))
+
+ it = pickle.loads(d)
+ next(it)
+ d = pickle.dumps(it)
+ self.assertEqual(self.type2test(it), self.type2test(reversed(data))[1:])
+
def test_no_comdat_folding(self):
# Issue 8847: In the PGO build, the MSVC linker's COMDAT folding
# optimization causes failures in code that relies on distinct
diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py
index 8a98a03..3ee4c6b 100644
--- a/Lib/test/test_types.py
+++ b/Lib/test/test_types.py
@@ -1,9 +1,11 @@
# Python test set -- part 6, built-in types
from test.support import run_unittest, run_with_locale
-import unittest
-import sys
+import collections
import locale
+import sys
+import types
+import unittest
class TypesTests(unittest.TestCase):
@@ -569,8 +571,583 @@ class TypesTests(unittest.TestCase):
self.assertGreater(tuple.__itemsize__, 0)
+class MappingProxyTests(unittest.TestCase):
+ mappingproxy = types.MappingProxyType
+
+ def test_constructor(self):
+ class userdict(dict):
+ pass
+
+ mapping = {'x': 1, 'y': 2}
+ self.assertEqual(self.mappingproxy(mapping), mapping)
+ mapping = userdict(x=1, y=2)
+ self.assertEqual(self.mappingproxy(mapping), mapping)
+ mapping = collections.ChainMap({'x': 1}, {'y': 2})
+ self.assertEqual(self.mappingproxy(mapping), mapping)
+
+ self.assertRaises(TypeError, self.mappingproxy, 10)
+ self.assertRaises(TypeError, self.mappingproxy, ("a", "tuple"))
+ self.assertRaises(TypeError, self.mappingproxy, ["a", "list"])
+
+ def test_methods(self):
+ attrs = set(dir(self.mappingproxy({}))) - set(dir(object()))
+ self.assertEqual(attrs, {
+ '__contains__',
+ '__getitem__',
+ '__iter__',
+ '__len__',
+ 'copy',
+ 'get',
+ 'items',
+ 'keys',
+ 'values',
+ })
+
+ def test_get(self):
+ view = self.mappingproxy({'a': 'A', 'b': 'B'})
+ self.assertEqual(view['a'], 'A')
+ self.assertEqual(view['b'], 'B')
+ self.assertRaises(KeyError, view.__getitem__, 'xxx')
+ self.assertEqual(view.get('a'), 'A')
+ self.assertIsNone(view.get('xxx'))
+ self.assertEqual(view.get('xxx', 42), 42)
+
+ def test_missing(self):
+ class dictmissing(dict):
+ def __missing__(self, key):
+ return "missing=%s" % key
+
+ view = self.mappingproxy(dictmissing(x=1))
+ self.assertEqual(view['x'], 1)
+ self.assertEqual(view['y'], 'missing=y')
+ self.assertEqual(view.get('x'), 1)
+ self.assertEqual(view.get('y'), None)
+ self.assertEqual(view.get('y', 42), 42)
+ self.assertTrue('x' in view)
+ self.assertFalse('y' in view)
+
+ def test_customdict(self):
+ class customdict(dict):
+ def __contains__(self, key):
+ if key == 'magic':
+ return True
+ else:
+ return dict.__contains__(self, key)
+
+ def __iter__(self):
+ return iter(('iter',))
+
+ def __len__(self):
+ return 500
+
+ def copy(self):
+ return 'copy'
+
+ def keys(self):
+ return 'keys'
+
+ def items(self):
+ return 'items'
+
+ def values(self):
+ return 'values'
+
+ def __getitem__(self, key):
+ return "getitem=%s" % dict.__getitem__(self, key)
+
+ def get(self, key, default=None):
+ return "get=%s" % dict.get(self, key, 'default=%r' % default)
+
+ custom = customdict({'key': 'value'})
+ view = self.mappingproxy(custom)
+ self.assertTrue('key' in view)
+ self.assertTrue('magic' in view)
+ self.assertFalse('xxx' in view)
+ self.assertEqual(view['key'], 'getitem=value')
+ self.assertRaises(KeyError, view.__getitem__, 'xxx')
+ self.assertEqual(tuple(view), ('iter',))
+ self.assertEqual(len(view), 500)
+ self.assertEqual(view.copy(), 'copy')
+ self.assertEqual(view.get('key'), 'get=value')
+ self.assertEqual(view.get('xxx'), 'get=default=None')
+ self.assertEqual(view.items(), 'items')
+ self.assertEqual(view.keys(), 'keys')
+ self.assertEqual(view.values(), 'values')
+
+ def test_chainmap(self):
+ d1 = {'x': 1}
+ d2 = {'y': 2}
+ mapping = collections.ChainMap(d1, d2)
+ view = self.mappingproxy(mapping)
+ self.assertTrue('x' in view)
+ self.assertTrue('y' in view)
+ self.assertFalse('z' in view)
+ self.assertEqual(view['x'], 1)
+ self.assertEqual(view['y'], 2)
+ self.assertRaises(KeyError, view.__getitem__, 'z')
+ self.assertEqual(tuple(sorted(view)), ('x', 'y'))
+ self.assertEqual(len(view), 2)
+ copy = view.copy()
+ self.assertIsNot(copy, mapping)
+ self.assertIsInstance(copy, collections.ChainMap)
+ self.assertEqual(copy, mapping)
+ self.assertEqual(view.get('x'), 1)
+ self.assertEqual(view.get('y'), 2)
+ self.assertIsNone(view.get('z'))
+ self.assertEqual(tuple(sorted(view.items())), (('x', 1), ('y', 2)))
+ self.assertEqual(tuple(sorted(view.keys())), ('x', 'y'))
+ self.assertEqual(tuple(sorted(view.values())), (1, 2))
+
+ def test_contains(self):
+ view = self.mappingproxy(dict.fromkeys('abc'))
+ self.assertTrue('a' in view)
+ self.assertTrue('b' in view)
+ self.assertTrue('c' in view)
+ self.assertFalse('xxx' in view)
+
+ def test_views(self):
+ mapping = {}
+ view = self.mappingproxy(mapping)
+ keys = view.keys()
+ values = view.values()
+ items = view.items()
+ self.assertEqual(list(keys), [])
+ self.assertEqual(list(values), [])
+ self.assertEqual(list(items), [])
+ mapping['key'] = 'value'
+ self.assertEqual(list(keys), ['key'])
+ self.assertEqual(list(values), ['value'])
+ self.assertEqual(list(items), [('key', 'value')])
+
+ def test_len(self):
+ for expected in range(6):
+ data = dict.fromkeys('abcde'[:expected])
+ self.assertEqual(len(data), expected)
+ view = self.mappingproxy(data)
+ self.assertEqual(len(view), expected)
+
+ def test_iterators(self):
+ keys = ('x', 'y')
+ values = (1, 2)
+ items = tuple(zip(keys, values))
+ view = self.mappingproxy(dict(items))
+ self.assertEqual(set(view), set(keys))
+ self.assertEqual(set(view.keys()), set(keys))
+ self.assertEqual(set(view.values()), set(values))
+ self.assertEqual(set(view.items()), set(items))
+
+ def test_copy(self):
+ original = {'key1': 27, 'key2': 51, 'key3': 93}
+ view = self.mappingproxy(original)
+ copy = view.copy()
+ self.assertEqual(type(copy), dict)
+ self.assertEqual(copy, original)
+ original['key1'] = 70
+ self.assertEqual(view['key1'], 70)
+ self.assertEqual(copy['key1'], 27)
+
+
+class ClassCreationTests(unittest.TestCase):
+
+ class Meta(type):
+ def __init__(cls, name, bases, ns, **kw):
+ super().__init__(name, bases, ns)
+ @staticmethod
+ def __new__(mcls, name, bases, ns, **kw):
+ return super().__new__(mcls, name, bases, ns)
+ @classmethod
+ def __prepare__(mcls, name, bases, **kw):
+ ns = super().__prepare__(name, bases)
+ ns["y"] = 1
+ ns.update(kw)
+ return ns
+
+ def test_new_class_basics(self):
+ C = types.new_class("C")
+ self.assertEqual(C.__name__, "C")
+ self.assertEqual(C.__bases__, (object,))
+
+ def test_new_class_subclass(self):
+ C = types.new_class("C", (int,))
+ self.assertTrue(issubclass(C, int))
+
+ def test_new_class_meta(self):
+ Meta = self.Meta
+ settings = {"metaclass": Meta, "z": 2}
+ # We do this twice to make sure the passed in dict isn't mutated
+ for i in range(2):
+ C = types.new_class("C" + str(i), (), settings)
+ self.assertIsInstance(C, Meta)
+ self.assertEqual(C.y, 1)
+ self.assertEqual(C.z, 2)
+
+ def test_new_class_exec_body(self):
+ Meta = self.Meta
+ def func(ns):
+ ns["x"] = 0
+ C = types.new_class("C", (), {"metaclass": Meta, "z": 2}, func)
+ self.assertIsInstance(C, Meta)
+ self.assertEqual(C.x, 0)
+ self.assertEqual(C.y, 1)
+ self.assertEqual(C.z, 2)
+
+ def test_new_class_metaclass_keywords(self):
+ #Test that keywords are passed to the metaclass:
+ def meta_func(name, bases, ns, **kw):
+ return name, bases, ns, kw
+ res = types.new_class("X",
+ (int, object),
+ dict(metaclass=meta_func, x=0))
+ self.assertEqual(res, ("X", (int, object), {}, {"x": 0}))
+
+ def test_new_class_defaults(self):
+ # Test defaults/keywords:
+ C = types.new_class("C", (), {}, None)
+ self.assertEqual(C.__name__, "C")
+ self.assertEqual(C.__bases__, (object,))
+
+ def test_new_class_meta_with_base(self):
+ Meta = self.Meta
+ def func(ns):
+ ns["x"] = 0
+ C = types.new_class(name="C",
+ bases=(int,),
+ kwds=dict(metaclass=Meta, z=2),
+ exec_body=func)
+ self.assertTrue(issubclass(C, int))
+ self.assertIsInstance(C, Meta)
+ self.assertEqual(C.x, 0)
+ self.assertEqual(C.y, 1)
+ self.assertEqual(C.z, 2)
+
+ # Many of the following tests are derived from test_descr.py
+ def test_prepare_class(self):
+ # Basic test of metaclass derivation
+ expected_ns = {}
+ class A(type):
+ def __new__(*args, **kwargs):
+ return type.__new__(*args, **kwargs)
+
+ def __prepare__(*args):
+ return expected_ns
+
+ B = types.new_class("B", (object,))
+ C = types.new_class("C", (object,), {"metaclass": A})
+
+ # The most derived metaclass of D is A rather than type.
+ meta, ns, kwds = types.prepare_class("D", (B, C), {"metaclass": type})
+ self.assertIs(meta, A)
+ self.assertIs(ns, expected_ns)
+ self.assertEqual(len(kwds), 0)
+
+ def test_metaclass_derivation(self):
+ # issue1294232: correct metaclass calculation
+ new_calls = [] # to check the order of __new__ calls
+ class AMeta(type):
+ def __new__(mcls, name, bases, ns):
+ new_calls.append('AMeta')
+ return super().__new__(mcls, name, bases, ns)
+ @classmethod
+ def __prepare__(mcls, name, bases):
+ return {}
+
+ class BMeta(AMeta):
+ def __new__(mcls, name, bases, ns):
+ new_calls.append('BMeta')
+ return super().__new__(mcls, name, bases, ns)
+ @classmethod
+ def __prepare__(mcls, name, bases):
+ ns = super().__prepare__(name, bases)
+ ns['BMeta_was_here'] = True
+ return ns
+
+ A = types.new_class("A", (), {"metaclass": AMeta})
+ self.assertEqual(new_calls, ['AMeta'])
+ new_calls.clear()
+
+ B = types.new_class("B", (), {"metaclass": BMeta})
+ # BMeta.__new__ calls AMeta.__new__ with super:
+ self.assertEqual(new_calls, ['BMeta', 'AMeta'])
+ new_calls.clear()
+
+ C = types.new_class("C", (A, B))
+ # The most derived metaclass is BMeta:
+ self.assertEqual(new_calls, ['BMeta', 'AMeta'])
+ new_calls.clear()
+ # BMeta.__prepare__ should've been called:
+ self.assertIn('BMeta_was_here', C.__dict__)
+
+ # The order of the bases shouldn't matter:
+ C2 = types.new_class("C2", (B, A))
+ self.assertEqual(new_calls, ['BMeta', 'AMeta'])
+ new_calls.clear()
+ self.assertIn('BMeta_was_here', C2.__dict__)
+
+ # Check correct metaclass calculation when a metaclass is declared:
+ D = types.new_class("D", (C,), {"metaclass": type})
+ self.assertEqual(new_calls, ['BMeta', 'AMeta'])
+ new_calls.clear()
+ self.assertIn('BMeta_was_here', D.__dict__)
+
+ E = types.new_class("E", (C,), {"metaclass": AMeta})
+ self.assertEqual(new_calls, ['BMeta', 'AMeta'])
+ new_calls.clear()
+ self.assertIn('BMeta_was_here', E.__dict__)
+
+ def test_metaclass_override_function(self):
+ # Special case: the given metaclass isn't a class,
+ # so there is no metaclass calculation.
+ class A(metaclass=self.Meta):
+ pass
+
+ marker = object()
+ def func(*args, **kwargs):
+ return marker
+
+ X = types.new_class("X", (), {"metaclass": func})
+ Y = types.new_class("Y", (object,), {"metaclass": func})
+ Z = types.new_class("Z", (A,), {"metaclass": func})
+ self.assertIs(marker, X)
+ self.assertIs(marker, Y)
+ self.assertIs(marker, Z)
+
+ def test_metaclass_override_callable(self):
+ # The given metaclass is a class,
+ # but not a descendant of type.
+ new_calls = [] # to check the order of __new__ calls
+ prepare_calls = [] # to track __prepare__ calls
+ class ANotMeta:
+ def __new__(mcls, *args, **kwargs):
+ new_calls.append('ANotMeta')
+ return super().__new__(mcls)
+ @classmethod
+ def __prepare__(mcls, name, bases):
+ prepare_calls.append('ANotMeta')
+ return {}
+
+ class BNotMeta(ANotMeta):
+ def __new__(mcls, *args, **kwargs):
+ new_calls.append('BNotMeta')
+ return super().__new__(mcls)
+ @classmethod
+ def __prepare__(mcls, name, bases):
+ prepare_calls.append('BNotMeta')
+ return super().__prepare__(name, bases)
+
+ A = types.new_class("A", (), {"metaclass": ANotMeta})
+ self.assertIs(ANotMeta, type(A))
+ self.assertEqual(prepare_calls, ['ANotMeta'])
+ prepare_calls.clear()
+ self.assertEqual(new_calls, ['ANotMeta'])
+ new_calls.clear()
+
+ B = types.new_class("B", (), {"metaclass": BNotMeta})
+ self.assertIs(BNotMeta, type(B))
+ self.assertEqual(prepare_calls, ['BNotMeta', 'ANotMeta'])
+ prepare_calls.clear()
+ self.assertEqual(new_calls, ['BNotMeta', 'ANotMeta'])
+ new_calls.clear()
+
+ C = types.new_class("C", (A, B))
+ self.assertIs(BNotMeta, type(C))
+ self.assertEqual(prepare_calls, ['BNotMeta', 'ANotMeta'])
+ prepare_calls.clear()
+ self.assertEqual(new_calls, ['BNotMeta', 'ANotMeta'])
+ new_calls.clear()
+
+ C2 = types.new_class("C2", (B, A))
+ self.assertIs(BNotMeta, type(C2))
+ self.assertEqual(prepare_calls, ['BNotMeta', 'ANotMeta'])
+ prepare_calls.clear()
+ self.assertEqual(new_calls, ['BNotMeta', 'ANotMeta'])
+ new_calls.clear()
+
+ # This is a TypeError, because of a metaclass conflict:
+ # BNotMeta is neither a subclass, nor a superclass of type
+ with self.assertRaises(TypeError):
+ D = types.new_class("D", (C,), {"metaclass": type})
+
+ E = types.new_class("E", (C,), {"metaclass": ANotMeta})
+ self.assertIs(BNotMeta, type(E))
+ self.assertEqual(prepare_calls, ['BNotMeta', 'ANotMeta'])
+ prepare_calls.clear()
+ self.assertEqual(new_calls, ['BNotMeta', 'ANotMeta'])
+ new_calls.clear()
+
+ F = types.new_class("F", (object(), C))
+ self.assertIs(BNotMeta, type(F))
+ self.assertEqual(prepare_calls, ['BNotMeta', 'ANotMeta'])
+ prepare_calls.clear()
+ self.assertEqual(new_calls, ['BNotMeta', 'ANotMeta'])
+ new_calls.clear()
+
+ F2 = types.new_class("F2", (C, object()))
+ self.assertIs(BNotMeta, type(F2))
+ self.assertEqual(prepare_calls, ['BNotMeta', 'ANotMeta'])
+ prepare_calls.clear()
+ self.assertEqual(new_calls, ['BNotMeta', 'ANotMeta'])
+ new_calls.clear()
+
+ # TypeError: BNotMeta is neither a
+ # subclass, nor a superclass of int
+ with self.assertRaises(TypeError):
+ X = types.new_class("X", (C, int()))
+ with self.assertRaises(TypeError):
+ X = types.new_class("X", (int(), C))
+
+
+class SimpleNamespaceTests(unittest.TestCase):
+
+ def test_constructor(self):
+ ns1 = types.SimpleNamespace()
+ ns2 = types.SimpleNamespace(x=1, y=2)
+ ns3 = types.SimpleNamespace(**dict(x=1, y=2))
+
+ with self.assertRaises(TypeError):
+ types.SimpleNamespace(1, 2, 3)
+
+ self.assertEqual(len(ns1.__dict__), 0)
+ self.assertEqual(vars(ns1), {})
+ self.assertEqual(len(ns2.__dict__), 2)
+ self.assertEqual(vars(ns2), {'y': 2, 'x': 1})
+ self.assertEqual(len(ns3.__dict__), 2)
+ self.assertEqual(vars(ns3), {'y': 2, 'x': 1})
+
+ def test_unbound(self):
+ ns1 = vars(types.SimpleNamespace())
+ ns2 = vars(types.SimpleNamespace(x=1, y=2))
+
+ self.assertEqual(ns1, {})
+ self.assertEqual(ns2, {'y': 2, 'x': 1})
+
+ def test_underlying_dict(self):
+ ns1 = types.SimpleNamespace()
+ ns2 = types.SimpleNamespace(x=1, y=2)
+ ns3 = types.SimpleNamespace(a=True, b=False)
+ mapping = ns3.__dict__
+ del ns3
+
+ self.assertEqual(ns1.__dict__, {})
+ self.assertEqual(ns2.__dict__, {'y': 2, 'x': 1})
+ self.assertEqual(mapping, dict(a=True, b=False))
+
+ def test_attrget(self):
+ ns = types.SimpleNamespace(x=1, y=2, w=3)
+
+ self.assertEqual(ns.x, 1)
+ self.assertEqual(ns.y, 2)
+ self.assertEqual(ns.w, 3)
+ with self.assertRaises(AttributeError):
+ ns.z
+
+ def test_attrset(self):
+ ns1 = types.SimpleNamespace()
+ ns2 = types.SimpleNamespace(x=1, y=2, w=3)
+ ns1.a = 'spam'
+ ns1.b = 'ham'
+ ns2.z = 4
+ ns2.theta = None
+
+ self.assertEqual(ns1.__dict__, dict(a='spam', b='ham'))
+ self.assertEqual(ns2.__dict__, dict(x=1, y=2, w=3, z=4, theta=None))
+
+ def test_attrdel(self):
+ ns1 = types.SimpleNamespace()
+ ns2 = types.SimpleNamespace(x=1, y=2, w=3)
+
+ with self.assertRaises(AttributeError):
+ del ns1.spam
+ with self.assertRaises(AttributeError):
+ del ns2.spam
+
+ del ns2.y
+ self.assertEqual(vars(ns2), dict(w=3, x=1))
+ ns2.y = 'spam'
+ self.assertEqual(vars(ns2), dict(w=3, x=1, y='spam'))
+ del ns2.y
+ self.assertEqual(vars(ns2), dict(w=3, x=1))
+
+ ns1.spam = 5
+ self.assertEqual(vars(ns1), dict(spam=5))
+ del ns1.spam
+ self.assertEqual(vars(ns1), {})
+
+ def test_repr(self):
+ ns1 = types.SimpleNamespace(x=1, y=2, w=3)
+ ns2 = types.SimpleNamespace()
+ ns2.x = "spam"
+ ns2._y = 5
+
+ self.assertEqual(repr(ns1), "namespace(w=3, x=1, y=2)")
+ self.assertEqual(repr(ns2), "namespace(_y=5, x='spam')")
+
+ def test_nested(self):
+ ns1 = types.SimpleNamespace(a=1, b=2)
+ ns2 = types.SimpleNamespace()
+ ns3 = types.SimpleNamespace(x=ns1)
+ ns2.spam = ns1
+ ns2.ham = '?'
+ ns2.spam = ns3
+
+ self.assertEqual(vars(ns1), dict(a=1, b=2))
+ self.assertEqual(vars(ns2), dict(spam=ns3, ham='?'))
+ self.assertEqual(ns2.spam, ns3)
+ self.assertEqual(vars(ns3), dict(x=ns1))
+ self.assertEqual(ns3.x.a, 1)
+
+ def test_recursive(self):
+ ns1 = types.SimpleNamespace(c='cookie')
+ ns2 = types.SimpleNamespace()
+ ns3 = types.SimpleNamespace(x=1)
+ ns1.spam = ns1
+ ns2.spam = ns3
+ ns3.spam = ns2
+
+ self.assertEqual(ns1.spam, ns1)
+ self.assertEqual(ns1.spam.spam, ns1)
+ self.assertEqual(ns1.spam.spam, ns1.spam)
+ self.assertEqual(ns2.spam, ns3)
+ self.assertEqual(ns3.spam, ns2)
+ self.assertEqual(ns2.spam.spam, ns2)
+
+ def test_recursive_repr(self):
+ ns1 = types.SimpleNamespace(c='cookie')
+ ns2 = types.SimpleNamespace()
+ ns3 = types.SimpleNamespace(x=1)
+ ns1.spam = ns1
+ ns2.spam = ns3
+ ns3.spam = ns2
+
+ self.assertEqual(repr(ns1),
+ "namespace(c='cookie', spam=namespace(...))")
+ self.assertEqual(repr(ns2),
+ "namespace(spam=namespace(spam=namespace(...), x=1))")
+
+ def test_as_dict(self):
+ ns = types.SimpleNamespace(spam='spamspamspam')
+
+ with self.assertRaises(TypeError):
+ len(ns)
+ with self.assertRaises(TypeError):
+ iter(ns)
+ with self.assertRaises(TypeError):
+ 'spam' in ns
+ with self.assertRaises(TypeError):
+ ns['spam']
+
+ def test_subclass(self):
+ class Spam(types.SimpleNamespace):
+ pass
+
+ spam = Spam(ham=8, eggs=9)
+
+ self.assertIs(type(spam), Spam)
+ self.assertEqual(vars(spam), {'ham': 8, 'eggs': 9})
+
+
def test_main():
- run_unittest(TypesTests)
+ run_unittest(TypesTests, MappingProxyTests, ClassCreationTests,
+ SimpleNamespaceTests)
if __name__ == '__main__':
test_main()
diff --git a/Lib/test/test_ucn.py b/Lib/test/test_ucn.py
index fbbe78b..2e63745 100644
--- a/Lib/test/test_ucn.py
+++ b/Lib/test/test_ucn.py
@@ -8,9 +8,12 @@ Modified for Python 2.0 by Fredrik Lundh (fredrik@pythonware.com)
"""#"
import unittest
+import unicodedata
import _testcapi
from test import support
+from http.client import HTTPException
+from test.test_normalization import check_version
class UnicodeNamesTest(unittest.TestCase):
@@ -60,8 +63,6 @@ class UnicodeNamesTest(unittest.TestCase):
)
def test_ascii_letters(self):
- import unicodedata
-
for char in "".join(map(chr, range(ord("a"), ord("z")))):
name = "LATIN SMALL LETTER %s" % char.upper()
code = unicodedata.lookup(name)
@@ -82,7 +83,6 @@ class UnicodeNamesTest(unittest.TestCase):
self.checkletter("HANGUL SYLLABLE HWEOK", "\ud6f8")
self.checkletter("HANGUL SYLLABLE HIH", "\ud7a3")
- import unicodedata
self.assertRaises(ValueError, unicodedata.name, "\ud7a4")
def test_cjk_unified_ideographs(self):
@@ -98,14 +98,11 @@ class UnicodeNamesTest(unittest.TestCase):
self.checkletter("CJK UNIFIED IDEOGRAPH-2B81D", "\U0002B81D")
def test_bmp_characters(self):
- import unicodedata
- count = 0
for code in range(0x10000):
char = chr(code)
name = unicodedata.name(char, None)
if name is not None:
self.assertEqual(unicodedata.lookup(name), char)
- count += 1
def test_misc_symbols(self):
self.checkletter("PILCROW SIGN", "\u00b6")
@@ -113,8 +110,85 @@ class UnicodeNamesTest(unittest.TestCase):
self.checkletter("HALFWIDTH KATAKANA SEMI-VOICED SOUND MARK", "\uFF9F")
self.checkletter("FULLWIDTH LATIN SMALL LETTER A", "\uFF41")
+ def test_aliases(self):
+ # Check that the aliases defined in the NameAliases.txt file work.
+ # This should be updated when new aliases are added or the file
+ # should be downloaded and parsed instead. See #12753.
+ aliases = [
+ ('LATIN CAPITAL LETTER GHA', 0x01A2),
+ ('LATIN SMALL LETTER GHA', 0x01A3),
+ ('KANNADA LETTER LLLA', 0x0CDE),
+ ('LAO LETTER FO FON', 0x0E9D),
+ ('LAO LETTER FO FAY', 0x0E9F),
+ ('LAO LETTER RO', 0x0EA3),
+ ('LAO LETTER LO', 0x0EA5),
+ ('TIBETAN MARK BKA- SHOG GI MGO RGYAN', 0x0FD0),
+ ('YI SYLLABLE ITERATION MARK', 0xA015),
+ ('PRESENTATION FORM FOR VERTICAL RIGHT WHITE LENTICULAR BRACKET', 0xFE18),
+ ('BYZANTINE MUSICAL SYMBOL FTHORA SKLIRON CHROMA VASIS', 0x1D0C5)
+ ]
+ for alias, codepoint in aliases:
+ self.checkletter(alias, chr(codepoint))
+ name = unicodedata.name(chr(codepoint))
+ self.assertNotEqual(name, alias)
+ self.assertEqual(unicodedata.lookup(alias),
+ unicodedata.lookup(name))
+ with self.assertRaises(KeyError):
+ unicodedata.ucd_3_2_0.lookup(alias)
+
+ def test_aliases_names_in_pua_range(self):
+ # We are storing aliases in the PUA 15, but their names shouldn't leak
+ for cp in range(0xf0000, 0xf0100):
+ with self.assertRaises(ValueError) as cm:
+ unicodedata.name(chr(cp))
+ self.assertEqual(str(cm.exception), 'no such name')
+
+ def test_named_sequences_names_in_pua_range(self):
+ # We are storing named seq in the PUA 15, but their names shouldn't leak
+ for cp in range(0xf0100, 0xf0fff):
+ with self.assertRaises(ValueError) as cm:
+ unicodedata.name(chr(cp))
+ self.assertEqual(str(cm.exception), 'no such name')
+
+ def test_named_sequences_sample(self):
+ # Check a few named sequences. See #12753.
+ sequences = [
+ ('LATIN SMALL LETTER R WITH TILDE', '\u0072\u0303'),
+ ('TAMIL SYLLABLE SAI', '\u0BB8\u0BC8'),
+ ('TAMIL SYLLABLE MOO', '\u0BAE\u0BCB'),
+ ('TAMIL SYLLABLE NNOO', '\u0BA3\u0BCB'),
+ ('TAMIL CONSONANT KSS', '\u0B95\u0BCD\u0BB7\u0BCD'),
+ ]
+ for seqname, codepoints in sequences:
+ self.assertEqual(unicodedata.lookup(seqname), codepoints)
+ with self.assertRaises(SyntaxError):
+ self.checkletter(seqname, None)
+ with self.assertRaises(KeyError):
+ unicodedata.ucd_3_2_0.lookup(seqname)
+
+ def test_named_sequences_full(self):
+ # Check all the named sequences
+ url = ("http://www.unicode.org/Public/%s/ucd/NamedSequences.txt" %
+ unicodedata.unidata_version)
+ try:
+ testdata = support.open_urlresource(url, encoding="utf-8",
+ check=check_version)
+ except (IOError, HTTPException):
+ self.skipTest("Could not retrieve " + url)
+ self.addCleanup(testdata.close)
+ for line in testdata:
+ line = line.strip()
+ if not line or line.startswith('#'):
+ continue
+ seqname, codepoints = line.split(';')
+ codepoints = ''.join(chr(int(cp, 16)) for cp in codepoints.split())
+ self.assertEqual(unicodedata.lookup(seqname), codepoints)
+ with self.assertRaises(SyntaxError):
+ self.checkletter(seqname, None)
+ with self.assertRaises(KeyError):
+ unicodedata.ucd_3_2_0.lookup(seqname)
+
def test_errors(self):
- import unicodedata
self.assertRaises(TypeError, unicodedata.name)
self.assertRaises(TypeError, unicodedata.name, 'xx')
self.assertRaises(TypeError, unicodedata.lookup)
@@ -145,7 +219,7 @@ class UnicodeNamesTest(unittest.TestCase):
@unittest.skipUnless(_testcapi.INT_MAX < _testcapi.PY_SSIZE_T_MAX,
"needs UINT_MAX < SIZE_MAX")
@support.bigmemtest(size=_testcapi.UINT_MAX + 1,
- memuse=2 + 4 // len('\U00010000'), dry_run=False)
+ memuse=2 + 1, dry_run=False)
def test_issue16335(self, size):
# very very long bogus character name
x = b'\\N{SPACE' + b'x' * (_testcapi.UINT_MAX + 1) + b'}'
diff --git a/Lib/test/test_unicode.py b/Lib/test/test_unicode.py
index 47af8b9..f7d8686 100644
--- a/Lib/test/test_unicode.py
+++ b/Lib/test/test_unicode.py
@@ -5,17 +5,13 @@ Written by Marc-Andre Lemburg (mal@lemburg.com).
(c) Copyright CNRI, All Rights Reserved. NO WARRANTY.
"""#"
+import _string
import codecs
import struct
import sys
import unittest
import warnings
from test import support, string_tests
-import _string
-
-# decorator to skip tests on narrow builds
-requires_wide_build = unittest.skipIf(sys.maxunicode == 65535,
- 'requires wide build')
# Error handling (bad decoder return)
def search_function(encoding):
@@ -37,7 +33,8 @@ codecs.register(search_function)
class UnicodeTest(string_tests.CommonTest,
string_tests.MixinStrUnicodeUserStringTest,
- string_tests.MixinStrUnicodeTest):
+ string_tests.MixinStrUnicodeTest,
+ unittest.TestCase):
type2test = str
@@ -175,6 +172,15 @@ class UnicodeTest(string_tests.CommonTest,
def test_find(self):
string_tests.CommonTest.test_find(self)
+ # test implementation details of the memchr fast path
+ self.checkequal(100, 'a' * 100 + '\u0102', 'find', '\u0102')
+ self.checkequal(-1, 'a' * 100 + '\u0102', 'find', '\u0201')
+ self.checkequal(-1, 'a' * 100 + '\u0102', 'find', '\u0120')
+ self.checkequal(-1, 'a' * 100 + '\u0102', 'find', '\u0220')
+ self.checkequal(100, 'a' * 100 + '\U00100304', 'find', '\U00100304')
+ self.checkequal(-1, 'a' * 100 + '\U00100304', 'find', '\U00100204')
+ self.checkequal(-1, 'a' * 100 + '\U00100304', 'find', '\U00102004')
+ # check mixed argument types
self.checkequalnofix(0, 'abcdefghiabc', 'find', 'abc')
self.checkequalnofix(9, 'abcdefghiabc', 'find', 'abc', 1)
self.checkequalnofix(-1, 'abcdefghiabc', 'find', 'def', 4)
@@ -184,6 +190,14 @@ class UnicodeTest(string_tests.CommonTest,
def test_rfind(self):
string_tests.CommonTest.test_rfind(self)
+ # test implementation details of the memrchr fast path
+ self.checkequal(0, '\u0102' + 'a' * 100 , 'rfind', '\u0102')
+ self.checkequal(-1, '\u0102' + 'a' * 100 , 'rfind', '\u0201')
+ self.checkequal(-1, '\u0102' + 'a' * 100 , 'rfind', '\u0120')
+ self.checkequal(-1, '\u0102' + 'a' * 100 , 'rfind', '\u0220')
+ self.checkequal(0, '\U00100304' + 'a' * 100, 'rfind', '\U00100304')
+ self.checkequal(-1, '\U00100304' + 'a' * 100, 'rfind', '\U00100204')
+ self.checkequal(-1, '\U00100304' + 'a' * 100, 'rfind', '\U00102004')
# check mixed argument types
self.checkequalnofix(9, 'abcdefghiabc', 'rfind', 'abc')
self.checkequalnofix(12, 'abcdefghiabc', 'rfind', '')
@@ -280,6 +294,12 @@ class UnicodeTest(string_tests.CommonTest,
self.checkequalnofix('one@two!three!', 'one!two!three!', 'replace', '!', '@', 1)
self.assertRaises(TypeError, 'replace'.replace, "r", 42)
+ @support.cpython_only
+ def test_replace_id(self):
+ pattern = 'abc'
+ text = 'abc def'
+ self.assertIs(text.replace(pattern, pattern), text)
+
def test_bytes_comparison(self):
with support.check_warnings():
warnings.simplefilter('ignore', BytesWarning)
@@ -350,6 +370,8 @@ class UnicodeTest(string_tests.CommonTest,
def test_islower(self):
string_tests.MixinStrUnicodeUserStringTest.test_islower(self)
self.checkequalnofix(False, '\u1FFc', 'islower')
+ self.assertFalse('\u2167'.islower())
+ self.assertTrue('\u2177'.islower())
# non-BMP, uppercase
self.assertFalse('\U00010401'.islower())
self.assertFalse('\U00010427'.islower())
@@ -364,6 +386,8 @@ class UnicodeTest(string_tests.CommonTest,
string_tests.MixinStrUnicodeUserStringTest.test_isupper(self)
if not sys.platform.startswith('java'):
self.checkequalnofix(False, '\u1FFc', 'isupper')
+ self.assertTrue('\u2167'.isupper())
+ self.assertFalse('\u2177'.isupper())
# non-BMP, uppercase
self.assertTrue('\U00010401'.isupper())
self.assertTrue('\U00010427'.isupper())
@@ -520,7 +544,6 @@ class UnicodeTest(string_tests.CommonTest,
self.assertFalse(meth(s), '%a.%s() is False' % (s, meth_name))
- @requires_wide_build
def test_lower(self):
string_tests.CommonTest.test_lower(self)
self.assertEqual('\U00010427'.lower(), '\U0001044F')
@@ -530,8 +553,28 @@ class UnicodeTest(string_tests.CommonTest,
'\U0001044F\U0001044F')
self.assertEqual('X\U00010427x\U0001044F'.lower(),
'x\U0001044Fx\U0001044F')
+ self.assertEqual('fi'.lower(), 'fi')
+ self.assertEqual('\u0130'.lower(), '\u0069\u0307')
+ # Special case for GREEK CAPITAL LETTER SIGMA U+03A3
+ self.assertEqual('\u03a3'.lower(), '\u03c3')
+ self.assertEqual('\u0345\u03a3'.lower(), '\u0345\u03c3')
+ self.assertEqual('A\u0345\u03a3'.lower(), 'a\u0345\u03c2')
+ self.assertEqual('A\u0345\u03a3a'.lower(), 'a\u0345\u03c3a')
+ self.assertEqual('A\u0345\u03a3'.lower(), 'a\u0345\u03c2')
+ self.assertEqual('A\u03a3\u0345'.lower(), 'a\u03c2\u0345')
+ self.assertEqual('\u03a3\u0345 '.lower(), '\u03c3\u0345 ')
+ self.assertEqual('\U0008fffe'.lower(), '\U0008fffe')
+ self.assertEqual('\u2177'.lower(), '\u2177')
+
+ def test_casefold(self):
+ self.assertEqual('hello'.casefold(), 'hello')
+ self.assertEqual('hELlo'.casefold(), 'hello')
+ self.assertEqual('ß'.casefold(), 'ss')
+ self.assertEqual('fi'.casefold(), 'fi')
+ self.assertEqual('\u03a3'.casefold(), '\u03c3')
+ self.assertEqual('A\u0345\u03a3'.casefold(), 'a\u03b9\u03c3')
+ self.assertEqual('\u00b5'.casefold(), '\u03bc')
- @requires_wide_build
def test_upper(self):
string_tests.CommonTest.test_upper(self)
self.assertEqual('\U0001044F'.upper(), '\U00010427')
@@ -541,8 +584,14 @@ class UnicodeTest(string_tests.CommonTest,
'\U00010427\U00010427')
self.assertEqual('X\U00010427x\U0001044F'.upper(),
'X\U00010427X\U00010427')
+ self.assertEqual('fi'.upper(), 'FI')
+ self.assertEqual('\u0130'.upper(), '\u0130')
+ self.assertEqual('\u03a3'.upper(), '\u03a3')
+ self.assertEqual('ß'.upper(), 'SS')
+ self.assertEqual('\u1fd2'.upper(), '\u0399\u0308\u0300')
+ self.assertEqual('\U0008fffe'.upper(), '\U0008fffe')
+ self.assertEqual('\u2177'.upper(), '\u2167')
- @requires_wide_build
def test_capitalize(self):
string_tests.CommonTest.test_capitalize(self)
self.assertEqual('\U0001044F'.capitalize(), '\U00010427')
@@ -554,8 +603,12 @@ class UnicodeTest(string_tests.CommonTest,
'\U00010427\U0001044F')
self.assertEqual('X\U00010427x\U0001044F'.capitalize(),
'X\U0001044Fx\U0001044F')
+ self.assertEqual('h\u0130'.capitalize(), 'H\u0069\u0307')
+ exp = '\u0399\u0308\u0300\u0069\u0307'
+ self.assertEqual('\u1fd2\u0130'.capitalize(), exp)
+ self.assertEqual('finnish'.capitalize(), 'FInnish')
+ self.assertEqual('A\u0345\u03a3'.capitalize(), 'A\u0345\u03c2')
- @requires_wide_build
def test_title(self):
string_tests.MixinStrUnicodeUserStringTest.test_title(self)
self.assertEqual('\U0001044F'.title(), '\U00010427')
@@ -569,8 +622,10 @@ class UnicodeTest(string_tests.CommonTest,
'\U00010427\U0001044F \U00010427\U0001044F')
self.assertEqual('X\U00010427x\U0001044F X\U00010427x\U0001044F'.title(),
'X\U0001044Fx\U0001044F X\U0001044Fx\U0001044F')
+ self.assertEqual('fiNNISH'.title(), 'Finnish')
+ self.assertEqual('A\u03a3 \u1fa1xy'.title(), 'A\u03c2 \u1fa9xy')
+ self.assertEqual('A\u03a3A'.title(), 'A\u03c3a')
- @requires_wide_build
def test_swapcase(self):
string_tests.CommonTest.test_swapcase(self)
self.assertEqual('\U0001044F'.swapcase(), '\U00010427')
@@ -583,6 +638,19 @@ class UnicodeTest(string_tests.CommonTest,
'\U00010427\U0001044F')
self.assertEqual('X\U00010427x\U0001044F'.swapcase(),
'x\U0001044FX\U00010427')
+ self.assertEqual('fi'.swapcase(), 'FI')
+ self.assertEqual('\u0130'.swapcase(), '\u0069\u0307')
+ # Special case for GREEK CAPITAL LETTER SIGMA U+03A3
+ self.assertEqual('\u03a3'.swapcase(), '\u03c3')
+ self.assertEqual('\u0345\u03a3'.swapcase(), '\u0399\u03c3')
+ self.assertEqual('A\u0345\u03a3'.swapcase(), 'a\u0399\u03c2')
+ self.assertEqual('A\u0345\u03a3a'.swapcase(), 'a\u0399\u03c3A')
+ self.assertEqual('A\u0345\u03a3'.swapcase(), 'a\u0399\u03c2')
+ self.assertEqual('A\u03a3\u0345'.swapcase(), 'a\u03c2\u0399')
+ self.assertEqual('\u03a3\u0345 '.swapcase(), '\u03c3\u0399 ')
+ self.assertEqual('\u03a3'.swapcase(), '\u03c3')
+ self.assertEqual('ß'.swapcase(), 'SS')
+ self.assertEqual('\u1fd2'.swapcase(), '\u0399\u0308\u0300')
def test_contains(self):
# Testing Unicode contains method
@@ -776,7 +844,7 @@ class UnicodeTest(string_tests.CommonTest,
self.assertEqual('{0!s}'.format(G('data')), 'string is data')
msg = 'object.__format__ with a non-empty format string is deprecated'
- with support.check_warnings((msg, PendingDeprecationWarning)):
+ with support.check_warnings((msg, DeprecationWarning)):
self.assertEqual('{0:^10}'.format(E('data')), ' E(data) ')
self.assertEqual('{0:^10s}'.format(E('data')), ' E(data) ')
self.assertEqual('{0:>15s}'.format(G('data')), ' string is data')
@@ -858,6 +926,14 @@ class UnicodeTest(string_tests.CommonTest,
self.assertRaises(ValueError, format, '', '#')
self.assertRaises(ValueError, format, '', '#20')
+ # Non-ASCII
+ self.assertEqual("{0:s}{1:s}".format("ABC", "\u0410\u0411\u0412"),
+ 'ABC\u0410\u0411\u0412')
+ self.assertEqual("{0:.3s}".format("ABC\u0410\u0411\u0412"),
+ 'ABC')
+ self.assertEqual("{0:.0s}".format("ABC\u0410\u0411\u0412"),
+ '')
+
def test_format_map(self):
self.assertEqual(''.format_map({}), '')
self.assertEqual('a'.format_map({}), 'a')
@@ -1005,6 +1081,10 @@ class UnicodeTest(string_tests.CommonTest,
self.assertEqual('%f' % INF, 'inf')
self.assertEqual('%F' % INF, 'INF')
+ # PEP 393
+ self.assertEqual('%.1s' % "a\xe9\u20ac", 'a')
+ self.assertEqual('%.2s' % "a\xe9\u20ac", 'a\xe9')
+
@support.cpython_only
def test_formatting_huge_precision(self):
from _testcapi import INT_MAX
@@ -1041,10 +1121,13 @@ class UnicodeTest(string_tests.CommonTest,
class UnicodeSubclass(str):
pass
- self.assertEqual(
- str(UnicodeSubclass('unicode subclass becomes unicode')),
- 'unicode subclass becomes unicode'
- )
+ for text in ('ascii', '\xe9', '\u20ac', '\U0010FFFF'):
+ subclass = UnicodeSubclass(text)
+ self.assertEqual(str(subclass), text)
+ self.assertEqual(len(subclass), len(text))
+ if text == 'ascii':
+ self.assertEqual(subclass.encode('ascii'), b'ascii')
+ self.assertEqual(subclass.encode('utf-8'), b'ascii')
self.assertEqual(
str('strings are converted to unicode'),
@@ -1170,15 +1253,12 @@ class UnicodeTest(string_tests.CommonTest,
def test_codecs_utf8(self):
self.assertEqual(''.encode('utf-8'), b'')
self.assertEqual('\u20ac'.encode('utf-8'), b'\xe2\x82\xac')
- if sys.maxunicode == 65535:
- self.assertEqual('\ud800\udc02'.encode('utf-8'), b'\xf0\x90\x80\x82')
- self.assertEqual('\ud84d\udc56'.encode('utf-8'), b'\xf0\xa3\x91\x96')
+ self.assertEqual('\U00010002'.encode('utf-8'), b'\xf0\x90\x80\x82')
+ self.assertEqual('\U00023456'.encode('utf-8'), b'\xf0\xa3\x91\x96')
self.assertEqual('\ud800'.encode('utf-8', 'surrogatepass'), b'\xed\xa0\x80')
self.assertEqual('\udc00'.encode('utf-8', 'surrogatepass'), b'\xed\xb0\x80')
- if sys.maxunicode == 65535:
- self.assertEqual(
- ('\ud800\udc02'*1000).encode('utf-8'),
- b'\xf0\x90\x80\x82'*1000)
+ self.assertEqual(('\U00010002'*10).encode('utf-8'),
+ b'\xf0\x90\x80\x82'*10)
self.assertEqual(
'\u6b63\u78ba\u306b\u8a00\u3046\u3068\u7ffb\u8a33\u306f'
'\u3055\u308c\u3066\u3044\u307e\u305b\u3093\u3002\u4e00'
@@ -1290,7 +1370,7 @@ class UnicodeTest(string_tests.CommonTest,
# with start byte of a 2-byte sequence
(b'\xc2', FFFD), # only the start byte
(b'\xc2\xc2', FFFD*2), # 2 start bytes
- (b'\xc2\xc2\xc2', FFFD*3), # 2 start bytes
+ (b'\xc2\xc2\xc2', FFFD*3), # 3 start bytes
(b'\xc2\x41', FFFD+'A'), # invalid continuation byte
# with start byte of a 3-byte sequence
(b'\xe1', FFFD), # only the start byte
@@ -1360,6 +1440,226 @@ class UnicodeTest(string_tests.CommonTest,
self.assertEqual(seq.decode('utf-8', 'ignore'),
res.replace('\uFFFD', ''))
+ def to_bytestring(self, seq):
+ return bytes(int(c, 16) for c in seq.split())
+
+ def assertCorrectUTF8Decoding(self, seq, res, err):
+ """
+ Check that an invalid UTF-8 sequence raises an UnicodeDecodeError when
+ 'strict' is used, returns res when 'replace' is used, and that doesn't
+ return anything when 'ignore' is used.
+ """
+ with self.assertRaises(UnicodeDecodeError) as cm:
+ seq.decode('utf-8')
+ exc = cm.exception
+
+ self.assertIn(err, str(exc))
+ self.assertEqual(seq.decode('utf-8', 'replace'), res)
+ self.assertEqual((b'aaaa' + seq + b'bbbb').decode('utf-8', 'replace'),
+ 'aaaa' + res + 'bbbb')
+ res = res.replace('\ufffd', '')
+ self.assertEqual(seq.decode('utf-8', 'ignore'), res)
+ self.assertEqual((b'aaaa' + seq + b'bbbb').decode('utf-8', 'ignore'),
+ 'aaaa' + res + 'bbbb')
+
+ def test_invalid_start_byte(self):
+ """
+ Test that an 'invalid start byte' error is raised when the first byte
+ is not in the ASCII range or is not a valid start byte of a 2-, 3-, or
+ 4-bytes sequence. The invalid start byte is replaced with a single
+ U+FFFD when errors='replace'.
+ E.g. <80> is a continuation byte and can appear only after a start byte.
+ """
+ FFFD = '\ufffd'
+ for byte in b'\x80\xA0\x9F\xBF\xC0\xC1\xF5\xFF':
+ self.assertCorrectUTF8Decoding(bytes([byte]), '\ufffd',
+ 'invalid start byte')
+
+ def test_unexpected_end_of_data(self):
+ """
+ Test that an 'unexpected end of data' error is raised when the string
+ ends after a start byte of a 2-, 3-, or 4-bytes sequence without having
+ enough continuation bytes. The incomplete sequence is replaced with a
+ single U+FFFD when errors='replace'.
+ E.g. in the sequence <F3 80 80>, F3 is the start byte of a 4-bytes
+ sequence, but it's followed by only 2 valid continuation bytes and the
+ last continuation bytes is missing.
+ Note: the continuation bytes must be all valid, if one of them is
+ invalid another error will be raised.
+ """
+ sequences = [
+ 'C2', 'DF',
+ 'E0 A0', 'E0 BF', 'E1 80', 'E1 BF', 'EC 80', 'EC BF',
+ 'ED 80', 'ED 9F', 'EE 80', 'EE BF', 'EF 80', 'EF BF',
+ 'F0 90', 'F0 BF', 'F0 90 80', 'F0 90 BF', 'F0 BF 80', 'F0 BF BF',
+ 'F1 80', 'F1 BF', 'F1 80 80', 'F1 80 BF', 'F1 BF 80', 'F1 BF BF',
+ 'F3 80', 'F3 BF', 'F3 80 80', 'F3 80 BF', 'F3 BF 80', 'F3 BF BF',
+ 'F4 80', 'F4 8F', 'F4 80 80', 'F4 80 BF', 'F4 8F 80', 'F4 8F BF'
+ ]
+ FFFD = '\ufffd'
+ for seq in sequences:
+ self.assertCorrectUTF8Decoding(self.to_bytestring(seq), '\ufffd',
+ 'unexpected end of data')
+
+ def test_invalid_cb_for_2bytes_seq(self):
+ """
+ Test that an 'invalid continuation byte' error is raised when the
+ continuation byte of a 2-bytes sequence is invalid. The start byte
+ is replaced by a single U+FFFD and the second byte is handled
+ separately when errors='replace'.
+ E.g. in the sequence <C2 41>, C2 is the start byte of a 2-bytes
+ sequence, but 41 is not a valid continuation byte because it's the
+ ASCII letter 'A'.
+ """
+ FFFD = '\ufffd'
+ FFFDx2 = FFFD * 2
+ sequences = [
+ ('C2 00', FFFD+'\x00'), ('C2 7F', FFFD+'\x7f'),
+ ('C2 C0', FFFDx2), ('C2 FF', FFFDx2),
+ ('DF 00', FFFD+'\x00'), ('DF 7F', FFFD+'\x7f'),
+ ('DF C0', FFFDx2), ('DF FF', FFFDx2),
+ ]
+ for seq, res in sequences:
+ self.assertCorrectUTF8Decoding(self.to_bytestring(seq), res,
+ 'invalid continuation byte')
+
+ def test_invalid_cb_for_3bytes_seq(self):
+ """
+ Test that an 'invalid continuation byte' error is raised when the
+ continuation byte(s) of a 3-bytes sequence are invalid. When
+ errors='replace', if the first continuation byte is valid, the first
+ two bytes (start byte + 1st cb) are replaced by a single U+FFFD and the
+ third byte is handled separately, otherwise only the start byte is
+ replaced with a U+FFFD and the other continuation bytes are handled
+ separately.
+ E.g. in the sequence <E1 80 41>, E1 is the start byte of a 3-bytes
+ sequence, 80 is a valid continuation byte, but 41 is not a valid cb
+ because it's the ASCII letter 'A'.
+ Note: when the start byte is E0 or ED, the valid ranges for the first
+ continuation byte are limited to A0..BF and 80..9F respectively.
+ Python 2 used to consider all the bytes in range 80..BF valid when the
+ start byte was ED. This is fixed in Python 3.
+ """
+ FFFD = '\ufffd'
+ FFFDx2 = FFFD * 2
+ sequences = [
+ ('E0 00', FFFD+'\x00'), ('E0 7F', FFFD+'\x7f'), ('E0 80', FFFDx2),
+ ('E0 9F', FFFDx2), ('E0 C0', FFFDx2), ('E0 FF', FFFDx2),
+ ('E0 A0 00', FFFD+'\x00'), ('E0 A0 7F', FFFD+'\x7f'),
+ ('E0 A0 C0', FFFDx2), ('E0 A0 FF', FFFDx2),
+ ('E0 BF 00', FFFD+'\x00'), ('E0 BF 7F', FFFD+'\x7f'),
+ ('E0 BF C0', FFFDx2), ('E0 BF FF', FFFDx2), ('E1 00', FFFD+'\x00'),
+ ('E1 7F', FFFD+'\x7f'), ('E1 C0', FFFDx2), ('E1 FF', FFFDx2),
+ ('E1 80 00', FFFD+'\x00'), ('E1 80 7F', FFFD+'\x7f'),
+ ('E1 80 C0', FFFDx2), ('E1 80 FF', FFFDx2),
+ ('E1 BF 00', FFFD+'\x00'), ('E1 BF 7F', FFFD+'\x7f'),
+ ('E1 BF C0', FFFDx2), ('E1 BF FF', FFFDx2), ('EC 00', FFFD+'\x00'),
+ ('EC 7F', FFFD+'\x7f'), ('EC C0', FFFDx2), ('EC FF', FFFDx2),
+ ('EC 80 00', FFFD+'\x00'), ('EC 80 7F', FFFD+'\x7f'),
+ ('EC 80 C0', FFFDx2), ('EC 80 FF', FFFDx2),
+ ('EC BF 00', FFFD+'\x00'), ('EC BF 7F', FFFD+'\x7f'),
+ ('EC BF C0', FFFDx2), ('EC BF FF', FFFDx2), ('ED 00', FFFD+'\x00'),
+ ('ED 7F', FFFD+'\x7f'),
+ ('ED A0', FFFDx2), ('ED BF', FFFDx2), # see note ^
+ ('ED C0', FFFDx2), ('ED FF', FFFDx2), ('ED 80 00', FFFD+'\x00'),
+ ('ED 80 7F', FFFD+'\x7f'), ('ED 80 C0', FFFDx2),
+ ('ED 80 FF', FFFDx2), ('ED 9F 00', FFFD+'\x00'),
+ ('ED 9F 7F', FFFD+'\x7f'), ('ED 9F C0', FFFDx2),
+ ('ED 9F FF', FFFDx2), ('EE 00', FFFD+'\x00'),
+ ('EE 7F', FFFD+'\x7f'), ('EE C0', FFFDx2), ('EE FF', FFFDx2),
+ ('EE 80 00', FFFD+'\x00'), ('EE 80 7F', FFFD+'\x7f'),
+ ('EE 80 C0', FFFDx2), ('EE 80 FF', FFFDx2),
+ ('EE BF 00', FFFD+'\x00'), ('EE BF 7F', FFFD+'\x7f'),
+ ('EE BF C0', FFFDx2), ('EE BF FF', FFFDx2), ('EF 00', FFFD+'\x00'),
+ ('EF 7F', FFFD+'\x7f'), ('EF C0', FFFDx2), ('EF FF', FFFDx2),
+ ('EF 80 00', FFFD+'\x00'), ('EF 80 7F', FFFD+'\x7f'),
+ ('EF 80 C0', FFFDx2), ('EF 80 FF', FFFDx2),
+ ('EF BF 00', FFFD+'\x00'), ('EF BF 7F', FFFD+'\x7f'),
+ ('EF BF C0', FFFDx2), ('EF BF FF', FFFDx2),
+ ]
+ for seq, res in sequences:
+ self.assertCorrectUTF8Decoding(self.to_bytestring(seq), res,
+ 'invalid continuation byte')
+
+ def test_invalid_cb_for_4bytes_seq(self):
+ """
+ Test that an 'invalid continuation byte' error is raised when the
+ continuation byte(s) of a 4-bytes sequence are invalid. When
+ errors='replace',the start byte and all the following valid
+ continuation bytes are replaced with a single U+FFFD, and all the bytes
+ starting from the first invalid continuation bytes (included) are
+ handled separately.
+ E.g. in the sequence <E1 80 41>, E1 is the start byte of a 3-bytes
+ sequence, 80 is a valid continuation byte, but 41 is not a valid cb
+ because it's the ASCII letter 'A'.
+ Note: when the start byte is E0 or ED, the valid ranges for the first
+ continuation byte are limited to A0..BF and 80..9F respectively.
+ However, when the start byte is ED, Python 2 considers all the bytes
+ in range 80..BF valid. This is fixed in Python 3.
+ """
+ FFFD = '\ufffd'
+ FFFDx2 = FFFD * 2
+ sequences = [
+ ('F0 00', FFFD+'\x00'), ('F0 7F', FFFD+'\x7f'), ('F0 80', FFFDx2),
+ ('F0 8F', FFFDx2), ('F0 C0', FFFDx2), ('F0 FF', FFFDx2),
+ ('F0 90 00', FFFD+'\x00'), ('F0 90 7F', FFFD+'\x7f'),
+ ('F0 90 C0', FFFDx2), ('F0 90 FF', FFFDx2),
+ ('F0 BF 00', FFFD+'\x00'), ('F0 BF 7F', FFFD+'\x7f'),
+ ('F0 BF C0', FFFDx2), ('F0 BF FF', FFFDx2),
+ ('F0 90 80 00', FFFD+'\x00'), ('F0 90 80 7F', FFFD+'\x7f'),
+ ('F0 90 80 C0', FFFDx2), ('F0 90 80 FF', FFFDx2),
+ ('F0 90 BF 00', FFFD+'\x00'), ('F0 90 BF 7F', FFFD+'\x7f'),
+ ('F0 90 BF C0', FFFDx2), ('F0 90 BF FF', FFFDx2),
+ ('F0 BF 80 00', FFFD+'\x00'), ('F0 BF 80 7F', FFFD+'\x7f'),
+ ('F0 BF 80 C0', FFFDx2), ('F0 BF 80 FF', FFFDx2),
+ ('F0 BF BF 00', FFFD+'\x00'), ('F0 BF BF 7F', FFFD+'\x7f'),
+ ('F0 BF BF C0', FFFDx2), ('F0 BF BF FF', FFFDx2),
+ ('F1 00', FFFD+'\x00'), ('F1 7F', FFFD+'\x7f'), ('F1 C0', FFFDx2),
+ ('F1 FF', FFFDx2), ('F1 80 00', FFFD+'\x00'),
+ ('F1 80 7F', FFFD+'\x7f'), ('F1 80 C0', FFFDx2),
+ ('F1 80 FF', FFFDx2), ('F1 BF 00', FFFD+'\x00'),
+ ('F1 BF 7F', FFFD+'\x7f'), ('F1 BF C0', FFFDx2),
+ ('F1 BF FF', FFFDx2), ('F1 80 80 00', FFFD+'\x00'),
+ ('F1 80 80 7F', FFFD+'\x7f'), ('F1 80 80 C0', FFFDx2),
+ ('F1 80 80 FF', FFFDx2), ('F1 80 BF 00', FFFD+'\x00'),
+ ('F1 80 BF 7F', FFFD+'\x7f'), ('F1 80 BF C0', FFFDx2),
+ ('F1 80 BF FF', FFFDx2), ('F1 BF 80 00', FFFD+'\x00'),
+ ('F1 BF 80 7F', FFFD+'\x7f'), ('F1 BF 80 C0', FFFDx2),
+ ('F1 BF 80 FF', FFFDx2), ('F1 BF BF 00', FFFD+'\x00'),
+ ('F1 BF BF 7F', FFFD+'\x7f'), ('F1 BF BF C0', FFFDx2),
+ ('F1 BF BF FF', FFFDx2), ('F3 00', FFFD+'\x00'),
+ ('F3 7F', FFFD+'\x7f'), ('F3 C0', FFFDx2), ('F3 FF', FFFDx2),
+ ('F3 80 00', FFFD+'\x00'), ('F3 80 7F', FFFD+'\x7f'),
+ ('F3 80 C0', FFFDx2), ('F3 80 FF', FFFDx2),
+ ('F3 BF 00', FFFD+'\x00'), ('F3 BF 7F', FFFD+'\x7f'),
+ ('F3 BF C0', FFFDx2), ('F3 BF FF', FFFDx2),
+ ('F3 80 80 00', FFFD+'\x00'), ('F3 80 80 7F', FFFD+'\x7f'),
+ ('F3 80 80 C0', FFFDx2), ('F3 80 80 FF', FFFDx2),
+ ('F3 80 BF 00', FFFD+'\x00'), ('F3 80 BF 7F', FFFD+'\x7f'),
+ ('F3 80 BF C0', FFFDx2), ('F3 80 BF FF', FFFDx2),
+ ('F3 BF 80 00', FFFD+'\x00'), ('F3 BF 80 7F', FFFD+'\x7f'),
+ ('F3 BF 80 C0', FFFDx2), ('F3 BF 80 FF', FFFDx2),
+ ('F3 BF BF 00', FFFD+'\x00'), ('F3 BF BF 7F', FFFD+'\x7f'),
+ ('F3 BF BF C0', FFFDx2), ('F3 BF BF FF', FFFDx2),
+ ('F4 00', FFFD+'\x00'), ('F4 7F', FFFD+'\x7f'), ('F4 90', FFFDx2),
+ ('F4 BF', FFFDx2), ('F4 C0', FFFDx2), ('F4 FF', FFFDx2),
+ ('F4 80 00', FFFD+'\x00'), ('F4 80 7F', FFFD+'\x7f'),
+ ('F4 80 C0', FFFDx2), ('F4 80 FF', FFFDx2),
+ ('F4 8F 00', FFFD+'\x00'), ('F4 8F 7F', FFFD+'\x7f'),
+ ('F4 8F C0', FFFDx2), ('F4 8F FF', FFFDx2),
+ ('F4 80 80 00', FFFD+'\x00'), ('F4 80 80 7F', FFFD+'\x7f'),
+ ('F4 80 80 C0', FFFDx2), ('F4 80 80 FF', FFFDx2),
+ ('F4 80 BF 00', FFFD+'\x00'), ('F4 80 BF 7F', FFFD+'\x7f'),
+ ('F4 80 BF C0', FFFDx2), ('F4 80 BF FF', FFFDx2),
+ ('F4 8F 80 00', FFFD+'\x00'), ('F4 8F 80 7F', FFFD+'\x7f'),
+ ('F4 8F 80 C0', FFFDx2), ('F4 8F 80 FF', FFFDx2),
+ ('F4 8F BF 00', FFFD+'\x00'), ('F4 8F BF 7F', FFFD+'\x7f'),
+ ('F4 8F BF C0', FFFDx2), ('F4 8F BF FF', FFFDx2)
+ ]
+ for seq, res in sequences:
+ self.assertCorrectUTF8Decoding(self.to_bytestring(seq), res,
+ 'invalid continuation byte')
+
def test_codecs_idna(self):
# Test whether trailing dot is preserved
self.assertEqual("www.python.org.".encode("idna"), b"www.python.org.")
@@ -1391,14 +1691,6 @@ class UnicodeTest(string_tests.CommonTest,
self.assertRaises(TypeError, str, b"hello", "test.unicode2")
self.assertRaises(TypeError, "hello".encode, "test.unicode1")
self.assertRaises(TypeError, "hello".encode, "test.unicode2")
- # executes PyUnicode_Encode()
- import imp
- self.assertRaises(
- ImportError,
- imp.find_module,
- "non-existing module",
- ["non-existing dir"]
- )
# Error handling (wrong arguments)
self.assertRaises(TypeError, "hello".encode, 42, 42, 42)
@@ -1416,18 +1708,25 @@ class UnicodeTest(string_tests.CommonTest,
self.assertEqual('hello'.encode('ascii'), b'hello')
self.assertEqual('hello'.encode('utf-7'), b'hello')
self.assertEqual('hello'.encode('utf-8'), b'hello')
- self.assertEqual('hello'.encode('utf8'), b'hello')
+ self.assertEqual('hello'.encode('utf-8'), b'hello')
self.assertEqual('hello'.encode('utf-16-le'), b'h\000e\000l\000l\000o\000')
self.assertEqual('hello'.encode('utf-16-be'), b'\000h\000e\000l\000l\000o')
self.assertEqual('hello'.encode('latin-1'), b'hello')
+ # Default encoding is utf-8
+ self.assertEqual('\u2603'.encode(), b'\xe2\x98\x83')
+
# Roundtrip safety for BMP (just the first 1024 chars)
for c in range(1024):
u = chr(c)
for encoding in ('utf-7', 'utf-8', 'utf-16', 'utf-16-le',
'utf-16-be', 'raw_unicode_escape',
'unicode_escape', 'unicode_internal'):
- self.assertEqual(str(u.encode(encoding),encoding), u)
+ with warnings.catch_warnings():
+ # unicode-internal has been deprecated
+ warnings.simplefilter("ignore", DeprecationWarning)
+
+ self.assertEqual(str(u.encode(encoding),encoding), u)
# Roundtrip safety for BMP (just the first 256 chars)
for c in range(256):
@@ -1442,18 +1741,20 @@ class UnicodeTest(string_tests.CommonTest,
self.assertEqual(str(u.encode(encoding),encoding), u)
# Roundtrip safety for non-BMP (just a few chars)
- u = '\U00010001\U00020002\U00030003\U00040004\U00050005'
- for encoding in ('utf-8', 'utf-16', 'utf-16-le', 'utf-16-be',
- #'raw_unicode_escape',
- 'unicode_escape', 'unicode_internal'):
- self.assertEqual(str(u.encode(encoding),encoding), u)
+ with warnings.catch_warnings():
+ # unicode-internal has been deprecated
+ warnings.simplefilter("ignore", DeprecationWarning)
+
+ u = '\U00010001\U00020002\U00030003\U00040004\U00050005'
+ for encoding in ('utf-8', 'utf-16', 'utf-16-le', 'utf-16-be',
+ 'raw_unicode_escape',
+ 'unicode_escape', 'unicode_internal'):
+ self.assertEqual(str(u.encode(encoding),encoding), u)
- # UTF-8 must be roundtrip safe for all UCS-2 code points
- # This excludes surrogates: in the full range, there would be
- # a surrogate pair (\udbff\udc00), which gets converted back
- # to a non-BMP character (\U0010fc00)
- u = ''.join(map(chr, list(range(0,0xd800)) +
- list(range(0xe000,0x10000))))
+ # UTF-8 must be roundtrip safe for all code points
+ # (except surrogates, which are forbidden).
+ u = ''.join(map(chr, list(range(0, 0xd800)) +
+ list(range(0xe000, 0x110000))))
for encoding in ('utf-8',):
self.assertEqual(str(u.encode(encoding),encoding), u)
@@ -1638,17 +1939,39 @@ class UnicodeTest(string_tests.CommonTest,
return
self.assertRaises(OverflowError, 't\tt\t'.expandtabs, sys.maxsize)
+ @support.cpython_only
+ def test_expandtabs_optimization(self):
+ s = 'abc'
+ self.assertIs(s.expandtabs(), s)
+
def test_raiseMemError(self):
- # Ensure that the freelist contains a consistent object, even
- # when a string allocation fails with a MemoryError.
- # This used to crash the interpreter,
- # or leak references when the number was smaller.
- charwidth = 4 if sys.maxunicode >= 0x10000 else 2
- # Note: sys.maxsize is half of the actual max allocation because of
- # the signedness of Py_ssize_t.
- alloc = lambda: "a" * (sys.maxsize // charwidth * 2)
- self.assertRaises(MemoryError, alloc)
- self.assertRaises(MemoryError, alloc)
+ if struct.calcsize('P') == 8:
+ # 64 bits pointers
+ ascii_struct_size = 48
+ compact_struct_size = 72
+ else:
+ # 32 bits pointers
+ ascii_struct_size = 24
+ compact_struct_size = 36
+
+ for char in ('a', '\xe9', '\u20ac', '\U0010ffff'):
+ code = ord(char)
+ if code < 0x100:
+ char_size = 1 # sizeof(Py_UCS1)
+ struct_size = ascii_struct_size
+ elif code < 0x10000:
+ char_size = 2 # sizeof(Py_UCS2)
+ struct_size = compact_struct_size
+ else:
+ char_size = 4 # sizeof(Py_UCS4)
+ struct_size = compact_struct_size
+ # Note: sys.maxsize is half of the actual max allocation because of
+ # the signedness of Py_ssize_t. Strings of maxlen-1 should in principle
+ # be allocatable, given enough memory.
+ maxlen = ((sys.maxsize - struct_size) // char_size)
+ alloc = lambda: char * maxlen
+ self.assertRaises(MemoryError, alloc)
+ self.assertRaises(MemoryError, alloc)
def test_format_subclass(self):
class S(str):
@@ -1661,11 +1984,10 @@ class UnicodeTest(string_tests.CommonTest,
# Test PyUnicode_FromFormat()
def test_from_format(self):
support.import_module('ctypes')
- from ctypes import pythonapi, py_object, c_int
- if sys.maxunicode == 65535:
- name = "PyUnicodeUCS2_FromFormat"
- else:
- name = "PyUnicodeUCS4_FromFormat"
+ from ctypes import (pythonapi, py_object,
+ c_int, c_long, c_longlong, c_ssize_t,
+ c_uint, c_ulong, c_ulonglong, c_size_t)
+ name = "PyUnicode_FromFormat"
_PyUnicode_FromFormat = getattr(pythonapi, name)
_PyUnicode_FromFormat.restype = py_object
@@ -1686,13 +2008,40 @@ class UnicodeTest(string_tests.CommonTest,
'string, got a non-ASCII byte: 0xe9$',
PyUnicode_FromFormat, b'unicode\xe9=%s', 'ascii')
+ # test "%c"
self.assertEqual(PyUnicode_FromFormat(b'%c', c_int(0xabcd)), '\uabcd')
self.assertEqual(PyUnicode_FromFormat(b'%c', c_int(0x10ffff)), '\U0010ffff')
- # other tests
+ # test "%"
+ self.assertEqual(PyUnicode_FromFormat(b'%'), '%')
+ self.assertEqual(PyUnicode_FromFormat(b'%%'), '%')
+ self.assertEqual(PyUnicode_FromFormat(b'%%s'), '%s')
+ self.assertEqual(PyUnicode_FromFormat(b'[%%]'), '[%]')
+ self.assertEqual(PyUnicode_FromFormat(b'%%%s', b'abc'), '%abc')
+
+ # test integer formats (%i, %d, %u)
+ self.assertEqual(PyUnicode_FromFormat(b'%03i', c_int(10)), '010')
+ self.assertEqual(PyUnicode_FromFormat(b'%0.4i', c_int(10)), '0010')
+ self.assertEqual(PyUnicode_FromFormat(b'%i', c_int(-123)), '-123')
+ self.assertEqual(PyUnicode_FromFormat(b'%li', c_long(-123)), '-123')
+ self.assertEqual(PyUnicode_FromFormat(b'%lli', c_longlong(-123)), '-123')
+ self.assertEqual(PyUnicode_FromFormat(b'%zi', c_ssize_t(-123)), '-123')
+
+ self.assertEqual(PyUnicode_FromFormat(b'%d', c_int(-123)), '-123')
+ self.assertEqual(PyUnicode_FromFormat(b'%ld', c_long(-123)), '-123')
+ self.assertEqual(PyUnicode_FromFormat(b'%lld', c_longlong(-123)), '-123')
+ self.assertEqual(PyUnicode_FromFormat(b'%zd', c_ssize_t(-123)), '-123')
+
+ self.assertEqual(PyUnicode_FromFormat(b'%u', c_uint(123)), '123')
+ self.assertEqual(PyUnicode_FromFormat(b'%lu', c_ulong(123)), '123')
+ self.assertEqual(PyUnicode_FromFormat(b'%llu', c_ulonglong(123)), '123')
+ self.assertEqual(PyUnicode_FromFormat(b'%zu', c_size_t(123)), '123')
+
+ # test %A
text = PyUnicode_FromFormat(b'%%A:%A', 'abc\xe9\uabcd\U0010ffff')
self.assertEqual(text, r"%A:'abc\xe9\uabcd\U0010ffff'")
+ # test %V
text = PyUnicode_FromFormat(b'repr=%V', 'abc', b'xyz')
self.assertEqual(text, 'repr=abc')
@@ -1706,6 +2055,13 @@ class UnicodeTest(string_tests.CommonTest,
text = PyUnicode_FromFormat(b'repr=%V', None, b'abc\xff')
self.assertEqual(text, 'repr=abc\ufffd')
+ # not supported: copy the raw format string. these tests are just here
+ # to check for crashs and should not be considered as specifications
+ self.assertEqual(PyUnicode_FromFormat(b'%1%s', b'abc'), '%s')
+ self.assertEqual(PyUnicode_FromFormat(b'%1abc'), '%1abc')
+ self.assertEqual(PyUnicode_FromFormat(b'%+i', c_int(10)), '%+i')
+ self.assertEqual(PyUnicode_FromFormat(b'%.%s', b'abc'), '%.%s')
+
# Test PyUnicode_AsWideChar()
def test_aswidechar(self):
from _testcapi import unicode_aswidechar
@@ -1766,6 +2122,66 @@ class UnicodeTest(string_tests.CommonTest,
self.assertEqual(size, nchar)
self.assertEqual(wchar, nonbmp + '\0')
+ def test_subclass_add(self):
+ class S(str):
+ def __add__(self, o):
+ return "3"
+ self.assertEqual(S("4") + S("5"), "3")
+ class S(str):
+ def __iadd__(self, o):
+ return "3"
+ s = S("1")
+ s += "4"
+ self.assertEqual(s, "3")
+
+ def test_encode_decimal(self):
+ from _testcapi import unicode_encodedecimal
+ self.assertEqual(unicode_encodedecimal('123'),
+ b'123')
+ self.assertEqual(unicode_encodedecimal('\u0663.\u0661\u0664'),
+ b'3.14')
+ self.assertEqual(unicode_encodedecimal("\N{EM SPACE}3.14\N{EN SPACE}"),
+ b' 3.14 ')
+ self.assertRaises(UnicodeEncodeError,
+ unicode_encodedecimal, "123\u20ac", "strict")
+ self.assertRaisesRegex(
+ ValueError,
+ "^'decimal' codec can't encode character",
+ unicode_encodedecimal, "123\u20ac", "replace")
+
+ def test_transform_decimal(self):
+ from _testcapi import unicode_transformdecimaltoascii as transform_decimal
+ self.assertEqual(transform_decimal('123'),
+ '123')
+ self.assertEqual(transform_decimal('\u0663.\u0661\u0664'),
+ '3.14')
+ self.assertEqual(transform_decimal("\N{EM SPACE}3.14\N{EN SPACE}"),
+ "\N{EM SPACE}3.14\N{EN SPACE}")
+ self.assertEqual(transform_decimal('123\u20ac'),
+ '123\u20ac')
+
+ def test_getnewargs(self):
+ text = 'abc'
+ args = text.__getnewargs__()
+ self.assertIsNot(args[0], text)
+ self.assertEqual(args[0], text)
+ self.assertEqual(len(args), 1)
+
+ def test_resize(self):
+ for length in range(1, 100, 7):
+ # generate a fresh string (refcount=1)
+ text = 'a' * length + 'b'
+
+ # fill wstr internal field
+ abc = text.encode('unicode_internal')
+ self.assertEqual(abc.decode('unicode_internal'), text)
+
+ # resize text: wstr field must be cleared and then recomputed
+ text += 'c'
+ abcdef = text.encode('unicode_internal')
+ self.assertNotEqual(abc, abcdef)
+ self.assertEqual(abcdef.decode('unicode_internal'), text)
+
class StringModuleTest(unittest.TestCase):
def test_formatter_parser(self):
@@ -1817,45 +2233,6 @@ class StringModuleTest(unittest.TestCase):
]])
self.assertRaises(TypeError, _string.formatter_field_name_split, 1)
- def test_encode_decimal(self):
- from _testcapi import unicode_encodedecimal
- self.assertEqual(unicode_encodedecimal('123'),
- b'123')
- self.assertEqual(unicode_encodedecimal('\u0663.\u0661\u0664'),
- b'3.14')
- self.assertEqual(unicode_encodedecimal("\N{EM SPACE}3.14\N{EN SPACE}"),
- b' 3.14 ')
- self.assertRaises(UnicodeEncodeError,
- unicode_encodedecimal, "123\u20ac", "strict")
- self.assertEqual(unicode_encodedecimal("123\u20ac", "replace"),
- b'123?')
- self.assertEqual(unicode_encodedecimal("123\u20ac", "ignore"),
- b'123')
- self.assertEqual(unicode_encodedecimal("123\u20ac", "xmlcharrefreplace"),
- b'123&#8364;')
- self.assertEqual(unicode_encodedecimal("123\u20ac", "backslashreplace"),
- b'123\\u20ac')
- self.assertEqual(unicode_encodedecimal("123\u20ac\N{EM SPACE}", "replace"),
- b'123? ')
- self.assertEqual(unicode_encodedecimal("123\u20ac\u20ac", "replace"),
- b'123??')
- self.assertEqual(unicode_encodedecimal("123\u20ac\u0660", "replace"),
- b'123?0')
-
- def test_transform_decimal(self):
- from _testcapi import unicode_transformdecimaltoascii as transform_decimal
- self.assertEqual(transform_decimal('123'),
- '123')
- self.assertEqual(transform_decimal('\u0663.\u0661\u0664'),
- '3.14')
- self.assertEqual(transform_decimal("\N{EM SPACE}3.14\N{EN SPACE}"),
- "\N{EM SPACE}3.14\N{EN SPACE}")
- self.assertEqual(transform_decimal('123\u20ac'),
- '123\u20ac')
-
-
-def test_main():
- support.run_unittest(__name__)
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_unicode_file.py b/Lib/test/test_unicode_file.py
index 6c2011a..faa8da3 100644
--- a/Lib/test/test_unicode_file.py
+++ b/Lib/test/test_unicode_file.py
@@ -6,7 +6,7 @@ import unicodedata
import unittest
from test.support import (run_unittest, rmtree,
- TESTFN_ENCODING, TESTFN_UNICODE, TESTFN_UNENCODABLE)
+ TESTFN_ENCODING, TESTFN_UNICODE, TESTFN_UNENCODABLE, create_empty_file)
if not os.path.supports_unicode_filenames:
try:
@@ -56,16 +56,20 @@ class TestUnicodeFiles(unittest.TestCase):
# Should be able to rename the file using either name.
self.assertTrue(os.path.isfile(filename1)) # must exist.
os.rename(filename1, filename2 + ".new")
- self.assertTrue(os.path.isfile(filename1+".new"))
+ self.assertFalse(os.path.isfile(filename2))
+ self.assertTrue(os.path.isfile(filename1 + '.new'))
os.rename(filename1 + ".new", filename2)
+ self.assertFalse(os.path.isfile(filename1 + '.new'))
self.assertTrue(os.path.isfile(filename2))
shutil.copy(filename1, filename2 + ".new")
os.unlink(filename1 + ".new") # remove using equiv name.
# And a couple of moves, one using each name.
shutil.move(filename1, filename2 + ".new")
- self.assertTrue(not os.path.exists(filename2))
+ self.assertFalse(os.path.exists(filename2))
+ self.assertTrue(os.path.exists(filename1 + '.new'))
shutil.move(filename1 + ".new", filename2)
+ self.assertFalse(os.path.exists(filename2 + '.new'))
self.assertTrue(os.path.exists(filename1))
# Note - due to the implementation of shutil.move,
# it tries a rename first. This only fails on Windows when on
@@ -73,10 +77,12 @@ class TestUnicodeFiles(unittest.TestCase):
# So we test the shutil.copy2 function, which is the thing most
# likely to fail.
shutil.copy2(filename1, filename2 + ".new")
+ self.assertTrue(os.path.isfile(filename1 + '.new'))
os.unlink(filename1 + ".new")
+ self.assertFalse(os.path.exists(filename2 + '.new'))
def _do_directory(self, make_name, chdir_name):
- cwd = os.getcwdb()
+ cwd = os.getcwd()
if os.path.isdir(make_name):
rmtree(make_name)
os.mkdir(make_name)
@@ -99,8 +105,7 @@ class TestUnicodeFiles(unittest.TestCase):
# top-level 'test' functions would be if they could take params
def _test_single(self, filename):
remove_if_exists(filename)
- f = open(filename, "w")
- f.close()
+ create_empty_file(filename)
try:
self._do_single(filename)
finally:
diff --git a/Lib/test/test_unicodedata.py b/Lib/test/test_unicodedata.py
index 9744256..99aa003 100644
--- a/Lib/test/test_unicodedata.py
+++ b/Lib/test/test_unicodedata.py
@@ -21,7 +21,7 @@ errors = 'surrogatepass'
class UnicodeMethodsTest(unittest.TestCase):
# update this, if the database changes
- expectedchecksum = '21b90f1aed00081b81ca7942b22196af090015a0'
+ expectedchecksum = 'bf7a78f1a532421b5033600102e23a92044dbba9'
def test_method_checksum(self):
h = hashlib.sha1()
@@ -80,7 +80,7 @@ class UnicodeDatabaseTest(unittest.TestCase):
class UnicodeFunctionsTest(UnicodeDatabaseTest):
# update this, if the database changes
- expectedchecksum = 'c23dfc0b5eaf3ca2aad32d733de96bb182ccda50'
+ expectedchecksum = '17fe2f12b788e4fff5479b469c4404bb6ecf841f'
def test_function_checksum(self):
data = []
h = hashlib.sha1()
@@ -108,6 +108,7 @@ class UnicodeFunctionsTest(UnicodeDatabaseTest):
self.assertEqual(self.db.digit('\u215b', None), None)
self.assertEqual(self.db.digit('\u2468'), 9)
self.assertEqual(self.db.digit('\U00020000', None), None)
+ self.assertEqual(self.db.digit('\U0001D7FD'), 7)
self.assertRaises(TypeError, self.db.digit)
self.assertRaises(TypeError, self.db.digit, 'xx')
@@ -120,6 +121,7 @@ class UnicodeFunctionsTest(UnicodeDatabaseTest):
self.assertEqual(self.db.numeric('\u2468'), 9.0)
self.assertEqual(self.db.numeric('\ua627'), 7.0)
self.assertEqual(self.db.numeric('\U00020000', None), None)
+ self.assertEqual(self.db.numeric('\U0001012A'), 9000)
self.assertRaises(TypeError, self.db.numeric)
self.assertRaises(TypeError, self.db.numeric, 'xx')
@@ -131,6 +133,7 @@ class UnicodeFunctionsTest(UnicodeDatabaseTest):
self.assertEqual(self.db.decimal('\u215b', None), None)
self.assertEqual(self.db.decimal('\u2468', None), None)
self.assertEqual(self.db.decimal('\U00020000', None), None)
+ self.assertEqual(self.db.decimal('\U0001D7FD'), 7)
self.assertRaises(TypeError, self.db.decimal)
self.assertRaises(TypeError, self.db.decimal, 'xx')
@@ -141,6 +144,7 @@ class UnicodeFunctionsTest(UnicodeDatabaseTest):
self.assertEqual(self.db.category('a'), 'Ll')
self.assertEqual(self.db.category('A'), 'Lu')
self.assertEqual(self.db.category('\U00020000'), 'Lo')
+ self.assertEqual(self.db.category('\U0001012A'), 'No')
self.assertRaises(TypeError, self.db.category)
self.assertRaises(TypeError, self.db.category, 'xx')
@@ -308,14 +312,6 @@ class UnicodeMiscTest(UnicodeDatabaseTest):
self.assertEqual(len(lines), 1,
r"\u%.4x should not be a linebreak" % i)
- def test_UCS4(self):
- # unicodedata should work with code points outside the BMP
- # even on a narrow Unicode build
- self.assertEqual(self.db.category("\U0001012A"), "No")
- self.assertEqual(self.db.numeric("\U0001012A"), 9000)
- self.assertEqual(self.db.decimal("\U0001D7FD"), 7)
- self.assertEqual(self.db.digit("\U0001D7FD"), 7)
-
def test_main():
test.support.run_unittest(
UnicodeMiscTest,
diff --git a/Lib/test/test_urllib.py b/Lib/test/test_urllib.py
index 3fc499e..52e7749 100644
--- a/Lib/test/test_urllib.py
+++ b/Lib/test/test_urllib.py
@@ -270,10 +270,9 @@ Content-Type: text/html; charset=iso-8859-1
def test_missing_localfile(self):
# Test for #10836
- with self.assertRaises(urllib.error.URLError) as e:
+ # 3.3 - URLError is not captured, explicit IOError is raised.
+ with self.assertRaises(IOError):
urlopen('file://localhost/a/file/which/doesnot/exists.py')
- self.assertTrue(e.exception.filename)
- self.assertTrue(e.exception.reason)
def test_file_notexists(self):
fd, tmp_file = tempfile.mkstemp()
@@ -286,21 +285,20 @@ Content-Type: text/html; charset=iso-8859-1
os.close(fd)
os.unlink(tmp_file)
self.assertFalse(os.path.exists(tmp_file))
- with self.assertRaises(urllib.error.URLError):
+ # 3.3 - IOError instead of URLError
+ with self.assertRaises(IOError):
urlopen(tmp_fileurl)
def test_ftp_nohost(self):
test_ftp_url = 'ftp:///path'
- with self.assertRaises(urllib.error.URLError) as e:
+ # 3.3 - IOError instead of URLError
+ with self.assertRaises(IOError):
urlopen(test_ftp_url)
- self.assertFalse(e.exception.filename)
- self.assertTrue(e.exception.reason)
def test_ftp_nonexisting(self):
- with self.assertRaises(urllib.error.URLError) as e:
+ # 3.3 - IOError instead of URLError
+ with self.assertRaises(IOError):
urlopen('ftp://localhost/a/file/which/doesnot/exists.py')
- self.assertFalse(e.exception.filename)
- self.assertTrue(e.exception.reason)
def test_userpass_inurl(self):
@@ -333,6 +331,10 @@ Content-Type: text/html; charset=iso-8859-1
finally:
self.unfakehttp()
+ def test_URLopener_deprecation(self):
+ with support.check_warnings(('',DeprecationWarning)):
+ urllib.request.URLopener()
+
class urlretrieve_FileTests(unittest.TestCase):
"""Test urllib.urlretrieve() on local files"""
@@ -366,7 +368,7 @@ class urlretrieve_FileTests(unittest.TestCase):
def constructLocalFileUrl(self, filePath):
filePath = os.path.abspath(filePath)
try:
- filePath.encode("utf8")
+ filePath.encode("utf-8")
except UnicodeEncodeError:
raise unittest.SkipTest("filePath is not encodable to utf8")
return "file://%s" % urllib.request.pathname2url(filePath)
@@ -419,11 +421,11 @@ class urlretrieve_FileTests(unittest.TestCase):
def test_reporthook(self):
# Make sure that the reporthook works.
- def hooktester(count, block_size, total_size, count_holder=[0]):
- self.assertIsInstance(count, int)
- self.assertIsInstance(block_size, int)
- self.assertIsInstance(total_size, int)
- self.assertEqual(count, count_holder[0])
+ def hooktester(block_count, block_read_size, file_size, count_holder=[0]):
+ self.assertIsInstance(block_count, int)
+ self.assertIsInstance(block_read_size, int)
+ self.assertIsInstance(file_size, int)
+ self.assertEqual(block_count, count_holder[0])
count_holder[0] = count_holder[0] + 1
second_temp = "%s.2" % support.TESTFN
self.registerFileForCleanUp(second_temp)
@@ -434,8 +436,8 @@ class urlretrieve_FileTests(unittest.TestCase):
def test_reporthook_0_bytes(self):
# Test on zero length file. Should call reporthook only 1 time.
report = []
- def hooktester(count, block_size, total_size, _report=report):
- _report.append((count, block_size, total_size))
+ def hooktester(block_count, block_read_size, file_size, _report=report):
+ _report.append((block_count, block_read_size, file_size))
srcFileName = self.createNewTempFile()
urllib.request.urlretrieve(self.constructLocalFileUrl(srcFileName),
support.TESTFN, hooktester)
@@ -445,31 +447,32 @@ class urlretrieve_FileTests(unittest.TestCase):
def test_reporthook_5_bytes(self):
# Test on 5 byte file. Should call reporthook only 2 times (once when
# the "network connection" is established and once when the block is
- # read). Since the block size is 8192 bytes, only one block read is
- # required to read the entire file.
+ # read).
report = []
- def hooktester(count, block_size, total_size, _report=report):
- _report.append((count, block_size, total_size))
+ def hooktester(block_count, block_read_size, file_size, _report=report):
+ _report.append((block_count, block_read_size, file_size))
srcFileName = self.createNewTempFile(b"x" * 5)
urllib.request.urlretrieve(self.constructLocalFileUrl(srcFileName),
support.TESTFN, hooktester)
self.assertEqual(len(report), 2)
- self.assertEqual(report[0][1], 8192)
self.assertEqual(report[0][2], 5)
+ self.assertEqual(report[1][2], 5)
def test_reporthook_8193_bytes(self):
# Test on 8193 byte file. Should call reporthook only 3 times (once
# when the "network connection" is established, once for the next 8192
# bytes, and once for the last byte).
report = []
- def hooktester(count, block_size, total_size, _report=report):
- _report.append((count, block_size, total_size))
+ def hooktester(block_count, block_read_size, file_size, _report=report):
+ _report.append((block_count, block_read_size, file_size))
srcFileName = self.createNewTempFile(b"x" * 8193)
urllib.request.urlretrieve(self.constructLocalFileUrl(srcFileName),
support.TESTFN, hooktester)
self.assertEqual(len(report), 3)
- self.assertEqual(report[0][1], 8192)
self.assertEqual(report[0][2], 8193)
+ self.assertEqual(report[0][1], 8192)
+ self.assertEqual(report[1][1], 8192)
+ self.assertEqual(report[2][1], 8192)
class urlretrieve_HttpTests(unittest.TestCase, FakeHTTPMixin):
@@ -1280,6 +1283,28 @@ class URLopener_Tests(unittest.TestCase):
# self.assertEqual(ftp.ftp.sock.gettimeout(), 30)
# ftp.close()
+class RequestTests(unittest.TestCase):
+ """Unit tests for urllib.request.Request."""
+
+ def test_default_values(self):
+ Request = urllib.request.Request
+ request = Request("http://www.python.org")
+ self.assertEqual(request.get_method(), 'GET')
+ request = Request("http://www.python.org", {})
+ self.assertEqual(request.get_method(), 'POST')
+
+ def test_with_method_arg(self):
+ Request = urllib.request.Request
+ request = Request("http://www.python.org", method='HEAD')
+ self.assertEqual(request.method, 'HEAD')
+ self.assertEqual(request.get_method(), 'HEAD')
+ request = Request("http://www.python.org", {}, method='HEAD')
+ self.assertEqual(request.method, 'HEAD')
+ self.assertEqual(request.get_method(), 'HEAD')
+ request = Request("http://www.python.org", method='GET')
+ self.assertEqual(request.get_method(), 'GET')
+ request.method = 'HEAD'
+ self.assertEqual(request.get_method(), 'HEAD')
def test_main():
@@ -1296,6 +1321,7 @@ def test_main():
Utility_Tests,
URLopener_Tests,
#FTPWrapperTests,
+ RequestTests,
)
diff --git a/Lib/test/test_urllib2.py b/Lib/test/test_urllib2.py
index 5eab30a..ccd5419 100644
--- a/Lib/test/test_urllib2.py
+++ b/Lib/test/test_urllib2.py
@@ -5,6 +5,7 @@ import os
import io
import socket
import array
+import sys
import urllib.request
# The proxy bypass method imported below has logic specific to the OSX
@@ -18,6 +19,22 @@ import urllib.error
# parse_keqv_list, parse_http_list, HTTPDigestAuthHandler
class TrivialTests(unittest.TestCase):
+
+ def test___all__(self):
+ # Verify which names are exposed
+ for module in 'request', 'response', 'parse', 'error', 'robotparser':
+ context = {}
+ exec('from urllib.%s import *' % module, context)
+ del context['__builtins__']
+ if module == 'request' and os.name == 'nt':
+ u, p = context.pop('url2pathname'), context.pop('pathname2url')
+ self.assertEqual(u.__module__, 'nturl2path')
+ self.assertEqual(p.__module__, 'nturl2path')
+ for k, v in context.items():
+ self.assertEqual(v.__module__, 'urllib.%s' % module,
+ "%r is exposed in 'urllib.%s' but defined in %r" %
+ (k, module, v.__module__))
+
def test_trivial(self):
# A couple trivial tests
@@ -536,10 +553,6 @@ class OpenerDirectorTests(unittest.TestCase):
self.assertRaises(urllib.error.URLError, o.open, req)
self.assertEqual(o.calls, [(handlers[0], "http_open", (req,), {})])
-## def test_error(self):
-## # XXX this doesn't actually seem to be used in standard library,
-## # but should really be tested anyway...
-
def test_http_error(self):
# XXX http_error_default
# http errors are a special case
@@ -567,6 +580,7 @@ class OpenerDirectorTests(unittest.TestCase):
self.assertEqual((handler, method_name), got[:2])
self.assertEqual(args, got[2])
+
def test_processors(self):
# *_request / *_response methods get called appropriately
o = OpenerDirector()
@@ -602,10 +616,30 @@ class OpenerDirectorTests(unittest.TestCase):
self.assertTrue(args[1] is None or
isinstance(args[1], MockResponse))
+ def test_method_deprecations(self):
+ req = Request("http://www.example.com")
+
+ with self.assertWarns(DeprecationWarning):
+ req.add_data("data")
+ with self.assertWarns(DeprecationWarning):
+ req.get_data()
+ with self.assertWarns(DeprecationWarning):
+ req.has_data()
+ with self.assertWarns(DeprecationWarning):
+ req.get_host()
+ with self.assertWarns(DeprecationWarning):
+ req.get_selector()
+ with self.assertWarns(DeprecationWarning):
+ req.is_unverifiable()
+ with self.assertWarns(DeprecationWarning):
+ req.get_origin_req_host()
+ with self.assertWarns(DeprecationWarning):
+ req.get_type()
+
def sanepathname2url(path):
try:
- path.encode("utf8")
+ path.encode("utf-8")
except UnicodeEncodeError:
raise unittest.SkipTest("path is not encodable to utf8")
urlpath = urllib.request.pathname2url(path)
@@ -957,8 +991,8 @@ class HandlerTests(unittest.TestCase):
newreq = h.http_request(req)
self.assertIs(cj.ach_req, req)
self.assertIs(cj.ach_req, newreq)
- self.assertEqual(req.get_origin_req_host(), "example.com")
- self.assertFalse(req.is_unverifiable())
+ self.assertEqual(req.origin_req_host, "example.com")
+ self.assertFalse(req.unverifiable)
newr = h.http_response(req, r)
self.assertIs(cj.ec_req, req)
self.assertIs(cj.ec_r, r)
@@ -990,7 +1024,7 @@ class HandlerTests(unittest.TestCase):
try:
self.assertEqual(o.req.get_method(), "GET")
except AttributeError:
- self.assertFalse(o.req.has_data())
+ self.assertFalse(o.req.data)
# now it's a GET, there should not be headers regarding content
# (possibly dragged from before being a POST)
@@ -1106,9 +1140,9 @@ class HandlerTests(unittest.TestCase):
handlers = add_ordered_mock_handlers(o, meth_spec)
req = Request("http://acme.example.com/")
- self.assertEqual(req.get_host(), "acme.example.com")
+ self.assertEqual(req.host, "acme.example.com")
r = o.open(req)
- self.assertEqual(req.get_host(), "proxy.example.com:3128")
+ self.assertEqual(req.host, "proxy.example.com:3128")
self.assertEqual([(handlers[0], "http_open")],
[tup[0:2] for tup in o.calls])
@@ -1119,13 +1153,13 @@ class HandlerTests(unittest.TestCase):
ph = urllib.request.ProxyHandler(dict(http="proxy.example.com"))
o.add_handler(ph)
req = Request("http://www.perl.org/")
- self.assertEqual(req.get_host(), "www.perl.org")
+ self.assertEqual(req.host, "www.perl.org")
r = o.open(req)
- self.assertEqual(req.get_host(), "proxy.example.com")
+ self.assertEqual(req.host, "proxy.example.com")
req = Request("http://www.python.org")
- self.assertEqual(req.get_host(), "www.python.org")
+ self.assertEqual(req.host, "www.python.org")
r = o.open(req)
- self.assertEqual(req.get_host(), "www.python.org")
+ self.assertEqual(req.host, "www.python.org")
del os.environ['no_proxy']
def test_proxy_no_proxy_all(self):
@@ -1134,9 +1168,9 @@ class HandlerTests(unittest.TestCase):
ph = urllib.request.ProxyHandler(dict(http="proxy.example.com"))
o.add_handler(ph)
req = Request("http://www.python.org")
- self.assertEqual(req.get_host(), "www.python.org")
+ self.assertEqual(req.host, "www.python.org")
r = o.open(req)
- self.assertEqual(req.get_host(), "www.python.org")
+ self.assertEqual(req.host, "www.python.org")
del os.environ['no_proxy']
@@ -1150,9 +1184,9 @@ class HandlerTests(unittest.TestCase):
handlers = add_ordered_mock_handlers(o, meth_spec)
req = Request("https://www.example.com/")
- self.assertEqual(req.get_host(), "www.example.com")
+ self.assertEqual(req.host, "www.example.com")
r = o.open(req)
- self.assertEqual(req.get_host(), "proxy.example.com:3128")
+ self.assertEqual(req.host, "proxy.example.com:3128")
self.assertEqual([(handlers[0], "https_open")],
[tup[0:2] for tup in o.calls])
@@ -1165,7 +1199,7 @@ class HandlerTests(unittest.TestCase):
req = Request("https://www.example.com/")
req.add_header("Proxy-Authorization","FooBar")
req.add_header("User-Agent","Grail")
- self.assertEqual(req.get_host(), "www.example.com")
+ self.assertEqual(req.host, "www.example.com")
self.assertIsNone(req._tunnel_host)
r = o.open(req)
# Verify Proxy-Authorization gets tunneled to request.
@@ -1176,9 +1210,11 @@ class HandlerTests(unittest.TestCase):
self.assertIn(("User-Agent","Grail"),
https_handler.httpconn.req_headers)
self.assertIsNotNone(req._tunnel_host)
- self.assertEqual(req.get_host(), "proxy.example.com:3128")
+ self.assertEqual(req.host, "proxy.example.com:3128")
self.assertEqual(req.get_header("Proxy-authorization"),"FooBar")
+ # TODO: This should be only for OSX
+ @unittest.skipUnless(sys.platform == 'darwin', "only relevant for OSX")
def test_osx_proxy_bypass(self):
bypass = {
'exclude_simple': False,
@@ -1298,6 +1334,26 @@ class HandlerTests(unittest.TestCase):
# _test_basic_auth called .open() twice)
self.assertEqual(opener.recorded, ["digest", "basic"]*2)
+ def test_unsupported_auth_digest_handler(self):
+ opener = OpenerDirector()
+ # While using DigestAuthHandler
+ digest_auth_handler = urllib.request.HTTPDigestAuthHandler(None)
+ http_handler = MockHTTPHandler(
+ 401, 'WWW-Authenticate: Kerberos\r\n\r\n')
+ opener.add_handler(digest_auth_handler)
+ opener.add_handler(http_handler)
+ self.assertRaises(ValueError,opener.open,"http://www.example.com")
+
+ def test_unsupported_auth_basic_handler(self):
+ # While using BasicAuthHandler
+ opener = OpenerDirector()
+ basic_auth_handler = urllib.request.HTTPBasicAuthHandler(None)
+ http_handler = MockHTTPHandler(
+ 401, 'WWW-Authenticate: NTLM\r\n\r\n')
+ opener.add_handler(basic_auth_handler)
+ opener.add_handler(http_handler)
+ self.assertRaises(ValueError,opener.open,"http://www.example.com")
+
def _test_basic_auth(self, opener, auth_handler, auth_header,
realm, http_handler, password_manager,
request_url, protected_url):
@@ -1335,6 +1391,7 @@ class HandlerTests(unittest.TestCase):
self.assertEqual(len(http_handler.requests), 1)
self.assertFalse(http_handler.requests[0].has_header(auth_header))
+
class MiscTests(unittest.TestCase):
def test_build_opener(self):
@@ -1390,11 +1447,11 @@ class RequestTests(unittest.TestCase):
self.assertEqual("POST", self.post.get_method())
self.assertEqual("GET", self.get.get_method())
- def test_add_data(self):
- self.assertFalse(self.get.has_data())
+ def test_data(self):
+ self.assertFalse(self.get.data)
self.assertEqual("GET", self.get.get_method())
- self.get.add_data("spam")
- self.assertTrue(self.get.has_data())
+ self.get.data = "spam"
+ self.assertTrue(self.get.data)
self.assertEqual("POST", self.get.get_method())
def test_get_full_url(self):
@@ -1402,36 +1459,36 @@ class RequestTests(unittest.TestCase):
self.get.get_full_url())
def test_selector(self):
- self.assertEqual("/~jeremy/", self.get.get_selector())
+ self.assertEqual("/~jeremy/", self.get.selector)
req = Request("http://www.python.org/")
- self.assertEqual("/", req.get_selector())
+ self.assertEqual("/", req.selector)
def test_get_type(self):
- self.assertEqual("http", self.get.get_type())
+ self.assertEqual("http", self.get.type)
def test_get_host(self):
- self.assertEqual("www.python.org", self.get.get_host())
+ self.assertEqual("www.python.org", self.get.host)
def test_get_host_unquote(self):
req = Request("http://www.%70ython.org/")
- self.assertEqual("www.python.org", req.get_host())
+ self.assertEqual("www.python.org", req.host)
def test_proxy(self):
self.assertFalse(self.get.has_proxy())
self.get.set_proxy("www.perl.org", "http")
self.assertTrue(self.get.has_proxy())
- self.assertEqual("www.python.org", self.get.get_origin_req_host())
- self.assertEqual("www.perl.org", self.get.get_host())
+ self.assertEqual("www.python.org", self.get.origin_req_host)
+ self.assertEqual("www.perl.org", self.get.host)
def test_wrapped_url(self):
req = Request("<URL:http://www.python.org>")
- self.assertEqual("www.python.org", req.get_host())
+ self.assertEqual("www.python.org", req.host)
def test_url_fragment(self):
req = Request("http://www.python.org/?qs=query#fragment=true")
- self.assertEqual("/?qs=query", req.get_selector())
+ self.assertEqual("/?qs=query", req.selector)
req = Request("http://www.python.org/#fun=true")
- self.assertEqual("/", req.get_selector())
+ self.assertEqual("/", req.selector)
# Issue 11703: geturl() omits fragment in the original URL.
url = 'http://docs.python.org/library/urllib2.html#OK'
@@ -1443,7 +1500,9 @@ class RequestTests(unittest.TestCase):
Issue 13211 reveals that HTTPError didn't implement the URLError
interface even though HTTPError is a subclass of URLError.
- >>> err = urllib.error.HTTPError(msg='something bad happened', url=None, code=None, hdrs=None, fp=None)
+ >>> msg = 'something bad happened'
+ >>> url = code = hdrs = fp = None
+ >>> err = urllib.error.HTTPError(url, code, msg, hdrs, fp)
>>> assert hasattr(err, 'reason')
>>> err.reason
'something bad happened'
diff --git a/Lib/test/test_urllib2_localnet.py b/Lib/test/test_urllib2_localnet.py
index f07471d..1eba14b 100644
--- a/Lib/test/test_urllib2_localnet.py
+++ b/Lib/test/test_urllib2_localnet.py
@@ -476,6 +476,13 @@ class TestUrlopen(unittest.TestCase):
self.urlopen("https://localhost:%s/bizarre" % handler.port,
cafile=CERT_fakehostname)
+ def test_https_with_cadefault(self):
+ handler = self.start_https_server(certfile=CERT_localhost)
+ # Self-signed cert should fail verification with system certificate store
+ with self.assertRaises(urllib.error.URLError) as cm:
+ self.urlopen("https://localhost:%s/bizarre" % handler.port,
+ cadefault=True)
+
def test_sending_headers(self):
handler = self.start_server()
req = urllib.request.Request("http://localhost:%s/" % handler.port,
diff --git a/Lib/test/test_urllib2net.py b/Lib/test/test_urllib2net.py
index e2b1d95..7f3c93a 100644
--- a/Lib/test/test_urllib2net.py
+++ b/Lib/test/test_urllib2net.py
@@ -83,12 +83,13 @@ class CloseSocketTest(unittest.TestCase):
def test_close(self):
# calling .close() on urllib2's response objects should close the
# underlying socket
-
- response = _urlopen_with_retry("http://www.python.org/")
- sock = response.fp
- self.assertTrue(not sock.closed)
- response.close()
- self.assertTrue(sock.closed)
+ url = "http://www.python.org/"
+ with support.transient_internet(url):
+ response = _urlopen_with_retry(url)
+ sock = response.fp
+ self.assertTrue(not sock.closed)
+ response.close()
+ self.assertTrue(sock.closed)
class OtherNetworkTests(unittest.TestCase):
def setUp(self):
diff --git a/Lib/test/test_urllibnet.py b/Lib/test/test_urllibnet.py
index 383b2af..d3fe69d 100644
--- a/Lib/test/test_urllibnet.py
+++ b/Lib/test/test_urllibnet.py
@@ -137,10 +137,10 @@ class urlretrieveNetworkTests(unittest.TestCase):
"""Tests urllib.request.urlretrieve using the network."""
@contextlib.contextmanager
- def urlretrieve(self, *args):
+ def urlretrieve(self, *args, **kwargs):
resource = args[0]
with support.transient_internet(resource):
- file_location, info = urllib.request.urlretrieve(*args)
+ file_location, info = urllib.request.urlretrieve(*args, **kwargs)
try:
yield file_location, info
finally:
@@ -170,9 +170,10 @@ class urlretrieveNetworkTests(unittest.TestCase):
self.assertIsInstance(info, email.message.Message,
"info is not an instance of email.message.Message")
+ logo = "http://www.python.org/community/logos/python-logo-master-v3-TM.png"
+
def test_data_header(self):
- logo = "http://www.python.org/community/logos/python-logo-master-v3-TM.png"
- with self.urlretrieve(logo) as (file_location, fileheaders):
+ with self.urlretrieve(self.logo) as (file_location, fileheaders):
datevalue = fileheaders.get('Date')
dateformat = '%a, %d %b %Y %H:%M:%S GMT'
try:
@@ -180,6 +181,31 @@ class urlretrieveNetworkTests(unittest.TestCase):
except ValueError:
self.fail('Date value not in %r format', dateformat)
+ def test_reporthook(self):
+ records = []
+ def recording_reporthook(blocks, block_size, total_size):
+ records.append((blocks, block_size, total_size))
+
+ with self.urlretrieve(self.logo, reporthook=recording_reporthook) as (
+ file_location, fileheaders):
+ expected_size = int(fileheaders['Content-Length'])
+
+ records_repr = repr(records) # For use in error messages.
+ self.assertGreater(len(records), 1, msg="There should always be two "
+ "calls; the first one before the transfer starts.")
+ self.assertEqual(records[0][0], 0)
+ self.assertGreater(records[0][1], 0,
+ msg="block size can't be 0 in %s" % records_repr)
+ self.assertEqual(records[0][2], expected_size)
+ self.assertEqual(records[-1][2], expected_size)
+
+ block_sizes = {block_size for _, block_size, _ in records}
+ self.assertEqual({records[0][1]}, block_sizes,
+ msg="block sizes in %s must be equal" % records_repr)
+ self.assertGreaterEqual(records[-1][0]*records[0][1], expected_size,
+ msg="number of blocks * block size must be"
+ " >= total size in %s" % records_repr)
+
def test_main():
support.requires('network')
diff --git a/Lib/test/test_userlist.py b/Lib/test/test_userlist.py
index 868ed24..6381070 100644
--- a/Lib/test/test_userlist.py
+++ b/Lib/test/test_userlist.py
@@ -52,6 +52,12 @@ class UserListTest(list_tests.CommonTest):
return str(key) + '!!!'
self.assertEqual(next(iter(T((1,2)))), "0!!!")
+ def test_userlist_copy(self):
+ u = self.type2test([6, 8, 1, 9, 1])
+ v = u.copy()
+ self.assertEqual(u, v)
+ self.assertEqual(type(u), type(v))
+
def test_main():
support.run_unittest(UserListTest)
diff --git a/Lib/test/test_userstring.py b/Lib/test/test_userstring.py
index 7a8b932..71fcac2 100755
--- a/Lib/test/test_userstring.py
+++ b/Lib/test/test_userstring.py
@@ -3,6 +3,7 @@
# UserString instances should behave similar to builtin string objects.
import string
+import unittest
from test import support, string_tests
from collections import UserString
@@ -10,6 +11,7 @@ from collections import UserString
class UserStringTest(
string_tests.CommonTest,
string_tests.MixinStrUnicodeUserStringTest,
+ unittest.TestCase
):
type2test = UserString
@@ -17,11 +19,11 @@ class UserStringTest(
# Overwrite the three testing methods, because UserString
# can't cope with arguments propagated to UserString
# (and we don't test with subclasses)
- def checkequal(self, result, object, methodname, *args):
+ def checkequal(self, result, object, methodname, *args, **kwargs):
result = self.fixtype(result)
object = self.fixtype(object)
# we don't fix the arguments, because UserString can't cope with it
- realresult = getattr(object, methodname)(*args)
+ realresult = getattr(object, methodname)(*args, **kwargs)
self.assertEqual(
result,
realresult
@@ -42,8 +44,5 @@ class UserStringTest(
getattr(object, methodname)(*args)
-def test_main():
- support.run_unittest(UserStringTest)
-
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_uuid.py b/Lib/test/test_uuid.py
index 43fa656..7bc59ed 100644
--- a/Lib/test/test_uuid.py
+++ b/Lib/test/test_uuid.py
@@ -471,14 +471,14 @@ class TestUUID(TestCase):
if pid == 0:
os.close(fds[0])
value = uuid.uuid4()
- os.write(fds[1], value.hex.encode('latin1'))
+ os.write(fds[1], value.hex.encode('latin-1'))
os._exit(0)
else:
os.close(fds[1])
parent_value = uuid.uuid4().hex
os.waitpid(pid, 0)
- child_value = os.read(fds[0], 100).decode('latin1')
+ child_value = os.read(fds[0], 100).decode('latin-1')
self.assertNotEqual(parent_value, child_value)
diff --git a/Lib/test/test_venv.py b/Lib/test/test_venv.py
new file mode 100644
index 0000000..9696844
--- /dev/null
+++ b/Lib/test/test_venv.py
@@ -0,0 +1,203 @@
+"""
+Test harness for the venv module.
+
+Copyright (C) 2011-2012 Vinay Sajip.
+Licensed to the PSF under a contributor agreement.
+"""
+
+import os
+import os.path
+import shutil
+import subprocess
+import sys
+import tempfile
+from test.support import (captured_stdout, captured_stderr, run_unittest,
+ can_symlink)
+import unittest
+import venv
+
+class BaseTest(unittest.TestCase):
+ """Base class for venv tests."""
+
+ def setUp(self):
+ self.env_dir = os.path.realpath(tempfile.mkdtemp())
+ if os.name == 'nt':
+ self.bindir = 'Scripts'
+ self.pydocname = 'pydoc.py'
+ self.lib = ('Lib',)
+ self.include = 'Include'
+ else:
+ self.bindir = 'bin'
+ self.pydocname = 'pydoc'
+ self.lib = ('lib', 'python%s' % sys.version[:3])
+ self.include = 'include'
+ if sys.platform == 'darwin' and '__PYVENV_LAUNCHER__' in os.environ:
+ executable = os.environ['__PYVENV_LAUNCHER__']
+ else:
+ executable = sys.executable
+ self.exe = os.path.split(executable)[-1]
+
+ def tearDown(self):
+ shutil.rmtree(self.env_dir)
+
+ def run_with_capture(self, func, *args, **kwargs):
+ with captured_stdout() as output:
+ with captured_stderr() as error:
+ func(*args, **kwargs)
+ return output.getvalue(), error.getvalue()
+
+ def get_env_file(self, *args):
+ return os.path.join(self.env_dir, *args)
+
+ def get_text_file_contents(self, *args):
+ with open(self.get_env_file(*args), 'r') as f:
+ result = f.read()
+ return result
+
+class BasicTest(BaseTest):
+ """Test venv module functionality."""
+
+ def isdir(self, *args):
+ fn = self.get_env_file(*args)
+ self.assertTrue(os.path.isdir(fn))
+
+ def test_defaults(self):
+ """
+ Test the create function with default arguments.
+ """
+ shutil.rmtree(self.env_dir)
+ self.run_with_capture(venv.create, self.env_dir)
+ self.isdir(self.bindir)
+ self.isdir(self.include)
+ self.isdir(*self.lib)
+ data = self.get_text_file_contents('pyvenv.cfg')
+ if sys.platform == 'darwin' and ('__PYVENV_LAUNCHER__'
+ in os.environ):
+ executable = os.environ['__PYVENV_LAUNCHER__']
+ else:
+ executable = sys.executable
+ path = os.path.dirname(executable)
+ self.assertIn('home = %s' % path, data)
+ data = self.get_text_file_contents(self.bindir, self.pydocname)
+ self.assertTrue(data.startswith('#!%s%s' % (self.env_dir, os.sep)))
+ fn = self.get_env_file(self.bindir, self.exe)
+ if not os.path.exists(fn): # diagnostics for Windows buildbot failures
+ bd = self.get_env_file(self.bindir)
+ print('Contents of %r:' % bd)
+ print(' %r' % os.listdir(bd))
+ self.assertTrue(os.path.exists(fn), 'File %r should exist.' % fn)
+
+ @unittest.skipIf(sys.prefix != sys.base_prefix, 'Test not appropriate '
+ 'in a venv')
+ def test_prefixes(self):
+ """
+ Test that the prefix values are as expected.
+ """
+ #check our prefixes
+ self.assertEqual(sys.base_prefix, sys.prefix)
+ self.assertEqual(sys.base_exec_prefix, sys.exec_prefix)
+
+ # check a venv's prefixes
+ shutil.rmtree(self.env_dir)
+ self.run_with_capture(venv.create, self.env_dir)
+ envpy = os.path.join(self.env_dir, self.bindir, self.exe)
+ cmd = [envpy, '-c', None]
+ for prefix, expected in (
+ ('prefix', self.env_dir),
+ ('prefix', self.env_dir),
+ ('base_prefix', sys.prefix),
+ ('base_exec_prefix', sys.exec_prefix)):
+ cmd[2] = 'import sys; print(sys.%s)' % prefix
+ p = subprocess.Popen(cmd, stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ out, err = p.communicate()
+ self.assertEqual(out.strip(), expected.encode())
+
+ def test_overwrite_existing(self):
+ """
+ Test control of overwriting an existing environment directory.
+ """
+ self.assertRaises(ValueError, venv.create, self.env_dir)
+ builder = venv.EnvBuilder(clear=True)
+ builder.create(self.env_dir)
+
+ def test_upgrade(self):
+ """
+ Test upgrading an existing environment directory.
+ """
+ builder = venv.EnvBuilder(upgrade=True)
+ self.run_with_capture(builder.create, self.env_dir)
+ self.isdir(self.bindir)
+ self.isdir(self.include)
+ self.isdir(*self.lib)
+ fn = self.get_env_file(self.bindir, self.exe)
+ if not os.path.exists(fn): # diagnostics for Windows buildbot failures
+ bd = self.get_env_file(self.bindir)
+ print('Contents of %r:' % bd)
+ print(' %r' % os.listdir(bd))
+ self.assertTrue(os.path.exists(fn), 'File %r should exist.' % fn)
+
+ def test_isolation(self):
+ """
+ Test isolation from system site-packages
+ """
+ for ssp, s in ((True, 'true'), (False, 'false')):
+ builder = venv.EnvBuilder(clear=True, system_site_packages=ssp)
+ builder.create(self.env_dir)
+ data = self.get_text_file_contents('pyvenv.cfg')
+ self.assertIn('include-system-site-packages = %s\n' % s, data)
+
+ @unittest.skipUnless(can_symlink(), 'Needs symlinks')
+ def test_symlinking(self):
+ """
+ Test symlinking works as expected
+ """
+ for usl in (False, True):
+ builder = venv.EnvBuilder(clear=True, symlinks=usl)
+ builder.create(self.env_dir)
+ fn = self.get_env_file(self.bindir, self.exe)
+ # Don't test when False, because e.g. 'python' is always
+ # symlinked to 'python3.3' in the env, even when symlinking in
+ # general isn't wanted.
+ if usl:
+ self.assertTrue(os.path.islink(fn))
+
+ # If a venv is created from a source build and that venv is used to
+ # run the test, the pyvenv.cfg in the venv created in the test will
+ # point to the venv being used to run the test, and we lose the link
+ # to the source build - so Python can't initialise properly.
+ @unittest.skipIf(sys.prefix != sys.base_prefix, 'Test not appropriate '
+ 'in a venv')
+ def test_executable(self):
+ """
+ Test that the sys.executable value is as expected.
+ """
+ shutil.rmtree(self.env_dir)
+ self.run_with_capture(venv.create, self.env_dir)
+ envpy = os.path.join(os.path.realpath(self.env_dir), self.bindir, self.exe)
+ cmd = [envpy, '-c', 'import sys; print(sys.executable)']
+ p = subprocess.Popen(cmd, stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ out, err = p.communicate()
+ self.assertEqual(out.strip(), envpy.encode())
+
+ @unittest.skipUnless(can_symlink(), 'Needs symlinks')
+ def test_executable_symlinks(self):
+ """
+ Test that the sys.executable value is as expected.
+ """
+ shutil.rmtree(self.env_dir)
+ builder = venv.EnvBuilder(clear=True, symlinks=True)
+ builder.create(self.env_dir)
+ envpy = os.path.join(os.path.realpath(self.env_dir), self.bindir, self.exe)
+ cmd = [envpy, '-c', 'import sys; print(sys.executable)']
+ p = subprocess.Popen(cmd, stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ out, err = p.communicate()
+ self.assertEqual(out.strip(), envpy.encode())
+
+def test_main():
+ run_unittest(BasicTest)
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_wait3.py b/Lib/test/test_wait3.py
index 786e60b..bd06c8d 100644
--- a/Lib/test/test_wait3.py
+++ b/Lib/test/test_wait3.py
@@ -19,13 +19,16 @@ except AttributeError:
class Wait3Test(ForkWait):
def wait_impl(self, cpid):
- for i in range(10):
+ # This many iterations can be required, since some previously run
+ # tests (e.g. test_ctypes) could have spawned a lot of children
+ # very quickly.
+ for i in range(30):
# wait3() shouldn't hang, but some of the buildbots seem to hang
# in the forking tests. This is an attempt to fix the problem.
spid, status, rusage = os.wait3(os.WNOHANG)
if spid == cpid:
break
- time.sleep(1.0)
+ time.sleep(0.1)
self.assertEqual(spid, cpid)
self.assertEqual(status, 0, "cause = %d, exit = %d" % (status&0xff, status>>8))
diff --git a/Lib/test/test_warnings.py b/Lib/test/test_warnings.py
index 79be835..464ff40 100644
--- a/Lib/test/test_warnings.py
+++ b/Lib/test/test_warnings.py
@@ -40,7 +40,7 @@ def warnings_state(module):
module.filters = original_filters
-class BaseTest(unittest.TestCase):
+class BaseTest:
"""Basic bookkeeping required for testing."""
@@ -63,7 +63,7 @@ class BaseTest(unittest.TestCase):
super(BaseTest, self).tearDown()
-class FilterTests(object):
+class FilterTests(BaseTest):
"""Testing the filtering functionality."""
@@ -186,14 +186,14 @@ class FilterTests(object):
self.assertEqual(str(w[-1].message), text)
self.assertTrue(w[-1].category is UserWarning)
-class CFilterTests(BaseTest, FilterTests):
+class CFilterTests(FilterTests, unittest.TestCase):
module = c_warnings
-class PyFilterTests(BaseTest, FilterTests):
+class PyFilterTests(FilterTests, unittest.TestCase):
module = py_warnings
-class WarnTests(unittest.TestCase):
+class WarnTests(BaseTest):
"""Test warnings.warn() and warnings.warn_explicit()."""
@@ -360,7 +360,7 @@ class WarnTests(unittest.TestCase):
self.module.warn(BadStrWarning())
-class CWarnTests(BaseTest, WarnTests):
+class CWarnTests(WarnTests, unittest.TestCase):
module = c_warnings
# As an early adopter, we sanity check the
@@ -369,7 +369,7 @@ class CWarnTests(BaseTest, WarnTests):
self.assertFalse(original_warnings is self.module)
self.assertFalse(hasattr(self.module.warn, '__code__'))
-class PyWarnTests(BaseTest, WarnTests):
+class PyWarnTests(WarnTests, unittest.TestCase):
module = py_warnings
# As an early adopter, we sanity check the
@@ -379,7 +379,7 @@ class PyWarnTests(BaseTest, WarnTests):
self.assertTrue(hasattr(self.module.warn, '__code__'))
-class WCmdLineTests(unittest.TestCase):
+class WCmdLineTests(BaseTest):
def test_improper_input(self):
# Uses the private _setoption() function to test the parsing
@@ -410,14 +410,14 @@ class WCmdLineTests(unittest.TestCase):
self.assertFalse(out.strip())
self.assertNotIn(b'RuntimeWarning', err)
-class CWCmdLineTests(BaseTest, WCmdLineTests):
+class CWCmdLineTests(WCmdLineTests, unittest.TestCase):
module = c_warnings
-class PyWCmdLineTests(BaseTest, WCmdLineTests):
+class PyWCmdLineTests(WCmdLineTests, unittest.TestCase):
module = py_warnings
-class _WarningsTests(BaseTest):
+class _WarningsTests(BaseTest, unittest.TestCase):
"""Tests specific to the _warnings module."""
@@ -512,12 +512,11 @@ class _WarningsTests(BaseTest):
def test_showwarning_not_callable(self):
with original_warnings.catch_warnings(module=self.module):
self.module.filterwarnings("always", category=UserWarning)
- old_showwarning = self.module.showwarning
+ self.module.showwarning = print
+ with support.captured_output('stdout'):
+ self.module.warn('Warning!')
self.module.showwarning = 23
- try:
- self.assertRaises(TypeError, self.module.warn, "Warning!")
- finally:
- self.module.showwarning = old_showwarning
+ self.assertRaises(TypeError, self.module.warn, "Warning!")
def test_show_warning_output(self):
# With showarning() missing, make sure that output is okay.
@@ -547,15 +546,18 @@ class _WarningsTests(BaseTest):
globals_dict = globals()
oldfile = globals_dict['__file__']
try:
- with original_warnings.catch_warnings(module=self.module) as w:
+ catch = original_warnings.catch_warnings(record=True,
+ module=self.module)
+ with catch as w:
self.module.filterwarnings("always", category=UserWarning)
globals_dict['__file__'] = None
original_warnings.warn('test', UserWarning)
+ self.assertTrue(len(w))
finally:
globals_dict['__file__'] = oldfile
-class WarningsDisplayTests(unittest.TestCase):
+class WarningsDisplayTests(BaseTest):
"""Test the displaying of warnings and the ability to overload functions
related to displaying warnings."""
@@ -599,10 +601,10 @@ class WarningsDisplayTests(unittest.TestCase):
file_object, expected_file_line)
self.assertEqual(expect, file_object.getvalue())
-class CWarningsDisplayTests(BaseTest, WarningsDisplayTests):
+class CWarningsDisplayTests(WarningsDisplayTests, unittest.TestCase):
module = c_warnings
-class PyWarningsDisplayTests(BaseTest, WarningsDisplayTests):
+class PyWarningsDisplayTests(WarningsDisplayTests, unittest.TestCase):
module = py_warnings
@@ -708,10 +710,10 @@ class CatchWarningTests(BaseTest):
with support.check_warnings(('foo', RuntimeWarning)):
wmod.warn("foo")
-class CCatchWarningTests(CatchWarningTests):
+class CCatchWarningTests(CatchWarningTests, unittest.TestCase):
module = c_warnings
-class PyCatchWarningTests(CatchWarningTests):
+class PyCatchWarningTests(CatchWarningTests, unittest.TestCase):
module = py_warnings
@@ -760,10 +762,10 @@ class EnvironmentVariableTests(BaseTest):
"['ignore:DeprecaciónWarning']".encode('utf-8'))
self.assertEqual(p.wait(), 0)
-class CEnvironmentVariableTests(EnvironmentVariableTests):
+class CEnvironmentVariableTests(EnvironmentVariableTests, unittest.TestCase):
module = c_warnings
-class PyEnvironmentVariableTests(EnvironmentVariableTests):
+class PyEnvironmentVariableTests(EnvironmentVariableTests, unittest.TestCase):
module = py_warnings
@@ -786,20 +788,12 @@ class BootstrapTest(unittest.TestCase):
env=env)
self.assertEqual(retcode, 0)
-def test_main():
+
+def setUpModule():
py_warnings.onceregistry.clear()
c_warnings.onceregistry.clear()
- support.run_unittest(
- CFilterTests, PyFilterTests,
- CWarnTests, PyWarnTests,
- CWCmdLineTests, PyWCmdLineTests,
- _WarningsTests,
- CWarningsDisplayTests, PyWarningsDisplayTests,
- CCatchWarningTests, PyCatchWarningTests,
- CEnvironmentVariableTests, PyEnvironmentVariableTests,
- BootstrapTest,
- )
+tearDownModule = setUpModule
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_webbrowser.py b/Lib/test/test_webbrowser.py
new file mode 100644
index 0000000..34b9364
--- /dev/null
+++ b/Lib/test/test_webbrowser.py
@@ -0,0 +1,192 @@
+import webbrowser
+import unittest
+import subprocess
+from unittest import mock
+from test import support
+
+
+URL = 'http://www.example.com'
+CMD_NAME = 'test'
+
+
+class PopenMock(mock.MagicMock):
+
+ def poll(self):
+ return 0
+
+ def wait(self, seconds=None):
+ return 0
+
+
+class CommandTestMixin:
+
+ def _test(self, meth, *, args=[URL], kw={}, options, arguments):
+ """Given a web browser instance method name along with arguments and
+ keywords for same (which defaults to the single argument URL), creates
+ a browser instance from the class pointed to by self.browser, calls the
+ indicated instance method with the indicated arguments, and compares
+ the resulting options and arguments passed to Popen by the browser
+ instance against the 'options' and 'args' lists. Options are compared
+ in a position independent fashion, and the arguments are compared in
+ sequence order to whatever is left over after removing the options.
+
+ """
+ popen = PopenMock()
+ support.patch(self, subprocess, 'Popen', popen)
+ browser = self.browser_class(name=CMD_NAME)
+ getattr(browser, meth)(*args, **kw)
+ popen_args = subprocess.Popen.call_args[0][0]
+ self.assertEqual(popen_args[0], CMD_NAME)
+ popen_args.pop(0)
+ for option in options:
+ self.assertIn(option, popen_args)
+ popen_args.pop(popen_args.index(option))
+ self.assertEqual(popen_args, arguments)
+
+
+class GenericBrowserCommandTest(CommandTestMixin, unittest.TestCase):
+
+ browser_class = webbrowser.GenericBrowser
+
+ def test_open(self):
+ self._test('open',
+ options=[],
+ arguments=[URL])
+
+
+class BackgroundBrowserCommandTest(CommandTestMixin, unittest.TestCase):
+
+ browser_class = webbrowser.BackgroundBrowser
+
+ def test_open(self):
+ self._test('open',
+ options=[],
+ arguments=[URL])
+
+
+class ChromeCommandTest(CommandTestMixin, unittest.TestCase):
+
+ browser_class = webbrowser.Chrome
+
+ def test_open(self):
+ self._test('open',
+ options=[],
+ arguments=[URL])
+
+ def test_open_with_autoraise_false(self):
+ self._test('open', kw=dict(autoraise=False),
+ options=[],
+ arguments=[URL])
+
+ def test_open_new(self):
+ self._test('open_new',
+ options=['--new-window'],
+ arguments=[URL])
+
+ def test_open_new_tab(self):
+ self._test('open_new_tab',
+ options=[],
+ arguments=[URL])
+
+
+class MozillaCommandTest(CommandTestMixin, unittest.TestCase):
+
+ browser_class = webbrowser.Mozilla
+
+ def test_open(self):
+ self._test('open',
+ options=['-raise', '-remote'],
+ arguments=['openURL({})'.format(URL)])
+
+ def test_open_with_autoraise_false(self):
+ self._test('open', kw=dict(autoraise=False),
+ options=['-noraise', '-remote'],
+ arguments=['openURL({})'.format(URL)])
+
+ def test_open_new(self):
+ self._test('open_new',
+ options=['-raise', '-remote'],
+ arguments=['openURL({},new-window)'.format(URL)])
+
+ def test_open_new_tab(self):
+ self._test('open_new_tab',
+ options=['-raise', '-remote'],
+ arguments=['openURL({},new-tab)'.format(URL)])
+
+
+class GaleonCommandTest(CommandTestMixin, unittest.TestCase):
+
+ browser_class = webbrowser.Galeon
+
+ def test_open(self):
+ self._test('open',
+ options=['-n'],
+ arguments=[URL])
+
+ def test_open_with_autoraise_false(self):
+ self._test('open', kw=dict(autoraise=False),
+ options=['-noraise', '-n'],
+ arguments=[URL])
+
+ def test_open_new(self):
+ self._test('open_new',
+ options=['-w'],
+ arguments=[URL])
+
+ def test_open_new_tab(self):
+ self._test('open_new_tab',
+ options=['-w'],
+ arguments=[URL])
+
+
+class OperaCommandTest(CommandTestMixin, unittest.TestCase):
+
+ browser_class = webbrowser.Opera
+
+ def test_open(self):
+ self._test('open',
+ options=['-remote'],
+ arguments=['openURL({})'.format(URL)])
+
+ def test_open_with_autoraise_false(self):
+ self._test('open', kw=dict(autoraise=False),
+ options=['-remote', '-noraise'],
+ arguments=['openURL({})'.format(URL)])
+
+ def test_open_new(self):
+ self._test('open_new',
+ options=['-remote'],
+ arguments=['openURL({},new-window)'.format(URL)])
+
+ def test_open_new(self):
+ self._test('open_new_tab',
+ options=['-remote'],
+ arguments=['openURL({},new-page)'.format(URL)])
+
+
+class ELinksCommandTest(CommandTestMixin, unittest.TestCase):
+
+ browser_class = webbrowser.Elinks
+
+ def test_open(self):
+ self._test('open', options=['-remote'],
+ arguments=['openURL({})'.format(URL)])
+
+ def test_open_with_autoraise_false(self):
+ self._test('open',
+ options=['-remote'],
+ arguments=['openURL({})'.format(URL)])
+
+ def test_open_new(self):
+ self._test('open_new',
+ options=['-remote'],
+ arguments=['openURL({},new-window)'.format(URL)])
+
+ def test_open_new_tab(self):
+ self._test('open_new_tab',
+ options=['-remote'],
+ arguments=['openURL({},new-tab)'.format(URL)])
+
+
+if __name__=='__main__':
+ unittest.main()
diff --git a/Lib/test/test_winsound.py b/Lib/test/test_winsound.py
index 34c3dea..eb7f75f 100644
--- a/Lib/test/test_winsound.py
+++ b/Lib/test/test_winsound.py
@@ -16,16 +16,12 @@ def has_sound(sound):
try:
# Ask the mixer API for the number of devices it knows about.
# When there are no devices, PlaySound will fail.
- if ctypes.windll.winmm.mixerGetNumDevs() is 0:
+ if ctypes.windll.winmm.mixerGetNumDevs() == 0:
return False
key = winreg.OpenKeyEx(winreg.HKEY_CURRENT_USER,
"AppEvents\Schemes\Apps\.Default\{0}\.Default".format(sound))
- value = winreg.EnumValue(key, 0)[1]
- if value is not "":
- return True
- else:
- return False
+ return winreg.EnumValue(key, 0)[1] != ""
except WindowsError:
return False
diff --git a/Lib/test/test_wsgiref.py b/Lib/test/test_wsgiref.py
index 08f8d9a..05c1f4f 100644
--- a/Lib/test/test_wsgiref.py
+++ b/Lib/test/test_wsgiref.py
@@ -9,6 +9,8 @@ from wsgiref.simple_server import WSGIServer, WSGIRequestHandler, demo_app
from wsgiref.simple_server import make_server
from io import StringIO, BytesIO, BufferedReader
from socketserver import BaseServer
+from platform import python_implementation
+
import os
import re
import sys
@@ -100,9 +102,11 @@ def compare_generic_iter(make_it,match):
class IntegrationTests(TestCase):
def check_hello(self, out, has_length=True):
+ pyver = (python_implementation() + "/" +
+ sys.version.split()[0])
self.assertEqual(out,
("HTTP/1.0 200 OK\r\n"
- "Server: WSGIServer/0.2 Python/"+sys.version.split()[0]+"\r\n"
+ "Server: WSGIServer/0.2 " + pyver +"\r\n"
"Content-Type: text/plain\r\n"
"Date: Mon, 05 Jun 2006 18:49:54 GMT\r\n" +
(has_length and "Content-Length: 13\r\n" or "") +
@@ -156,9 +160,11 @@ class IntegrationTests(TestCase):
out, err = run_amock(validator(app))
self.assertTrue(err.endswith('"GET / HTTP/1.0" 200 4\n'))
ver = sys.version.split()[0].encode('ascii')
+ py = python_implementation().encode('ascii')
+ pyver = py + b"/" + ver
self.assertEqual(
b"HTTP/1.0 200 OK\r\n"
- b"Server: WSGIServer/0.2 Python/" + ver + b"\r\n"
+ b"Server: WSGIServer/0.2 "+ pyver + b"\r\n"
b"Content-Type: text/plain; charset=utf-8\r\n"
b"Date: Wed, 24 Dec 2008 13:29:32 GMT\r\n"
b"\r\n"
diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py
index f61292f..61161b6 100644
--- a/Lib/test/test_xml_etree.py
+++ b/Lib/test/test_xml_etree.py
@@ -1,28 +1,41 @@
# xml.etree test. This file contains enough tests to make sure that
# all included components work as they should.
# Large parts are extracted from the upstream test suite.
-
-# IMPORTANT: the same doctests are run from "test_xml_etree_c" in
-# order to ensure consistency between the C implementation and the
-# Python implementation.
+#
+# PLEASE write all new tests using the standard unittest infrastructure and
+# not doctest.
+#
+# IMPORTANT: the same tests are run from "test_xml_etree_c" in order
+# to ensure consistency between the C implementation and the Python
+# implementation.
#
# For this purpose, the module-level "ET" symbol is temporarily
# monkey-patched when running the "test_xml_etree_c" test suite.
# Don't re-import "xml.etree.ElementTree" module in the docstring,
# except if the test is specific to the Python implementation.
-import sys
import html
+import io
+import operator
+import pickle
+import sys
import unittest
+import weakref
+from itertools import product
from test import support
-from test.support import findfile
+from test.support import TESTFN, findfile, unlink, import_fresh_module, gc_collect
-from xml.etree import ElementTree as ET
+# pyET is the pure-Python implementation.
+#
+# ET is pyET in test_xml_etree and is the C accelerated version in
+# test_xml_etree_c.
+pyET = None
+ET = None
SIMPLE_XMLFILE = findfile("simple.xml", subdir="xmltestdata")
try:
- SIMPLE_XMLFILE.encode("utf8")
+ SIMPLE_XMLFILE.encode("utf-8")
except UnicodeEncodeError:
raise unittest.SkipTest("filename is not encodable to utf8")
SIMPLE_NS_XMLFILE = findfile("simple-ns.xml", subdir="xmltestdata")
@@ -57,6 +70,22 @@ SAMPLE_XML_NS = """
</body>
"""
+SAMPLE_XML_NS_ELEMS = """
+<root>
+<h:table xmlns:h="hello">
+ <h:tr>
+ <h:td>Apples</h:td>
+ <h:td>Bananas</h:td>
+ </h:tr>
+</h:table>
+
+<f:table xmlns:f="foo">
+ <f:name>African Coffee Table</f:name>
+ <f:width>80</f:width>
+ <f:length>120</f:length>
+</f:table>
+</root>
+"""
def sanity():
"""
@@ -148,6 +177,38 @@ def check_element(element):
for elem in element:
check_element(elem)
+class ElementTestCase:
+ @classmethod
+ def setUpClass(cls):
+ cls.modules = {pyET, ET}
+
+ def pickleRoundTrip(self, obj, name, dumper, loader):
+ save_m = sys.modules[name]
+ try:
+ sys.modules[name] = dumper
+ temp = pickle.dumps(obj)
+ sys.modules[name] = loader
+ result = pickle.loads(temp)
+ except pickle.PicklingError as pe:
+ # pyET must be second, because pyET may be (equal to) ET.
+ human = dict([(ET, "cET"), (pyET, "pyET")])
+ raise support.TestFailed("Failed to round-trip %r from %r to %r"
+ % (obj,
+ human.get(dumper, dumper),
+ human.get(loader, loader))) from pe
+ finally:
+ sys.modules[name] = save_m
+ return result
+
+ def assertEqualElements(self, alice, bob):
+ self.assertIsInstance(alice, (ET.Element, pyET.Element))
+ self.assertIsInstance(bob, (ET.Element, pyET.Element))
+ self.assertEqual(len(list(alice)), len(list(bob)))
+ for x, y in zip(alice, bob):
+ self.assertEqualElements(x, y)
+ properties = operator.attrgetter('tag', 'tail', 'text', 'attrib')
+ self.assertEqual(properties(alice), properties(bob))
+
# --------------------------------------------------------------------
# element tree tests
@@ -188,10 +249,8 @@ def interface():
These methods return an iterable. See bug 6472.
- >>> check_method(element.iter("tag").__next__)
>>> check_method(element.iterfind("tag").__next__)
>>> check_method(element.iterfind("*").__next__)
- >>> check_method(tree.iter("tag").__next__)
>>> check_method(tree.iterfind("tag").__next__)
>>> check_method(tree.iterfind("*").__next__)
@@ -270,166 +329,6 @@ def cdata():
'<tag>hello</tag>'
"""
-# Only with Python implementation
-def simplefind():
- """
- Test find methods using the elementpath fallback.
-
- >>> from xml.etree import ElementTree
-
- >>> CurrentElementPath = ElementTree.ElementPath
- >>> ElementTree.ElementPath = ElementTree._SimpleElementPath()
- >>> elem = ElementTree.XML(SAMPLE_XML)
- >>> elem.find("tag").tag
- 'tag'
- >>> ElementTree.ElementTree(elem).find("tag").tag
- 'tag'
- >>> elem.findtext("tag")
- 'text'
- >>> elem.findtext("tog")
- >>> elem.findtext("tog", "default")
- 'default'
- >>> ElementTree.ElementTree(elem).findtext("tag")
- 'text'
- >>> summarize_list(elem.findall("tag"))
- ['tag', 'tag']
- >>> summarize_list(elem.findall(".//tag"))
- ['tag', 'tag', 'tag']
-
- Path syntax doesn't work in this case.
-
- >>> elem.find("section/tag")
- >>> elem.findtext("section/tag")
- >>> summarize_list(elem.findall("section/tag"))
- []
-
- >>> ElementTree.ElementPath = CurrentElementPath
- """
-
-def find():
- """
- Test find methods (including xpath syntax).
-
- >>> elem = ET.XML(SAMPLE_XML)
- >>> elem.find("tag").tag
- 'tag'
- >>> ET.ElementTree(elem).find("tag").tag
- 'tag'
- >>> elem.find("section/tag").tag
- 'tag'
- >>> elem.find("./tag").tag
- 'tag'
- >>> ET.ElementTree(elem).find("./tag").tag
- 'tag'
- >>> ET.ElementTree(elem).find("/tag").tag
- 'tag'
- >>> elem[2] = ET.XML(SAMPLE_SECTION)
- >>> elem.find("section/nexttag").tag
- 'nexttag'
- >>> ET.ElementTree(elem).find("section/tag").tag
- 'tag'
- >>> ET.ElementTree(elem).find("tog")
- >>> ET.ElementTree(elem).find("tog/foo")
- >>> elem.findtext("tag")
- 'text'
- >>> elem.findtext("section/nexttag")
- ''
- >>> elem.findtext("section/nexttag", "default")
- ''
- >>> elem.findtext("tog")
- >>> elem.findtext("tog", "default")
- 'default'
- >>> ET.ElementTree(elem).findtext("tag")
- 'text'
- >>> ET.ElementTree(elem).findtext("tog/foo")
- >>> ET.ElementTree(elem).findtext("tog/foo", "default")
- 'default'
- >>> ET.ElementTree(elem).findtext("./tag")
- 'text'
- >>> ET.ElementTree(elem).findtext("/tag")
- 'text'
- >>> elem.findtext("section/tag")
- 'subtext'
- >>> ET.ElementTree(elem).findtext("section/tag")
- 'subtext'
- >>> ET.XML('<root><empty /></root>').findtext('empty')
- ''
- >>> summarize_list(elem.findall("."))
- ['body']
- >>> summarize_list(elem.findall("tag"))
- ['tag', 'tag']
- >>> summarize_list(elem.findall("tog"))
- []
- >>> summarize_list(elem.findall("tog/foo"))
- []
- >>> summarize_list(elem.findall("*"))
- ['tag', 'tag', 'section']
- >>> summarize_list(elem.findall(".//tag"))
- ['tag', 'tag', 'tag', 'tag']
- >>> summarize_list(elem.findall("section/tag"))
- ['tag']
- >>> summarize_list(elem.findall("section//tag"))
- ['tag', 'tag']
- >>> summarize_list(elem.findall("section/*"))
- ['tag', 'nexttag', 'nextsection']
- >>> summarize_list(elem.findall("section//*"))
- ['tag', 'nexttag', 'nextsection', 'tag']
- >>> summarize_list(elem.findall("section/.//*"))
- ['tag', 'nexttag', 'nextsection', 'tag']
- >>> summarize_list(elem.findall("*/*"))
- ['tag', 'nexttag', 'nextsection']
- >>> summarize_list(elem.findall("*//*"))
- ['tag', 'nexttag', 'nextsection', 'tag']
- >>> summarize_list(elem.findall("*/tag"))
- ['tag']
- >>> summarize_list(elem.findall("*/./tag"))
- ['tag']
- >>> summarize_list(elem.findall("./tag"))
- ['tag', 'tag']
- >>> summarize_list(elem.findall(".//tag"))
- ['tag', 'tag', 'tag', 'tag']
- >>> summarize_list(elem.findall("././tag"))
- ['tag', 'tag']
- >>> summarize_list(elem.findall(".//tag[@class]"))
- ['tag', 'tag', 'tag']
- >>> summarize_list(elem.findall(".//tag[@class='a']"))
- ['tag']
- >>> summarize_list(elem.findall(".//tag[@class='b']"))
- ['tag', 'tag']
- >>> summarize_list(elem.findall(".//tag[@id]"))
- ['tag']
- >>> summarize_list(elem.findall(".//section[tag]"))
- ['section']
- >>> summarize_list(elem.findall(".//section[element]"))
- []
- >>> summarize_list(elem.findall("../tag"))
- []
- >>> summarize_list(elem.findall("section/../tag"))
- ['tag', 'tag']
- >>> summarize_list(ET.ElementTree(elem).findall("./tag"))
- ['tag', 'tag']
-
- Following example is invalid in 1.2.
- A leading '*' is assumed in 1.3.
-
- >>> elem.findall("section//") == elem.findall("section//*")
- True
-
- ET's Path module handles this case incorrectly; this gives
- a warning in 1.3, and the behaviour will be modified in 1.4.
-
- >>> summarize_list(ET.ElementTree(elem).findall("/tag"))
- ['tag', 'tag']
-
- >>> elem = ET.XML(SAMPLE_XML_NS)
- >>> summarize_list(elem.findall("tag"))
- []
- >>> summarize_list(elem.findall("{http://effbot.org/ns}tag"))
- ['{http://effbot.org/ns}tag', '{http://effbot.org/ns}tag']
- >>> summarize_list(elem.findall(".//{http://effbot.org/ns}tag"))
- ['{http://effbot.org/ns}tag', '{http://effbot.org/ns}tag', '{http://effbot.org/ns}tag']
- """
-
def file_init():
"""
>>> import io
@@ -448,31 +347,23 @@ def file_init():
'empty-element'
"""
-def bad_find():
- """
- Check bad or unsupported path expressions.
-
- >>> elem = ET.XML(SAMPLE_XML)
- >>> elem.findall("/tag")
- Traceback (most recent call last):
- SyntaxError: cannot use absolute path on element
- """
-
def path_cache():
"""
Check that the path cache behaves sanely.
+ >>> from xml.etree import ElementPath
+
>>> elem = ET.XML(SAMPLE_XML)
>>> for i in range(10): ET.ElementTree(elem).find('./'+str(i))
- >>> cache_len_10 = len(ET.ElementPath._cache)
+ >>> cache_len_10 = len(ElementPath._cache)
>>> for i in range(10): ET.ElementTree(elem).find('./'+str(i))
- >>> len(ET.ElementPath._cache) == cache_len_10
+ >>> len(ElementPath._cache) == cache_len_10
True
>>> for i in range(20): ET.ElementTree(elem).find('./'+str(i))
- >>> len(ET.ElementPath._cache) > cache_len_10
+ >>> len(ElementPath._cache) > cache_len_10
True
>>> for i in range(600): ET.ElementTree(elem).find('./'+str(i))
- >>> len(ET.ElementPath._cache) < 500
+ >>> len(ElementPath._cache) < 500
True
"""
@@ -904,65 +795,6 @@ def check_encoding(encoding):
"""
ET.XML("<?xml version='1.0' encoding='%s'?><xml />" % encoding)
-def encoding():
- r"""
- Test encoding issues.
-
- >>> elem = ET.Element("tag")
- >>> elem.text = "abc"
- >>> serialize(elem)
- '<tag>abc</tag>'
- >>> serialize(elem, encoding="utf-8")
- b'<tag>abc</tag>'
- >>> serialize(elem, encoding="us-ascii")
- b'<tag>abc</tag>'
- >>> serialize(elem, encoding="iso-8859-1")
- b"<?xml version='1.0' encoding='iso-8859-1'?>\n<tag>abc</tag>"
-
- >>> elem.text = "<&\"\'>"
- >>> serialize(elem)
- '<tag>&lt;&amp;"\'&gt;</tag>'
- >>> serialize(elem, encoding="utf-8")
- b'<tag>&lt;&amp;"\'&gt;</tag>'
- >>> serialize(elem, encoding="us-ascii") # cdata characters
- b'<tag>&lt;&amp;"\'&gt;</tag>'
- >>> serialize(elem, encoding="iso-8859-1")
- b'<?xml version=\'1.0\' encoding=\'iso-8859-1\'?>\n<tag>&lt;&amp;"\'&gt;</tag>'
-
- >>> elem.attrib["key"] = "<&\"\'>"
- >>> elem.text = None
- >>> serialize(elem)
- '<tag key="&lt;&amp;&quot;\'&gt;" />'
- >>> serialize(elem, encoding="utf-8")
- b'<tag key="&lt;&amp;&quot;\'&gt;" />'
- >>> serialize(elem, encoding="us-ascii")
- b'<tag key="&lt;&amp;&quot;\'&gt;" />'
- >>> serialize(elem, encoding="iso-8859-1")
- b'<?xml version=\'1.0\' encoding=\'iso-8859-1\'?>\n<tag key="&lt;&amp;&quot;\'&gt;" />'
-
- >>> elem.text = '\xe5\xf6\xf6<>'
- >>> elem.attrib.clear()
- >>> serialize(elem)
- '<tag>\xe5\xf6\xf6&lt;&gt;</tag>'
- >>> serialize(elem, encoding="utf-8")
- b'<tag>\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;</tag>'
- >>> serialize(elem, encoding="us-ascii")
- b'<tag>&#229;&#246;&#246;&lt;&gt;</tag>'
- >>> serialize(elem, encoding="iso-8859-1")
- b"<?xml version='1.0' encoding='iso-8859-1'?>\n<tag>\xe5\xf6\xf6&lt;&gt;</tag>"
-
- >>> elem.attrib["key"] = '\xe5\xf6\xf6<>'
- >>> elem.text = None
- >>> serialize(elem)
- '<tag key="\xe5\xf6\xf6&lt;&gt;" />'
- >>> serialize(elem, encoding="utf-8")
- b'<tag key="\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;" />'
- >>> serialize(elem, encoding="us-ascii")
- b'<tag key="&#229;&#246;&#246;&lt;&gt;" />'
- >>> serialize(elem, encoding="iso-8859-1")
- b'<?xml version=\'1.0\' encoding=\'iso-8859-1\'?>\n<tag key="\xe5\xf6\xf6&lt;&gt;" />'
- """
-
def methods():
r"""
Test serialization methods.
@@ -981,36 +813,6 @@ def methods():
'1 < 2\n'
"""
-def iterators():
- """
- Test iterators.
-
- >>> e = ET.XML("<html><body>this is a <i>paragraph</i>.</body>..</html>")
- >>> summarize_list(e.iter())
- ['html', 'body', 'i']
- >>> summarize_list(e.find("body").iter())
- ['body', 'i']
- >>> summarize(next(e.iter()))
- 'html'
- >>> "".join(e.itertext())
- 'this is a paragraph...'
- >>> "".join(e.find("body").itertext())
- 'this is a paragraph.'
- >>> next(e.itertext())
- 'this is a '
-
- Method iterparse should return an iterator. See bug 6472.
-
- >>> sourcefile = serialize(e, to_string=False)
- >>> next(ET.iterparse(sourcefile)) # doctest: +ELLIPSIS
- ('end', <Element 'i' at 0x...>)
-
- >>> tree = ET.ElementTree(None)
- >>> tree.iter()
- Traceback (most recent call last):
- AttributeError: 'NoneType' object has no attribute 'iter'
- """
-
ENTITY_XML = """\
<!DOCTYPE points [
<!ENTITY % user-entities SYSTEM 'user-entities.xml'>
@@ -1051,26 +853,6 @@ def entity():
'<document>text</document>'
"""
-def error(xml):
- """
-
- Test error handling.
-
- >>> issubclass(ET.ParseError, SyntaxError)
- True
- >>> error("foo").position
- (1, 0)
- >>> error("<tag>&foo;</tag>").position
- (1, 5)
- >>> error("foobar<").position
- (1, 6)
-
- """
- try:
- ET.XML(xml)
- except ET.ParseError:
- return sys.exc_info()[1]
-
def namespace():
"""
Test namespace issues.
@@ -1258,8 +1040,8 @@ def processinginstruction():
>>> ET.tostring(ET.PI('test', '<testing&>'))
b'<?test <testing&>?>'
- >>> ET.tostring(ET.PI('test', '<testing&>\xe3'), 'latin1')
- b"<?xml version='1.0' encoding='latin1'?>\\n<?test <testing&>\\xe3?>"
+ >>> ET.tostring(ET.PI('test', '<testing&>\xe3'), 'latin-1')
+ b"<?xml version='1.0' encoding='latin-1'?>\\n<?test <testing&>\\xe3?>"
"""
#
@@ -1338,21 +1120,20 @@ XINCLUDE["default.xml"] = """\
</document>
""".format(html.escape(SIMPLE_XMLFILE, True))
+
def xinclude_loader(href, parse="xml", encoding=None):
try:
data = XINCLUDE[href]
except KeyError:
- raise IOError("resource not found")
+ raise OSError("resource not found")
if parse == "xml":
- from xml.etree.ElementTree import XML
- return XML(data)
+ data = ET.XML(data)
return data
def xinclude():
r"""
Basic inclusion example (XInclude C.1)
- >>> from xml.etree import ElementTree as ET
>>> from xml.etree import ElementInclude
>>> document = xinclude_loader("C1.xml")
@@ -1407,26 +1188,10 @@ def xinclude():
>>> document = xinclude_loader("C5.xml")
>>> ElementInclude.include(document, xinclude_loader)
Traceback (most recent call last):
- IOError: resource not found
+ OSError: resource not found
>>> # print(serialize(document)) # C5
"""
-def xinclude_default():
- """
- >>> from xml.etree import ElementInclude
-
- >>> document = xinclude_loader("default.xml")
- >>> ElementInclude.include(document)
- >>> print(serialize(document)) # default
- <document>
- <p>Example.</p>
- <root>
- <element key="value">text</element>
- <element>text</element>tail
- <empty-element />
- </root>
- </document>
- """
#
# badly formatted xi:include tags
@@ -1614,7 +1379,7 @@ def bug_xmltoolkit55():
class ExceptionFile:
def read(self, x):
- raise IOError
+ raise OSError
def xmltoolkit60():
"""
@@ -1622,7 +1387,7 @@ def xmltoolkit60():
Handle crash in stream source.
>>> tree = ET.parse(ExceptionFile())
Traceback (most recent call last):
- IOError
+ OSError
"""
@@ -1853,25 +1618,858 @@ def check_issue10777():
>>> ET.register_namespace('test10777', 'http://myuri/')
"""
-def check_html_empty_elems_serialization(self):
- # issue 15970
- # from http://www.w3.org/TR/html401/index/elements.html
- """
+# --------------------------------------------------------------------
- >>> empty_elems = ['AREA', 'BASE', 'BASEFONT', 'BR', 'COL', 'FRAME', 'HR',
- ... 'IMG', 'INPUT', 'ISINDEX', 'LINK', 'META', 'PARAM']
- >>> elems = ''.join('<%s />' % elem for elem in empty_elems)
- >>> serialize(ET.XML('<html>%s</html>' % elems), method='html')
- '<html><AREA><BASE><BASEFONT><BR><COL><FRAME><HR><IMG><INPUT><ISINDEX><LINK><META><PARAM></html>'
- >>> serialize(ET.XML('<html>%s</html>' % elems.lower()), method='html')
- '<html><area><base><basefont><br><col><frame><hr><img><input><isindex><link><meta><param></html>'
- >>> elems = ''.join('<%s></%s>' % (elem, elem) for elem in empty_elems)
- >>> serialize(ET.XML('<html>%s</html>' % elems), method='html')
- '<html><AREA><BASE><BASEFONT><BR><COL><FRAME><HR><IMG><INPUT><ISINDEX><LINK><META><PARAM></html>'
- >>> serialize(ET.XML('<html>%s</html>' % elems.lower()), method='html')
- '<html><area><base><basefont><br><col><frame><hr><img><input><isindex><link><meta><param></html>'
- """
+class BasicElementTest(ElementTestCase, unittest.TestCase):
+ def test_augmentation_type_errors(self):
+ e = ET.Element('joe')
+ self.assertRaises(TypeError, e.append, 'b')
+ self.assertRaises(TypeError, e.extend, [ET.Element('bar'), 'foo'])
+ self.assertRaises(TypeError, e.insert, 0, 'foo')
+
+ def test_cyclic_gc(self):
+ class Dummy:
+ pass
+
+ # Test the shortest cycle: d->element->d
+ d = Dummy()
+ d.dummyref = ET.Element('joe', attr=d)
+ wref = weakref.ref(d)
+ del d
+ gc_collect()
+ self.assertIsNone(wref())
+
+ # A longer cycle: d->e->e2->d
+ e = ET.Element('joe')
+ d = Dummy()
+ d.dummyref = e
+ wref = weakref.ref(d)
+ e2 = ET.SubElement(e, 'foo', attr=d)
+ del d, e, e2
+ gc_collect()
+ self.assertIsNone(wref())
+
+ # A cycle between Element objects as children of one another
+ # e1->e2->e3->e1
+ e1 = ET.Element('e1')
+ e2 = ET.Element('e2')
+ e3 = ET.Element('e3')
+ e1.append(e2)
+ e2.append(e2)
+ e3.append(e1)
+ wref = weakref.ref(e1)
+ del e1, e2, e3
+ gc_collect()
+ self.assertIsNone(wref())
+
+ def test_weakref(self):
+ flag = False
+ def wref_cb(w):
+ nonlocal flag
+ flag = True
+ e = ET.Element('e')
+ wref = weakref.ref(e, wref_cb)
+ self.assertEqual(wref().tag, 'e')
+ del e
+ self.assertEqual(flag, True)
+ self.assertEqual(wref(), None)
+
+ def test_get_keyword_args(self):
+ e1 = ET.Element('foo' , x=1, y=2, z=3)
+ self.assertEqual(e1.get('x', default=7), 1)
+ self.assertEqual(e1.get('w', default=7), 7)
+
+ def test_pickle(self):
+ # issue #16076: the C implementation wasn't pickleable.
+ for dumper, loader in product(self.modules, repeat=2):
+ e = dumper.Element('foo', bar=42)
+ e.text = "text goes here"
+ e.tail = "opposite of head"
+ dumper.SubElement(e, 'child').append(dumper.Element('grandchild'))
+ e.append(dumper.Element('child'))
+ e.findall('.//grandchild')[0].set('attr', 'other value')
+
+ e2 = self.pickleRoundTrip(e, 'xml.etree.ElementTree',
+ dumper, loader)
+
+ self.assertEqual(e2.tag, 'foo')
+ self.assertEqual(e2.attrib['bar'], 42)
+ self.assertEqual(len(e2), 2)
+ self.assertEqualElements(e, e2)
+
+class ElementTreeTest(unittest.TestCase):
+ def test_istype(self):
+ self.assertIsInstance(ET.ParseError, type)
+ self.assertIsInstance(ET.QName, type)
+ self.assertIsInstance(ET.ElementTree, type)
+ self.assertIsInstance(ET.Element, type)
+ self.assertIsInstance(ET.TreeBuilder, type)
+ self.assertIsInstance(ET.XMLParser, type)
+
+ def test_Element_subclass_trivial(self):
+ class MyElement(ET.Element):
+ pass
+
+ mye = MyElement('foo')
+ self.assertIsInstance(mye, ET.Element)
+ self.assertIsInstance(mye, MyElement)
+ self.assertEqual(mye.tag, 'foo')
+
+ # test that attribute assignment works (issue 14849)
+ mye.text = "joe"
+ self.assertEqual(mye.text, "joe")
+
+ def test_Element_subclass_constructor(self):
+ class MyElement(ET.Element):
+ def __init__(self, tag, attrib={}, **extra):
+ super(MyElement, self).__init__(tag + '__', attrib, **extra)
+
+ mye = MyElement('foo', {'a': 1, 'b': 2}, c=3, d=4)
+ self.assertEqual(mye.tag, 'foo__')
+ self.assertEqual(sorted(mye.items()),
+ [('a', 1), ('b', 2), ('c', 3), ('d', 4)])
+
+ def test_Element_subclass_new_method(self):
+ class MyElement(ET.Element):
+ def newmethod(self):
+ return self.tag
+
+ mye = MyElement('joe')
+ self.assertEqual(mye.newmethod(), 'joe')
+
+ def test_html_empty_elems_serialization(self):
+ # issue 15970
+ # from http://www.w3.org/TR/html401/index/elements.html
+ for element in ['AREA', 'BASE', 'BASEFONT', 'BR', 'COL', 'FRAME', 'HR',
+ 'IMG', 'INPUT', 'ISINDEX', 'LINK', 'META', 'PARAM']:
+ for elem in [element, element.lower()]:
+ expected = '<%s>' % elem
+ serialized = serialize(ET.XML('<%s />' % elem), method='html')
+ self.assertEqual(serialized, expected)
+ serialized = serialize(ET.XML('<%s></%s>' % (elem,elem)),
+ method='html')
+ self.assertEqual(serialized, expected)
+
+
+class ElementFindTest(unittest.TestCase):
+ def test_find_simple(self):
+ e = ET.XML(SAMPLE_XML)
+ self.assertEqual(e.find('tag').tag, 'tag')
+ self.assertEqual(e.find('section/tag').tag, 'tag')
+ self.assertEqual(e.find('./tag').tag, 'tag')
+
+ e[2] = ET.XML(SAMPLE_SECTION)
+ self.assertEqual(e.find('section/nexttag').tag, 'nexttag')
+
+ self.assertEqual(e.findtext('./tag'), 'text')
+ self.assertEqual(e.findtext('section/tag'), 'subtext')
+
+ # section/nexttag is found but has no text
+ self.assertEqual(e.findtext('section/nexttag'), '')
+ self.assertEqual(e.findtext('section/nexttag', 'default'), '')
+
+ # tog doesn't exist and 'default' kicks in
+ self.assertIsNone(e.findtext('tog'))
+ self.assertEqual(e.findtext('tog', 'default'), 'default')
+
+ # Issue #16922
+ self.assertEqual(ET.XML('<tag><empty /></tag>').findtext('empty'), '')
+
+ def test_find_xpath(self):
+ LINEAR_XML = '''
+ <body>
+ <tag class='a'/>
+ <tag class='b'/>
+ <tag class='c'/>
+ <tag class='d'/>
+ </body>'''
+ e = ET.XML(LINEAR_XML)
+
+ # Test for numeric indexing and last()
+ self.assertEqual(e.find('./tag[1]').attrib['class'], 'a')
+ self.assertEqual(e.find('./tag[2]').attrib['class'], 'b')
+ self.assertEqual(e.find('./tag[last()]').attrib['class'], 'd')
+ self.assertEqual(e.find('./tag[last()-1]').attrib['class'], 'c')
+ self.assertEqual(e.find('./tag[last()-2]').attrib['class'], 'b')
+
+ def test_findall(self):
+ e = ET.XML(SAMPLE_XML)
+ e[2] = ET.XML(SAMPLE_SECTION)
+ self.assertEqual(summarize_list(e.findall('.')), ['body'])
+ self.assertEqual(summarize_list(e.findall('tag')), ['tag', 'tag'])
+ self.assertEqual(summarize_list(e.findall('tog')), [])
+ self.assertEqual(summarize_list(e.findall('tog/foo')), [])
+ self.assertEqual(summarize_list(e.findall('*')),
+ ['tag', 'tag', 'section'])
+ self.assertEqual(summarize_list(e.findall('.//tag')),
+ ['tag'] * 4)
+ self.assertEqual(summarize_list(e.findall('section/tag')), ['tag'])
+ self.assertEqual(summarize_list(e.findall('section//tag')), ['tag'] * 2)
+ self.assertEqual(summarize_list(e.findall('section/*')),
+ ['tag', 'nexttag', 'nextsection'])
+ self.assertEqual(summarize_list(e.findall('section//*')),
+ ['tag', 'nexttag', 'nextsection', 'tag'])
+ self.assertEqual(summarize_list(e.findall('section/.//*')),
+ ['tag', 'nexttag', 'nextsection', 'tag'])
+ self.assertEqual(summarize_list(e.findall('*/*')),
+ ['tag', 'nexttag', 'nextsection'])
+ self.assertEqual(summarize_list(e.findall('*//*')),
+ ['tag', 'nexttag', 'nextsection', 'tag'])
+ self.assertEqual(summarize_list(e.findall('*/tag')), ['tag'])
+ self.assertEqual(summarize_list(e.findall('*/./tag')), ['tag'])
+ self.assertEqual(summarize_list(e.findall('./tag')), ['tag'] * 2)
+ self.assertEqual(summarize_list(e.findall('././tag')), ['tag'] * 2)
+
+ self.assertEqual(summarize_list(e.findall('.//tag[@class]')),
+ ['tag'] * 3)
+ self.assertEqual(summarize_list(e.findall('.//tag[@class="a"]')),
+ ['tag'])
+ self.assertEqual(summarize_list(e.findall('.//tag[@class="b"]')),
+ ['tag'] * 2)
+ self.assertEqual(summarize_list(e.findall('.//tag[@id]')),
+ ['tag'])
+ self.assertEqual(summarize_list(e.findall('.//section[tag]')),
+ ['section'])
+ self.assertEqual(summarize_list(e.findall('.//section[element]')), [])
+ self.assertEqual(summarize_list(e.findall('../tag')), [])
+ self.assertEqual(summarize_list(e.findall('section/../tag')),
+ ['tag'] * 2)
+ self.assertEqual(e.findall('section//'), e.findall('section//*'))
+
+ def test_test_find_with_ns(self):
+ e = ET.XML(SAMPLE_XML_NS)
+ self.assertEqual(summarize_list(e.findall('tag')), [])
+ self.assertEqual(
+ summarize_list(e.findall("{http://effbot.org/ns}tag")),
+ ['{http://effbot.org/ns}tag'] * 2)
+ self.assertEqual(
+ summarize_list(e.findall(".//{http://effbot.org/ns}tag")),
+ ['{http://effbot.org/ns}tag'] * 3)
+
+ def test_bad_find(self):
+ e = ET.XML(SAMPLE_XML)
+ with self.assertRaisesRegex(SyntaxError, 'cannot use absolute path'):
+ e.findall('/tag')
+
+ def test_find_through_ElementTree(self):
+ e = ET.XML(SAMPLE_XML)
+ self.assertEqual(ET.ElementTree(e).find('tag').tag, 'tag')
+ self.assertEqual(ET.ElementTree(e).findtext('tag'), 'text')
+ self.assertEqual(summarize_list(ET.ElementTree(e).findall('tag')),
+ ['tag'] * 2)
+ # this produces a warning
+ self.assertEqual(summarize_list(ET.ElementTree(e).findall('//tag')),
+ ['tag'] * 3)
+
+
+class ElementIterTest(unittest.TestCase):
+ def _ilist(self, elem, tag=None):
+ return summarize_list(elem.iter(tag))
+
+ def test_basic(self):
+ doc = ET.XML("<html><body>this is a <i>paragraph</i>.</body>..</html>")
+ self.assertEqual(self._ilist(doc), ['html', 'body', 'i'])
+ self.assertEqual(self._ilist(doc.find('body')), ['body', 'i'])
+ self.assertEqual(next(doc.iter()).tag, 'html')
+ self.assertEqual(''.join(doc.itertext()), 'this is a paragraph...')
+ self.assertEqual(''.join(doc.find('body').itertext()),
+ 'this is a paragraph.')
+ self.assertEqual(next(doc.itertext()), 'this is a ')
+
+ # iterparse should return an iterator
+ sourcefile = serialize(doc, to_string=False)
+ self.assertEqual(next(ET.iterparse(sourcefile))[0], 'end')
+
+ # With an explitit parser too (issue #9708)
+ sourcefile = serialize(doc, to_string=False)
+ parser = ET.XMLParser(target=ET.TreeBuilder())
+ self.assertEqual(next(ET.iterparse(sourcefile, parser=parser))[0],
+ 'end')
+
+ tree = ET.ElementTree(None)
+ self.assertRaises(AttributeError, tree.iter)
+
+ # Issue #16913
+ doc = ET.XML("<root>a&amp;<sub>b&amp;</sub>c&amp;</root>")
+ self.assertEqual(''.join(doc.itertext()), 'a&b&c&')
+
+ def test_corners(self):
+ # single root, no subelements
+ a = ET.Element('a')
+ self.assertEqual(self._ilist(a), ['a'])
+
+ # one child
+ b = ET.SubElement(a, 'b')
+ self.assertEqual(self._ilist(a), ['a', 'b'])
+
+ # one child and one grandchild
+ c = ET.SubElement(b, 'c')
+ self.assertEqual(self._ilist(a), ['a', 'b', 'c'])
+
+ # two children, only first with grandchild
+ d = ET.SubElement(a, 'd')
+ self.assertEqual(self._ilist(a), ['a', 'b', 'c', 'd'])
+
+ # replace first child by second
+ a[0] = a[1]
+ del a[1]
+ self.assertEqual(self._ilist(a), ['a', 'd'])
+
+ def test_iter_by_tag(self):
+ doc = ET.XML('''
+ <document>
+ <house>
+ <room>bedroom1</room>
+ <room>bedroom2</room>
+ </house>
+ <shed>nothing here
+ </shed>
+ <house>
+ <room>bedroom8</room>
+ </house>
+ </document>''')
+
+ self.assertEqual(self._ilist(doc, 'room'), ['room'] * 3)
+ self.assertEqual(self._ilist(doc, 'house'), ['house'] * 2)
+
+ # test that iter also accepts 'tag' as a keyword arg
+ self.assertEqual(
+ summarize_list(doc.iter(tag='room')),
+ ['room'] * 3)
+
+ # make sure both tag=None and tag='*' return all tags
+ all_tags = ['document', 'house', 'room', 'room',
+ 'shed', 'house', 'room']
+ self.assertEqual(self._ilist(doc), all_tags)
+ self.assertEqual(self._ilist(doc, '*'), all_tags)
+
+
+class TreeBuilderTest(unittest.TestCase):
+ sample1 = ('<!DOCTYPE html PUBLIC'
+ ' "-//W3C//DTD XHTML 1.0 Transitional//EN"'
+ ' "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">'
+ '<html>text<div>subtext</div>tail</html>')
+
+ sample2 = '''<toplevel>sometext</toplevel>'''
+
+ def _check_sample1_element(self, e):
+ self.assertEqual(e.tag, 'html')
+ self.assertEqual(e.text, 'text')
+ self.assertEqual(e.tail, None)
+ self.assertEqual(e.attrib, {})
+ children = list(e)
+ self.assertEqual(len(children), 1)
+ child = children[0]
+ self.assertEqual(child.tag, 'div')
+ self.assertEqual(child.text, 'subtext')
+ self.assertEqual(child.tail, 'tail')
+ self.assertEqual(child.attrib, {})
+
+ def test_dummy_builder(self):
+ class BaseDummyBuilder:
+ def close(self):
+ return 42
+
+ class DummyBuilder(BaseDummyBuilder):
+ data = start = end = lambda *a: None
+
+ parser = ET.XMLParser(target=DummyBuilder())
+ parser.feed(self.sample1)
+ self.assertEqual(parser.close(), 42)
+
+ parser = ET.XMLParser(target=BaseDummyBuilder())
+ parser.feed(self.sample1)
+ self.assertEqual(parser.close(), 42)
+
+ parser = ET.XMLParser(target=object())
+ parser.feed(self.sample1)
+ self.assertIsNone(parser.close())
+
+ def test_subclass(self):
+ class MyTreeBuilder(ET.TreeBuilder):
+ def foobar(self, x):
+ return x * 2
+
+ tb = MyTreeBuilder()
+ self.assertEqual(tb.foobar(10), 20)
+
+ parser = ET.XMLParser(target=tb)
+ parser.feed(self.sample1)
+
+ e = parser.close()
+ self._check_sample1_element(e)
+
+ def test_element_factory(self):
+ lst = []
+ def myfactory(tag, attrib):
+ nonlocal lst
+ lst.append(tag)
+ return ET.Element(tag, attrib)
+
+ tb = ET.TreeBuilder(element_factory=myfactory)
+ parser = ET.XMLParser(target=tb)
+ parser.feed(self.sample2)
+ parser.close()
+
+ self.assertEqual(lst, ['toplevel'])
+
+ def _check_element_factory_class(self, cls):
+ tb = ET.TreeBuilder(element_factory=cls)
+
+ parser = ET.XMLParser(target=tb)
+ parser.feed(self.sample1)
+ e = parser.close()
+ self.assertIsInstance(e, cls)
+ self._check_sample1_element(e)
+
+ def test_element_factory_subclass(self):
+ class MyElement(ET.Element):
+ pass
+ self._check_element_factory_class(MyElement)
+
+ def test_element_factory_pure_python_subclass(self):
+ # Mimick SimpleTAL's behaviour (issue #16089): both versions of
+ # TreeBuilder should be able to cope with a subclass of the
+ # pure Python Element class.
+ base = ET._Element
+ # Not from a C extension
+ self.assertEqual(base.__module__, 'xml.etree.ElementTree')
+ # Force some multiple inheritance with a C class to make things
+ # more interesting.
+ class MyElement(base, ValueError):
+ pass
+ self._check_element_factory_class(MyElement)
+
+ def test_doctype(self):
+ class DoctypeParser:
+ _doctype = None
+
+ def doctype(self, name, pubid, system):
+ self._doctype = (name, pubid, system)
+
+ def close(self):
+ return self._doctype
+
+ parser = ET.XMLParser(target=DoctypeParser())
+ parser.feed(self.sample1)
+
+ self.assertEqual(parser.close(),
+ ('html', '-//W3C//DTD XHTML 1.0 Transitional//EN',
+ 'http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd'))
+
+
+class XincludeTest(unittest.TestCase):
+ def _my_loader(self, href, parse):
+ # Used to avoid a test-dependency problem where the default loader
+ # of ElementInclude uses the pyET parser for cET tests.
+ if parse == 'xml':
+ with open(href, 'rb') as f:
+ return ET.parse(f).getroot()
+ else:
+ return None
+
+ def test_xinclude_default(self):
+ from xml.etree import ElementInclude
+ doc = xinclude_loader('default.xml')
+ ElementInclude.include(doc, self._my_loader)
+ s = serialize(doc)
+ self.assertEqual(s.strip(), '''<document>
+ <p>Example.</p>
+ <root>
+ <element key="value">text</element>
+ <element>text</element>tail
+ <empty-element />
+</root>
+</document>''')
+
+
+class XMLParserTest(unittest.TestCase):
+ sample1 = '<file><line>22</line></file>'
+ sample2 = ('<!DOCTYPE html PUBLIC'
+ ' "-//W3C//DTD XHTML 1.0 Transitional//EN"'
+ ' "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">'
+ '<html>text</html>')
+
+ def _check_sample_element(self, e):
+ self.assertEqual(e.tag, 'file')
+ self.assertEqual(e[0].tag, 'line')
+ self.assertEqual(e[0].text, '22')
+
+ def test_constructor_args(self):
+ # Positional args. The first (html) is not supported, but should be
+ # nevertheless correctly accepted.
+ parser = ET.XMLParser(None, ET.TreeBuilder(), 'utf-8')
+ parser.feed(self.sample1)
+ self._check_sample_element(parser.close())
+
+ # Now as keyword args.
+ parser2 = ET.XMLParser(encoding='utf-8', html=[{}], target=ET.TreeBuilder())
+ parser2.feed(self.sample1)
+ self._check_sample_element(parser2.close())
+
+ def test_subclass(self):
+ class MyParser(ET.XMLParser):
+ pass
+ parser = MyParser()
+ parser.feed(self.sample1)
+ self._check_sample_element(parser.close())
+
+ def test_subclass_doctype(self):
+ _doctype = None
+ class MyParserWithDoctype(ET.XMLParser):
+ def doctype(self, name, pubid, system):
+ nonlocal _doctype
+ _doctype = (name, pubid, system)
+
+ parser = MyParserWithDoctype()
+ parser.feed(self.sample2)
+ parser.close()
+ self.assertEqual(_doctype,
+ ('html', '-//W3C//DTD XHTML 1.0 Transitional//EN',
+ 'http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd'))
+
+
+class NamespaceParseTest(unittest.TestCase):
+ def test_find_with_namespace(self):
+ nsmap = {'h': 'hello', 'f': 'foo'}
+ doc = ET.fromstring(SAMPLE_XML_NS_ELEMS)
+
+ self.assertEqual(len(doc.findall('{hello}table', nsmap)), 1)
+ self.assertEqual(len(doc.findall('.//{hello}td', nsmap)), 2)
+ self.assertEqual(len(doc.findall('.//{foo}name', nsmap)), 1)
+
+
+class ElementSlicingTest(unittest.TestCase):
+ def _elem_tags(self, elemlist):
+ return [e.tag for e in elemlist]
+
+ def _subelem_tags(self, elem):
+ return self._elem_tags(list(elem))
+
+ def _make_elem_with_children(self, numchildren):
+ """Create an Element with a tag 'a', with the given amount of children
+ named 'a0', 'a1' ... and so on.
+
+ """
+ e = ET.Element('a')
+ for i in range(numchildren):
+ ET.SubElement(e, 'a%s' % i)
+ return e
+
+ def test_getslice_single_index(self):
+ e = self._make_elem_with_children(10)
+
+ self.assertEqual(e[1].tag, 'a1')
+ self.assertEqual(e[-2].tag, 'a8')
+
+ self.assertRaises(IndexError, lambda: e[12])
+
+ def test_getslice_range(self):
+ e = self._make_elem_with_children(6)
+
+ self.assertEqual(self._elem_tags(e[3:]), ['a3', 'a4', 'a5'])
+ self.assertEqual(self._elem_tags(e[3:6]), ['a3', 'a4', 'a5'])
+ self.assertEqual(self._elem_tags(e[3:16]), ['a3', 'a4', 'a5'])
+ self.assertEqual(self._elem_tags(e[3:5]), ['a3', 'a4'])
+ self.assertEqual(self._elem_tags(e[3:-1]), ['a3', 'a4'])
+ self.assertEqual(self._elem_tags(e[:2]), ['a0', 'a1'])
+
+ def test_getslice_steps(self):
+ e = self._make_elem_with_children(10)
+
+ self.assertEqual(self._elem_tags(e[8:10:1]), ['a8', 'a9'])
+ self.assertEqual(self._elem_tags(e[::3]), ['a0', 'a3', 'a6', 'a9'])
+ self.assertEqual(self._elem_tags(e[::8]), ['a0', 'a8'])
+ self.assertEqual(self._elem_tags(e[1::8]), ['a1', 'a9'])
+
+ def test_getslice_negative_steps(self):
+ e = self._make_elem_with_children(4)
+
+ self.assertEqual(self._elem_tags(e[::-1]), ['a3', 'a2', 'a1', 'a0'])
+ self.assertEqual(self._elem_tags(e[::-2]), ['a3', 'a1'])
+
+ def test_delslice(self):
+ e = self._make_elem_with_children(4)
+ del e[0:2]
+ self.assertEqual(self._subelem_tags(e), ['a2', 'a3'])
+
+ e = self._make_elem_with_children(4)
+ del e[0:]
+ self.assertEqual(self._subelem_tags(e), [])
+
+ e = self._make_elem_with_children(4)
+ del e[::-1]
+ self.assertEqual(self._subelem_tags(e), [])
+
+ e = self._make_elem_with_children(4)
+ del e[::-2]
+ self.assertEqual(self._subelem_tags(e), ['a0', 'a2'])
+
+ e = self._make_elem_with_children(4)
+ del e[1::2]
+ self.assertEqual(self._subelem_tags(e), ['a0', 'a2'])
+
+ e = self._make_elem_with_children(2)
+ del e[::2]
+ self.assertEqual(self._subelem_tags(e), ['a1'])
+
+
+class IOTest(unittest.TestCase):
+ def tearDown(self):
+ unlink(TESTFN)
+
+ def test_encoding(self):
+ # Test encoding issues.
+ elem = ET.Element("tag")
+ elem.text = "abc"
+ self.assertEqual(serialize(elem), '<tag>abc</tag>')
+ self.assertEqual(serialize(elem, encoding="utf-8"),
+ b'<tag>abc</tag>')
+ self.assertEqual(serialize(elem, encoding="us-ascii"),
+ b'<tag>abc</tag>')
+ for enc in ("iso-8859-1", "utf-16", "utf-32"):
+ self.assertEqual(serialize(elem, encoding=enc),
+ ("<?xml version='1.0' encoding='%s'?>\n"
+ "<tag>abc</tag>" % enc).encode(enc))
+
+ elem = ET.Element("tag")
+ elem.text = "<&\"\'>"
+ self.assertEqual(serialize(elem), '<tag>&lt;&amp;"\'&gt;</tag>')
+ self.assertEqual(serialize(elem, encoding="utf-8"),
+ b'<tag>&lt;&amp;"\'&gt;</tag>')
+ self.assertEqual(serialize(elem, encoding="us-ascii"),
+ b'<tag>&lt;&amp;"\'&gt;</tag>')
+ for enc in ("iso-8859-1", "utf-16", "utf-32"):
+ self.assertEqual(serialize(elem, encoding=enc),
+ ("<?xml version='1.0' encoding='%s'?>\n"
+ "<tag>&lt;&amp;\"'&gt;</tag>" % enc).encode(enc))
+
+ elem = ET.Element("tag")
+ elem.attrib["key"] = "<&\"\'>"
+ self.assertEqual(serialize(elem), '<tag key="&lt;&amp;&quot;\'&gt;" />')
+ self.assertEqual(serialize(elem, encoding="utf-8"),
+ b'<tag key="&lt;&amp;&quot;\'&gt;" />')
+ self.assertEqual(serialize(elem, encoding="us-ascii"),
+ b'<tag key="&lt;&amp;&quot;\'&gt;" />')
+ for enc in ("iso-8859-1", "utf-16", "utf-32"):
+ self.assertEqual(serialize(elem, encoding=enc),
+ ("<?xml version='1.0' encoding='%s'?>\n"
+ "<tag key=\"&lt;&amp;&quot;'&gt;\" />" % enc).encode(enc))
+
+ elem = ET.Element("tag")
+ elem.text = '\xe5\xf6\xf6<>'
+ self.assertEqual(serialize(elem), '<tag>\xe5\xf6\xf6&lt;&gt;</tag>')
+ self.assertEqual(serialize(elem, encoding="utf-8"),
+ b'<tag>\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;</tag>')
+ self.assertEqual(serialize(elem, encoding="us-ascii"),
+ b'<tag>&#229;&#246;&#246;&lt;&gt;</tag>')
+ for enc in ("iso-8859-1", "utf-16", "utf-32"):
+ self.assertEqual(serialize(elem, encoding=enc),
+ ("<?xml version='1.0' encoding='%s'?>\n"
+ "<tag>åöö&lt;&gt;</tag>" % enc).encode(enc))
+
+ elem = ET.Element("tag")
+ elem.attrib["key"] = '\xe5\xf6\xf6<>'
+ self.assertEqual(serialize(elem), '<tag key="\xe5\xf6\xf6&lt;&gt;" />')
+ self.assertEqual(serialize(elem, encoding="utf-8"),
+ b'<tag key="\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;" />')
+ self.assertEqual(serialize(elem, encoding="us-ascii"),
+ b'<tag key="&#229;&#246;&#246;&lt;&gt;" />')
+ for enc in ("iso-8859-1", "utf-16", "utf-16le", "utf-16be", "utf-32"):
+ self.assertEqual(serialize(elem, encoding=enc),
+ ("<?xml version='1.0' encoding='%s'?>\n"
+ "<tag key=\"åöö&lt;&gt;\" />" % enc).encode(enc))
+
+ def test_write_to_filename(self):
+ tree = ET.ElementTree(ET.XML('''<site />'''))
+ tree.write(TESTFN)
+ with open(TESTFN, 'rb') as f:
+ self.assertEqual(f.read(), b'''<site />''')
+
+ def test_write_to_text_file(self):
+ tree = ET.ElementTree(ET.XML('''<site />'''))
+ with open(TESTFN, 'w', encoding='utf-8') as f:
+ tree.write(f, encoding='unicode')
+ self.assertFalse(f.closed)
+ with open(TESTFN, 'rb') as f:
+ self.assertEqual(f.read(), b'''<site />''')
+
+ def test_write_to_binary_file(self):
+ tree = ET.ElementTree(ET.XML('''<site />'''))
+ with open(TESTFN, 'wb') as f:
+ tree.write(f)
+ self.assertFalse(f.closed)
+ with open(TESTFN, 'rb') as f:
+ self.assertEqual(f.read(), b'''<site />''')
+
+ def test_write_to_binary_file_with_bom(self):
+ tree = ET.ElementTree(ET.XML('''<site />'''))
+ # test BOM writing to buffered file
+ with open(TESTFN, 'wb') as f:
+ tree.write(f, encoding='utf-16')
+ self.assertFalse(f.closed)
+ with open(TESTFN, 'rb') as f:
+ self.assertEqual(f.read(),
+ '''<?xml version='1.0' encoding='utf-16'?>\n'''
+ '''<site />'''.encode("utf-16"))
+ # test BOM writing to non-buffered file
+ with open(TESTFN, 'wb', buffering=0) as f:
+ tree.write(f, encoding='utf-16')
+ self.assertFalse(f.closed)
+ with open(TESTFN, 'rb') as f:
+ self.assertEqual(f.read(),
+ '''<?xml version='1.0' encoding='utf-16'?>\n'''
+ '''<site />'''.encode("utf-16"))
+
+ def test_read_from_stringio(self):
+ tree = ET.ElementTree()
+ stream = io.StringIO('''<?xml version="1.0"?><site></site>''')
+ tree.parse(stream)
+ self.assertEqual(tree.getroot().tag, 'site')
+
+ def test_write_to_stringio(self):
+ tree = ET.ElementTree(ET.XML('''<site />'''))
+ stream = io.StringIO()
+ tree.write(stream, encoding='unicode')
+ self.assertEqual(stream.getvalue(), '''<site />''')
+
+ def test_read_from_bytesio(self):
+ tree = ET.ElementTree()
+ raw = io.BytesIO(b'''<?xml version="1.0"?><site></site>''')
+ tree.parse(raw)
+ self.assertEqual(tree.getroot().tag, 'site')
+
+ def test_write_to_bytesio(self):
+ tree = ET.ElementTree(ET.XML('''<site />'''))
+ raw = io.BytesIO()
+ tree.write(raw)
+ self.assertEqual(raw.getvalue(), b'''<site />''')
+
+ class dummy:
+ pass
+
+ def test_read_from_user_text_reader(self):
+ stream = io.StringIO('''<?xml version="1.0"?><site></site>''')
+ reader = self.dummy()
+ reader.read = stream.read
+ tree = ET.ElementTree()
+ tree.parse(reader)
+ self.assertEqual(tree.getroot().tag, 'site')
+
+ def test_write_to_user_text_writer(self):
+ tree = ET.ElementTree(ET.XML('''<site />'''))
+ stream = io.StringIO()
+ writer = self.dummy()
+ writer.write = stream.write
+ tree.write(writer, encoding='unicode')
+ self.assertEqual(stream.getvalue(), '''<site />''')
+
+ def test_read_from_user_binary_reader(self):
+ raw = io.BytesIO(b'''<?xml version="1.0"?><site></site>''')
+ reader = self.dummy()
+ reader.read = raw.read
+ tree = ET.ElementTree()
+ tree.parse(reader)
+ self.assertEqual(tree.getroot().tag, 'site')
+ tree = ET.ElementTree()
+
+ def test_write_to_user_binary_writer(self):
+ tree = ET.ElementTree(ET.XML('''<site />'''))
+ raw = io.BytesIO()
+ writer = self.dummy()
+ writer.write = raw.write
+ tree.write(writer)
+ self.assertEqual(raw.getvalue(), b'''<site />''')
+
+ def test_write_to_user_binary_writer_with_bom(self):
+ tree = ET.ElementTree(ET.XML('''<site />'''))
+ raw = io.BytesIO()
+ writer = self.dummy()
+ writer.write = raw.write
+ writer.seekable = lambda: True
+ writer.tell = raw.tell
+ tree.write(writer, encoding="utf-16")
+ self.assertEqual(raw.getvalue(),
+ '''<?xml version='1.0' encoding='utf-16'?>\n'''
+ '''<site />'''.encode("utf-16"))
+
+ def test_tostringlist_invariant(self):
+ root = ET.fromstring('<tag>foo</tag>')
+ self.assertEqual(
+ ET.tostring(root, 'unicode'),
+ ''.join(ET.tostringlist(root, 'unicode')))
+ self.assertEqual(
+ ET.tostring(root, 'utf-16'),
+ b''.join(ET.tostringlist(root, 'utf-16')))
+
+
+class ParseErrorTest(unittest.TestCase):
+ def test_subclass(self):
+ self.assertIsInstance(ET.ParseError(), SyntaxError)
+
+ def _get_error(self, s):
+ try:
+ ET.fromstring(s)
+ except ET.ParseError as e:
+ return e
+
+ def test_error_position(self):
+ self.assertEqual(self._get_error('foo').position, (1, 0))
+ self.assertEqual(self._get_error('<tag>&foo;</tag>').position, (1, 5))
+ self.assertEqual(self._get_error('foobar<').position, (1, 6))
+
+ def test_error_code(self):
+ import xml.parsers.expat.errors as ERRORS
+ self.assertEqual(self._get_error('foo').code,
+ ERRORS.codes[ERRORS.XML_ERROR_SYNTAX])
+
+
+class KeywordArgsTest(unittest.TestCase):
+ # Test various issues with keyword arguments passed to ET.Element
+ # constructor and methods
+ def test_issue14818(self):
+ x = ET.XML("<a>foo</a>")
+ self.assertEqual(x.find('a', None),
+ x.find(path='a', namespaces=None))
+ self.assertEqual(x.findtext('a', None, None),
+ x.findtext(path='a', default=None, namespaces=None))
+ self.assertEqual(x.findall('a', None),
+ x.findall(path='a', namespaces=None))
+ self.assertEqual(list(x.iterfind('a', None)),
+ list(x.iterfind(path='a', namespaces=None)))
+
+ self.assertEqual(ET.Element('a').attrib, {})
+ elements = [
+ ET.Element('a', dict(href="#", id="foo")),
+ ET.Element('a', attrib=dict(href="#", id="foo")),
+ ET.Element('a', dict(href="#"), id="foo"),
+ ET.Element('a', href="#", id="foo"),
+ ET.Element('a', dict(href="#", id="foo"), href="#", id="foo"),
+ ]
+ for e in elements:
+ self.assertEqual(e.tag, 'a')
+ self.assertEqual(e.attrib, dict(href="#", id="foo"))
+
+ e2 = ET.SubElement(elements[0], 'foobar', attrib={'key1': 'value1'})
+ self.assertEqual(e2.attrib['key1'], 'value1')
+
+ with self.assertRaisesRegex(TypeError, 'must be dict, not str'):
+ ET.Element('a', "I'm not a dict")
+ with self.assertRaisesRegex(TypeError, 'must be dict, not str'):
+ ET.Element('a', attrib="I'm not a dict")
+
+# --------------------------------------------------------------------
+
+class NoAcceleratorTest(unittest.TestCase):
+ def setUp(self):
+ if not pyET:
+ raise unittest.SkipTest('only for the Python version')
+
+ # Test that the C accelerator was not imported for pyET
+ def test_correct_import_pyET(self):
+ self.assertEqual(pyET.Element.__module__, 'xml.etree.ElementTree')
+ self.assertEqual(pyET.SubElement.__module__, 'xml.etree.ElementTree')
# --------------------------------------------------------------------
@@ -1894,44 +2492,69 @@ class CleanContext(object):
("This method will be removed in future versions. "
"Use .+ instead.", DeprecationWarning),
("This method will be removed in future versions. "
- "Use .+ instead.", PendingDeprecationWarning),
- # XMLParser.doctype() is deprecated.
- ("This method of XMLParser is deprecated. Define doctype.. "
- "method on the TreeBuilder target.", DeprecationWarning))
+ "Use .+ instead.", PendingDeprecationWarning))
self.checkwarnings = support.check_warnings(*deprecations, quiet=quiet)
def __enter__(self):
- from xml.etree import ElementTree
- self._nsmap = ElementTree._namespace_map
- self._path_cache = ElementTree.ElementPath._cache
+ from xml.etree import ElementPath
+ self._nsmap = ET.register_namespace._namespace_map
# Copy the default namespace mapping
- ElementTree._namespace_map = self._nsmap.copy()
+ self._nsmap_copy = self._nsmap.copy()
# Copy the path cache (should be empty)
- ElementTree.ElementPath._cache = self._path_cache.copy()
+ self._path_cache = ElementPath._cache
+ ElementPath._cache = self._path_cache.copy()
self.checkwarnings.__enter__()
def __exit__(self, *args):
- from xml.etree import ElementTree
+ from xml.etree import ElementPath
# Restore mapping and path cache
- ElementTree._namespace_map = self._nsmap
- ElementTree.ElementPath._cache = self._path_cache
+ self._nsmap.clear()
+ self._nsmap.update(self._nsmap_copy)
+ ElementPath._cache = self._path_cache
self.checkwarnings.__exit__(*args)
-def test_main(module_name='xml.etree.ElementTree'):
- from test import test_xml_etree
+def test_main(module=None):
+ # When invoked without a module, runs the Python ET tests by loading pyET.
+ # Otherwise, uses the given module as the ET.
+ global pyET
+ pyET = import_fresh_module('xml.etree.ElementTree',
+ blocked=['_elementtree'])
+ if module is None:
+ module = pyET
+
+ global ET
+ ET = module
+
+ test_classes = [
+ ElementSlicingTest,
+ BasicElementTest,
+ IOTest,
+ ParseErrorTest,
+ XincludeTest,
+ ElementTreeTest,
+ ElementFindTest,
+ ElementIterTest,
+ TreeBuilderTest,
+ ]
+
+ # These tests will only run for the pure-Python version that doesn't import
+ # _elementtree. We can't use skipUnless here, because pyET is filled in only
+ # after the module is loaded.
+ if pyET is not ET:
+ test_classes.extend([
+ NoAcceleratorTest,
+ ])
- use_py_module = (module_name == 'xml.etree.ElementTree')
-
- # The same doctests are used for both the Python and the C implementations
- assert test_xml_etree.ET.__name__ == module_name
-
- # XXX the C module should give the same warnings as the Python module
- with CleanContext(quiet=not use_py_module):
- support.run_doctest(test_xml_etree, verbosity=True)
+ try:
+ # XXX the C module should give the same warnings as the Python module
+ with CleanContext(quiet=(pyET is not ET)):
+ support.run_unittest(*test_classes)
+ support.run_doctest(sys.modules[__name__], verbosity=True)
+ finally:
+ # don't interfere with subsequent tests
+ ET = pyET = None
- # The module should not be changed by the tests
- assert test_xml_etree.ET.__name__ == module_name
if __name__ == '__main__':
test_main()
diff --git a/Lib/test/test_xml_etree_c.py b/Lib/test/test_xml_etree_c.py
index 2ff118f..e44c6ca 100644
--- a/Lib/test/test_xml_etree_c.py
+++ b/Lib/test/test_xml_etree_c.py
@@ -1,36 +1,11 @@
# xml.etree test for cElementTree
-
+import sys, struct
from test import support
-from test.support import bigmemtest, _2G
+from test.support import import_fresh_module
import unittest
-cET = support.import_module('xml.etree.cElementTree')
-
-
-# cElementTree specific tests
-
-def sanity():
- r"""
- Import sanity.
-
- >>> from xml.etree import cElementTree
-
- Issue #6697.
-
- >>> e = cElementTree.Element('a')
- >>> getattr(e, '\uD800') # doctest: +ELLIPSIS
- Traceback (most recent call last):
- ...
- UnicodeEncodeError: ...
-
- >>> p = cElementTree.XMLParser()
- >>> p.version.split()[0]
- 'Expat'
- >>> getattr(p, '\uD800')
- Traceback (most recent call last):
- ...
- AttributeError: 'XMLParser' object has no attribute '\ud800'
- """
+cET = import_fresh_module('xml.etree.ElementTree', fresh=['_elementtree'])
+cET_alias = import_fresh_module('xml.etree.cElementTree', fresh=['_elementtree', 'xml.etree'])
class MiscTests(unittest.TestCase):
@@ -47,27 +22,64 @@ class MiscTests(unittest.TestCase):
data = None
+@unittest.skipUnless(cET, 'requires _elementtree')
+class TestAliasWorking(unittest.TestCase):
+ # Test that the cET alias module is alive
+ def test_alias_working(self):
+ e = cET_alias.Element('foo')
+ self.assertEqual(e.tag, 'foo')
+
+
+@unittest.skipUnless(cET, 'requires _elementtree')
+class TestAcceleratorImported(unittest.TestCase):
+ # Test that the C accelerator was imported, as expected
+ def test_correct_import_cET(self):
+ self.assertEqual(cET.SubElement.__module__, '_elementtree')
+
+ def test_correct_import_cET_alias(self):
+ self.assertEqual(cET_alias.SubElement.__module__, '_elementtree')
+
+
+@unittest.skipUnless(cET, 'requires _elementtree')
+@support.cpython_only
+class SizeofTest(unittest.TestCase):
+ def setUp(self):
+ self.elementsize = support.calcobjsize('5P')
+ # extra
+ self.extra = struct.calcsize('PiiP4P')
+
+ check_sizeof = support.check_sizeof
+
+ def test_element(self):
+ e = cET.Element('a')
+ self.check_sizeof(e, self.elementsize)
+
+ def test_element_with_attrib(self):
+ e = cET.Element('a', href='about:')
+ self.check_sizeof(e, self.elementsize + self.extra)
+
+ def test_element_with_children(self):
+ e = cET.Element('a')
+ for i in range(5):
+ cET.SubElement(e, 'span')
+ # should have space for 8 children now
+ self.check_sizeof(e, self.elementsize + self.extra +
+ struct.calcsize('8P'))
+
def test_main():
from test import test_xml_etree, test_xml_etree_c
# Run the tests specific to the C implementation
- support.run_doctest(test_xml_etree_c, verbosity=True)
-
- support.run_unittest(MiscTests)
-
- # Assign the C implementation before running the doctests
- # Patch the __name__, to prevent confusion with the pure Python test
- pyET = test_xml_etree.ET
- py__name__ = test_xml_etree.__name__
- test_xml_etree.ET = cET
- if __name__ != '__main__':
- test_xml_etree.__name__ = __name__
- try:
- # Run the same test suite as xml.etree.ElementTree
- test_xml_etree.test_main(module_name='xml.etree.cElementTree')
- finally:
- test_xml_etree.ET = pyET
- test_xml_etree.__name__ = py__name__
+ support.run_unittest(
+ MiscTests,
+ TestAliasWorking,
+ TestAcceleratorImported,
+ SizeofTest,
+ )
+
+ # Run the same test suite as the Python module
+ test_xml_etree.test_main(module=cET)
+
if __name__ == '__main__':
test_main()
diff --git a/Lib/test/test_xmlrpc.py b/Lib/test/test_xmlrpc.py
index 3814191..16f85c5 100644
--- a/Lib/test/test_xmlrpc.py
+++ b/Lib/test/test_xmlrpc.py
@@ -24,6 +24,8 @@ alist = [{'astring': 'foo@bar.baz.spam',
'ashortlong': 2,
'anotherlist': ['.zyx.41'],
'abase64': xmlrpclib.Binary(b"my dog has fleas"),
+ 'b64bytes': b"my dog has fleas",
+ 'b64bytearray': bytearray(b"my dog has fleas"),
'boolean': False,
'unicode': '\u4000\u6000\u8000',
'ukey\u4000': 'regular value',
@@ -44,27 +46,54 @@ class XMLRPCTestCase(unittest.TestCase):
def test_dump_bare_datetime(self):
# This checks that an unwrapped datetime.date object can be handled
# by the marshalling code. This can't be done via test_dump_load()
- # since with use_datetime set to 1 the unmarshaller would create
+ # since with use_builtin_types set to 1 the unmarshaller would create
# datetime objects for the 'datetime[123]' keys as well
dt = datetime.datetime(2005, 2, 10, 11, 41, 23)
+ self.assertEqual(dt, xmlrpclib.DateTime('20050210T11:41:23'))
s = xmlrpclib.dumps((dt,))
- (newdt,), m = xmlrpclib.loads(s, use_datetime=1)
+
+ result, m = xmlrpclib.loads(s, use_builtin_types=True)
+ (newdt,) = result
self.assertEqual(newdt, dt)
- self.assertEqual(m, None)
+ self.assertIs(type(newdt), datetime.datetime)
+ self.assertIsNone(m)
+
+ result, m = xmlrpclib.loads(s, use_builtin_types=False)
+ (newdt,) = result
+ self.assertEqual(newdt, dt)
+ self.assertIs(type(newdt), xmlrpclib.DateTime)
+ self.assertIsNone(m)
+
+ result, m = xmlrpclib.loads(s, use_datetime=True)
+ (newdt,) = result
+ self.assertEqual(newdt, dt)
+ self.assertIs(type(newdt), datetime.datetime)
+ self.assertIsNone(m)
+
+ result, m = xmlrpclib.loads(s, use_datetime=False)
+ (newdt,) = result
+ self.assertEqual(newdt, dt)
+ self.assertIs(type(newdt), xmlrpclib.DateTime)
+ self.assertIsNone(m)
- (newdt,), m = xmlrpclib.loads(s, use_datetime=0)
- self.assertEqual(newdt, xmlrpclib.DateTime('20050210T11:41:23'))
def test_datetime_before_1900(self):
# same as before but with a date before 1900
dt = datetime.datetime(1, 2, 10, 11, 41, 23)
+ self.assertEqual(dt, xmlrpclib.DateTime('00010210T11:41:23'))
s = xmlrpclib.dumps((dt,))
- (newdt,), m = xmlrpclib.loads(s, use_datetime=1)
+
+ result, m = xmlrpclib.loads(s, use_builtin_types=True)
+ (newdt,) = result
self.assertEqual(newdt, dt)
- self.assertEqual(m, None)
+ self.assertIs(type(newdt), datetime.datetime)
+ self.assertIsNone(m)
- (newdt,), m = xmlrpclib.loads(s, use_datetime=0)
- self.assertEqual(newdt, xmlrpclib.DateTime('00010210T11:41:23'))
+ result, m = xmlrpclib.loads(s, use_builtin_types=False)
+ (newdt,) = result
+ self.assertEqual(newdt, dt)
+ self.assertIs(type(newdt), xmlrpclib.DateTime)
+ self.assertIsNone(m)
def test_bug_1164912 (self):
d = xmlrpclib.DateTime()
@@ -125,6 +154,22 @@ class XMLRPCTestCase(unittest.TestCase):
self.assertRaises(OverflowError, m.dump_int,
xmlrpclib.MININT-1, dummy_write)
+ def test_dump_double(self):
+ xmlrpclib.dumps((float(2 ** 34),))
+ xmlrpclib.dumps((float(xmlrpclib.MAXINT),
+ float(xmlrpclib.MININT)))
+ xmlrpclib.dumps((float(xmlrpclib.MAXINT + 42),
+ float(xmlrpclib.MININT - 42)))
+
+ def dummy_write(s):
+ pass
+
+ m = xmlrpclib.Marshaller()
+ m.dump_double(xmlrpclib.MAXINT, dummy_write)
+ m.dump_double(xmlrpclib.MININT, dummy_write)
+ m.dump_double(xmlrpclib.MAXINT + 42, dummy_write)
+ m.dump_double(xmlrpclib.MININT - 42, dummy_write)
+
def test_dump_none(self):
value = alist + [None]
arg1 = (alist + [None],)
@@ -133,6 +178,25 @@ class XMLRPCTestCase(unittest.TestCase):
xmlrpclib.loads(strg)[0][0])
self.assertRaises(TypeError, xmlrpclib.dumps, (arg1,))
+ def test_dump_bytes(self):
+ sample = b"my dog has fleas"
+ self.assertEqual(sample, xmlrpclib.Binary(sample))
+ for type_ in bytes, bytearray, xmlrpclib.Binary:
+ value = type_(sample)
+ s = xmlrpclib.dumps((value,))
+
+ result, m = xmlrpclib.loads(s, use_builtin_types=True)
+ (newvalue,) = result
+ self.assertEqual(newvalue, sample)
+ self.assertIs(type(newvalue), bytes)
+ self.assertIsNone(m)
+
+ result, m = xmlrpclib.loads(s, use_builtin_types=False)
+ (newvalue,) = result
+ self.assertEqual(newvalue, sample)
+ self.assertIs(type(newvalue), xmlrpclib.Binary)
+ self.assertIsNone(m)
+
def test_get_host_info(self):
# see bug #3613, this raised a TypeError
transp = xmlrpc.client.Transport()
@@ -140,9 +204,6 @@ class XMLRPCTestCase(unittest.TestCase):
('host.tld',
[('Authorization', 'Basic dXNlcg==')], {}))
- def test_dump_bytes(self):
- self.assertRaises(TypeError, xmlrpclib.dumps, (b"my dog has fleas",))
-
def test_ssl_presence(self):
try:
import ssl
@@ -980,10 +1041,44 @@ class CGIHandlerTestCase(unittest.TestCase):
len(content))
+class UseBuiltinTypesTestCase(unittest.TestCase):
+
+ def test_use_builtin_types(self):
+ # SimpleXMLRPCDispatcher.__init__ accepts use_builtin_types, which
+ # makes all dispatch of binary data as bytes instances, and all
+ # dispatch of datetime argument as datetime.datetime instances.
+ self.log = []
+ expected_bytes = b"my dog has fleas"
+ expected_date = datetime.datetime(2008, 5, 26, 18, 25, 12)
+ marshaled = xmlrpclib.dumps((expected_bytes, expected_date), 'foobar')
+ def foobar(*args):
+ self.log.extend(args)
+ handler = xmlrpc.server.SimpleXMLRPCDispatcher(
+ allow_none=True, encoding=None, use_builtin_types=True)
+ handler.register_function(foobar)
+ handler._marshaled_dispatch(marshaled)
+ self.assertEqual(len(self.log), 2)
+ mybytes, mydate = self.log
+ self.assertEqual(self.log, [expected_bytes, expected_date])
+ self.assertIs(type(mydate), datetime.datetime)
+ self.assertIs(type(mybytes), bytes)
+
+ def test_cgihandler_has_use_builtin_types_flag(self):
+ handler = xmlrpc.server.CGIXMLRPCRequestHandler(use_builtin_types=True)
+ self.assertTrue(handler.use_builtin_types)
+
+ def test_xmlrpcserver_has_use_builtin_types_flag(self):
+ server = xmlrpc.server.SimpleXMLRPCServer(("localhost", 0),
+ use_builtin_types=True)
+ server.server_close()
+ self.assertTrue(server.use_builtin_types)
+
+
@support.reap_threads
def test_main():
xmlrpc_tests = [XMLRPCTestCase, HelperTestCase, DateTimeTestCase,
BinaryTestCase, FaultTestCase]
+ xmlrpc_tests.append(UseBuiltinTypesTestCase)
xmlrpc_tests.append(SimpleServerTestCase)
xmlrpc_tests.append(KeepaliveServerTestCase1)
xmlrpc_tests.append(KeepaliveServerTestCase2)
diff --git a/Lib/test/test_xmlrpc_net.py b/Lib/test/test_xmlrpc_net.py
index d72f8ac..dfb5f9a 100644
--- a/Lib/test/test_xmlrpc_net.py
+++ b/Lib/test/test_xmlrpc_net.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
-import collections
+import collections.abc
import errno
import socket
import sys
@@ -49,7 +49,7 @@ class CurrentTimeTest(unittest.TestCase):
# Perform a minimal sanity check on the result, just to be sure
# the request means what we think it means.
- self.assertIsInstance(builders, collections.Sequence)
+ self.assertIsInstance(builders, collections.abc.Sequence)
self.assertTrue([x for x in builders if "3.x" in x], builders)
diff --git a/Lib/test/test_zipfile.py b/Lib/test/test_zipfile.py
index c00c83e..5e837cd 100644
--- a/Lib/test/test_zipfile.py
+++ b/Lib/test/test_zipfile.py
@@ -1,9 +1,3 @@
-# We can test part of the module without zlib.
-try:
- import zlib
-except ImportError:
- zlib = None
-
import io
import os
import sys
@@ -20,7 +14,8 @@ from random import randint, random
from unittest import skipUnless
from test.support import (TESTFN, run_unittest, findfile, unlink,
- captured_stdout)
+ requires_zlib, requires_bz2, requires_lzma,
+ captured_stdout)
TESTFN2 = TESTFN + "2"
TESTFNDIR = TESTFN + "d"
@@ -270,44 +265,44 @@ class TestsWithSourceFile(unittest.TestCase):
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.zip_iterlines_test(f, zipfile.ZIP_STORED)
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_deflated(self):
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.zip_test(f, zipfile.ZIP_DEFLATED)
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_open_deflated(self):
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.zip_open_test(f, zipfile.ZIP_DEFLATED)
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_random_open_deflated(self):
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.zip_random_open_test(f, zipfile.ZIP_DEFLATED)
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_readline_read_deflated(self):
# Issue #7610: calls to readline() interleaved with calls to read().
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.zip_readline_read_test(f, zipfile.ZIP_DEFLATED)
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_readline_deflated(self):
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.zip_readline_test(f, zipfile.ZIP_DEFLATED)
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_readlines_deflated(self):
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.zip_readlines_test(f, zipfile.ZIP_DEFLATED)
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_iterlines_deflated(self):
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.zip_iterlines_test(f, zipfile.ZIP_DEFLATED)
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_low_compression(self):
"""Check for cases where compressed data is larger than original."""
# Create the ZIP archive
@@ -320,6 +315,103 @@ class TestsWithSourceFile(unittest.TestCase):
self.assertEqual(openobj.read(1), b'1')
self.assertEqual(openobj.read(1), b'2')
+ @requires_bz2
+ def test_bzip2(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_test(f, zipfile.ZIP_BZIP2)
+
+ @requires_bz2
+ def test_open_bzip2(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_open_test(f, zipfile.ZIP_BZIP2)
+
+ @requires_bz2
+ def test_random_open_bzip2(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_random_open_test(f, zipfile.ZIP_BZIP2)
+
+ @requires_bz2
+ def test_readline_read_bzip2(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_readline_read_test(f, zipfile.ZIP_BZIP2)
+
+ @requires_bz2
+ def test_readline_bzip2(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_readline_test(f, zipfile.ZIP_BZIP2)
+
+ @requires_bz2
+ def test_readlines_bzip2(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_readlines_test(f, zipfile.ZIP_BZIP2)
+
+ @requires_bz2
+ def test_iterlines_bzip2(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_iterlines_test(f, zipfile.ZIP_BZIP2)
+
+ @requires_bz2
+ def test_low_compression_bzip2(self):
+ """Check for cases where compressed data is larger than original."""
+ # Create the ZIP archive
+ with zipfile.ZipFile(TESTFN2, "w", zipfile.ZIP_BZIP2) as zipfp:
+ zipfp.writestr("strfile", '12')
+
+ # Get an open object for strfile
+ with zipfile.ZipFile(TESTFN2, "r", zipfile.ZIP_BZIP2) as zipfp:
+ with zipfp.open("strfile") as openobj:
+ self.assertEqual(openobj.read(1), b'1')
+ self.assertEqual(openobj.read(1), b'2')
+
+ @requires_lzma
+ def test_lzma(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_test(f, zipfile.ZIP_LZMA)
+
+ @requires_lzma
+ def test_open_lzma(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_open_test(f, zipfile.ZIP_LZMA)
+
+ @requires_lzma
+ def test_random_open_lzma(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_random_open_test(f, zipfile.ZIP_LZMA)
+
+ @requires_lzma
+ def test_readline_read_lzma(self):
+ # Issue #7610: calls to readline() interleaved with calls to read().
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_readline_read_test(f, zipfile.ZIP_LZMA)
+
+ @requires_lzma
+ def test_readline_lzma(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_readline_test(f, zipfile.ZIP_LZMA)
+
+ @requires_lzma
+ def test_readlines_lzma(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_readlines_test(f, zipfile.ZIP_LZMA)
+
+ @requires_lzma
+ def test_iterlines_lzma(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_iterlines_test(f, zipfile.ZIP_LZMA)
+
+ @requires_lzma
+ def test_low_compression_lzma(self):
+ """Check for cases where compressed data is larger than original."""
+ # Create the ZIP archive
+ with zipfile.ZipFile(TESTFN2, "w", zipfile.ZIP_LZMA) as zipfp:
+ zipfp.writestr("strfile", '12')
+
+ # Get an open object for strfile
+ with zipfile.ZipFile(TESTFN2, "r", zipfile.ZIP_LZMA) as zipfp:
+ with zipfp.open("strfile") as openobj:
+ self.assertEqual(openobj.read(1), b'1')
+ self.assertEqual(openobj.read(1), b'2')
+
def test_absolute_arcnames(self):
with zipfile.ZipFile(TESTFN2, "w", zipfile.ZIP_STORED) as zipfp:
zipfp.write(TESTFN, "/absolute")
@@ -378,7 +470,7 @@ class TestsWithSourceFile(unittest.TestCase):
with open(TESTFN, "rb") as f:
self.assertEqual(zipfp.read(TESTFN), f.read())
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_per_file_compression(self):
"""Check that files within a Zip archive can have different
compression options."""
@@ -446,8 +538,15 @@ class TestsWithSourceFile(unittest.TestCase):
with open(filename, 'rb') as f:
self.assertEqual(f.read(), content)
- def test_extract_hackers_arcnames(self):
- hacknames = [
+ def test_sanitize_windows_name(self):
+ san = zipfile.ZipFile._sanitize_windows_name
+ # Passing pathsep in allows this test to work regardless of platform.
+ self.assertEqual(san(r',,?,C:,foo,bar/z', ','), r'_,C_,foo,bar/z')
+ self.assertEqual(san(r'a\b,c<d>e|f"g?h*i', ','), r'a\b,c_d_e_f_g_h_i')
+ self.assertEqual(san('../../foo../../ba..r', '/'), r'foo/ba..r')
+
+ def test_extract_hackers_arcnames_common_cases(self):
+ common_hacknames = [
('../foo/bar', 'foo/bar'),
('foo/../bar', 'foo/bar'),
('foo/../../bar', 'foo/bar'),
@@ -457,8 +556,12 @@ class TestsWithSourceFile(unittest.TestCase):
('/foo/../bar', 'foo/bar'),
('/foo/../../bar', 'foo/bar'),
]
- if os.path.sep == '\\': # Windows.
- hacknames.extend([
+ self._test_extract_hackers_arcnames(common_hacknames)
+
+ @unittest.skipIf(os.path.sep != '\\', 'Requires \\ as path separator.')
+ def test_extract_hackers_arcnames_windows_only(self):
+ """Test combination of path fixing and windows name sanitization."""
+ windows_hacknames = [
(r'..\foo\bar', 'foo/bar'),
(r'..\/foo\/bar', 'foo/bar'),
(r'foo/\..\/bar', 'foo/bar'),
@@ -478,14 +581,19 @@ class TestsWithSourceFile(unittest.TestCase):
(r'C:/../C:/foo/bar', 'C_/foo/bar'),
(r'a:b\c<d>e|f"g?h*i', 'b/c_d_e_f_g_h_i'),
('../../foo../../ba..r', 'foo/ba..r'),
- ])
- else: # Unix
- hacknames.extend([
- ('//foo/bar', 'foo/bar'),
- ('../../foo../../ba..r', 'foo../ba..r'),
- (r'foo/..\bar', r'foo/..\bar'),
- ])
+ ]
+ self._test_extract_hackers_arcnames(windows_hacknames)
+
+ @unittest.skipIf(os.path.sep != '/', r'Requires / as path separator.')
+ def test_extract_hackers_arcnames_posix_only(self):
+ posix_hacknames = [
+ ('//foo/bar', 'foo/bar'),
+ ('../../foo../../ba..r', 'foo../ba..r'),
+ (r'foo/..\bar', r'foo/..\bar'),
+ ]
+ self._test_extract_hackers_arcnames(posix_hacknames)
+ def _test_extract_hackers_arcnames(self, hacknames):
for arcname, fixedname in hacknames:
content = b'foobar' + arcname.encode()
with zipfile.ZipFile(TESTFN2, 'w', zipfile.ZIP_STORED) as zipfp:
@@ -502,7 +610,8 @@ class TestsWithSourceFile(unittest.TestCase):
with zipfile.ZipFile(TESTFN2, 'r') as zipfp:
writtenfile = zipfp.extract(arcname, targetpath)
self.assertEqual(writtenfile, correctfile,
- msg="extract %r" % arcname)
+ msg='extract %r: %r != %r' %
+ (arcname, writtenfile, correctfile))
self.check_file(correctfile, content)
shutil.rmtree('target')
@@ -527,19 +636,32 @@ class TestsWithSourceFile(unittest.TestCase):
os.remove(TESTFN2)
- def test_writestr_compression(self):
+ def test_writestr_compression_stored(self):
zipfp = zipfile.ZipFile(TESTFN2, "w")
zipfp.writestr("a.txt", "hello world", compress_type=zipfile.ZIP_STORED)
- if zlib:
- zipfp.writestr("b.txt", "hello world", compress_type=zipfile.ZIP_DEFLATED)
-
info = zipfp.getinfo('a.txt')
self.assertEqual(info.compress_type, zipfile.ZIP_STORED)
- if zlib:
- info = zipfp.getinfo('b.txt')
- self.assertEqual(info.compress_type, zipfile.ZIP_DEFLATED)
+ @requires_zlib
+ def test_writestr_compression_deflated(self):
+ zipfp = zipfile.ZipFile(TESTFN2, "w")
+ zipfp.writestr("b.txt", "hello world", compress_type=zipfile.ZIP_DEFLATED)
+ info = zipfp.getinfo('b.txt')
+ self.assertEqual(info.compress_type, zipfile.ZIP_DEFLATED)
+ @requires_bz2
+ def test_writestr_compression_bzip2(self):
+ zipfp = zipfile.ZipFile(TESTFN2, "w")
+ zipfp.writestr("b.txt", "hello world", compress_type=zipfile.ZIP_BZIP2)
+ info = zipfp.getinfo('b.txt')
+ self.assertEqual(info.compress_type, zipfile.ZIP_BZIP2)
+
+ @requires_lzma
+ def test_writestr_compression_lzma(self):
+ zipfp = zipfile.ZipFile(TESTFN2, "w")
+ zipfp.writestr("b.txt", "hello world", compress_type=zipfile.ZIP_LZMA)
+ info = zipfp.getinfo('b.txt')
+ self.assertEqual(info.compress_type, zipfile.ZIP_LZMA)
def zip_test_writestr_permissions(self, f, compression):
# Make sure that writestr creates files with mode 0600,
@@ -595,7 +717,12 @@ class TestsWithSourceFile(unittest.TestCase):
self.assertRaises(ValueError, zipfp.write, TESTFN)
- @skipUnless(zlib, "requires zlib")
+
+
+
+
+
+ @requires_zlib
def test_unicode_filenames(self):
# bug #10801
fname = findfile('zip_cp437_header.zip')
@@ -704,11 +831,21 @@ class TestZip64InSmallFiles(unittest.TestCase):
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.zip_test(f, zipfile.ZIP_STORED)
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_deflated(self):
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.zip_test(f, zipfile.ZIP_DEFLATED)
+ @requires_bz2
+ def test_bzip2(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_test(f, zipfile.ZIP_BZIP2)
+
+ @requires_lzma
+ def test_lzma(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_test(f, zipfile.ZIP_LZMA)
+
def test_absolute_arcnames(self):
with zipfile.ZipFile(TESTFN2, "w", zipfile.ZIP_STORED,
allowZip64=True) as zipfp:
@@ -859,8 +996,40 @@ class OtherTests(unittest.TestCase):
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00'
b'\x00afilePK\x05\x06\x00\x00\x00\x00\x01\x00'
b'\x01\x003\x00\x00\x003\x00\x00\x00\x00\x00'),
+ zipfile.ZIP_BZIP2: (
+ b'PK\x03\x04\x14\x03\x00\x00\x0c\x00nu\x0c=FA'
+ b'KE8\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00af'
+ b'ileBZh91AY&SY\xd4\xa8\xca'
+ b'\x7f\x00\x00\x0f\x11\x80@\x00\x06D\x90\x80 \x00 \xa5'
+ b'P\xd9!\x03\x03\x13\x13\x13\x89\xa9\xa9\xc2u5:\x9f'
+ b'\x8b\xb9"\x9c(HjTe?\x80PK\x01\x02\x14'
+ b'\x03\x14\x03\x00\x00\x0c\x00nu\x0c=FAKE8'
+ b'\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\x00\x00'
+ b'\x00 \x80\x80\x81\x00\x00\x00\x00afilePK'
+ b'\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00\x00[\x00'
+ b'\x00\x00\x00\x00'),
+ zipfile.ZIP_LZMA: (
+ b'PK\x03\x04\x14\x03\x00\x00\x0e\x00nu\x0c=FA'
+ b'KE\x1b\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00af'
+ b'ile\t\x04\x05\x00]\x00\x00\x00\x04\x004\x19I'
+ b'\xee\x8d\xe9\x17\x89:3`\tq!.8\x00PK'
+ b'\x01\x02\x14\x03\x14\x03\x00\x00\x0e\x00nu\x0c=FA'
+ b'KE\x1b\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00\x00\x00'
+ b'\x00\x00\x00\x00 \x80\x80\x81\x00\x00\x00\x00afil'
+ b'ePK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00'
+ b'\x00>\x00\x00\x00\x00\x00'),
}
+ def test_unsupported_version(self):
+ # File has an extract_version of 120
+ data = (b'PK\x03\x04x\x00\x00\x00\x00\x00!p\xa1@\x00\x00\x00\x00\x00\x00'
+ b'\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00xPK\x01\x02x\x03x\x00\x00\x00\x00'
+ b'\x00!p\xa1@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00'
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00\x00xPK\x05\x06'
+ b'\x00\x00\x00\x00\x01\x00\x01\x00/\x00\x00\x00\x1f\x00\x00\x00\x00\x00')
+ self.assertRaises(NotImplementedError, zipfile.ZipFile,
+ io.BytesIO(data), 'r')
+
def test_unicode_filenames(self):
with zipfile.ZipFile(TESTFN, "w") as zf:
zf.writestr("foo.txt", "Test for unicode filename")
@@ -1151,10 +1320,18 @@ class OtherTests(unittest.TestCase):
def test_testzip_with_bad_crc_stored(self):
self.check_testzip_with_bad_crc(zipfile.ZIP_STORED)
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_testzip_with_bad_crc_deflated(self):
self.check_testzip_with_bad_crc(zipfile.ZIP_DEFLATED)
+ @requires_bz2
+ def test_testzip_with_bad_crc_bzip2(self):
+ self.check_testzip_with_bad_crc(zipfile.ZIP_BZIP2)
+
+ @requires_lzma
+ def test_testzip_with_bad_crc_lzma(self):
+ self.check_testzip_with_bad_crc(zipfile.ZIP_LZMA)
+
def check_read_with_bad_crc(self, compression):
"""Tests that files with bad CRCs raise a BadZipFile exception when read."""
zipdata = self.zips_with_bad_crc[compression]
@@ -1179,10 +1356,18 @@ class OtherTests(unittest.TestCase):
def test_read_with_bad_crc_stored(self):
self.check_read_with_bad_crc(zipfile.ZIP_STORED)
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_read_with_bad_crc_deflated(self):
self.check_read_with_bad_crc(zipfile.ZIP_DEFLATED)
+ @requires_bz2
+ def test_read_with_bad_crc_bzip2(self):
+ self.check_read_with_bad_crc(zipfile.ZIP_BZIP2)
+
+ @requires_lzma
+ def test_read_with_bad_crc_lzma(self):
+ self.check_read_with_bad_crc(zipfile.ZIP_LZMA)
+
def check_read_return_size(self, compression):
# Issue #9837: ZipExtFile.read() shouldn't return more bytes
# than requested.
@@ -1199,10 +1384,18 @@ class OtherTests(unittest.TestCase):
def test_read_return_size_stored(self):
self.check_read_return_size(zipfile.ZIP_STORED)
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_read_return_size_deflated(self):
self.check_read_return_size(zipfile.ZIP_DEFLATED)
+ @requires_bz2
+ def test_read_return_size_bzip2(self):
+ self.check_read_return_size(zipfile.ZIP_BZIP2)
+
+ @requires_lzma
+ def test_read_return_size_lzma(self):
+ self.check_read_return_size(zipfile.ZIP_LZMA)
+
def test_empty_zipfile(self):
# Check that creating a file in 'w' or 'a' mode and closing without
# adding any files to the archives creates a valid empty ZIP file
@@ -1289,7 +1482,7 @@ class DecryptionTests(unittest.TestCase):
self.zip2.setpassword(b"perl")
self.assertRaises(RuntimeError, self.zip2.read, "zero")
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_good_password(self):
self.zip.setpassword(b"python")
self.assertEqual(self.zip.read("test.txt"), self.plain)
@@ -1339,11 +1532,21 @@ class TestsWithRandomBinaryFiles(unittest.TestCase):
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.zip_test(f, zipfile.ZIP_STORED)
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_deflated(self):
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.zip_test(f, zipfile.ZIP_DEFLATED)
+ @requires_bz2
+ def test_bzip2(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_test(f, zipfile.ZIP_BZIP2)
+
+ @requires_lzma
+ def test_lzma(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_test(f, zipfile.ZIP_LZMA)
+
def zip_open_test(self, f, compression):
self.make_test_archive(f, compression)
@@ -1379,11 +1582,21 @@ class TestsWithRandomBinaryFiles(unittest.TestCase):
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.zip_open_test(f, zipfile.ZIP_STORED)
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_open_deflated(self):
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.zip_open_test(f, zipfile.ZIP_DEFLATED)
+ @requires_bz2
+ def test_open_bzip2(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_open_test(f, zipfile.ZIP_BZIP2)
+
+ @requires_lzma
+ def test_open_lzma(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_open_test(f, zipfile.ZIP_LZMA)
+
def zip_random_open_test(self, f, compression):
self.make_test_archive(f, compression)
@@ -1407,13 +1620,23 @@ class TestsWithRandomBinaryFiles(unittest.TestCase):
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.zip_random_open_test(f, zipfile.ZIP_STORED)
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_random_open_deflated(self):
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.zip_random_open_test(f, zipfile.ZIP_DEFLATED)
+ @requires_bz2
+ def test_random_open_bzip2(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_random_open_test(f, zipfile.ZIP_BZIP2)
-@skipUnless(zlib, "requires zlib")
+ @requires_lzma
+ def test_random_open_lzma(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.zip_random_open_test(f, zipfile.ZIP_LZMA)
+
+
+@requires_zlib
class TestsWithMultipleOpens(unittest.TestCase):
def setUp(self):
# Create the ZIP archive
@@ -1605,32 +1828,82 @@ class UniversalNewlineTests(unittest.TestCase):
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.iterlines_test(f, zipfile.ZIP_STORED)
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_read_deflated(self):
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.read_test(f, zipfile.ZIP_DEFLATED)
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_readline_read_deflated(self):
# Issue #7610: calls to readline() interleaved with calls to read().
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.readline_read_test(f, zipfile.ZIP_DEFLATED)
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_readline_deflated(self):
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.readline_test(f, zipfile.ZIP_DEFLATED)
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_readlines_deflated(self):
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.readlines_test(f, zipfile.ZIP_DEFLATED)
- @skipUnless(zlib, "requires zlib")
+ @requires_zlib
def test_iterlines_deflated(self):
for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
self.iterlines_test(f, zipfile.ZIP_DEFLATED)
+ @requires_bz2
+ def test_read_bzip2(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.read_test(f, zipfile.ZIP_BZIP2)
+
+ @requires_bz2
+ def test_readline_read_bzip2(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.readline_read_test(f, zipfile.ZIP_BZIP2)
+
+ @requires_bz2
+ def test_readline_bzip2(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.readline_test(f, zipfile.ZIP_BZIP2)
+
+ @requires_bz2
+ def test_readlines_bzip2(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.readlines_test(f, zipfile.ZIP_BZIP2)
+
+ @requires_bz2
+ def test_iterlines_bzip2(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.iterlines_test(f, zipfile.ZIP_BZIP2)
+
+ @requires_lzma
+ def test_read_lzma(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.read_test(f, zipfile.ZIP_LZMA)
+
+ @requires_lzma
+ def test_readline_read_lzma(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.readline_read_test(f, zipfile.ZIP_LZMA)
+
+ @requires_lzma
+ def test_readline_lzma(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.readline_test(f, zipfile.ZIP_LZMA)
+
+ @requires_lzma
+ def test_readlines_lzma(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.readlines_test(f, zipfile.ZIP_LZMA)
+
+ @requires_lzma
+ def test_iterlines_lzma(self):
+ for f in (TESTFN2, TemporaryFile(), io.BytesIO()):
+ self.iterlines_test(f, zipfile.ZIP_LZMA)
+
def tearDown(self):
for sep, fn in self.arcfiles.items():
os.remove(fn)
diff --git a/Lib/test/test_zipfile64.py b/Lib/test/test_zipfile64.py
index 0e7d73f..a8fb7ab 100644
--- a/Lib/test/test_zipfile64.py
+++ b/Lib/test/test_zipfile64.py
@@ -11,12 +11,6 @@ support.requires(
'test requires loads of disk-space bytes and a long time to run'
)
-# We can test part of the module without zlib.
-try:
- import zlib
-except ImportError:
- zlib = None
-
import zipfile, os, unittest
import time
import sys
@@ -24,7 +18,7 @@ import sys
from io import StringIO
from tempfile import TemporaryFile
-from test.support import TESTFN, run_unittest
+from test.support import TESTFN, run_unittest, requires_zlib
TESTFN2 = TESTFN + "2"
@@ -81,12 +75,12 @@ class TestsWithSourceFile(unittest.TestCase):
for f in TemporaryFile(), TESTFN2:
self.zipTest(f, zipfile.ZIP_STORED)
- if zlib:
- def testDeflated(self):
- # Try the temp file first. If we do TESTFN2 first, then it hogs
- # gigabytes of disk space for the duration of the test.
- for f in TemporaryFile(), TESTFN2:
- self.zipTest(f, zipfile.ZIP_DEFLATED)
+ @requires_zlib
+ def testDeflated(self):
+ # Try the temp file first. If we do TESTFN2 first, then it hogs
+ # gigabytes of disk space for the duration of the test.
+ for f in TemporaryFile(), TESTFN2:
+ self.zipTest(f, zipfile.ZIP_DEFLATED)
def tearDown(self):
for fname in TESTFN, TESTFN2:
diff --git a/Lib/test/test_zipimport.py b/Lib/test/test_zipimport.py
index df5ff9d..f7cb8b9 100644
--- a/Lib/test/test_zipimport.py
+++ b/Lib/test/test_zipimport.py
@@ -9,12 +9,6 @@ import unittest
from test import support
from test.test_importhooks import ImportHooksBaseTestCase, test_src, test_co
-# some tests can be ran even without zlib
-try:
- import zlib
-except ImportError:
- zlib = None
-
from zipfile import ZipFile, ZipInfo, ZIP_STORED, ZIP_DEFLATED
import zipimport
@@ -25,7 +19,7 @@ import io
from traceback import extract_tb, extract_stack, print_tb
raise_src = 'def do_raise(): raise TypeError\n'
-def make_pyc(co, mtime):
+def make_pyc(co, mtime, size):
data = marshal.dumps(co)
if type(mtime) is type(0.0):
# Mac mtimes need a bit of special casing
@@ -33,14 +27,14 @@ def make_pyc(co, mtime):
mtime = int(mtime)
else:
mtime = int(-0x100000000 + int(mtime))
- pyc = imp.get_magic() + struct.pack("<i", int(mtime)) + data
+ pyc = imp.get_magic() + struct.pack("<ii", int(mtime), size & 0xFFFFFFFF) + data
return pyc
def module_path_to_dotted_name(path):
return path.replace(os.sep, '.')
NOW = time.time()
-test_pyc = make_pyc(test_co, NOW)
+test_pyc = make_pyc(test_co, NOW, len(test_src))
TESTMOD = "ziptestmodule"
@@ -211,6 +205,10 @@ class UncompressedZipImportTestCase(ImportHooksBaseTestCase):
mod = zi.load_module(TESTPACK)
self.assertEqual(zi.get_filename(TESTPACK), mod.__file__)
+ existing_pack_path = __import__(TESTPACK).__path__[0]
+ expected_path_path = os.path.join(TEMP_ZIP, TESTPACK)
+ self.assertEqual(existing_pack_path, expected_path_path)
+
self.assertEqual(zi.is_package(packdir + '__init__'), False)
self.assertEqual(zi.is_package(packdir + TESTPACK2), True)
self.assertEqual(zi.is_package(packdir2 + TESTMOD), False)
@@ -299,7 +297,7 @@ class UncompressedZipImportTestCase(ImportHooksBaseTestCase):
return __file__
if __loader__.get_data("some.data") != b"some data":
raise AssertionError("bad data")\n"""
- pyc = make_pyc(compile(src, "<???>", "exec"), NOW)
+ pyc = make_pyc(compile(src, "<???>", "exec"), NOW, len(src))
files = {TESTMOD + pyc_ext: (NOW, pyc),
"some.data": (NOW, "some data")}
self.doTest(pyc_ext, files, TESTMOD)
@@ -319,7 +317,7 @@ class UncompressedZipImportTestCase(ImportHooksBaseTestCase):
self.doTest(".py", files, TESTMOD, call=self.assertModuleSource)
def testGetCompiledSource(self):
- pyc = make_pyc(compile(test_src, "<???>", "exec"), NOW)
+ pyc = make_pyc(compile(test_src, "<???>", "exec"), NOW, len(test_src))
files = {TESTMOD + ".py": (NOW, test_src),
TESTMOD + pyc_ext: (NOW, pyc)}
self.doTest(pyc_ext, files, TESTMOD, call=self.assertModuleSource)
@@ -392,7 +390,7 @@ class UncompressedZipImportTestCase(ImportHooksBaseTestCase):
os.remove(filename)
-@unittest.skipUnless(zlib, "requires zlib")
+@support.requires_zlib
class CompressedZipImportTestCase(UncompressedZipImportTestCase):
compression = ZIP_DEFLATED
@@ -417,7 +415,7 @@ class BadFileZipImportTestCase(unittest.TestCase):
def testEmptyFile(self):
support.unlink(TESTMOD)
- open(TESTMOD, 'w+').close()
+ support.create_empty_file(TESTMOD)
self.assertZipFailure(TESTMOD)
def testFileUnreadable(self):
diff --git a/Lib/test/test_zipimport_support.py b/Lib/test/test_zipimport_support.py
index 060108f..f7f3398 100644
--- a/Lib/test/test_zipimport_support.py
+++ b/Lib/test/test_zipimport_support.py
@@ -167,20 +167,19 @@ class ZipSupportTests(unittest.TestCase):
test_zipped_doctest.test_DocTestRunner.verbose_flag,
test_zipped_doctest.test_Example,
test_zipped_doctest.test_debug,
- test_zipped_doctest.test_pdb_set_trace,
- test_zipped_doctest.test_pdb_set_trace_nested,
test_zipped_doctest.test_testsource,
test_zipped_doctest.test_trailing_space_in_test,
test_zipped_doctest.test_DocTestSuite,
test_zipped_doctest.test_DocTestFinder,
]
- # These remaining tests are the ones which need access
+ # These tests are the ones which need access
# to the data files, so we don't run them
fail_due_to_missing_data_files = [
test_zipped_doctest.test_DocFileSuite,
test_zipped_doctest.test_testfile,
test_zipped_doctest.test_unittest_reportflags,
]
+
for obj in known_good_tests:
_run_object_doctest(obj, test_zipped_doctest)
finally:
diff --git a/Lib/test/test_zlib.py b/Lib/test/test_zlib.py
index 4661c1d..2f6f840 100644
--- a/Lib/test/test_zlib.py
+++ b/Lib/test/test_zlib.py
@@ -7,10 +7,16 @@ from test.support import bigmemtest, _1G, _4G
zlib = support.import_module('zlib')
-try:
- import mmap
-except ImportError:
- mmap = None
+
+class VersionTestCase(unittest.TestCase):
+
+ def test_library_version(self):
+ # Test that the major version of the actual library in use matches the
+ # major version that we were compiled against. We can't guarantee that
+ # the minor versions will match (even on the machine on which the module
+ # was compiled), and the API is stable between minor versions, so
+ # testing only the major versions avoids spurious failures.
+ self.assertEqual(zlib.ZLIB_RUNTIME_VERSION[0], zlib.ZLIB_VERSION[0])
class ChecksumTestCase(unittest.TestCase):
@@ -173,10 +179,8 @@ class CompressTestCase(BaseCompressTestCase, unittest.TestCase):
def test_big_decompress_buffer(self, size):
self.check_big_decompress_buffer(size, zlib.decompress)
- @bigmemtest(size=_4G + 100, memuse=1)
+ @bigmemtest(size=_4G + 100, memuse=1, dry_run=False)
def test_length_overflow(self, size):
- if size < _4G + 100:
- self.skipTest("not enough free memory, need at least 4 GB")
data = b'x' * size
try:
self.assertRaises(OverflowError, zlib.compress, data, 1)
@@ -421,6 +425,35 @@ class CompressObjectTestCase(BaseCompressTestCase, unittest.TestCase):
dco = zlib.decompressobj()
self.assertEqual(dco.flush(), b"") # Returns nothing
+ def test_dictionary(self):
+ h = HAMLET_SCENE
+ # Build a simulated dictionary out of the words in HAMLET.
+ words = h.split()
+ random.shuffle(words)
+ zdict = b''.join(words)
+ # Use it to compress HAMLET.
+ co = zlib.compressobj(zdict=zdict)
+ cd = co.compress(h) + co.flush()
+ # Verify that it will decompress with the dictionary.
+ dco = zlib.decompressobj(zdict=zdict)
+ self.assertEqual(dco.decompress(cd) + dco.flush(), h)
+ # Verify that it fails when not given the dictionary.
+ dco = zlib.decompressobj()
+ self.assertRaises(zlib.error, dco.decompress, cd)
+
+ def test_dictionary_streaming(self):
+ # This simulates the reuse of a compressor object for compressing
+ # several separate data streams.
+ co = zlib.compressobj(zdict=HAMLET_SCENE)
+ do = zlib.decompressobj(zdict=HAMLET_SCENE)
+ piece = HAMLET_SCENE[1000:1500]
+ d0 = co.compress(piece) + co.flush(zlib.Z_SYNC_FLUSH)
+ d1 = co.compress(piece[100:]) + co.flush(zlib.Z_SYNC_FLUSH)
+ d2 = co.compress(piece[:-100]) + co.flush(zlib.Z_SYNC_FLUSH)
+ self.assertEqual(do.decompress(d0), piece)
+ self.assertEqual(do.decompress(d1), piece[100:])
+ self.assertEqual(do.decompress(d2), piece[:-100])
+
def test_decompress_incomplete_stream(self):
# This is 'foo', deflated
x = b'x\x9cK\xcb\xcf\x07\x00\x02\x82\x01E'
@@ -434,6 +467,26 @@ class CompressObjectTestCase(BaseCompressTestCase, unittest.TestCase):
y += dco.flush()
self.assertEqual(y, b'foo')
+ def test_decompress_eof(self):
+ x = b'x\x9cK\xcb\xcf\x07\x00\x02\x82\x01E' # 'foo'
+ dco = zlib.decompressobj()
+ self.assertFalse(dco.eof)
+ dco.decompress(x[:-5])
+ self.assertFalse(dco.eof)
+ dco.decompress(x[-5:])
+ self.assertTrue(dco.eof)
+ dco.flush()
+ self.assertTrue(dco.eof)
+
+ def test_decompress_eof_incomplete_stream(self):
+ x = b'x\x9cK\xcb\xcf\x07\x00\x02\x82\x01E' # 'foo'
+ dco = zlib.decompressobj()
+ self.assertFalse(dco.eof)
+ dco.decompress(x[:-5])
+ self.assertFalse(dco.eof)
+ dco.flush()
+ self.assertFalse(dco.eof)
+
def test_decompress_unused_data(self):
# Repeated calls to decompress() after EOF should accumulate data in
# dco.unused_data, instead of just storing the arg to the last call.
@@ -455,6 +508,7 @@ class CompressObjectTestCase(BaseCompressTestCase, unittest.TestCase):
data += dco.decompress(
dco.unconsumed_tail + x[i : i + step], maxlen)
data += dco.flush()
+ self.assertTrue(dco.eof)
self.assertEqual(data, source)
self.assertEqual(dco.unconsumed_tail, b'')
self.assertEqual(dco.unused_data, remainder)
@@ -547,10 +601,8 @@ class CompressObjectTestCase(BaseCompressTestCase, unittest.TestCase):
decompress = lambda s: d.decompress(s) + d.flush()
self.check_big_decompress_buffer(size, decompress)
- @bigmemtest(size=_4G + 100, memuse=1)
+ @bigmemtest(size=_4G + 100, memuse=1, dry_run=False)
def test_length_overflow(self, size):
- if size < _4G + 100:
- self.skipTest("not enough free memory, need at least 4 GB")
data = b'x' * size
c = zlib.compressobj(1)
d = zlib.decompressobj()
@@ -651,6 +703,7 @@ LAERTES
def test_main():
support.run_unittest(
+ VersionTestCase,
ChecksumTestCase,
ChecksumBigBufferTestCase,
ExceptionTestCase,
diff --git a/Lib/test/testbz2_bigmem.bz2 b/Lib/test/testbz2_bigmem.bz2
deleted file mode 100644
index c9a4616..0000000
--- a/Lib/test/testbz2_bigmem.bz2
+++ /dev/null
Binary files differ
diff --git a/Lib/test/threaded_import_hangers.py b/Lib/test/threaded_import_hangers.py
index adf03e3..5484e60 100644
--- a/Lib/test/threaded_import_hangers.py
+++ b/Lib/test/threaded_import_hangers.py
@@ -35,8 +35,11 @@ for name, func, args in [
("os.path.abspath", os.path.abspath, ('.',)),
]:
- t = Worker(func, args)
- t.start()
- t.join(TIMEOUT)
- if t.is_alive():
- errors.append("%s appeared to hang" % name)
+ try:
+ t = Worker(func, args)
+ t.start()
+ t.join(TIMEOUT)
+ if t.is_alive():
+ errors.append("%s appeared to hang" % name)
+ finally:
+ del t
diff --git a/Lib/test/tokenize_tests.txt b/Lib/test/tokenize_tests.txt
index 06c83b0..2c5fb10 100644
--- a/Lib/test/tokenize_tests.txt
+++ b/Lib/test/tokenize_tests.txt
@@ -114,8 +114,12 @@ x = b'abc' + B'ABC'
y = b"abc" + B"ABC"
x = br'abc' + Br'ABC' + bR'ABC' + BR'ABC'
y = br"abc" + Br"ABC" + bR"ABC" + BR"ABC"
+x = rb'abc' + rB'ABC' + Rb'ABC' + RB'ABC'
+y = rb"abc" + rB"ABC" + Rb"ABC" + RB"ABC"
x = br'\\' + BR'\\'
+x = rb'\\' + RB'\\'
x = br'\'' + ''
+x = rb'\'' + ''
y = br'''
foo bar \\
baz''' + BR'''
@@ -124,6 +128,10 @@ y = Br"""foo
bar \\ baz
""" + bR'''spam
'''
+y = rB"""foo
+bar \\ baz
+""" + Rb'''spam
+'''
# Indentation
if 1: