# Version-dependent constants (oletools/rtfobj.py): iterating over a byte
# string yields 1-character str objects on Python 2 but integers on Python 3,
# so the characters used for comparisons must be defined accordingly.
if sys.version_info[0] <= 2:
    # Python 2.x - Characters
    BACKSLASH = '\\'
    BRACE_OPEN = '{'
    BRACE_CLOSE = '}'
    UNICODE_TYPE = unicode  # noqa: F821 -- name only evaluated on Python 2
else:
    # Python 3.x - Integers
    BACKSLASH = ord('\\')
    BRACE_OPEN = ord('{')
    BRACE_CLOSE = ord('}')
    UNICODE_TYPE = str

RTF_MAGIC = b'\x7b\\rt'   # \x7b == b'{' but does not mess up auto-indent


def _matches_rtf_magic(head):
    """Return True if `head` equals RTF_MAGIC, ignoring ASCII case.

    `head` may be a byte string / bytearray or a text string. Text is encoded
    as ASCII first; text that cannot be encoded is never a match.
    """
    if isinstance(head, UNICODE_TYPE):
        try:
            head = head.encode('ascii', 'strict')
        except UnicodeError:
            return False
    return head.lower() == RTF_MAGIC


def is_rtf(arg, treat_str_as_data=False):
    """ determine whether given file / stream / array represents an rtf file

    arg can be either a file name, a byte stream (located at start), a
    bytes / bytearray object, a list/tuple or any other iterable that yields
    the same items as iterating over a byte string (integers on Python 3,
    1-character strings on Python 2).

    For text strings (and py2-str, which is bytes) it is not clear whether
    arg is a file name or the data read from a file. Argument
    treat_str_as_data clarifies: if True, the string itself is compared
    against the magic; if False (default) it is used as a file name and the
    file is opened for reading.

    The comparison ignores case. Returns True or False. Raises IOError/OSError
    if arg is a file name that cannot be opened and TypeError if arg is none
    of the supported kinds of object.
    """
    magic_len = len(RTF_MAGIC)

    # raw byte data; on Python 2 bytes is str, which is handled further down
    if isinstance(arg, (bytes, bytearray)) and not isinstance(arg, str):
        return _matches_rtf_magic(arg[:magic_len])

    # strings: data or file name, depending on treat_str_as_data
    if isinstance(arg, (UNICODE_TYPE, str)):
        if treat_str_as_data:
            return _matches_rtf_magic(arg[:magic_len])
        with open(arg, 'rb') as reader:
            return _matches_rtf_magic(reader.read(magic_len))

    if hasattr(arg, 'read'):      # a stream (i.e. file-like object)
        return _matches_rtf_magic(arg.read(magic_len))

    # anything else must be iterable; iter() also accepts iterators, so
    # lists, tuples, generators and other containers are treated alike
    iter_arg = iter(arg)
    for magic_byte, upper_cased in zip(RTF_MAGIC, RTF_MAGIC.upper()):
        # the None default covers input that is shorter than the magic
        if next(iter_arg, None) not in (magic_byte, upper_cased):
            return False
    return True   # checked the complete magic without returning False --> match