Commit 1542df504ef321511df84f15c16a15aed252c58c

Authored by Philippe Lagadec
Committed by GitHub
2 parents 26b43390 911b2732

Merge pull request #308 from christian-intra2net/central-logger-json

Unified logging with json option
oletools/common/log_helper/__init__.py 0 → 100644
# Import the module under an alias so the package-level name `log_helper`
# can be reused for the shared instance below.
from . import log_helper as log_helper_

# Single shared LogHelper instance: all oletools modules import this one
# object so they agree on logging state (JSON mode, registered loggers).
log_helper = log_helper_.LogHelper()

__all__ = ['log_helper']
... ...
oletools/common/log_helper/_json_formatter.py 0 → 100644
  1 +import logging
  2 +import json
  3 +
  4 +
class JsonFormatter(logging.Formatter):
    """
    Format every message to be logged as a JSON object
    """

    def __init__(self, fmt=None, datefmt=None):
        logging.Formatter.__init__(self, fmt, datefmt)
        # Track whether we have emitted a record yet: all records after the
        # first are prefixed with a comma so the overall output forms a
        # valid JSON list (the caller prints the surrounding '[' and ']').
        self._is_first_line = True

    def format(self, record):
        """
        Return *record* rendered as a JSON object string.

        Since we don't buffer messages, we always prepend messages with a comma to make
        the output JSON-compatible. The only exception is when printing the first line,
        so we need to keep track of it.
        """
        # Use record.getMessage() instead of record.msg so that %-style
        # arguments (e.g. logger.debug('x=%s', x)) are interpolated; the
        # raw record.msg would leave the placeholders unformatted.
        json_dict = dict(msg=record.getMessage(), level=record.levelname)
        formatted_message = ' ' + json.dumps(json_dict)

        if self._is_first_line:
            self._is_first_line = False
            return formatted_message

        return ', ' + formatted_message
... ...
oletools/common/log_helper/_logger_adapter.py 0 → 100644
  1 +import logging
  2 +from . import _root_logger_wrapper
  3 +
  4 +
class OletoolsLoggerAdapter(logging.LoggerAdapter):
    """
    Adapter class for all loggers returned by the logging module.
    """

    # Callable that reports whether JSON output is active; installed via
    # set_json_enabled_function(). None until then.
    _json_enabled = None

    def print_str(self, message):
        """
        Drop-in replacement for print(): emits *message* directly in normal
        mode, but routes it through the logging machinery when JSON output
        is enabled so the result stays JSON-compatible.
        """
        json_check = self._json_enabled
        if json_check is not None and json_check():
            # Messages from this function should always be printed, so in
            # JSON mode log them at the root logger's current level.
            self.log(_root_logger_wrapper.level(), message)
        else:
            print(message)

    def set_json_enabled_function(self, json_enabled):
        """
        Register the callable used to check whether JSON output is enabled.
        """
        self._json_enabled = json_enabled

    def level(self):
        """Return the level of the wrapped logger."""
        return self.logger.level
... ...
oletools/common/log_helper/_root_logger_wrapper.py 0 → 100644
  1 +import logging
  2 +
  3 +
def is_logging_initialized():
    """
    We use the same strategy as the logging module when checking if
    the logging was initialized - look for handlers in the root logger
    """
    return bool(logging.root.handlers)
  10 +
  11 +
def set_formatter(fmt):
    """
    Set the formatter to be used by every handler of the root logger.
    No-op when logging has not been initialized yet.
    """
    # Handlers on the root logger mean logging was configured
    # (same check as is_logging_initialized(), inlined here).
    if logging.root.handlers:
        for root_handler in logging.root.handlers:
            root_handler.setFormatter(fmt)
  21 +
  22 +
def level():
    """Return the log level currently set on the root logger."""
    # logging.getLogger() with no argument returns the root logger.
    return logging.getLogger().level
... ...
oletools/common/log_helper/log_helper.py 0 → 100644
  1 +"""
  2 +log_helper.py
  3 +
  4 +General logging helpers
  5 +
  6 +.. codeauthor:: Intra2net AG <info@intra2net>
  7 +"""
  8 +
  9 +# === LICENSE =================================================================
  10 +#
  11 +# Redistribution and use in source and binary forms, with or without
  12 +# modification, are permitted provided that the following conditions are met:
  13 +#
  14 +# * Redistributions of source code must retain the above copyright notice,
  15 +# this list of conditions and the following disclaimer.
  16 +# * Redistributions in binary form must reproduce the above copyright notice,
  17 +# this list of conditions and the following disclaimer in the documentation
  18 +# and/or other materials provided with the distribution.
  19 +#
  20 +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  21 +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  22 +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  23 +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
  24 +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  25 +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  26 +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  27 +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  28 +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  29 +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  30 +# POSSIBILITY OF SUCH DAMAGE.
  31 +
  32 +# -----------------------------------------------------------------------------
  33 +# CHANGELOG:
  34 +# 2017-12-07 v0.01 CH: - first version
  35 +# 2018-02-05 v0.02 SA: - fixed log level selection and reformatted code
  36 +# 2018-02-06 v0.03 SA: - refactored code to deal with NullHandlers
  37 +# 2018-02-07 v0.04 SA: - fixed control of handlers propagation
  38 +# 2018-04-23 v0.05 SA: - refactored the whole logger to use an OOP approach
  39 +
  40 +# -----------------------------------------------------------------------------
  41 +# TODO:
  42 +
  43 +
  44 +from ._json_formatter import JsonFormatter
  45 +from ._logger_adapter import OletoolsLoggerAdapter
  46 +from . import _root_logger_wrapper
  47 +import logging
  48 +import sys
  49 +
  50 +
# Map the level names accepted on the command line to the logging
# module's numeric levels.
LOG_LEVELS = {
    'debug': logging.DEBUG,
    'info': logging.INFO,
    'warning': logging.WARNING,
    'error': logging.ERROR,
    'critical': logging.CRITICAL
}

# Logger name used when callers do not supply one.
DEFAULT_LOGGER_NAME = 'oletools'
# Default output format: level name padded to 8 chars, then the message.
DEFAULT_MESSAGE_FORMAT = '%(levelname)-8s %(message)s'
  61 +
  62 +
class LogHelper:
    """
    Central logging manager for oletools.

    Library modules request "silent" loggers (NullHandler, level above
    CRITICAL) via get_or_create_silent_logger(), so importing oletools
    produces no output. A tool's main() then calls enable_logging() once
    to actually turn logging on (optionally JSON-formatted) and
    end_logging() when finished.
    """

    def __init__(self):
        self._all_names = set()  # set so we do not have duplicates
        self._use_json = False  # whether output must be JSON-compatible
        self._is_enabled = False  # True between enable_logging and end_logging

    def get_or_create_silent_logger(self, name=DEFAULT_LOGGER_NAME, level=logging.CRITICAL + 1):
        """
        Get a logger or create one if it doesn't exist, setting a NullHandler
        as the handler (to avoid printing to the console).
        By default we also use a higher logging level so every message will
        be ignored.
        This will prevent oletools from logging unnecessarily when being imported
        from external tools.
        """
        return self._get_or_create_logger(name, level, logging.NullHandler())

    def enable_logging(self, use_json, level, log_format=DEFAULT_MESSAGE_FORMAT, stream=None):
        """
        This function initializes the root logger and enables logging.
        We set the level of the root logger to the one passed by calling logging.basicConfig.
        We also set the level of every logger we created to 0 (logging.NOTSET), meaning that
        the level of the root logger will be used to tell if messages should be logged.
        Additionally, since our loggers use the NullHandler, they won't log anything themselves,
        but due to having propagation enabled they will pass messages to the root logger,
        which in turn will log to the stream set in this function.
        Since the root logger is the one doing the work, when using JSON we set its formatter
        so that every message logged is JSON-compatible.
        """
        if self._is_enabled:
            raise ValueError('re-enabling logging. Not sure whether that is ok...')

        # KeyError here means an invalid level name was passed in.
        log_level = LOG_LEVELS[level]
        logging.basicConfig(level=log_level, format=log_format, stream=stream)
        self._is_enabled = True

        self._use_json = use_json
        # Wrap the current excepthook so end_logging() runs (closing the
        # JSON list) even when the tool dies with an uncaught exception.
        sys.excepthook = self._get_except_hook(sys.excepthook)

        # since there could be loggers already created we go through all of them
        # and set their levels to 0 so they will use the root logger's level
        for name in self._all_names:
            logger = self.get_or_create_silent_logger(name)
            self._set_logger_level(logger, logging.NOTSET)

        # add a JSON formatter to the root logger, which will be used by every logger
        if self._use_json:
            _root_logger_wrapper.set_formatter(JsonFormatter())
            # open the JSON list; end_logging() prints the closing ']'
            print('[')

    def end_logging(self):
        """
        Must be called at the end of the main function if the caller wants
        json-compatible output
        """
        if not self._is_enabled:
            return
        self._is_enabled = False

        # end logging
        self._all_names = set()
        logging.shutdown()

        # end json list
        if self._use_json:
            print(']')
        self._use_json = False

    def _get_except_hook(self, old_hook):
        """
        Global hook for exceptions so we can always end logging.
        We wrap any hook currently set to avoid overwriting global hooks set by oletools.
        Note that this is only called by enable_logging, which in turn is called by
        the main() function in oletools' scripts. When scripts are being imported this
        code won't execute and won't affect global hooks.
        """
        def hook(exctype, value, traceback):
            # close logging (and the JSON list) before delegating to the
            # previously-installed hook
            self.end_logging()
            old_hook(exctype, value, traceback)

        return hook

    def _get_or_create_logger(self, name, level, handler=None):
        """
        Get or create a new logger. This newly created logger will have the
        handler and level that was passed, but if it already exists it's not changed.
        We also wrap the logger in an adapter so we can easily extend its functionality.
        """

        # logging.getLogger creates a logger if it doesn't exist,
        # so we need to check before calling it
        if handler and not self._log_exists(name):
            logger = logging.getLogger(name)
            logger.addHandler(handler)
            self._set_logger_level(logger, level)
        else:
            logger = logging.getLogger(name)

        # Keep track of every logger we created so we can easily change
        # their levels whenever needed
        self._all_names.add(name)

        adapted_logger = OletoolsLoggerAdapter(logger, None)
        # late-bound lambda: reflects the current JSON setting, not the
        # one at creation time
        adapted_logger.set_json_enabled_function(lambda: self._use_json)

        return adapted_logger

    @staticmethod
    def _set_logger_level(logger, level):
        """
        If the logging is already initialized, we set the level of our logger
        to 0, meaning that it will reuse the level of the root logger.
        That means that if the root logger level changes, we will keep using
        its level and not logging unnecessarily.
        """

        # if this log was wrapped, unwrap it to set the level
        if isinstance(logger, OletoolsLoggerAdapter):
            logger = logger.logger

        if _root_logger_wrapper.is_logging_initialized():
            logger.setLevel(logging.NOTSET)
        else:
            logger.setLevel(level)

    @staticmethod
    def _log_exists(name):
        """
        We check the log manager instead of our global _all_names variable
        since the logger could have been created outside of the helper
        """
        return name in logging.Logger.manager.loggerDict
... ...
oletools/msodde.py
... ... @@ -53,8 +53,6 @@ import argparse
53 53 import os
54 54 from os.path import abspath, dirname
55 55 import sys
56   -import json
57   -import logging
58 56 import re
59 57 import csv
60 58  
... ... @@ -63,6 +61,7 @@ import olefile
63 61 from oletools import ooxml
64 62 from oletools import xls_parser
65 63 from oletools import rtfobj
  64 +from oletools.common.log_helper import log_helper
66 65  
67 66 # -----------------------------------------------------------------------------
68 67 # CHANGELOG:
... ... @@ -212,63 +211,12 @@ THIS IS WORK IN PROGRESS - Check updates regularly!
212 211 Please report any issue at https://github.com/decalage2/oletools/issues
213 212 """ % __version__
214 213  
215   -BANNER_JSON = dict(type='meta', version=__version__, name='msodde',
216   - link='http://decalage.info/python/oletools',
217   - message='THIS IS WORK IN PROGRESS - Check updates regularly! '
218   - 'Please report any issue at '
219   - 'https://github.com/decalage2/oletools/issues')
220   -
221 214 # === LOGGING =================================================================
222 215  
223 216 DEFAULT_LOG_LEVEL = "warning" # Default log level
224   -LOG_LEVELS = {
225   - 'debug': logging.DEBUG,
226   - 'info': logging.INFO,
227   - 'warning': logging.WARNING,
228   - 'error': logging.ERROR,
229   - 'critical': logging.CRITICAL
230   -}
231   -
232   -
233   -class NullHandler(logging.Handler):
234   - """
235   - Log Handler without output, to avoid printing messages if logging is not
236   - configured by the main application.
237   - Python 2.7 has logging.NullHandler, but this is necessary for 2.6:
238   - see https://docs.python.org/2.6/library/logging.html#configuring-logging-for-a-library
239   - """
240   - def emit(self, record):
241   - pass
242   -
243   -
244   -def get_logger(name, level=logging.CRITICAL+1):
245   - """
246   - Create a suitable logger object for this module.
247   - The goal is not to change settings of the root logger, to avoid getting
248   - other modules' logs on the screen.
249   - If a logger exists with same name, reuse it. (Else it would have duplicate
250   - handlers and messages would be doubled.)
251   - The level is set to CRITICAL+1 by default, to avoid any logging.
252   - """
253   - # First, test if there is already a logger with the same name, else it
254   - # will generate duplicate messages (due to duplicate handlers):
255   - if name in logging.Logger.manager.loggerDict:
256   - # NOTE: another less intrusive but more "hackish" solution would be to
257   - # use getLogger then test if its effective level is not default.
258   - logger = logging.getLogger(name)
259   - # make sure level is OK:
260   - logger.setLevel(level)
261   - return logger
262   - # get a new logger:
263   - logger = logging.getLogger(name)
264   - # only add a NullHandler for this logger, it is up to the application
265   - # to configure its own logging:
266   - logger.addHandler(NullHandler())
267   - logger.setLevel(level)
268   - return logger
269 217  
270 218 # a global logger object used for debugging:
271   -log = get_logger('msodde')
  219 +logger = log_helper.get_or_create_silent_logger('msodde')
272 220  
273 221  
274 222 # === UNICODE IN PY2 =========================================================
... ... @@ -312,7 +260,7 @@ def ensure_stdout_handles_unicode():
312 260 encoding = 'utf8'
313 261  
314 262 # logging is probably not initialized yet, but just in case
315   - log.debug('wrapping sys.stdout with encoder using {0}'.format(encoding))
  263 + logger.debug('wrapping sys.stdout with encoder using {0}'.format(encoding))
316 264  
317 265 wrapper = codecs.getwriter(encoding)
318 266 sys.stdout = wrapper(sys.stdout)
... ... @@ -396,7 +344,7 @@ def process_doc_field(data):
396 344 """ check if field instructions start with DDE
397 345  
398 346 expects unicode input, returns unicode output (empty if not dde) """
399   - log.debug('processing field {0}'.format(data))
  347 + logger.debug('processing field {0}'.format(data))
400 348  
401 349 if data.lstrip().lower().startswith(u'dde'):
402 350 return data
... ... @@ -434,7 +382,7 @@ def process_doc_stream(stream):
434 382  
435 383 if char == OLE_FIELD_START:
436 384 if have_start and max_size_exceeded:
437   - log.debug('big field was not a field after all')
  385 + logger.debug('big field was not a field after all')
438 386 have_start = True
439 387 have_sep = False
440 388 max_size_exceeded = False
... ... @@ -446,7 +394,7 @@ def process_doc_stream(stream):
446 394 # now we are after start char but not at end yet
447 395 if char == OLE_FIELD_SEP:
448 396 if have_sep:
449   - log.debug('unexpected field: has multiple separators!')
  397 + logger.debug('unexpected field: has multiple separators!')
450 398 have_sep = True
451 399 elif char == OLE_FIELD_END:
452 400 # have complete field now, process it
... ... @@ -464,7 +412,7 @@ def process_doc_stream(stream):
464 412 if max_size_exceeded:
465 413 pass
466 414 elif len(field_contents) > OLE_FIELD_MAX_SIZE:
467   - log.debug('field exceeds max size of {0}. Ignore rest'
  415 + logger.debug('field exceeds max size of {0}. Ignore rest'
468 416 .format(OLE_FIELD_MAX_SIZE))
469 417 max_size_exceeded = True
470 418  
... ... @@ -482,9 +430,9 @@ def process_doc_stream(stream):
482 430 field_contents += u'?'
483 431  
484 432 if max_size_exceeded:
485   - log.debug('big field was not a field after all')
  433 + logger.debug('big field was not a field after all')
486 434  
487   - log.debug('Checked {0} characters, found {1} fields'
  435 + logger.debug('Checked {0} characters, found {1} fields'
488 436 .format(idx, len(result_parts)))
489 437  
490 438 return result_parts
... ... @@ -498,7 +446,7 @@ def process_doc(filepath):
498 446 empty if none were found. dde-links will still begin with the dde[auto] key
499 447 word (possibly after some whitespace)
500 448 """
501   - log.debug('process_doc')
  449 + logger.debug('process_doc')
502 450 ole = olefile.OleFileIO(filepath, path_encoding=None)
503 451  
504 452 links = []
... ... @@ -508,7 +456,7 @@ def process_doc(filepath):
508 456 # this direntry is not part of the tree --> unused or orphan
509 457 direntry = ole._load_direntry(sid)
510 458 is_stream = direntry.entry_type == olefile.STGTY_STREAM
511   - log.debug('direntry {:2d} {}: {}'
  459 + logger.debug('direntry {:2d} {}: {}'
512 460 .format(sid, '[orphan]' if is_orphan else direntry.name,
513 461 'is stream of size {}'.format(direntry.size)
514 462 if is_stream else
... ... @@ -593,7 +541,7 @@ def process_docx(filepath, field_filter_mode=None):
593 541 ddetext += unquote(elem.text)
594 542  
595 543 # apply field command filter
596   - log.debug('filtering with mode "{0}"'.format(field_filter_mode))
  544 + logger.debug('filtering with mode "{0}"'.format(field_filter_mode))
597 545 if field_filter_mode in (FIELD_FILTER_ALL, None):
598 546 clean_fields = all_fields
599 547 elif field_filter_mode == FIELD_FILTER_DDE:
... ... @@ -652,7 +600,7 @@ def field_is_blacklisted(contents):
652 600 index = FIELD_BLACKLIST_CMDS.index(words[0].lower())
653 601 except ValueError: # first word is no blacklisted command
654 602 return False
655   - log.debug('trying to match "{0}" to blacklist command {1}'
  603 + logger.debug('trying to match "{0}" to blacklist command {1}'
656 604 .format(contents, FIELD_BLACKLIST[index]))
657 605 _, nargs_required, nargs_optional, sw_with_arg, sw_solo, sw_format \
658 606 = FIELD_BLACKLIST[index]
... ... @@ -664,11 +612,11 @@ def field_is_blacklisted(contents):
664 612 break
665 613 nargs += 1
666 614 if nargs < nargs_required:
667   - log.debug('too few args: found {0}, but need at least {1} in "{2}"'
  615 + logger.debug('too few args: found {0}, but need at least {1} in "{2}"'
668 616 .format(nargs, nargs_required, contents))
669 617 return False
670 618 elif nargs > nargs_required + nargs_optional:
671   - log.debug('too many args: found {0}, but need at most {1}+{2} in "{3}"'
  619 + logger.debug('too many args: found {0}, but need at most {1}+{2} in "{3}"'
672 620 .format(nargs, nargs_required, nargs_optional, contents))
673 621 return False
674 622  
... ... @@ -678,14 +626,14 @@ def field_is_blacklisted(contents):
678 626 for word in words[1+nargs:]:
679 627 if expect_arg: # this is an argument for the last switch
680 628 if arg_choices and (word not in arg_choices):
681   - log.debug('Found invalid switch argument "{0}" in "{1}"'
  629 + logger.debug('Found invalid switch argument "{0}" in "{1}"'
682 630 .format(word, contents))
683 631 return False
684 632 expect_arg = False
685 633 arg_choices = [] # in general, do not enforce choices
686 634 continue # "no further questions, your honor"
687 635 elif not FIELD_SWITCH_REGEX.match(word):
688   - log.debug('expected switch, found "{0}" in "{1}"'
  636 + logger.debug('expected switch, found "{0}" in "{1}"'
689 637 .format(word, contents))
690 638 return False
691 639 # we want a switch and we got a valid one
... ... @@ -707,7 +655,7 @@ def field_is_blacklisted(contents):
707 655 if 'numeric' in sw_format:
708 656 arg_choices = [] # too many choices to list them here
709 657 else:
710   - log.debug('unexpected switch {0} in "{1}"'
  658 + logger.debug('unexpected switch {0} in "{1}"'
711 659 .format(switch, contents))
712 660 return False
713 661  
... ... @@ -733,11 +681,11 @@ def process_xlsx(filepath):
733 681 # binary parts, e.g. contained in .xlsb
734 682 for subfile, content_type, handle in parser.iter_non_xml():
735 683 try:
736   - logging.info('Parsing non-xml subfile {0} with content type {1}'
  684 + logger.info('Parsing non-xml subfile {0} with content type {1}'
737 685 .format(subfile, content_type))
738 686 for record in xls_parser.parse_xlsb_part(handle, content_type,
739 687 subfile):
740   - logging.debug('{0}: {1}'.format(subfile, record))
  688 + logger.debug('{0}: {1}'.format(subfile, record))
741 689 if isinstance(record, xls_parser.XlsbBeginSupBook) and \
742 690 record.link_type == \
743 691 xls_parser.XlsbBeginSupBook.LINK_TYPE_DDE:
... ... @@ -747,14 +695,14 @@ def process_xlsx(filepath):
747 695 if content_type.startswith('application/vnd.ms-excel.') or \
748 696 content_type.startswith('application/vnd.ms-office.'): # pylint: disable=bad-indentation
749 697 # should really be able to parse these either as xml or records
750   - log_func = logging.warning
  698 + log_func = logger.warning
751 699 elif content_type.startswith('image/') or content_type == \
752 700 'application/vnd.openxmlformats-officedocument.' + \
753 701 'spreadsheetml.printerSettings':
754 702 # understandable that these are not record-base
755   - log_func = logging.debug
  703 + log_func = logger.debug
756 704 else: # default
757   - log_func = logging.info
  705 + log_func = logger.info
758 706 log_func('Failed to parse {0} of content type {1}'
759 707 .format(subfile, content_type))
760 708 # in any case: continue with next
... ... @@ -774,15 +722,15 @@ class RtfFieldParser(rtfobj.RtfParser):
774 722  
775 723 def open_destination(self, destination):
776 724 if destination.cword == b'fldinst':
777   - log.debug('*** Start field data at index %Xh' % destination.start)
  725 + logger.debug('*** Start field data at index %Xh' % destination.start)
778 726  
779 727 def close_destination(self, destination):
780 728 if destination.cword == b'fldinst':
781   - log.debug('*** Close field data at index %Xh' % self.index)
782   - log.debug('Field text: %r' % destination.data)
  729 + logger.debug('*** Close field data at index %Xh' % self.index)
  730 + logger.debug('Field text: %r' % destination.data)
783 731 # remove extra spaces and newline chars:
784 732 field_clean = destination.data.translate(None, b'\r\n').strip()
785   - log.debug('Cleaned Field text: %r' % field_clean)
  733 + logger.debug('Cleaned Field text: %r' % field_clean)
786 734 self.fields.append(field_clean)
787 735  
788 736 def control_symbol(self, matchobject):
... ... @@ -804,7 +752,7 @@ def process_rtf(file_handle, field_filter_mode=None):
804 752 rtfparser.parse()
805 753 all_fields = [field.decode('ascii') for field in rtfparser.fields]
806 754 # apply field command filter
807   - log.debug('found {1} fields, filtering with mode "{0}"'
  755 + logger.debug('found {1} fields, filtering with mode "{0}"'
808 756 .format(field_filter_mode, len(all_fields)))
809 757 if field_filter_mode in (FIELD_FILTER_ALL, None):
810 758 clean_fields = all_fields
... ... @@ -853,7 +801,7 @@ def process_csv(filepath):
853 801  
854 802 if is_small and not results:
855 803 # easy to mis-sniff small files. Try different delimiters
856   - log.debug('small file, no results; try all delimiters')
  804 + logger.debug('small file, no results; try all delimiters')
857 805 file_handle.seek(0)
858 806 other_delim = CSV_DELIMITERS.replace(dialect.delimiter, '')
859 807 for delim in other_delim:
... ... @@ -861,12 +809,12 @@ def process_csv(filepath):
861 809 file_handle.seek(0)
862 810 results, _ = process_csv_dialect(file_handle, delim)
863 811 except csv.Error: # e.g. sniffing fails
864   - log.debug('failed to csv-parse with delimiter {0!r}'
  812 + logger.debug('failed to csv-parse with delimiter {0!r}'
865 813 .format(delim))
866 814  
867 815 if is_small and not results:
868 816 # try whole file as single cell, since sniffing fails in this case
869   - log.debug('last attempt: take whole file as single unquoted cell')
  817 + logger.debug('last attempt: take whole file as single unquoted cell')
870 818 file_handle.seek(0)
871 819 match = CSV_DDE_FORMAT.match(file_handle.read(CSV_SMALL_THRESH))
872 820 if match:
... ... @@ -882,7 +830,7 @@ def process_csv_dialect(file_handle, delimiters):
882 830 dialect = csv.Sniffer().sniff(file_handle.read(CSV_SMALL_THRESH),
883 831 delimiters=delimiters)
884 832 dialect.strict = False # microsoft is never strict
885   - log.debug('sniffed csv dialect with delimiter {0!r} '
  833 + logger.debug('sniffed csv dialect with delimiter {0!r} '
886 834 'and quote char {1!r}'
887 835 .format(dialect.delimiter, dialect.quotechar))
888 836  
... ... @@ -924,7 +872,7 @@ def process_excel_xml(filepath):
924 872 break
925 873 if formula is None:
926 874 continue
927   - log.debug('found cell with formula {0}'.format(formula))
  875 + logger.debug('found cell with formula {0}'.format(formula))
928 876 match = re.match(XML_DDE_FORMAT, formula)
929 877 if match:
930 878 dde_links.append(u' '.join(match.groups()[:2]))
... ... @@ -934,40 +882,40 @@ def process_excel_xml(filepath):
934 882 def process_file(filepath, field_filter_mode=None):
935 883 """ decides which of the process_* functions to call """
936 884 if olefile.isOleFile(filepath):
937   - log.debug('Is OLE. Checking streams to see whether this is xls')
  885 + logger.debug('Is OLE. Checking streams to see whether this is xls')
938 886 if xls_parser.is_xls(filepath):
939   - log.debug('Process file as excel 2003 (xls)')
  887 + logger.debug('Process file as excel 2003 (xls)')
940 888 return process_xls(filepath)
941 889 else:
942   - log.debug('Process file as word 2003 (doc)')
  890 + logger.debug('Process file as word 2003 (doc)')
943 891 return process_doc(filepath)
944 892  
945 893 with open(filepath, 'rb') as file_handle:
946 894 if file_handle.read(4) == RTF_START:
947   - log.debug('Process file as rtf')
  895 + logger.debug('Process file as rtf')
948 896 return process_rtf(file_handle, field_filter_mode)
949 897  
950 898 try:
951 899 doctype = ooxml.get_type(filepath)
952   - log.debug('Detected file type: {0}'.format(doctype))
  900 + logger.debug('Detected file type: {0}'.format(doctype))
953 901 except Exception as exc:
954   - log.debug('Exception trying to xml-parse file: {0}'.format(exc))
  902 + logger.debug('Exception trying to xml-parse file: {0}'.format(exc))
955 903 doctype = None
956 904  
957 905 if doctype == ooxml.DOCTYPE_EXCEL:
958   - log.debug('Process file as excel 2007+ (xlsx)')
  906 + logger.debug('Process file as excel 2007+ (xlsx)')
959 907 return process_xlsx(filepath)
960 908 elif doctype in (ooxml.DOCTYPE_EXCEL_XML, ooxml.DOCTYPE_EXCEL_XML2003):
961   - log.debug('Process file as xml from excel 2003/2007+')
  909 + logger.debug('Process file as xml from excel 2003/2007+')
962 910 return process_excel_xml(filepath)
963 911 elif doctype in (ooxml.DOCTYPE_WORD_XML, ooxml.DOCTYPE_WORD_XML2003):
964   - log.debug('Process file as xml from word 2003/2007+')
  912 + logger.debug('Process file as xml from word 2003/2007+')
965 913 return process_docx(filepath)
966 914 elif doctype is None:
967   - log.debug('Process file as csv')
  915 + logger.debug('Process file as csv')
968 916 return process_csv(filepath)
969 917 else: # could be docx; if not: this is the old default code path
970   - log.debug('Process file as word 2007+ (docx)')
  918 + logger.debug('Process file as word 2007+ (docx)')
971 919 return process_docx(filepath, field_filter_mode)
972 920  
973 921  
... ... @@ -985,27 +933,14 @@ def main(cmd_line_args=None):
985 933 # Setup logging to the console:
986 934 # here we use stdout instead of stderr by default, so that the output
987 935 # can be redirected properly.
988   - logging.basicConfig(level=LOG_LEVELS[args.loglevel], stream=sys.stdout,
989   - format='%(levelname)-8s %(message)s')
990   - # enable logging in the modules:
991   - log.setLevel(logging.NOTSET)
992   -
993   - if args.json and args.loglevel.lower() == 'debug':
994   - log.warning('Debug log output will not be json-compatible!')
  936 + log_helper.enable_logging(args.json, args.loglevel, stream=sys.stdout)
995 937  
996 938 if args.nounquote:
997 939 global NO_QUOTES
998 940 NO_QUOTES = True
999 941  
1000   - if args.json:
1001   - jout = []
1002   - jout.append(BANNER_JSON)
1003   - else:
1004   - # print banner with version
1005   - print(BANNER)
1006   -
1007   - if not args.json:
1008   - print('Opening file: %s' % args.filepath)
  942 + logger.print_str(BANNER)
  943 + logger.print_str('Opening file: %s' % args.filepath)
1009 944  
1010 945 text = ''
1011 946 return_code = 1
... ... @@ -1013,22 +948,12 @@ def main(cmd_line_args=None):
1013 948 text = process_file(args.filepath, args.field_filter_mode)
1014 949 return_code = 0
1015 950 except Exception as exc:
1016   - if args.json:
1017   - jout.append(dict(type='error', error=type(exc).__name__,
1018   - message=str(exc)))
1019   - else:
1020   - raise # re-raise last known exception, keeping trace intact
1021   -
1022   - if args.json:
1023   - for line in text.splitlines():
1024   - if line.strip():
1025   - jout.append(dict(type='dde-link', link=line.strip()))
1026   - json.dump(jout, sys.stdout, check_circular=False, indent=4)
1027   - print() # add a newline after closing "]"
1028   - return return_code # required if we catch an exception in json-mode
1029   - else:
1030   - print ('DDE Links:')
1031   - print(text)
  951 + logger.exception(exc.message)
  952 +
  953 + logger.print_str('DDE Links:')
  954 + logger.print_str(text)
  955 +
  956 + log_helper.end_logging()
1032 957  
1033 958 return return_code
1034 959  
... ...
oletools/ooxml.py
... ... @@ -14,7 +14,7 @@ TODO: may have to tell apart single xml types: office2003 looks much different
14 14 """
15 15  
16 16 import sys
17   -import logging
  17 +from oletools.common.log_helper import log_helper
18 18 from zipfile import ZipFile, BadZipfile, is_zipfile
19 19 from os.path import splitext
20 20 import io
... ... @@ -27,6 +27,7 @@ try:
27 27 except ImportError:
28 28 import xml.etree.cElementTree as ET
29 29  
  30 +logger = log_helper.get_or_create_silent_logger('ooxml')
30 31  
31 32 #: subfiles that have to be part of every ooxml file
32 33 FILE_CONTENT_TYPES = '[Content_Types].xml'
... ... @@ -142,7 +143,7 @@ def get_type(filename):
142 143 is_xls = False
143 144 is_ppt = False
144 145 for _, elem, _ in parser.iter_xml(FILE_CONTENT_TYPES):
145   - logging.debug(u' ' + debug_str(elem))
  146 + logger.debug(u' ' + debug_str(elem))
146 147 try:
147 148 content_type = elem.attrib['ContentType']
148 149 except KeyError: # ContentType not an attr
... ... @@ -160,7 +161,7 @@ def get_type(filename):
160 161 if not is_doc and not is_xls and not is_ppt:
161 162 return DOCTYPE_NONE
162 163 else:
163   - logging.warning('Encountered contradictory content types')
  164 + logger.warning('Encountered contradictory content types')
164 165 return DOCTYPE_MIXED
165 166  
166 167  
... ... @@ -220,7 +221,7 @@ class ZipSubFile(object):
220 221 self.name = filename
221 222 if size is None:
222 223 self.size = container.getinfo(filename).file_size
223   - logging.debug('zip stream has size {0}'.format(self.size))
  224 + logger.debug('zip stream has size {0}'.format(self.size))
224 225 else:
225 226 self.size = size
226 227 if 'w' in mode.lower():
... ... @@ -484,10 +485,10 @@ class XmlParser(object):
484 485 want_tags = []
485 486 elif isstr(tags):
486 487 want_tags = [tags, ]
487   - logging.debug('looking for tags: {0}'.format(tags))
  488 + logger.debug('looking for tags: {0}'.format(tags))
488 489 else:
489 490 want_tags = tags
490   - logging.debug('looking for tags: {0}'.format(tags))
  491 + logger.debug('looking for tags: {0}'.format(tags))
491 492  
492 493 for subfile, handle in self.iter_files(subfiles):
493 494 events = ('start', 'end')
... ... @@ -499,7 +500,7 @@ class XmlParser(object):
499 500 continue
500 501 if event == 'start':
501 502 if elem.tag in want_tags:
502   - logging.debug('remember start of tag {0} at {1}'
  503 + logger.debug('remember start of tag {0} at {1}'
503 504 .format(elem.tag, depth))
504 505 inside_tags.append((elem.tag, depth))
505 506 depth += 1
... ... @@ -515,18 +516,18 @@ class XmlParser(object):
515 516 if inside_tags[-1] == curr_tag:
516 517 inside_tags.pop()
517 518 else:
518   - logging.error('found end for wanted tag {0} '
  519 + logger.error('found end for wanted tag {0} '
519 520 'but last start tag {1} does not'
520 521 ' match'.format(curr_tag,
521 522 inside_tags[-1]))
522 523 # try to recover: close all deeper tags
523 524 while inside_tags and \
524 525 inside_tags[-1][1] >= depth:
525   - logging.debug('recover: pop {0}'
  526 + logger.debug('recover: pop {0}'
526 527 .format(inside_tags[-1]))
527 528 inside_tags.pop()
528 529 except IndexError: # no inside_tag[-1]
529   - logging.error('found end of {0} at depth {1} but '
  530 + logger.error('found end of {0} at depth {1} but '
530 531 'no start event')
531 532 # yield element
532 533 if is_wanted or not want_tags:
... ... @@ -543,12 +544,12 @@ class XmlParser(object):
543 544 if subfile is None: # this is no zip subfile but single xml
544 545 raise BadOOXML(self.filename, 'is neither zip nor xml')
545 546 elif subfile.endswith('.xml'):
546   - logger = logging.warning
  547 + log = logger.warning
547 548 else:
548   - logger = logging.debug
549   - logger(' xml-parsing for {0} failed ({1}). '
550   - .format(subfile, err) +
551   - 'Run iter_non_xml to investigate.')
  549 + log = logger.debug
  550 + log(' xml-parsing for {0} failed ({1}). '
  551 + .format(subfile, err) +
  552 + 'Run iter_non_xml to investigate.')
552 553 assert(depth == 0)
553 554  
554 555 def get_content_types(self):
... ... @@ -571,14 +572,14 @@ class XmlParser(object):
571 572 if extension.startswith('.'):
572 573 extension = extension[1:]
573 574 defaults.append((extension, elem.attrib['ContentType']))
574   - logging.debug('found content type for extension {0[0]}: {0[1]}'
  575 + logger.debug('found content type for extension {0[0]}: {0[1]}'
575 576 .format(defaults[-1]))
576 577 elif elem.tag.endswith('Override'):
577 578 subfile = elem.attrib['PartName']
578 579 if subfile.startswith('/'):
579 580 subfile = subfile[1:]
580 581 files.append((subfile, elem.attrib['ContentType']))
581   - logging.debug('found content type for subfile {0[0]}: {0[1]}'
  582 + logger.debug('found content type for subfile {0[0]}: {0[1]}'
582 583 .format(files[-1]))
583 584 return dict(files), dict(defaults)
584 585  
... ... @@ -595,7 +596,7 @@ class XmlParser(object):
595 596 To handle binary parts of an xlsb file, use xls_parser.parse_xlsb_part
596 597 """
597 598 if not self.did_iter_all:
598   - logging.warning('Did not iterate through complete file. '
  599 + logger.warning('Did not iterate through complete file. '
599 600 'Should run iter_xml() without args, first.')
600 601 if not self.subfiles_no_xml:
601 602 return
... ... @@ -628,7 +629,7 @@ def test():
628 629  
629 630 see module doc for more info
630 631 """
631   - logging.basicConfig(level=logging.DEBUG)
  632 + log_helper.enable_logging(False, logger.DEBUG)
632 633 if len(sys.argv) != 2:
633 634 print(u'To test this code, give me a single file as arg')
634 635 return 2
... ... @@ -647,6 +648,9 @@ def test():
647 648 if index > 100:
648 649 print(u'...')
649 650 break
  651 +
  652 + log_helper.end_logging()
  653 +
650 654 return 0
651 655  
652 656  
... ...
tests/json/__init__.py renamed to tests/common/__init__.py
tests/common/log_helper/__init__.py 0 → 100644
tests/common/log_helper/log_helper_test_imported.py 0 → 100644
  1 +"""
  2 +Dummy file that logs messages, meant to be imported
  3 +by the main test file
  4 +"""
  5 +
  6 +from oletools.common.log_helper import log_helper
  7 +import logging
  8 +
# Exact message strings, one per log level; the test driver greps the
# subprocess output for these to tell imported-module messages apart
# from the main script's messages.
DEBUG_MESSAGE = 'imported: debug log'
INFO_MESSAGE = 'imported: info log'
WARNING_MESSAGE = 'imported: warning log'
ERROR_MESSAGE = 'imported: error log'
CRITICAL_MESSAGE = 'imported: critical log'

# Silent logger with an explicit non-default level (ERROR); it should stay
# quiet until the importing script calls log_helper.enable_logging().
logger = log_helper.get_or_create_silent_logger('test_imported', logging.ERROR)
  16 +
  17 +
def log():
    """Emit one message at each level through this module's logger."""
    for emit, message in ((logger.debug, DEBUG_MESSAGE),
                          (logger.info, INFO_MESSAGE),
                          (logger.warning, WARNING_MESSAGE),
                          (logger.error, ERROR_MESSAGE),
                          (logger.critical, CRITICAL_MESSAGE)):
        emit(message)
... ...
tests/common/log_helper/log_helper_test_main.py 0 → 100644
  1 +""" Test log_helpers """
  2 +
  3 +import sys
  4 +from tests.common.log_helper import log_helper_test_imported
  5 +from oletools.common.log_helper import log_helper
  6 +
# Exact message strings, one per log level; test_log_helper checks the
# subprocess output for these to verify what was (or was not) logged.
DEBUG_MESSAGE = 'main: debug log'
INFO_MESSAGE = 'main: info log'
WARNING_MESSAGE = 'main: warning log'
ERROR_MESSAGE = 'main: error log'
CRITICAL_MESSAGE = 'main: critical log'

# Silent logger with the helper's default level; it should print nothing
# until enable_logging() is called below.
logger = log_helper.get_or_create_silent_logger('test_main')
  14 +
  15 +
def init_logging_and_log(args):
    """
    Exercise the logging scenarios selected by *args*.

    Recognized argument tokens (the log level is always the last one):
    - '<level>' alone: log without enabling -- should never print
    - 'as-json': format output as JSON ('as-json <level>' alone should
      also never print)
    - 'enable': turn logging on before emitting messages, so they appear
      (plain or as JSON, depending on 'as-json')
    - 'throw': raise after logging but before end_logging(), to check that
      JSON output stays well-formed even after an unhandled exception
    """
    # the level should always be the last argument passed
    requested_level = args[-1]
    want_json = 'as-json' in args
    must_raise = 'throw' in args

    if 'enable' in args:
        log_helper.enable_logging(want_json, requested_level, stream=sys.stdout)

    _log()

    if must_raise:
        raise Exception('An exception occurred before ending the logging')

    log_helper.end_logging()
  45 +
  46 +
def _log():
    """Emit one message per level, then trigger the imported module's logs."""
    for emit, message in ((logger.debug, DEBUG_MESSAGE),
                          (logger.info, INFO_MESSAGE),
                          (logger.warning, WARNING_MESSAGE),
                          (logger.error, ERROR_MESSAGE),
                          (logger.critical, CRITICAL_MESSAGE)):
        emit(message)
    log_helper_test_imported.log()
  54 +
  55 +
# Run the selected logging scenario when invoked as a script; the test
# suite launches this file in a subprocess and inspects its stdout.
if __name__ == '__main__':
    init_logging_and_log(sys.argv[1:])
... ...
tests/common/log_helper/test_log_helper.py 0 → 100644
  1 +""" Test the log helper
  2 +
  3 +This tests the generic log helper.
  4 +Check if it handles imported modules correctly
  5 +and that the default silent logger won't log when nothing is enabled
  6 +"""
  7 +
  8 +import unittest
  9 +import sys
  10 +import json
  11 +import subprocess
  12 +from tests.common.log_helper import log_helper_test_main
  13 +from tests.common.log_helper import log_helper_test_imported
  14 +from os.path import dirname, join, relpath, abspath
  15 +
# this is the common base of "tests" and "oletools" dirs
ROOT_DIRECTORY = abspath(join(__file__, '..', '..', '..', '..'))
# script launched in a subprocess by the tests, relative to ROOT_DIRECTORY
TEST_FILE = relpath(join(dirname(__file__), 'log_helper_test_main.py'), ROOT_DIRECTORY)
# re-use the interpreter that is running these tests for the subprocess
PYTHON_EXECUTABLE = sys.executable

# all messages the main test script may emit, ordered by ascending severity
MAIN_LOG_MESSAGES = [
    log_helper_test_main.DEBUG_MESSAGE,
    log_helper_test_main.INFO_MESSAGE,
    log_helper_test_main.WARNING_MESSAGE,
    log_helper_test_main.ERROR_MESSAGE,
    log_helper_test_main.CRITICAL_MESSAGE
]
  28 +
  29 +
class TestLogHelper(unittest.TestCase):
    """
    Run log_helper_test_main.py in a subprocess for each scenario and
    check its captured stdout (plain or JSON).
    """

    def test_it_doesnt_log_when_not_enabled(self):
        output = self._run_test(['debug'])
        self.assertEqual(len(output), 0)

    def test_it_doesnt_log_json_when_not_enabled(self):
        output = self._run_test(['as-json', 'debug'])
        self.assertEqual(len(output), 0)

    def test_logs_when_enabled(self):
        output = self._run_test(['enable', 'warning'])

        # debug/info are below the 'warning' threshold and must be absent
        expected_messages = [
            log_helper_test_main.WARNING_MESSAGE,
            log_helper_test_main.ERROR_MESSAGE,
            log_helper_test_main.CRITICAL_MESSAGE,
            log_helper_test_imported.WARNING_MESSAGE,
            log_helper_test_imported.ERROR_MESSAGE,
            log_helper_test_imported.CRITICAL_MESSAGE
        ]

        for msg in expected_messages:
            self.assertIn(msg, output)

    def test_logs_json_when_enabled(self):
        output = self._run_test(['enable', 'as-json', 'critical'])

        self._assert_json_messages(output, [
            log_helper_test_main.CRITICAL_MESSAGE,
            log_helper_test_imported.CRITICAL_MESSAGE
        ])

    def test_json_correct_on_exceptions(self):
        """
        Test that even on unhandled exceptions our JSON is always correct
        """
        output = self._run_test(['enable', 'as-json', 'throw', 'critical'], False)
        self._assert_json_messages(output, [
            log_helper_test_main.CRITICAL_MESSAGE,
            log_helper_test_imported.CRITICAL_MESSAGE
        ])

    def _assert_json_messages(self, output, messages):
        """ Parse output as JSON and compare each 'msg' field in order. """
        try:
            json_data = json.loads(output)
            # assertEqual, not the deprecated assertEquals alias
            # (removed in Python 3.12)
            self.assertEqual(len(json_data), len(messages))

            for expected, entry in zip(messages, json_data):
                self.assertEqual(expected, entry['msg'])
        except ValueError:
            self.fail('Invalid json:\n' + output)

        self.assertNotEqual(len(json_data), 0, msg='Output was empty')

    def _run_test(self, args, should_succeed=True):
        """
        Use subprocess to better simulate the real scenario and avoid
        logging conflicts when running multiple tests (since logging depends
        on singletons, we might get errors or false positives between
        sequential tests runs)
        """
        child = subprocess.Popen(
            [PYTHON_EXECUTABLE, TEST_FILE] + args,
            shell=False,
            env={'PYTHONPATH': ROOT_DIRECTORY},
            universal_newlines=True,
            cwd=ROOT_DIRECTORY,
            stdin=None,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE
        )
        (output, output_err) = child.communicate()

        if not isinstance(output, str):
            output = output.decode('utf-8')

        self.assertEqual(child.returncode == 0, should_succeed)

        return output.strip()
  108 +
  109 +
# allow running the tests directly, in case somebody calls this file as a script
if __name__ == '__main__':
    unittest.main()
... ...
tests/json/test_output.py deleted
1   -""" Test validity of json output
2   -
3   -Some scripts have a json output flag. Verify that at default log levels output
4   -can be captured as-is and parsed by a json parser -- checking the return code
5   -if desired.
6   -"""
7   -
8   -import unittest
9   -import sys
10   -import json
11   -import os
12   -from os.path import join
13   -from oletools import msodde
14   -from tests.test_utils import OutputCapture, DATA_BASE_DIR
15   -
16   -if sys.version_info[0] <= 2:
17   - from oletools import olevba
18   -else:
19   - from oletools import olevba3 as olevba
20   -
21   -
class TestValidJson(unittest.TestCase):
    """
    Ensure that script output is valid json.
    If check_return_code is True we also ignore the output
    of runs that didn't succeed.
    """

    @staticmethod
    def iter_test_files():
        """ Iterate over all test files in DATA_BASE_DIR """
        for dirpath, _, filenames in os.walk(DATA_BASE_DIR):
            for filename in filenames:
                yield join(dirpath, filename)

    def run_and_parse(self, program, args, print_output=False, check_return_code=True):
        """ run single program with single file and parse output """
        with OutputCapture() as capturer:  # capture stdout
            try:
                return_code = program(args)
            except Exception:
                return_code = 1  # would result in non-zero exit
            except SystemExit as se:
                return_code = se.code or 0  # se.code can be None
            # value comparison, not identity: 'is not 0' relied on CPython's
            # small-int caching and raises a SyntaxWarning on python >= 3.8
            if check_return_code and return_code != 0:
                if print_output:
                    print('Command failed ({0}) -- not parsing output'
                          .format(return_code))
                return []  # no need to test

        self.assertNotEqual(return_code, None,
                            msg='self-test fail: return_code not set')

        # now test output
        if print_output:
            print(capturer.get_data())
        try:
            json_data = json.loads(capturer.get_data())
        except ValueError:
            self.fail('Invalid json:\n' + capturer.get_data())
        self.assertNotEqual(len(json_data), 0, msg='Output was empty')
        return json_data

    def run_all_files(self, program, args_without_filename, print_output=False):
        """ run test for a single program over all test files """
        n_files = 0
        for testfile in self.iter_test_files():  # loop over all input
            args = args_without_filename + [testfile, ]
            self.run_and_parse(program, args, print_output)
            n_files += 1
        self.assertNotEqual(n_files, 0,
                            msg='self-test fail: No test files found')

    def test_msodde(self):
        """ Test msodde.py """
        self.run_all_files(msodde.main, ['-j', ])

    def test_olevba(self):
        """ Test olevba.py with default args """
        self.run_all_files(olevba.main, ['-j', ])

    def test_olevba_analysis(self):
        """ Test olevba.py with -a """
        self.run_all_files(olevba.main, ['-j', '-a', ])

    def test_olevba_recurse(self):
        """ Test olevba.py with -r """
        json_data = self.run_and_parse(olevba.main,
                                       ['-j', '-r', join(DATA_BASE_DIR, '*')],
                                       check_return_code=False)
        self.assertNotEqual(len(json_data), 0,
                            msg='olevba[3] returned non-zero or no output')
        self.assertNotEqual(json_data[-1]['n_processed'], 0,
                            msg='self-test fail: No test files found!')
95   -
96   -
97   -# just in case somebody calls this file as a script
98   -if __name__ == '__main__':
99   - unittest.main()
tests/msodde/test_basic.py
... ... @@ -10,15 +10,13 @@ from __future__ import print_function
10 10  
11 11 import unittest
12 12 from oletools import msodde
13   -from tests.test_utils import OutputCapture, DATA_BASE_DIR as BASE_DIR
14   -import shlex
  13 +from tests.test_utils import DATA_BASE_DIR as BASE_DIR
15 14 from os.path import join
16 15 from traceback import print_exc
17 16  
18 17  
19 18 class TestReturnCode(unittest.TestCase):
20 19 """ check return codes and exception behaviour (not text output) """
21   -
22 20 def test_valid_doc(self):
23 21 """ check that a valid doc file leads to 0 exit status """
24 22 for filename in (
... ... @@ -59,87 +57,73 @@ class TestReturnCode(unittest.TestCase):
59 57  
60 58 def do_test_validity(self, args, expect_error=False):
61 59 """ helper for test_valid_doc[x] """
62   - args = shlex.split(args)
63   - return_code = -1
64 60 have_exception = False
65 61 try:
66   - return_code = msodde.main(args)
  62 + msodde.process_file(args, msodde.FIELD_FILTER_BLACKLIST)
67 63 except Exception:
68 64 have_exception = True
69 65 print_exc()
70 66 except SystemExit as exc: # sys.exit() was called
71   - return_code = exc.code
  67 + have_exception = True
72 68 if exc.code is None:
73   - return_code = 0
  69 + have_exception = False
74 70  
75   - self.assertEqual(expect_error, have_exception or (return_code != 0),
76   - msg='Args={0}, expect={1}, exc={2}, return={3}'
77   - .format(args, expect_error, have_exception,
78   - return_code))
  71 + self.assertEqual(expect_error, have_exception,
  72 + msg='Args={0}, expect={1}, exc={2}'
  73 + .format(args, expect_error, have_exception))
79 74  
80 75  
81 76 class TestDdeLinks(unittest.TestCase):
82 77 """ capture output of msodde and check dde-links are found correctly """
83 78  
84   - def get_dde_from_output(self, capturer):
  79 + @staticmethod
  80 + def get_dde_from_output(output):
85 81 """ helper to read dde links from captured output
86   -
87   - duplicate in tests/msodde/test_csv
88 82 """
89   - have_start_line = False
90   - result = []
91   - for line in capturer:
92   - if not line.strip():
93   - continue # skip empty lines
94   - if have_start_line:
95   - result.append(line)
96   - elif line == 'DDE Links:':
97   - have_start_line = True
98   -
99   - self.assertTrue(have_start_line) # ensure output was complete
100   - return result
  83 + return [o for o in output.splitlines()]
101 84  
102 85 def test_with_dde(self):
103 86 """ check that dde links appear on stdout """
104 87 filename = 'dde-test-from-office2003.doc'
105   - with OutputCapture() as capturer:
106   - msodde.main([join(BASE_DIR, 'msodde', filename)])
107   - self.assertNotEqual(len(self.get_dde_from_output(capturer)), 0,
  88 + output = msodde.process_file(
  89 + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST)
  90 + self.assertNotEqual(len(self.get_dde_from_output(output)), 0,
108 91 msg='Found no dde links in output of ' + filename)
109 92  
110 93 def test_no_dde(self):
111 94 """ check that no dde links appear on stdout """
112 95 filename = 'harmless-clean.doc'
113   - with OutputCapture() as capturer:
114   - msodde.main([join(BASE_DIR, 'msodde', filename)])
115   - self.assertEqual(len(self.get_dde_from_output(capturer)), 0,
  96 + output = msodde.process_file(
  97 + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST)
  98 + self.assertEqual(len(self.get_dde_from_output(output)), 0,
116 99 msg='Found dde links in output of ' + filename)
117 100  
118 101 def test_with_dde_utf16le(self):
119 102 """ check that dde links appear on stdout """
120 103 filename = 'dde-test-from-office2013-utf_16le-korean.doc'
121   - with OutputCapture() as capturer:
122   - msodde.main([join(BASE_DIR, 'msodde', filename)])
123   - self.assertNotEqual(len(self.get_dde_from_output(capturer)), 0,
  104 + output = msodde.process_file(
  105 + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST)
  106 + self.assertNotEqual(len(self.get_dde_from_output(output)), 0,
124 107 msg='Found no dde links in output of ' + filename)
125 108  
126 109 def test_excel(self):
127 110 """ check that dde links are found in excel 2007+ files """
128 111 expect = ['DDE-Link cmd /c calc.exe', ]
129 112 for extn in 'xlsx', 'xlsm', 'xlsb':
130   - with OutputCapture() as capturer:
131   - msodde.main([join(BASE_DIR, 'msodde', 'dde-test.' + extn), ])
132   - self.assertEqual(expect, self.get_dde_from_output(capturer),
  113 + output = msodde.process_file(
  114 + join(BASE_DIR, 'msodde', 'dde-test.' + extn), msodde.FIELD_FILTER_BLACKLIST)
  115 +
  116 + self.assertEqual(expect, self.get_dde_from_output(output),
133 117 msg='unexpected output for dde-test.{0}: {1}'
134   - .format(extn, capturer.get_data()))
  118 + .format(extn, output))
135 119  
136 120 def test_xml(self):
137 121 """ check that dde in xml from word / excel is found """
138 122 for name_part in 'excel2003', 'word2003', 'word2007':
139 123 filename = 'dde-in-' + name_part + '.xml'
140   - with OutputCapture() as capturer:
141   - msodde.main([join(BASE_DIR, 'msodde', filename), ])
142   - links = self.get_dde_from_output(capturer)
  124 + output = msodde.process_file(
  125 + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST)
  126 + links = self.get_dde_from_output(output)
143 127 self.assertEqual(len(links), 1, 'found {0} dde-links in {1}'
144 128 .format(len(links), filename))
145 129 self.assertTrue('cmd' in links[0], 'no "cmd" in dde-link for {0}'
... ... @@ -150,16 +134,16 @@ class TestDdeLinks(unittest.TestCase):
150 134 def test_clean_rtf_blacklist(self):
151 135 """ find a lot of hyperlinks in rtf spec """
152 136 filename = 'RTF-Spec-1.7.rtf'
153   - with OutputCapture() as capturer:
154   - msodde.main([join(BASE_DIR, 'msodde', filename)])
155   - self.assertEqual(len(self.get_dde_from_output(capturer)), 1413)
  137 + output = msodde.process_file(
  138 + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST)
  139 + self.assertEqual(len(self.get_dde_from_output(output)), 1413)
156 140  
157 141 def test_clean_rtf_ddeonly(self):
158 142 """ find no dde links in rtf spec """
159 143 filename = 'RTF-Spec-1.7.rtf'
160   - with OutputCapture() as capturer:
161   - msodde.main(['-d', join(BASE_DIR, 'msodde', filename)])
162   - self.assertEqual(len(self.get_dde_from_output(capturer)), 0,
  144 + output = msodde.process_file(
  145 + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_DDE)
  146 + self.assertEqual(len(self.get_dde_from_output(output)), 0,
163 147 msg='Found dde links in output of ' + filename)
164 148  
165 149  
... ...
tests/msodde/test_csv.py
... ... @@ -9,7 +9,7 @@ import os
9 9 from os.path import join
10 10  
11 11 from oletools import msodde
12   -from tests.test_utils import OutputCapture, DATA_BASE_DIR
  12 +from tests.test_utils import DATA_BASE_DIR
13 13  
14 14  
15 15 class TestCSV(unittest.TestCase):
... ... @@ -69,11 +69,8 @@ class TestCSV(unittest.TestCase):
69 69 def test_file(self):
70 70 """ test simple small example file """
71 71 filename = join(DATA_BASE_DIR, 'msodde', 'dde-in-csv.csv')
72   - with OutputCapture() as capturer:
73   - capturer.reload_module(msodde) # re-create logger
74   - ret_code = msodde.main([filename, ])
75   - self.assertEqual(ret_code, 0)
76   - links = self.get_dde_from_output(capturer)
  72 + output = msodde.process_file(filename, msodde.FIELD_FILTER_BLACKLIST)
  73 + links = self.get_dde_from_output(output)
77 74 self.assertEqual(len(links), 1)
78 75 self.assertEqual(links[0],
79 76 r"cmd '/k \..\..\..\Windows\System32\calc.exe'")
... ... @@ -91,12 +88,10 @@ class TestCSV(unittest.TestCase):
91 88 if self.DO_DEBUG:
92 89 args += ['-l', 'debug']
93 90  
94   - with OutputCapture() as capturer:
95   - capturer.reload_module(msodde) # re-create logger
96   - ret_code = msodde.main(args)
97   - self.assertEqual(ret_code, 0, 'checking sample resulted in '
98   - 'error:\n' + sample_text)
99   - return capturer
  91 + processed_args = msodde.process_args(args)
  92 +
  93 + return msodde.process_file(
  94 + processed_args.filepath, processed_args.field_filter_mode)
100 95  
101 96 except Exception:
102 97 raise
... ... @@ -111,25 +106,11 @@ class TestCSV(unittest.TestCase):
111 106 os.remove(filename)
112 107 filename = None # just in case
113 108  
114   - def get_dde_from_output(self, capturer):
  109 + @staticmethod
  110 + def get_dde_from_output(output):
115 111 """ helper to read dde links from captured output
116   -
117   - duplicate in tests/msodde/test_basic
118 112 """
119   - have_start_line = False
120   - result = []
121   - for line in capturer:
122   - if self.DO_DEBUG:
123   - print('captured: ' + line)
124   - if not line.strip():
125   - continue # skip empty lines
126   - if have_start_line:
127   - result.append(line)
128   - elif line == 'DDE Links:':
129   - have_start_line = True
130   -
131   - self.assertTrue(have_start_line) # ensure output was complete
132   - return result
  113 + return [o for o in output.splitlines()]
133 114  
134 115 def test_regex(self):
135 116 """ check that regex captures other ways to include dde commands
... ...
tests/test_utils/__init__.py
1   -from .output_capture import OutputCapture
2   -
3 1 from os.path import dirname, join
4 2  
5 3 # Directory with test data, independent of current working directory
... ...
tests/test_utils/output_capture.py deleted
1   -""" class OutputCapture to test what scripts print to stdout """
2   -
3   -from __future__ import print_function
4   -import sys
5   -import logging
6   -
7   -
8   -# python 2/3 version conflict:
9   -if sys.version_info.major <= 2:
10   - from StringIO import StringIO
11   - # reload is a builtin
12   -else:
13   - from io import StringIO
14   - if sys.version_info.minor < 4:
15   - from imp import reload
16   - else:
17   - from importlib import reload
18   -
19   -
20   -class OutputCapture:
21   - """ context manager that captures stdout
22   -
23   - use as follows::
24   -
25   - with OutputCapture() as capturer:
26   - run_my_script(some_args)
27   -
28   - # either test line-by-line ...
29   - for line in capturer:
30   - some_test(line)
31   - # ...or test all output in one go
32   - some_test(capturer.get_data())
33   -
34   - In order to solve issues with old logger instances still remembering closed
35   - StringIO instances as "their" stdout, logging is shutdown and restarted
36   - upon entering this Context Manager. This means that you may have to reload
37   - your module, as well.
38   - """
39   -
40   - def __init__(self):
41   - self.buffer = StringIO()
42   - self.orig_stdout = None
43   - self.data = None
44   -
45   - def __enter__(self):
46   - # Avoid problems with old logger instances that still remember an old
47   - # closed StringIO as their sys.stdout
48   - logging.shutdown()
49   - reload(logging)
50   -
51   - # replace sys.stdout with own buffer.
52   - self.orig_stdout = sys.stdout
53   - sys.stdout = self.buffer
54   - return self
55   -
56   - def __exit__(self, exc_type, exc_value, traceback):
57   - sys.stdout = self.orig_stdout # re-set to original
58   - self.data = self.buffer.getvalue()
59   - self.buffer.close() # close buffer
60   - self.buffer = None
61   -
62   - if exc_type: # there has been an error
63   - print('Got error during output capture!')
64   - print('Print captured output and re-raise:')
65   - for line in self.data.splitlines():
66   - print(line.rstrip()) # print output before re-raising
67   -
68   - def get_data(self):
69   - """ retrieve all the captured data """
70   - if self.buffer is not None:
71   - return self.buffer.getvalue()
72   - elif self.data is not None:
73   - return self.data
74   - else: # should not be possible
75   - raise RuntimeError('programming error or someone messed with data!')
76   -
77   - def __iter__(self):
78   - for line in self.get_data().splitlines():
79   - yield line
80   -
81   - def reload_module(self, mod):
82   - """ Wrapper around reload function for different python versions """
83   - return reload(mod)