From afb9dc887c6e8ae17b6a54c6124399e8bdc82253 Mon Sep 17 00:00:00 2001
From: Hugo van Kemenade <1324225+hugovk@users.noreply.github.com>
Date: Mon, 13 Jan 2025 13:05:02 +0200
Subject: gh-128595: Add test class helper to force no terminal colour
 (#128687)

Co-authored-by: Erlend E. Aasland <erlend.aasland@protonmail.com>
---
 Lib/test/support/__init__.py          | 47 +++++++++++++++++++++++------------
 Lib/test/test_code_module.py          |  3 ++-
 Lib/test/test_exceptions.py           |  1 +
 Lib/test/test_traceback.py            |  9 ++++++-
 Lib/test/test_unittest/test_result.py | 11 +++++---
 5 files changed, 49 insertions(+), 22 deletions(-)

diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py
index 42e7b87..ee9520a8 100644
--- a/Lib/test/support/__init__.py
+++ b/Lib/test/support/__init__.py
@@ -60,6 +60,7 @@ __all__ = [
     "skip_on_s390x",
     "without_optimizer",
     "force_not_colorized",
+    "force_not_colorized_test_class",
     "BrokenIter",
     "in_systemd_nspawn_sync_suppressed",
     "run_no_yield_async_fn", "run_yielding_async_fn", "async_yield",
@@ -2832,30 +2833,44 @@ def iter_slot_wrappers(cls):
             yield name, True
 
 
+@contextlib.contextmanager
+def no_color():
+    import _colorize
+    from .os_helper import EnvironmentVarGuard
+
+    with (
+        swap_attr(_colorize, "can_colorize", lambda: False),
+        EnvironmentVarGuard() as env,
+    ):
+        for var in {"FORCE_COLOR", "NO_COLOR", "PYTHON_COLORS"}:
+            env.unset(var)
+        env.set("NO_COLOR", "1")
+        yield
+
+
 def force_not_colorized(func):
     """Force the terminal not to be colorized."""
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
-        import _colorize
-        original_fn = _colorize.can_colorize
-        variables: dict[str, str | None] = {
-            "PYTHON_COLORS": None, "FORCE_COLOR": None, "NO_COLOR": None
-        }
-        try:
-            for key in variables:
-                variables[key] = os.environ.pop(key, None)
-            os.environ["NO_COLOR"] = "1"
-            _colorize.can_colorize = lambda: False
+        with no_color():
             return func(*args, **kwargs)
-        finally:
-            _colorize.can_colorize = original_fn
-            del os.environ["NO_COLOR"]
-            for key, value in variables.items():
-                if value is not None:
-                    os.environ[key] = value
     return wrapper
 
 
+def force_not_colorized_test_class(cls):
+    """Force the terminal not to be colorized for the entire test class."""
+    original_setUpClass = cls.setUpClass
+
+    @classmethod
+    @functools.wraps(cls.setUpClass)
+    def new_setUpClass(cls):
+        cls.enterClassContext(no_color())
+        original_setUpClass()
+
+    cls.setUpClass = new_setUpClass
+    return cls
+
+
 def initialized_with_pyrepl():
     """Detect whether PyREPL was used during Python initialization."""
     # If the main module has a __file__ attribute it's a Python module, which means PyREPL.
diff --git a/Lib/test/test_code_module.py b/Lib/test/test_code_module.py
index 37c7bc7..20b960c 100644
--- a/Lib/test/test_code_module.py
+++ b/Lib/test/test_code_module.py
@@ -5,9 +5,9 @@ import unittest
 from textwrap import dedent
 from contextlib import ExitStack
 from unittest import mock
+from test.support import force_not_colorized_test_class
 from test.support import import_helper
 
-
 code = import_helper.import_module('code')
 
 
@@ -30,6 +30,7 @@ class MockSys:
         del self.sysmod.ps2
 
 
+@force_not_colorized_test_class
 class TestInteractiveConsole(unittest.TestCase, MockSys):
     maxDiff = None
 
diff --git a/Lib/test/test_exceptions.py b/Lib/test/test_exceptions.py
index 6ccfa95..206e22e 100644
--- a/Lib/test/test_exceptions.py
+++ b/Lib/test/test_exceptions.py
@@ -2274,6 +2274,7 @@ class SyntaxErrorTests(unittest.TestCase):
                     self.assertIn(expected, err.getvalue())
                     the_exception = exc
 
+    @force_not_colorized
     def test_subclass(self):
         class MySyntaxError(SyntaxError):
             pass
diff --git a/Lib/test/test_traceback.py b/Lib/test/test_traceback.py
index 31f0a61d..abdfc46 100644
--- a/Lib/test/test_traceback.py
+++ b/Lib/test/test_traceback.py
@@ -21,7 +21,7 @@ from test.support import (Error, captured_output, cpython_only, ALWAYS_EQ,
 from test.support.os_helper import TESTFN, unlink
 from test.support.script_helper import assert_python_ok, assert_python_failure
 from test.support.import_helper import forget
-from test.support import force_not_colorized
+from test.support import force_not_colorized, force_not_colorized_test_class
 
 import json
 import textwrap
@@ -1712,6 +1712,7 @@ class TracebackErrorLocationCaretTestBase:
 
 
 @requires_debug_ranges()
+@force_not_colorized_test_class
 class PurePythonTracebackErrorCaretTests(
     PurePythonExceptionFormattingMixin,
     TracebackErrorLocationCaretTestBase,
@@ -1725,6 +1726,7 @@ class PurePythonTracebackErrorCaretTests(
 
 @cpython_only
 @requires_debug_ranges()
+@force_not_colorized_test_class
 class CPythonTracebackErrorCaretTests(
     CAPIExceptionFormattingMixin,
     TracebackErrorLocationCaretTestBase,
@@ -1736,6 +1738,7 @@ class CPythonTracebackErrorCaretTests(
 
 @cpython_only
 @requires_debug_ranges()
+@force_not_colorized_test_class
 class CPythonTracebackLegacyErrorCaretTests(
     CAPIExceptionFormattingLegacyMixin,
     TracebackErrorLocationCaretTestBase,
@@ -2149,10 +2152,12 @@ context_message = (
 boundaries = re.compile(
     '(%s|%s)' % (re.escape(cause_message), re.escape(context_message)))
 
+@force_not_colorized_test_class
 class TestTracebackFormat(unittest.TestCase, TracebackFormatMixin):
     pass
 
 @cpython_only
+@force_not_colorized_test_class
 class TestFallbackTracebackFormat(unittest.TestCase, TracebackFormatMixin):
     DEBUG_RANGES = False
     def setUp(self) -> None:
@@ -2940,6 +2945,7 @@ class BaseExceptionReportingTests:
         self.assertEqual(report, expected)
 
 
+@force_not_colorized_test_class
 class PyExcReportingTests(BaseExceptionReportingTests, unittest.TestCase):
     #
     # This checks reporting through the 'traceback' module, with both
@@ -2956,6 +2962,7 @@ class PyExcReportingTests(BaseExceptionReportingTests, unittest.TestCase):
         return s
 
 
+@force_not_colorized_test_class
 class CExcReportingTests(BaseExceptionReportingTests, unittest.TestCase):
     #
     # This checks built-in reporting by the interpreter.
diff --git a/Lib/test/test_unittest/test_result.py b/Lib/test/test_unittest/test_result.py
index 746b9fa..ad6f52d 100644
--- a/Lib/test/test_unittest/test_result.py
+++ b/Lib/test/test_unittest/test_result.py
@@ -1,13 +1,15 @@
 import io
 import sys
 import textwrap
-
-from test.support import warnings_helper, captured_stdout
-
 import traceback
 import unittest
 from unittest.util import strclass
-from test.support import force_not_colorized
+from test.support import warnings_helper
+from test.support import (
+    captured_stdout,
+    force_not_colorized,
+    force_not_colorized_test_class,
+)
 from test.test_unittest.support import BufferedWriter
 
 
@@ -772,6 +774,7 @@ class Test_OldTestResult(unittest.TestCase):
         runner.run(Test('testFoo'))
 
 
+@force_not_colorized_test_class
 class TestOutputBuffering(unittest.TestCase):
 
     def setUp(self):
-- 
cgit v0.12