diff --git a/oletools/rtfobj.py b/oletools/rtfobj.py index ac2a3e5..5873fb0 100644 --- a/oletools/rtfobj.py +++ b/oletools/rtfobj.py @@ -690,35 +690,36 @@ def is_rtf(arg, treat_str_as_data=False): magic_len = len(RTF_MAGIC) if isinstance(arg, UNICODE_TYPE): with open(arg, 'rb') as reader: - return reader.read(len(RTF_MAGIC)).lower() == RTF_MAGIC + return reader.read(len(RTF_MAGIC)) == RTF_MAGIC if isinstance(arg, bytes) and not isinstance(arg, str): # only in PY3 - return arg[:magic_len].lower() == RTF_MAGIC + return arg[:magic_len] == RTF_MAGIC if isinstance(arg, bytearray): - return arg[:magic_len].lower() == RTF_MAGIC + return arg[:magic_len] == RTF_MAGIC if isinstance(arg, str): # could be bytes, but we assume file name if treat_str_as_data: try: - return arg[:magic_len].encode('ascii', errors='strict').lower()\ + return arg[:magic_len].encode('ascii', errors='strict')\ == RTF_MAGIC except UnicodeError: return False else: with open(arg, 'rb') as reader: - return reader.read(len(RTF_MAGIC)).lower() == RTF_MAGIC + return reader.read(len(RTF_MAGIC)) == RTF_MAGIC if hasattr(arg, 'read'): # a stream (i.e. file-like object) - return arg.read(len(RTF_MAGIC)).lower() == RTF_MAGIC + return arg.read(len(RTF_MAGIC)) == RTF_MAGIC if isinstance(arg, (list, tuple)): iter_arg = iter(arg) else: iter_arg = arg # check iterable - for magic_byte, upper_cased in zip(RTF_MAGIC, RTF_MAGIC.upper()): + for magic_byte in zip(RTF_MAGIC): try: - if next(iter_arg) not in (magic_byte, upper_cased): + if next(iter_arg) not in magic_byte: return False except StopIteration: return False + return True # checked the complete magic without returning False --> match diff --git a/tests/rtfobj/test_is_rtf.py b/tests/rtfobj/test_is_rtf.py index 3f00186..43ea40f 100644 --- a/tests/rtfobj/test_is_rtf.py +++ b/tests/rtfobj/test_is_rtf.py @@ -18,13 +18,13 @@ class TestIsRtf(unittest.TestCase): def test_bytearray(self): """ test that is_rtf works with bytearray """ self.assertTrue(is_rtf(bytearray(RTF_MAGIC + b'asdfasdfasdfasdfasdf'))) - self.assertTrue(is_rtf(bytearray(RTF_MAGIC.upper() + b'asdfasdasdff'))) + self.assertFalse(is_rtf(bytearray(RTF_MAGIC.upper() + b'asdfasdasdff'))) self.assertFalse(is_rtf(bytearray(b'asdfasdfasdfasdfasdfasdfsdfsdfa'))) def test_bytes(self): """ test that is_rtf works with bytearray """ self.assertTrue(is_rtf(RTF_MAGIC + b'asasdffdfasdfasdfasdfasdf', True)) - self.assertTrue(is_rtf(RTF_MAGIC.upper() + b'asdffasdfasdasdff', True)) + self.assertFalse(is_rtf(RTF_MAGIC.upper() + b'asdffasdfasdasdff', True)) self.assertFalse(is_rtf(b'asdfasdfasdfasdfasdfasdasdfffsdfsdfa', True)) def test_tuple(self): @@ -33,7 +33,7 @@ class TestIsRtf(unittest.TestCase): self.assertTrue(is_rtf(data)) data = tuple(byte_char for byte_char in RTF_MAGIC.upper() + b'asfasdf') - self.assertTrue(is_rtf(data)) + self.assertFalse(is_rtf(data)) data = tuple(byte_char for byte_char in b'asdfasfassdfsdsfeereasdfwdf') self.assertFalse(is_rtf(data)) @@ -44,7 +44,7 @@ class TestIsRtf(unittest.TestCase): self.assertTrue(is_rtf(data)) data = (byte_char for byte_char in RTF_MAGIC.upper() + b'asdfassfasdf') - self.assertTrue(is_rtf(data)) + self.assertFalse(is_rtf(data)) data = (byte_char for byte_char in b'asdfasfasasdfasdfasdfsdfdwerwedf') self.assertFalse(is_rtf(data))