summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_imghdr.py
blob: b54daf8e2ca1a77aea49cbff6db2ee2c010c9d32 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import imghdr
import io
import os
import unittest
import warnings
from test.support import findfile, TESTFN, unlink

# Pairs of (sample image file name, type string imghdr.what() must report).
# The tests look each file name up in the 'imghdrdata' test-data directory.
# Every sample is named 'python.<extension>', so only the extension and the
# expected type are spelled out here.
TEST_FILES = tuple(
    ('python.' + extension, expected_type)
    for extension, expected_type in (
        ('png', 'png'),
        ('gif', 'gif'),
        ('bmp', 'bmp'),
        ('ppm', 'ppm'),
        ('pgm', 'pgm'),
        ('pbm', 'pbm'),
        ('jpg', 'jpeg'),
        ('ras', 'rast'),
        ('sgi', 'rgb'),
        ('tiff', 'tiff'),
        ('xbm', 'xbm'),
        ('webp', 'webp'),
        ('exr', 'exr'),
    )
)

class UnseekableIO(io.FileIO):
    """A FileIO whose positioning methods always fail.

    Only seek() and tell() are overridden; each one raises
    io.UnsupportedOperation unconditionally.  Everything else is inherited
    from io.FileIO unchanged.
    """

    def seek(self, *_args, **_kwargs):
        # Any combination of arguments is accepted and rejected alike.
        raise io.UnsupportedOperation()

    def tell(self):
        raise io.UnsupportedOperation()

class TestImghdr(unittest.TestCase):
    """Tests for imghdr.what() on sample images and on invalid input.

    Intent notes on the test methods are written as comments rather than
    docstrings so that unittest's verbose output keeps showing test names.
    """

    @classmethod
    def setUpClass(cls):
        # Cache the path and the raw bytes of one known-good PNG for the
        # tests that need a valid image but do not care about its format.
        cls.testfile = findfile('python.png', subdir='imghdrdata')
        with open(cls.testfile, 'rb') as stream:
            cls.testdata = stream.read()

    def tearDown(self):
        # Several tests create TESTFN; remove it after every test.
        unlink(TESTFN)

    def test_data(self):
        # Each sample must be recognised whether it is passed as a file
        # name, an open binary stream, bytes or a bytearray.  subTest keeps
        # one bad sample from hiding the others and names it in the report.
        for filename, expected in TEST_FILES:
            with self.subTest(filename=filename):
                path = findfile(filename, subdir='imghdrdata')
                self.assertEqual(imghdr.what(path), expected)
                with open(path, 'rb') as stream:
                    self.assertEqual(imghdr.what(stream), expected)
                with open(path, 'rb') as stream:
                    data = stream.read()
                self.assertEqual(imghdr.what(None, data), expected)
                self.assertEqual(imghdr.what(None, bytearray(data)), expected)

    def test_register_test(self):
        # A callable appended to imghdr.tests must be consulted by what().
        def test_jumbo(h, file):
            if h.startswith(b'eggs'):
                return 'ham'
        imghdr.tests.append(test_jumbo)
        # Drop the extra entry again so other tests see the original list.
        self.addCleanup(imghdr.tests.pop)
        self.assertEqual(imghdr.what(None, b'eggs'), 'ham')

    def test_file_pos(self):
        # The image starts at a non-zero offset; what() must detect it from
        # the stream's current position and leave that position unchanged.
        with open(TESTFN, 'wb') as stream:
            stream.write(b'ababagalamaga')
            pos = stream.tell()
            stream.write(self.testdata)
        with open(TESTFN, 'rb') as stream:
            stream.seek(pos)
            self.assertEqual(imghdr.what(stream), 'png')
            self.assertEqual(stream.tell(), pos)

    def test_bad_args(self):
        # Missing or wrongly typed arguments must raise, not return a guess.
        with self.assertRaises(TypeError):
            imghdr.what()
        with self.assertRaises(AttributeError):
            imghdr.what(None)
        with self.assertRaises(TypeError):
            imghdr.what(self.testfile, 1)
        with self.assertRaises(AttributeError):
            imghdr.what(os.fsencode(self.testfile))
        with open(self.testfile, 'rb') as f:
            with self.assertRaises(AttributeError):
                imghdr.what(f.fileno())

    def test_invalid_headers(self):
        # None of these byte strings may be reported as an image type.
        for header in (b'\211PN\r\n',
                       b'\001\331',
                       b'\x59\xA6',
                       b'cutecat',
                       b'000000JFI',
                       b'GIF80'):
            with self.subTest(header=header):
                self.assertIsNone(imghdr.what(None, header))

    def test_string_data(self):
        # Text (str) input must be rejected with TypeError, both as a
        # stream and as the header argument.  BytesWarning is ignored here,
        # presumably so that runs with -b/-bb still reach the TypeError.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", BytesWarning)
            for filename, _ in TEST_FILES:
                with self.subTest(filename=filename):
                    path = findfile(filename, subdir='imghdrdata')
                    with open(path, 'rb') as stream:
                        data = stream.read().decode('latin1')
                    with self.assertRaises(TypeError):
                        imghdr.what(io.StringIO(data))
                    with self.assertRaises(TypeError):
                        imghdr.what(None, data)

    def test_missing_file(self):
        with self.assertRaises(FileNotFoundError):
            imghdr.what('missing')

    def test_closed_file(self):
        # Closed streams, whether a real file or an in-memory buffer, must
        # raise ValueError.
        stream = open(self.testfile, 'rb')
        stream.close()
        with self.assertRaises(ValueError):
            imghdr.what(stream)
        stream = io.BytesIO(self.testdata)
        stream.close()
        with self.assertRaises(ValueError):
            imghdr.what(stream)

    def test_unseekable(self):
        # A stream whose tell()/seek() raise must make what() raise
        # io.UnsupportedOperation too.
        with open(TESTFN, 'wb') as stream:
            stream.write(self.testdata)
        with UnseekableIO(TESTFN, 'rb') as stream:
            with self.assertRaises(io.UnsupportedOperation):
                imghdr.what(stream)

    def test_output_stream(self):
        # A stream opened only for writing cannot be read, so what() must
        # raise OSError.
        with open(TESTFN, 'wb') as stream:
            stream.write(self.testdata)
            stream.seek(0)
            with self.assertRaises(OSError):
                imghdr.what(stream)

# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()