diff options
Diffstat (limited to 'Lib')
| -rw-r--r-- | Lib/_pydatetime.py | 35 | ||||
| -rw-r--r-- | Lib/test/datetimetester.py | 64 |
2 files changed, 47 insertions, 52 deletions
diff --git a/Lib/_pydatetime.py b/Lib/_pydatetime.py index 54c12d3..b7d569c 100644 --- a/Lib/_pydatetime.py +++ b/Lib/_pydatetime.py @@ -556,10 +556,6 @@ def _check_tzinfo_arg(tz): if tz is not None and not isinstance(tz, tzinfo): raise TypeError("tzinfo argument must be None or of a tzinfo subclass") -def _cmperror(x, y): - raise TypeError("can't compare '%s' to '%s'" % ( - type(x).__name__, type(y).__name__)) - def _divide_and_round(a, b): """divide a by b and round result to the nearest integer @@ -1113,32 +1109,33 @@ class date: # Comparisons of date objects with other. def __eq__(self, other): - if isinstance(other, date): + if isinstance(other, date) and not isinstance(other, datetime): return self._cmp(other) == 0 return NotImplemented def __le__(self, other): - if isinstance(other, date): + if isinstance(other, date) and not isinstance(other, datetime): return self._cmp(other) <= 0 return NotImplemented def __lt__(self, other): - if isinstance(other, date): + if isinstance(other, date) and not isinstance(other, datetime): return self._cmp(other) < 0 return NotImplemented def __ge__(self, other): - if isinstance(other, date): + if isinstance(other, date) and not isinstance(other, datetime): return self._cmp(other) >= 0 return NotImplemented def __gt__(self, other): - if isinstance(other, date): + if isinstance(other, date) and not isinstance(other, datetime): return self._cmp(other) > 0 return NotImplemented def _cmp(self, other): assert isinstance(other, date) + assert not isinstance(other, datetime) y, m, d = self._year, self._month, self._day y2, m2, d2 = other._year, other._month, other._day return _cmp((y, m, d), (y2, m2, d2)) @@ -2137,42 +2134,32 @@ class datetime(date): def __eq__(self, other): if isinstance(other, datetime): return self._cmp(other, allow_mixed=True) == 0 - elif not isinstance(other, date): - return NotImplemented else: - return False + return NotImplemented def __le__(self, other): if isinstance(other, datetime): return self._cmp(other) <= 0 - elif not isinstance(other, date): - return NotImplemented else: - _cmperror(self, other) + return NotImplemented def __lt__(self, other): if isinstance(other, datetime): return self._cmp(other) < 0 - elif not isinstance(other, date): - return NotImplemented else: - _cmperror(self, other) + return NotImplemented def __ge__(self, other): if isinstance(other, datetime): return self._cmp(other) >= 0 - elif not isinstance(other, date): - return NotImplemented else: - _cmperror(self, other) + return NotImplemented def __gt__(self, other): if isinstance(other, datetime): return self._cmp(other) > 0 - elif not isinstance(other, date): - return NotImplemented else: - _cmperror(self, other) + return NotImplemented def _cmp(self, other, allow_mixed=False): assert isinstance(other, datetime) diff --git a/Lib/test/datetimetester.py b/Lib/test/datetimetester.py index 53ad5e5..980a8e6 100644 --- a/Lib/test/datetimetester.py +++ b/Lib/test/datetimetester.py @@ -5435,42 +5435,50 @@ class TestTimezoneConversions(unittest.TestCase): class Oddballs(unittest.TestCase): - def test_bug_1028306(self): + def test_date_datetime_comparison(self): + # bpo-1028306, bpo-5516 (gh-49766) # Trying to compare a date to a datetime should act like a mixed- # type comparison, despite that datetime is a subclass of date. as_date = date.today() as_datetime = datetime.combine(as_date, time()) - self.assertTrue(as_date != as_datetime) - self.assertTrue(as_datetime != as_date) - self.assertFalse(as_date == as_datetime) - self.assertFalse(as_datetime == as_date) - self.assertRaises(TypeError, lambda: as_date < as_datetime) - self.assertRaises(TypeError, lambda: as_datetime < as_date) - self.assertRaises(TypeError, lambda: as_date <= as_datetime) - self.assertRaises(TypeError, lambda: as_datetime <= as_date) - self.assertRaises(TypeError, lambda: as_date > as_datetime) - self.assertRaises(TypeError, lambda: as_datetime > as_date) - self.assertRaises(TypeError, lambda: as_date >= as_datetime) - self.assertRaises(TypeError, lambda: as_datetime >= as_date) - - # Nevertheless, comparison should work with the base-class (date) - # projection if use of a date method is forced. - self.assertEqual(as_date.__eq__(as_datetime), True) - different_day = (as_date.day + 1) % 20 + 1 - as_different = as_datetime.replace(day= different_day) - self.assertEqual(as_date.__eq__(as_different), False) + date_sc = SubclassDate(as_date.year, as_date.month, as_date.day) + datetime_sc = SubclassDatetime(as_date.year, as_date.month, + as_date.day, 0, 0, 0) + for d in (as_date, date_sc): + for dt in (as_datetime, datetime_sc): + for x, y in (d, dt), (dt, d): + self.assertTrue(x != y) + self.assertFalse(x == y) + self.assertRaises(TypeError, lambda: x < y) + self.assertRaises(TypeError, lambda: x <= y) + self.assertRaises(TypeError, lambda: x > y) + self.assertRaises(TypeError, lambda: x >= y) # And date should compare with other subclasses of date. If a # subclass wants to stop this, it's up to the subclass to do so. - date_sc = SubclassDate(as_date.year, as_date.month, as_date.day) - self.assertEqual(as_date, date_sc) - self.assertEqual(date_sc, as_date) - # Ditto for datetimes. - datetime_sc = SubclassDatetime(as_datetime.year, as_datetime.month, - as_date.day, 0, 0, 0) - self.assertEqual(as_datetime, datetime_sc) - self.assertEqual(datetime_sc, as_datetime) + for x, y in ((as_date, date_sc), + (date_sc, as_date), + (as_datetime, datetime_sc), + (datetime_sc, as_datetime)): + self.assertTrue(x == y) + self.assertFalse(x != y) + self.assertFalse(x < y) + self.assertFalse(x > y) + self.assertTrue(x <= y) + self.assertTrue(x >= y) + + # Nevertheless, comparison should work if other object is an instance + # of date or datetime class with overridden comparison operators. + # So special methods should return NotImplemented, as if + # date and datetime were independent classes. + for x, y in (as_date, as_datetime), (as_datetime, as_date): + self.assertEqual(x.__eq__(y), NotImplemented) + self.assertEqual(x.__ne__(y), NotImplemented) + self.assertEqual(x.__lt__(y), NotImplemented) + self.assertEqual(x.__gt__(y), NotImplemented) + self.assertEqual(x.__gt__(y), NotImplemented) + self.assertEqual(x.__ge__(y), NotImplemented) def test_extra_attributes(self): with self.assertWarns(DeprecationWarning): |
