summary | refs | log | tree | commit | diff | stats
path: root/Lib/test/test_secrets.py
blob: a3d1a8cc10f5e7f4fd50acbb7299951ddbc1f2e7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""Test the secrets module.

As most of the functions in secrets are thin wrappers around functions
defined elsewhere, we don't need to test them exhaustively.
"""


import secrets
import unittest
import string


# === Unit tests ===

class Compare_Digest_Tests(unittest.TestCase):
    """Test secrets.compare_digest function.

    Every case is exercised twice: once with text strings and once with
    the same data encoded to bytes.
    """

    def test_equal(self):
        # Test compare_digest functionality with equal (byte/text) strings.
        for s in ("a", "bcd", "xyz123"):
            # subTest reports every failing input instead of stopping at
            # the first one.
            with self.subTest(s=s):
                # Two separately built, equal-valued operands.
                a = s*100
                b = s*100
                self.assertTrue(secrets.compare_digest(a, b))
                self.assertTrue(secrets.compare_digest(a.encode('utf-8'),
                                                       b.encode('utf-8')))

    def test_unequal(self):
        # Test compare_digest functionality with unequal (byte/text) strings.
        # Operands of different length.
        self.assertFalse(secrets.compare_digest("abc", "abcd"))
        self.assertFalse(secrets.compare_digest(b"abc", b"abcd"))
        # Operands of the same length that differ only in the last character.
        for s in ("x", "mn", "a1b2c3"):
            with self.subTest(s=s):
                a = s*100 + "q"
                b = s*100 + "k"
                self.assertFalse(secrets.compare_digest(a, b))
                self.assertFalse(secrets.compare_digest(a.encode('utf-8'),
                                                        b.encode('utf-8')))

    def test_bad_types(self):
        # Test that compare_digest raises with mixed types.
        a = 'abcde'
        b = a.encode('utf-8')
        # Sanity-check the fixtures through unittest rather than a bare
        # ``assert``, which is removed when Python runs with -O.
        self.assertIsInstance(a, str)
        self.assertIsInstance(b, bytes)
        self.assertRaises(TypeError, secrets.compare_digest, a, b)
        self.assertRaises(TypeError, secrets.compare_digest, b, a)

    def test_bool(self):
        # Test that compare_digest returns a bool, for both a matching and
        # a non-matching comparison.
        self.assertIsInstance(secrets.compare_digest("abc", "abc"), bool)
        self.assertIsInstance(secrets.compare_digest("abc", "xyz"), bool)


class Random_Tests(unittest.TestCase):
    """Test wrappers around SystemRandom methods.

    The results are random, so each test only checks that the returned
    values fall inside the expected range, over a handful of draws.
    """

    def test_randbits(self):
        # Test randbits: the result n must satisfy 0 <= n < 2**numbits.
        errmsg = "randbits(%d) returned %d"
        for numbits in (3, 12, 30):
            with self.subTest(numbits=numbits):
                # Draw several samples per width.
                for _ in range(6):
                    n = secrets.randbits(numbits)
                    self.assertTrue(0 <= n < 2**numbits,
                                    errmsg % (numbits, n))

    def test_choice(self):
        # Test choice: the result must be one of the supplied items.
        items = [1, 2, 4, 8, 16, 32, 64]
        for _ in range(10):
            # assertIn shows the offending value on failure.
            self.assertIn(secrets.choice(items), items)

    def test_randbelow(self):
        # Test randbelow: the result n must satisfy 0 <= n < i.
        errmsg = "randbelow(%d) returned %d"
        for i in range(2, 10):
            n = secrets.randbelow(i)
            self.assertIn(n, range(i), errmsg % (i, n))
        # A non-positive upper bound is rejected.
        self.assertRaises(ValueError, secrets.randbelow, 0)
        self.assertRaises(ValueError, secrets.randbelow, -1)


class Token_Tests(unittest.TestCase):
    """Test token functions.

    Covers token_bytes, token_hex and token_urlsafe: that they can be
    called without a size, and the type, length and alphabet of what they
    return.
    """

    def test_token_defaults(self):
        # Test that token_* functions handle default size correctly.
        for func in (secrets.token_bytes, secrets.token_hex,
                     secrets.token_urlsafe):
            name = func.__name__
            # One subTest per function, so a failure in one does not hide
            # the results for the others.
            with self.subTest(func=name):
                try:
                    func()
                except TypeError:
                    self.fail("%s cannot be called with no argument" % name)
                try:
                    func(None)
                except TypeError:
                    self.fail("%s cannot be called with None" % name)
        size = secrets.DEFAULT_ENTROPY
        self.assertEqual(len(secrets.token_bytes(None)), size)
        # Hex encoding uses two characters per byte.
        self.assertEqual(len(secrets.token_hex(None)), 2*size)

    def test_token_bytes(self):
        # Test token_bytes: returns bytes of exactly the requested length.
        self.assertIsInstance(secrets.token_bytes(11), bytes)
        for n in (1, 8, 17, 100):
            with self.subTest(n=n):
                self.assertEqual(len(secrets.token_bytes(n)), n)

    def test_token_hex(self):
        # Test token_hex: returns a str of 2*n hexadecimal digits.
        self.assertIsInstance(secrets.token_hex(7), str)
        for n in (1, 12, 25, 90):
            with self.subTest(n=n):
                s = secrets.token_hex(n)
                self.assertEqual(len(s), 2*n)
                self.assertTrue(all(c in string.hexdigits for c in s))

    def test_token_urlsafe(self):
        # Test token_urlsafe: returns a str containing only letters,
        # digits, '-' and '_'.
        self.assertIsInstance(secrets.token_urlsafe(9), str)
        legal = string.ascii_letters + string.digits + '-_'
        for n in (1, 11, 28, 76):
            with self.subTest(n=n):
                token = secrets.token_urlsafe(n)
                self.assertTrue(all(c in legal for c in token),
                                "illegal character in %r" % token)


# Run the tests in this file when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()