diff options
Diffstat (limited to 'Lib/wave.py')
-rw-r--r-- | Lib/wave.py | 22 |
1 files changed, 15 insertions, 7 deletions
diff --git a/Lib/wave.py b/Lib/wave.py index 9a45574..e446174 100644 --- a/Lib/wave.py +++ b/Lib/wave.py @@ -83,6 +83,7 @@ class Error(Exception): pass WAVE_FORMAT_PCM = 0x0001 +WAVE_FORMAT_EXTENSIBLE = 0xFFFE _array_fmts = None, 'b', 'h', None, 'i' @@ -377,16 +378,23 @@ class Wave_read: wFormatTag, self._nchannels, self._framerate, dwAvgBytesPerSec, wBlockAlign = struct.unpack_from('<HHLLH', chunk.read(14)) except struct.error: raise EOFError from None - if wFormatTag == WAVE_FORMAT_PCM: + if wFormatTag != WAVE_FORMAT_PCM and wFormatTag != WAVE_FORMAT_EXTENSIBLE: + raise Error('unknown format: %r' % (wFormatTag,)) + try: + sampwidth = struct.unpack_from('<H', chunk.read(2))[0] + except struct.error: + raise EOFError from None + if wFormatTag == WAVE_FORMAT_EXTENSIBLE: try: - sampwidth = struct.unpack_from('<H', chunk.read(2))[0] + # Only the first 2 bytes (of 16) of SubFormat are needed. + cbSize, wValidBitsPerSample, dwChannelMask, SubFormatFmt = struct.unpack_from('<HHLH', chunk.read(10)) except struct.error: raise EOFError from None - self._sampwidth = (sampwidth + 7) // 8 - if not self._sampwidth: - raise Error('bad sample width') - else: - raise Error('unknown format: %r' % (wFormatTag,)) + if SubFormatFmt != WAVE_FORMAT_PCM: + raise Error(f'unknown format: {SubFormatFmt}') + self._sampwidth = (sampwidth + 7) // 8 + if not self._sampwidth: + raise Error('bad sample width') if not self._nchannels: raise Error('bad # of channels') self._framesize = self._nchannels * self._sampwidth |