diff --git a/stego_lsb/WavSteg.py b/stego_lsb/WavSteg.py index 0ade20c..cd6c8ef 100644 --- a/stego_lsb/WavSteg.py +++ b/stego_lsb/WavSteg.py @@ -51,8 +51,8 @@ def hide_data(sound_path: str, file_path: str, output_path: str, num_lsb: int) - required_lsb = math.ceil(file_size * 8 / num_samples) raise ValueError(f"Input file too large to hide, requires {required_lsb} LSBs, using {num_lsb}") - if sample_width != 1 and sample_width != 2: - # Python's wave module doesn't support higher sample widths + if sample_width < 1 or sample_width > 4: + # WavSteg doesn't support higher sample widths, see setsampwidth() in cpython/Libwave.py raise ValueError("File has an unsupported bit-depth") start = time() @@ -83,8 +83,8 @@ def recover_data(sound_path: str, output_path: str, num_lsb: int, bytes_to_recov sound_frames = sound.readframes(num_frames) log.debug(f"{'Files read':<30} in {time() - start:.2f}s") - if sample_width != 1 and sample_width != 2: - # Python's wave module doesn't support higher sample widths + if sample_width < 1 or sample_width > 4: + # WavSteg doesn't support higher sample widths, see setsampwidth() in cpython/Libwave.py raise ValueError("File has an unsupported bit-depth") start = time() diff --git a/tests/test_wavsteg.py b/tests/test_wavsteg.py index 1d83fe6..f926237 100644 --- a/tests/test_wavsteg.py +++ b/tests/test_wavsteg.py @@ -11,8 +11,8 @@ class TestWavSteg(unittest.TestCase): def write_random_wav(self, filename: str, num_channels: int, sample_width: int, framerate: int, num_frames: int) -> None: - if sample_width != 1 and sample_width != 2: - # WavSteg doesn't support higher sample widths + if sample_width < 1 or sample_width > 4: + # WavSteg doesn't support higher sample widths, see setsampwidth() in cpython/Libwave.py raise ValueError("File has an unsupported bit-depth") with wave.open(filename, "w") as file: @@ -23,8 +23,10 @@ def write_random_wav(self, filename: str, num_channels: int, sample_width: int, dtype: Type[np.unsignedinteger[Any]] if sample_width == 1: dtype = np.uint8 - else: + elif sample_width == 2: dtype = np.uint16 + else: + dtype = np.uint32 data = np.random.randint(0, 2 ** (8 * sample_width), dtype=dtype, size=num_frames * num_channels) # note: typing does not recognize that "writeframes() accepts any bytes-like object" (see documentation) @@ -75,6 +77,12 @@ def test_consistency_8bit(self) -> None: def test_consistency_16bit(self) -> None: self.check_random_interleaving(byte_depth=2) + def test_consistency_24bit(self) -> None: + self.check_random_interleaving(byte_depth=3) + + def test_consistency_32bit(self) -> None: + self.check_random_interleaving(byte_depth=4) + if __name__ == "__main__": unittest.main()