Commit 1542df504ef321511df84f15c16a15aed252c58c

Authored by Philippe Lagadec
Committed by GitHub
2 parents 26b43390 911b2732

Merge pull request #308 from christian-intra2net/central-logger-json

Unified logging with json option
oletools/common/log_helper/__init__.py 0 → 100644
"""
Package init: expose one shared LogHelper instance.

All oletools modules import this single instance so logger bookkeeping
(names, JSON mode, enable/disable state) is centralized.
"""

from . import log_helper as log_helper_

# singleton used throughout oletools to obtain and configure loggers
log_helper = log_helper_.LogHelper()

__all__ = ['log_helper']
oletools/common/log_helper/_json_formatter.py 0 → 100644
  1 +import logging
  2 +import json
  3 +
  4 +
class JsonFormatter(logging.Formatter):
    """
    Format every message to be logged as a JSON object.

    The output is meant to be embedded in a JSON list: every record after
    the first is prefixed with a comma so the concatenated stream stays
    JSON-compatible. Records are not buffered, so the formatter has to
    remember whether it already produced the first line.
    """
    # Class-level default; the first call to format() shadows it with an
    # instance attribute set to False.
    _is_first_line = True

    def format(self, record):
        """
        Return *record* rendered as a one-line JSON object, prefixed with
        a comma for every record after the first.

        Fix: use record.getMessage() instead of record.msg so that
        %-style lazy arguments (logger.info('x %s', val)) are
        interpolated into the logged message.
        """
        json_dict = dict(msg=record.getMessage(), level=record.levelname)
        formatted_message = ' ' + json.dumps(json_dict)

        if self._is_first_line:
            self._is_first_line = False
            return formatted_message

        return ', ' + formatted_message
oletools/common/log_helper/_logger_adapter.py 0 → 100644
  1 +import logging
  2 +from . import _root_logger_wrapper
  3 +
  4 +
class OletoolsLoggerAdapter(logging.LoggerAdapter):
    """
    Adapter wrapped around every logger handed out by the log helper.

    Adds a print replacement that can route output through the logging
    machinery when JSON mode is active.
    """
    # Zero-argument callable reporting whether JSON output is enabled;
    # None until set_json_enabled_function() is called.
    _json_enabled = None

    def print_str(self, message):
        """
        Drop-in replacement for plain print() calls.

        When JSON output is enabled, the message is logged at the root
        logger's current level (so it is always emitted and gets JSON
        formatting); otherwise it is printed directly.
        """
        json_check = self._json_enabled
        if json_check and json_check():
            self.log(_root_logger_wrapper.level(), message)
        else:
            print(message)

    def set_json_enabled_function(self, json_enabled):
        """
        Register the callable used to check whether JSON output is on.
        """
        self._json_enabled = json_enabled

    def level(self):
        """Return the level of the wrapped logger."""
        return self.logger.level
oletools/common/log_helper/_root_logger_wrapper.py 0 → 100644
  1 +import logging
  2 +
  3 +
def is_logging_initialized():
    """
    Tell whether logging has been configured.

    Same heuristic the logging module itself uses: logging counts as
    initialized once the root logger has at least one handler.
    """
    return bool(logging.root.handlers)
  10 +
  11 +
def set_formatter(fmt):
    """
    Give every handler of the root logger the formatter *fmt*.

    A no-op when logging is not initialized yet: an uninitialized root
    logger has no handlers, so the loop body never runs.
    """
    for root_handler in logging.root.handlers:
        root_handler.setFormatter(fmt)
  21 +
  22 +
def level():
    """Return the numeric level currently set on the root logger."""
    root_level = logging.root.level
    return root_level
oletools/common/log_helper/log_helper.py 0 → 100644
  1 +"""
  2 +log_helper.py
  3 +
  4 +General logging helpers
  5 +
  6 +.. codeauthor:: Intra2net AG <info@intra2net>
  7 +"""
  8 +
  9 +# === LICENSE =================================================================
  10 +#
  11 +# Redistribution and use in source and binary forms, with or without
  12 +# modification, are permitted provided that the following conditions are met:
  13 +#
  14 +# * Redistributions of source code must retain the above copyright notice,
  15 +# this list of conditions and the following disclaimer.
  16 +# * Redistributions in binary form must reproduce the above copyright notice,
  17 +# this list of conditions and the following disclaimer in the documentation
  18 +# and/or other materials provided with the distribution.
  19 +#
  20 +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  21 +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  22 +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  23 +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
  24 +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  25 +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  26 +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  27 +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  28 +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  29 +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  30 +# POSSIBILITY OF SUCH DAMAGE.
  31 +
  32 +# -----------------------------------------------------------------------------
  33 +# CHANGELOG:
  34 +# 2017-12-07 v0.01 CH: - first version
  35 +# 2018-02-05 v0.02 SA: - fixed log level selection and reformatted code
  36 +# 2018-02-06 v0.03 SA: - refactored code to deal with NullHandlers
  37 +# 2018-02-07 v0.04 SA: - fixed control of handlers propagation
  38 +# 2018-04-23 v0.05 SA: - refactored the whole logger to use an OOP approach
  39 +
  40 +# -----------------------------------------------------------------------------
  41 +# TODO:
  42 +
  43 +
  44 +from ._json_formatter import JsonFormatter
  45 +from ._logger_adapter import OletoolsLoggerAdapter
  46 +from . import _root_logger_wrapper
  47 +import logging
  48 +import sys
  49 +
  50 +
# Mapping from level names (as accepted on tool command lines) to the
# numeric levels of the logging module.
LOG_LEVELS = {
    'debug': logging.DEBUG,
    'info': logging.INFO,
    'warning': logging.WARNING,
    'error': logging.ERROR,
    'critical': logging.CRITICAL
}

# Logger name used when callers do not supply one.
DEFAULT_LOGGER_NAME = 'oletools'
# Default non-JSON output format: "LEVEL    message".
DEFAULT_MESSAGE_FORMAT = '%(levelname)-8s %(message)s'
  61 +
  62 +
class LogHelper:
    """
    Central manager for oletools logging.

    Hands out silent (NullHandler-based) loggers while oletools is used
    as a library, and turns on real console/JSON output when a tool's
    main() calls enable_logging(). Intended to be used as a single shared
    instance (see the package __init__).
    """

    def __init__(self):
        # names of all loggers handed out so far
        self._all_names = set()  # set so we do not have duplicates
        # whether output should be JSON-compatible
        self._use_json = False
        # whether enable_logging() has been called (and not yet ended)
        self._is_enabled = False

    def get_or_create_silent_logger(self, name=DEFAULT_LOGGER_NAME, level=logging.CRITICAL + 1):
        """
        Get a logger or create one if it doesn't exist, setting a NullHandler
        as the handler (to avoid printing to the console).
        By default we also use a higher logging level so every message will
        be ignored.
        This will prevent oletools from logging unnecessarily when being imported
        from external tools.
        """
        return self._get_or_create_logger(name, level, logging.NullHandler())

    def enable_logging(self, use_json, level, log_format=DEFAULT_MESSAGE_FORMAT, stream=None):
        """
        This function initializes the root logger and enables logging.
        We set the level of the root logger to the one passed by calling logging.basicConfig.
        We also set the level of every logger we created to 0 (logging.NOTSET), meaning that
        the level of the root logger will be used to tell if messages should be logged.
        Additionally, since our loggers use the NullHandler, they won't log anything themselves,
        but due to having propagation enabled they will pass messages to the root logger,
        which in turn will log to the stream set in this function.
        Since the root logger is the one doing the work, when using JSON we set its formatter
        so that every message logged is JSON-compatible.

        :param use_json: bool, whether output should be JSON-compatible
        :param level: str key into LOG_LEVELS ('debug' ... 'critical');
                      raises KeyError for unknown names
        :param log_format: format string for non-JSON output
        :param stream: stream for logging.basicConfig (None = stderr)
        :raises ValueError: if logging was already enabled
        """
        if self._is_enabled:
            raise ValueError('re-enabling logging. Not sure whether that is ok...')

        log_level = LOG_LEVELS[level]
        logging.basicConfig(level=log_level, format=log_format, stream=stream)
        self._is_enabled = True

        self._use_json = use_json
        # wrap the current excepthook so end_logging() runs even on crash
        sys.excepthook = self._get_except_hook(sys.excepthook)

        # since there could be loggers already created we go through all of them
        # and set their levels to 0 so they will use the root logger's level
        for name in self._all_names:
            logger = self.get_or_create_silent_logger(name)
            self._set_logger_level(logger, logging.NOTSET)

        # add a JSON formatter to the root logger, which will be used by every logger
        if self._use_json:
            _root_logger_wrapper.set_formatter(JsonFormatter())
            # open the JSON list; end_logging() prints the closing bracket
            print('[')

    def end_logging(self):
        """
        Must be called at the end of the main function if the caller wants
        json-compatible output
        """
        if not self._is_enabled:
            return
        self._is_enabled = False

        # end logging
        self._all_names = set()
        logging.shutdown()

        # end json list
        if self._use_json:
            print(']')
        self._use_json = False

    def _get_except_hook(self, old_hook):
        """
        Global hook for exceptions so we can always end logging.
        We wrap any hook currently set to avoid overwriting global hooks set by oletools.
        Note that this is only called by enable_logging, which in turn is called by
        the main() function in oletools' scripts. When scripts are being imported this
        code won't execute and won't affect global hooks.
        """
        def hook(exctype, value, traceback):
            # close the JSON list (if any) before the original hook reports
            self.end_logging()
            old_hook(exctype, value, traceback)

        return hook

    def _get_or_create_logger(self, name, level, handler=None):
        """
        Get or create a new logger. This newly created logger will have the
        handler and level that was passed, but if it already exists it's not changed.
        We also wrap the logger in an adapter so we can easily extend its functionality.
        """

        # logging.getLogger creates a logger if it doesn't exist,
        # so we need to check before calling it
        if handler and not self._log_exists(name):
            logger = logging.getLogger(name)
            logger.addHandler(handler)
            self._set_logger_level(logger, level)
        else:
            logger = logging.getLogger(name)

        # Keep track of every logger we created so we can easily change
        # their levels whenever needed
        self._all_names.add(name)

        # late-bound lambda: reflects _use_json changes made after creation
        adapted_logger = OletoolsLoggerAdapter(logger, None)
        adapted_logger.set_json_enabled_function(lambda: self._use_json)

        return adapted_logger

    @staticmethod
    def _set_logger_level(logger, level):
        """
        If the logging is already initialized, we set the level of our logger
        to 0, meaning that it will reuse the level of the root logger.
        That means that if the root logger level changes, we will keep using
        its level and not logging unnecessarily.
        """

        # if this log was wrapped, unwrap it to set the level
        if isinstance(logger, OletoolsLoggerAdapter):
            logger = logger.logger

        if _root_logger_wrapper.is_logging_initialized():
            logger.setLevel(logging.NOTSET)
        else:
            logger.setLevel(level)

    @staticmethod
    def _log_exists(name):
        """
        We check the log manager instead of our global _all_names variable
        since the logger could have been created outside of the helper
        """
        return name in logging.Logger.manager.loggerDict
oletools/msodde.py
@@ -53,8 +53,6 @@ import argparse @@ -53,8 +53,6 @@ import argparse
53 import os 53 import os
54 from os.path import abspath, dirname 54 from os.path import abspath, dirname
55 import sys 55 import sys
56 -import json  
57 -import logging  
58 import re 56 import re
59 import csv 57 import csv
60 58
@@ -63,6 +61,7 @@ import olefile @@ -63,6 +61,7 @@ import olefile
63 from oletools import ooxml 61 from oletools import ooxml
64 from oletools import xls_parser 62 from oletools import xls_parser
65 from oletools import rtfobj 63 from oletools import rtfobj
  64 +from oletools.common.log_helper import log_helper
66 65
67 # ----------------------------------------------------------------------------- 66 # -----------------------------------------------------------------------------
68 # CHANGELOG: 67 # CHANGELOG:
@@ -212,63 +211,12 @@ THIS IS WORK IN PROGRESS - Check updates regularly! @@ -212,63 +211,12 @@ THIS IS WORK IN PROGRESS - Check updates regularly!
212 Please report any issue at https://github.com/decalage2/oletools/issues 211 Please report any issue at https://github.com/decalage2/oletools/issues
213 """ % __version__ 212 """ % __version__
214 213
215 -BANNER_JSON = dict(type='meta', version=__version__, name='msodde',  
216 - link='http://decalage.info/python/oletools',  
217 - message='THIS IS WORK IN PROGRESS - Check updates regularly! '  
218 - 'Please report any issue at '  
219 - 'https://github.com/decalage2/oletools/issues')  
220 -  
221 # === LOGGING ================================================================= 214 # === LOGGING =================================================================
222 215
223 DEFAULT_LOG_LEVEL = "warning" # Default log level 216 DEFAULT_LOG_LEVEL = "warning" # Default log level
224 -LOG_LEVELS = {  
225 - 'debug': logging.DEBUG,  
226 - 'info': logging.INFO,  
227 - 'warning': logging.WARNING,  
228 - 'error': logging.ERROR,  
229 - 'critical': logging.CRITICAL  
230 -}  
231 -  
232 -  
233 -class NullHandler(logging.Handler):  
234 - """  
235 - Log Handler without output, to avoid printing messages if logging is not  
236 - configured by the main application.  
237 - Python 2.7 has logging.NullHandler, but this is necessary for 2.6:  
238 - see https://docs.python.org/2.6/library/logging.html#configuring-logging-for-a-library  
239 - """  
240 - def emit(self, record):  
241 - pass  
242 -  
243 -  
244 -def get_logger(name, level=logging.CRITICAL+1):  
245 - """  
246 - Create a suitable logger object for this module.  
247 - The goal is not to change settings of the root logger, to avoid getting  
248 - other modules' logs on the screen.  
249 - If a logger exists with same name, reuse it. (Else it would have duplicate  
250 - handlers and messages would be doubled.)  
251 - The level is set to CRITICAL+1 by default, to avoid any logging.  
252 - """  
253 - # First, test if there is already a logger with the same name, else it  
254 - # will generate duplicate messages (due to duplicate handlers):  
255 - if name in logging.Logger.manager.loggerDict:  
256 - # NOTE: another less intrusive but more "hackish" solution would be to  
257 - # use getLogger then test if its effective level is not default.  
258 - logger = logging.getLogger(name)  
259 - # make sure level is OK:  
260 - logger.setLevel(level)  
261 - return logger  
262 - # get a new logger:  
263 - logger = logging.getLogger(name)  
264 - # only add a NullHandler for this logger, it is up to the application  
265 - # to configure its own logging:  
266 - logger.addHandler(NullHandler())  
267 - logger.setLevel(level)  
268 - return logger  
269 217
270 # a global logger object used for debugging: 218 # a global logger object used for debugging:
271 -log = get_logger('msodde') 219 +logger = log_helper.get_or_create_silent_logger('msodde')
272 220
273 221
274 # === UNICODE IN PY2 ========================================================= 222 # === UNICODE IN PY2 =========================================================
@@ -312,7 +260,7 @@ def ensure_stdout_handles_unicode(): @@ -312,7 +260,7 @@ def ensure_stdout_handles_unicode():
312 encoding = 'utf8' 260 encoding = 'utf8'
313 261
314 # logging is probably not initialized yet, but just in case 262 # logging is probably not initialized yet, but just in case
315 - log.debug('wrapping sys.stdout with encoder using {0}'.format(encoding)) 263 + logger.debug('wrapping sys.stdout with encoder using {0}'.format(encoding))
316 264
317 wrapper = codecs.getwriter(encoding) 265 wrapper = codecs.getwriter(encoding)
318 sys.stdout = wrapper(sys.stdout) 266 sys.stdout = wrapper(sys.stdout)
@@ -396,7 +344,7 @@ def process_doc_field(data): @@ -396,7 +344,7 @@ def process_doc_field(data):
396 """ check if field instructions start with DDE 344 """ check if field instructions start with DDE
397 345
398 expects unicode input, returns unicode output (empty if not dde) """ 346 expects unicode input, returns unicode output (empty if not dde) """
399 - log.debug('processing field {0}'.format(data)) 347 + logger.debug('processing field {0}'.format(data))
400 348
401 if data.lstrip().lower().startswith(u'dde'): 349 if data.lstrip().lower().startswith(u'dde'):
402 return data 350 return data
@@ -434,7 +382,7 @@ def process_doc_stream(stream): @@ -434,7 +382,7 @@ def process_doc_stream(stream):
434 382
435 if char == OLE_FIELD_START: 383 if char == OLE_FIELD_START:
436 if have_start and max_size_exceeded: 384 if have_start and max_size_exceeded:
437 - log.debug('big field was not a field after all') 385 + logger.debug('big field was not a field after all')
438 have_start = True 386 have_start = True
439 have_sep = False 387 have_sep = False
440 max_size_exceeded = False 388 max_size_exceeded = False
@@ -446,7 +394,7 @@ def process_doc_stream(stream): @@ -446,7 +394,7 @@ def process_doc_stream(stream):
446 # now we are after start char but not at end yet 394 # now we are after start char but not at end yet
447 if char == OLE_FIELD_SEP: 395 if char == OLE_FIELD_SEP:
448 if have_sep: 396 if have_sep:
449 - log.debug('unexpected field: has multiple separators!') 397 + logger.debug('unexpected field: has multiple separators!')
450 have_sep = True 398 have_sep = True
451 elif char == OLE_FIELD_END: 399 elif char == OLE_FIELD_END:
452 # have complete field now, process it 400 # have complete field now, process it
@@ -464,7 +412,7 @@ def process_doc_stream(stream): @@ -464,7 +412,7 @@ def process_doc_stream(stream):
464 if max_size_exceeded: 412 if max_size_exceeded:
465 pass 413 pass
466 elif len(field_contents) > OLE_FIELD_MAX_SIZE: 414 elif len(field_contents) > OLE_FIELD_MAX_SIZE:
467 - log.debug('field exceeds max size of {0}. Ignore rest' 415 + logger.debug('field exceeds max size of {0}. Ignore rest'
468 .format(OLE_FIELD_MAX_SIZE)) 416 .format(OLE_FIELD_MAX_SIZE))
469 max_size_exceeded = True 417 max_size_exceeded = True
470 418
@@ -482,9 +430,9 @@ def process_doc_stream(stream): @@ -482,9 +430,9 @@ def process_doc_stream(stream):
482 field_contents += u'?' 430 field_contents += u'?'
483 431
484 if max_size_exceeded: 432 if max_size_exceeded:
485 - log.debug('big field was not a field after all') 433 + logger.debug('big field was not a field after all')
486 434
487 - log.debug('Checked {0} characters, found {1} fields' 435 + logger.debug('Checked {0} characters, found {1} fields'
488 .format(idx, len(result_parts))) 436 .format(idx, len(result_parts)))
489 437
490 return result_parts 438 return result_parts
@@ -498,7 +446,7 @@ def process_doc(filepath): @@ -498,7 +446,7 @@ def process_doc(filepath):
498 empty if none were found. dde-links will still begin with the dde[auto] key 446 empty if none were found. dde-links will still begin with the dde[auto] key
499 word (possibly after some whitespace) 447 word (possibly after some whitespace)
500 """ 448 """
501 - log.debug('process_doc') 449 + logger.debug('process_doc')
502 ole = olefile.OleFileIO(filepath, path_encoding=None) 450 ole = olefile.OleFileIO(filepath, path_encoding=None)
503 451
504 links = [] 452 links = []
@@ -508,7 +456,7 @@ def process_doc(filepath): @@ -508,7 +456,7 @@ def process_doc(filepath):
508 # this direntry is not part of the tree --> unused or orphan 456 # this direntry is not part of the tree --> unused or orphan
509 direntry = ole._load_direntry(sid) 457 direntry = ole._load_direntry(sid)
510 is_stream = direntry.entry_type == olefile.STGTY_STREAM 458 is_stream = direntry.entry_type == olefile.STGTY_STREAM
511 - log.debug('direntry {:2d} {}: {}' 459 + logger.debug('direntry {:2d} {}: {}'
512 .format(sid, '[orphan]' if is_orphan else direntry.name, 460 .format(sid, '[orphan]' if is_orphan else direntry.name,
513 'is stream of size {}'.format(direntry.size) 461 'is stream of size {}'.format(direntry.size)
514 if is_stream else 462 if is_stream else
@@ -593,7 +541,7 @@ def process_docx(filepath, field_filter_mode=None): @@ -593,7 +541,7 @@ def process_docx(filepath, field_filter_mode=None):
593 ddetext += unquote(elem.text) 541 ddetext += unquote(elem.text)
594 542
595 # apply field command filter 543 # apply field command filter
596 - log.debug('filtering with mode "{0}"'.format(field_filter_mode)) 544 + logger.debug('filtering with mode "{0}"'.format(field_filter_mode))
597 if field_filter_mode in (FIELD_FILTER_ALL, None): 545 if field_filter_mode in (FIELD_FILTER_ALL, None):
598 clean_fields = all_fields 546 clean_fields = all_fields
599 elif field_filter_mode == FIELD_FILTER_DDE: 547 elif field_filter_mode == FIELD_FILTER_DDE:
@@ -652,7 +600,7 @@ def field_is_blacklisted(contents): @@ -652,7 +600,7 @@ def field_is_blacklisted(contents):
652 index = FIELD_BLACKLIST_CMDS.index(words[0].lower()) 600 index = FIELD_BLACKLIST_CMDS.index(words[0].lower())
653 except ValueError: # first word is no blacklisted command 601 except ValueError: # first word is no blacklisted command
654 return False 602 return False
655 - log.debug('trying to match "{0}" to blacklist command {1}' 603 + logger.debug('trying to match "{0}" to blacklist command {1}'
656 .format(contents, FIELD_BLACKLIST[index])) 604 .format(contents, FIELD_BLACKLIST[index]))
657 _, nargs_required, nargs_optional, sw_with_arg, sw_solo, sw_format \ 605 _, nargs_required, nargs_optional, sw_with_arg, sw_solo, sw_format \
658 = FIELD_BLACKLIST[index] 606 = FIELD_BLACKLIST[index]
@@ -664,11 +612,11 @@ def field_is_blacklisted(contents): @@ -664,11 +612,11 @@ def field_is_blacklisted(contents):
664 break 612 break
665 nargs += 1 613 nargs += 1
666 if nargs < nargs_required: 614 if nargs < nargs_required:
667 - log.debug('too few args: found {0}, but need at least {1} in "{2}"' 615 + logger.debug('too few args: found {0}, but need at least {1} in "{2}"'
668 .format(nargs, nargs_required, contents)) 616 .format(nargs, nargs_required, contents))
669 return False 617 return False
670 elif nargs > nargs_required + nargs_optional: 618 elif nargs > nargs_required + nargs_optional:
671 - log.debug('too many args: found {0}, but need at most {1}+{2} in "{3}"' 619 + logger.debug('too many args: found {0}, but need at most {1}+{2} in "{3}"'
672 .format(nargs, nargs_required, nargs_optional, contents)) 620 .format(nargs, nargs_required, nargs_optional, contents))
673 return False 621 return False
674 622
@@ -678,14 +626,14 @@ def field_is_blacklisted(contents): @@ -678,14 +626,14 @@ def field_is_blacklisted(contents):
678 for word in words[1+nargs:]: 626 for word in words[1+nargs:]:
679 if expect_arg: # this is an argument for the last switch 627 if expect_arg: # this is an argument for the last switch
680 if arg_choices and (word not in arg_choices): 628 if arg_choices and (word not in arg_choices):
681 - log.debug('Found invalid switch argument "{0}" in "{1}"' 629 + logger.debug('Found invalid switch argument "{0}" in "{1}"'
682 .format(word, contents)) 630 .format(word, contents))
683 return False 631 return False
684 expect_arg = False 632 expect_arg = False
685 arg_choices = [] # in general, do not enforce choices 633 arg_choices = [] # in general, do not enforce choices
686 continue # "no further questions, your honor" 634 continue # "no further questions, your honor"
687 elif not FIELD_SWITCH_REGEX.match(word): 635 elif not FIELD_SWITCH_REGEX.match(word):
688 - log.debug('expected switch, found "{0}" in "{1}"' 636 + logger.debug('expected switch, found "{0}" in "{1}"'
689 .format(word, contents)) 637 .format(word, contents))
690 return False 638 return False
691 # we want a switch and we got a valid one 639 # we want a switch and we got a valid one
@@ -707,7 +655,7 @@ def field_is_blacklisted(contents): @@ -707,7 +655,7 @@ def field_is_blacklisted(contents):
707 if 'numeric' in sw_format: 655 if 'numeric' in sw_format:
708 arg_choices = [] # too many choices to list them here 656 arg_choices = [] # too many choices to list them here
709 else: 657 else:
710 - log.debug('unexpected switch {0} in "{1}"' 658 + logger.debug('unexpected switch {0} in "{1}"'
711 .format(switch, contents)) 659 .format(switch, contents))
712 return False 660 return False
713 661
@@ -733,11 +681,11 @@ def process_xlsx(filepath): @@ -733,11 +681,11 @@ def process_xlsx(filepath):
733 # binary parts, e.g. contained in .xlsb 681 # binary parts, e.g. contained in .xlsb
734 for subfile, content_type, handle in parser.iter_non_xml(): 682 for subfile, content_type, handle in parser.iter_non_xml():
735 try: 683 try:
736 - logging.info('Parsing non-xml subfile {0} with content type {1}' 684 + logger.info('Parsing non-xml subfile {0} with content type {1}'
737 .format(subfile, content_type)) 685 .format(subfile, content_type))
738 for record in xls_parser.parse_xlsb_part(handle, content_type, 686 for record in xls_parser.parse_xlsb_part(handle, content_type,
739 subfile): 687 subfile):
740 - logging.debug('{0}: {1}'.format(subfile, record)) 688 + logger.debug('{0}: {1}'.format(subfile, record))
741 if isinstance(record, xls_parser.XlsbBeginSupBook) and \ 689 if isinstance(record, xls_parser.XlsbBeginSupBook) and \
742 record.link_type == \ 690 record.link_type == \
743 xls_parser.XlsbBeginSupBook.LINK_TYPE_DDE: 691 xls_parser.XlsbBeginSupBook.LINK_TYPE_DDE:
@@ -747,14 +695,14 @@ def process_xlsx(filepath): @@ -747,14 +695,14 @@ def process_xlsx(filepath):
747 if content_type.startswith('application/vnd.ms-excel.') or \ 695 if content_type.startswith('application/vnd.ms-excel.') or \
748 content_type.startswith('application/vnd.ms-office.'): # pylint: disable=bad-indentation 696 content_type.startswith('application/vnd.ms-office.'): # pylint: disable=bad-indentation
749 # should really be able to parse these either as xml or records 697 # should really be able to parse these either as xml or records
750 - log_func = logging.warning 698 + log_func = logger.warning
751 elif content_type.startswith('image/') or content_type == \ 699 elif content_type.startswith('image/') or content_type == \
752 'application/vnd.openxmlformats-officedocument.' + \ 700 'application/vnd.openxmlformats-officedocument.' + \
753 'spreadsheetml.printerSettings': 701 'spreadsheetml.printerSettings':
754 # understandable that these are not record-base 702 # understandable that these are not record-base
755 - log_func = logging.debug 703 + log_func = logger.debug
756 else: # default 704 else: # default
757 - log_func = logging.info 705 + log_func = logger.info
758 log_func('Failed to parse {0} of content type {1}' 706 log_func('Failed to parse {0} of content type {1}'
759 .format(subfile, content_type)) 707 .format(subfile, content_type))
760 # in any case: continue with next 708 # in any case: continue with next
@@ -774,15 +722,15 @@ class RtfFieldParser(rtfobj.RtfParser): @@ -774,15 +722,15 @@ class RtfFieldParser(rtfobj.RtfParser):
774 722
775 def open_destination(self, destination): 723 def open_destination(self, destination):
776 if destination.cword == b'fldinst': 724 if destination.cword == b'fldinst':
777 - log.debug('*** Start field data at index %Xh' % destination.start) 725 + logger.debug('*** Start field data at index %Xh' % destination.start)
778 726
779 def close_destination(self, destination): 727 def close_destination(self, destination):
780 if destination.cword == b'fldinst': 728 if destination.cword == b'fldinst':
781 - log.debug('*** Close field data at index %Xh' % self.index)  
782 - log.debug('Field text: %r' % destination.data) 729 + logger.debug('*** Close field data at index %Xh' % self.index)
  730 + logger.debug('Field text: %r' % destination.data)
783 # remove extra spaces and newline chars: 731 # remove extra spaces and newline chars:
784 field_clean = destination.data.translate(None, b'\r\n').strip() 732 field_clean = destination.data.translate(None, b'\r\n').strip()
785 - log.debug('Cleaned Field text: %r' % field_clean) 733 + logger.debug('Cleaned Field text: %r' % field_clean)
786 self.fields.append(field_clean) 734 self.fields.append(field_clean)
787 735
788 def control_symbol(self, matchobject): 736 def control_symbol(self, matchobject):
@@ -804,7 +752,7 @@ def process_rtf(file_handle, field_filter_mode=None): @@ -804,7 +752,7 @@ def process_rtf(file_handle, field_filter_mode=None):
804 rtfparser.parse() 752 rtfparser.parse()
805 all_fields = [field.decode('ascii') for field in rtfparser.fields] 753 all_fields = [field.decode('ascii') for field in rtfparser.fields]
806 # apply field command filter 754 # apply field command filter
807 - log.debug('found {1} fields, filtering with mode "{0}"' 755 + logger.debug('found {1} fields, filtering with mode "{0}"'
808 .format(field_filter_mode, len(all_fields))) 756 .format(field_filter_mode, len(all_fields)))
809 if field_filter_mode in (FIELD_FILTER_ALL, None): 757 if field_filter_mode in (FIELD_FILTER_ALL, None):
810 clean_fields = all_fields 758 clean_fields = all_fields
@@ -853,7 +801,7 @@ def process_csv(filepath): @@ -853,7 +801,7 @@ def process_csv(filepath):
853 801
854 if is_small and not results: 802 if is_small and not results:
855 # easy to mis-sniff small files. Try different delimiters 803 # easy to mis-sniff small files. Try different delimiters
856 - log.debug('small file, no results; try all delimiters') 804 + logger.debug('small file, no results; try all delimiters')
857 file_handle.seek(0) 805 file_handle.seek(0)
858 other_delim = CSV_DELIMITERS.replace(dialect.delimiter, '') 806 other_delim = CSV_DELIMITERS.replace(dialect.delimiter, '')
859 for delim in other_delim: 807 for delim in other_delim:
@@ -861,12 +809,12 @@ def process_csv(filepath): @@ -861,12 +809,12 @@ def process_csv(filepath):
861 file_handle.seek(0) 809 file_handle.seek(0)
862 results, _ = process_csv_dialect(file_handle, delim) 810 results, _ = process_csv_dialect(file_handle, delim)
863 except csv.Error: # e.g. sniffing fails 811 except csv.Error: # e.g. sniffing fails
864 - log.debug('failed to csv-parse with delimiter {0!r}' 812 + logger.debug('failed to csv-parse with delimiter {0!r}'
865 .format(delim)) 813 .format(delim))
866 814
867 if is_small and not results: 815 if is_small and not results:
868 # try whole file as single cell, since sniffing fails in this case 816 # try whole file as single cell, since sniffing fails in this case
869 - log.debug('last attempt: take whole file as single unquoted cell') 817 + logger.debug('last attempt: take whole file as single unquoted cell')
870 file_handle.seek(0) 818 file_handle.seek(0)
871 match = CSV_DDE_FORMAT.match(file_handle.read(CSV_SMALL_THRESH)) 819 match = CSV_DDE_FORMAT.match(file_handle.read(CSV_SMALL_THRESH))
872 if match: 820 if match:
@@ -882,7 +830,7 @@ def process_csv_dialect(file_handle, delimiters): @@ -882,7 +830,7 @@ def process_csv_dialect(file_handle, delimiters):
882 dialect = csv.Sniffer().sniff(file_handle.read(CSV_SMALL_THRESH), 830 dialect = csv.Sniffer().sniff(file_handle.read(CSV_SMALL_THRESH),
883 delimiters=delimiters) 831 delimiters=delimiters)
884 dialect.strict = False # microsoft is never strict 832 dialect.strict = False # microsoft is never strict
885 - log.debug('sniffed csv dialect with delimiter {0!r} ' 833 + logger.debug('sniffed csv dialect with delimiter {0!r} '
886 'and quote char {1!r}' 834 'and quote char {1!r}'
887 .format(dialect.delimiter, dialect.quotechar)) 835 .format(dialect.delimiter, dialect.quotechar))
888 836
@@ -924,7 +872,7 @@ def process_excel_xml(filepath): @@ -924,7 +872,7 @@ def process_excel_xml(filepath):
924 break 872 break
925 if formula is None: 873 if formula is None:
926 continue 874 continue
927 - log.debug('found cell with formula {0}'.format(formula)) 875 + logger.debug('found cell with formula {0}'.format(formula))
928 match = re.match(XML_DDE_FORMAT, formula) 876 match = re.match(XML_DDE_FORMAT, formula)
929 if match: 877 if match:
930 dde_links.append(u' '.join(match.groups()[:2])) 878 dde_links.append(u' '.join(match.groups()[:2]))
@@ -934,40 +882,40 @@ def process_excel_xml(filepath): @@ -934,40 +882,40 @@ def process_excel_xml(filepath):
934 def process_file(filepath, field_filter_mode=None): 882 def process_file(filepath, field_filter_mode=None):
935 """ decides which of the process_* functions to call """ 883 """ decides which of the process_* functions to call """
936 if olefile.isOleFile(filepath): 884 if olefile.isOleFile(filepath):
937 - log.debug('Is OLE. Checking streams to see whether this is xls') 885 + logger.debug('Is OLE. Checking streams to see whether this is xls')
938 if xls_parser.is_xls(filepath): 886 if xls_parser.is_xls(filepath):
939 - log.debug('Process file as excel 2003 (xls)') 887 + logger.debug('Process file as excel 2003 (xls)')
940 return process_xls(filepath) 888 return process_xls(filepath)
941 else: 889 else:
942 - log.debug('Process file as word 2003 (doc)') 890 + logger.debug('Process file as word 2003 (doc)')
943 return process_doc(filepath) 891 return process_doc(filepath)
944 892
945 with open(filepath, 'rb') as file_handle: 893 with open(filepath, 'rb') as file_handle:
946 if file_handle.read(4) == RTF_START: 894 if file_handle.read(4) == RTF_START:
947 - log.debug('Process file as rtf') 895 + logger.debug('Process file as rtf')
948 return process_rtf(file_handle, field_filter_mode) 896 return process_rtf(file_handle, field_filter_mode)
949 897
950 try: 898 try:
951 doctype = ooxml.get_type(filepath) 899 doctype = ooxml.get_type(filepath)
952 - log.debug('Detected file type: {0}'.format(doctype)) 900 + logger.debug('Detected file type: {0}'.format(doctype))
953 except Exception as exc: 901 except Exception as exc:
954 - log.debug('Exception trying to xml-parse file: {0}'.format(exc)) 902 + logger.debug('Exception trying to xml-parse file: {0}'.format(exc))
955 doctype = None 903 doctype = None
956 904
957 if doctype == ooxml.DOCTYPE_EXCEL: 905 if doctype == ooxml.DOCTYPE_EXCEL:
958 - log.debug('Process file as excel 2007+ (xlsx)') 906 + logger.debug('Process file as excel 2007+ (xlsx)')
959 return process_xlsx(filepath) 907 return process_xlsx(filepath)
960 elif doctype in (ooxml.DOCTYPE_EXCEL_XML, ooxml.DOCTYPE_EXCEL_XML2003): 908 elif doctype in (ooxml.DOCTYPE_EXCEL_XML, ooxml.DOCTYPE_EXCEL_XML2003):
961 - log.debug('Process file as xml from excel 2003/2007+') 909 + logger.debug('Process file as xml from excel 2003/2007+')
962 return process_excel_xml(filepath) 910 return process_excel_xml(filepath)
963 elif doctype in (ooxml.DOCTYPE_WORD_XML, ooxml.DOCTYPE_WORD_XML2003): 911 elif doctype in (ooxml.DOCTYPE_WORD_XML, ooxml.DOCTYPE_WORD_XML2003):
964 - log.debug('Process file as xml from word 2003/2007+') 912 + logger.debug('Process file as xml from word 2003/2007+')
965 return process_docx(filepath) 913 return process_docx(filepath)
966 elif doctype is None: 914 elif doctype is None:
967 - log.debug('Process file as csv') 915 + logger.debug('Process file as csv')
968 return process_csv(filepath) 916 return process_csv(filepath)
969 else: # could be docx; if not: this is the old default code path 917 else: # could be docx; if not: this is the old default code path
970 - log.debug('Process file as word 2007+ (docx)') 918 + logger.debug('Process file as word 2007+ (docx)')
971 return process_docx(filepath, field_filter_mode) 919 return process_docx(filepath, field_filter_mode)
972 920
973 921
@@ -985,27 +933,14 @@ def main(cmd_line_args=None): @@ -985,27 +933,14 @@ def main(cmd_line_args=None):
985 # Setup logging to the console: 933 # Setup logging to the console:
986 # here we use stdout instead of stderr by default, so that the output 934 # here we use stdout instead of stderr by default, so that the output
987 # can be redirected properly. 935 # can be redirected properly.
988 - logging.basicConfig(level=LOG_LEVELS[args.loglevel], stream=sys.stdout,  
989 - format='%(levelname)-8s %(message)s')  
990 - # enable logging in the modules:  
991 - log.setLevel(logging.NOTSET)  
992 -  
993 - if args.json and args.loglevel.lower() == 'debug':  
994 - log.warning('Debug log output will not be json-compatible!') 936 + log_helper.enable_logging(args.json, args.loglevel, stream=sys.stdout)
995 937
996 if args.nounquote: 938 if args.nounquote:
997 global NO_QUOTES 939 global NO_QUOTES
998 NO_QUOTES = True 940 NO_QUOTES = True
999 941
1000 - if args.json:  
1001 - jout = []  
1002 - jout.append(BANNER_JSON)  
1003 - else:  
1004 - # print banner with version  
1005 - print(BANNER)  
1006 -  
1007 - if not args.json:  
1008 - print('Opening file: %s' % args.filepath) 942 + logger.print_str(BANNER)
  943 + logger.print_str('Opening file: %s' % args.filepath)
1009 944
1010 text = '' 945 text = ''
1011 return_code = 1 946 return_code = 1
@@ -1013,22 +948,12 @@ def main(cmd_line_args=None): @@ -1013,22 +948,12 @@ def main(cmd_line_args=None):
1013 text = process_file(args.filepath, args.field_filter_mode) 948 text = process_file(args.filepath, args.field_filter_mode)
1014 return_code = 0 949 return_code = 0
1015 except Exception as exc: 950 except Exception as exc:
1016 - if args.json:  
1017 - jout.append(dict(type='error', error=type(exc).__name__,  
1018 - message=str(exc)))  
1019 - else:  
1020 - raise # re-raise last known exception, keeping trace intact  
1021 -  
1022 - if args.json:  
1023 - for line in text.splitlines():  
1024 - if line.strip():  
1025 - jout.append(dict(type='dde-link', link=line.strip()))  
1026 - json.dump(jout, sys.stdout, check_circular=False, indent=4)  
1027 - print() # add a newline after closing "]"  
1028 - return return_code # required if we catch an exception in json-mode  
1029 - else:  
1030 - print ('DDE Links:')  
1031 - print(text) 951 + logger.exception(exc.message)
  952 +
  953 + logger.print_str('DDE Links:')
  954 + logger.print_str(text)
  955 +
  956 + log_helper.end_logging()
1032 957
1033 return return_code 958 return return_code
1034 959
oletools/ooxml.py
@@ -14,7 +14,7 @@ TODO: may have to tell apart single xml types: office2003 looks much different @@ -14,7 +14,7 @@ TODO: may have to tell apart single xml types: office2003 looks much different
14 """ 14 """
15 15
16 import sys 16 import sys
17 -import logging 17 +from oletools.common.log_helper import log_helper
18 from zipfile import ZipFile, BadZipfile, is_zipfile 18 from zipfile import ZipFile, BadZipfile, is_zipfile
19 from os.path import splitext 19 from os.path import splitext
20 import io 20 import io
@@ -27,6 +27,7 @@ try: @@ -27,6 +27,7 @@ try:
27 except ImportError: 27 except ImportError:
28 import xml.etree.cElementTree as ET 28 import xml.etree.cElementTree as ET
29 29
  30 +logger = log_helper.get_or_create_silent_logger('ooxml')
30 31
31 #: subfiles that have to be part of every ooxml file 32 #: subfiles that have to be part of every ooxml file
32 FILE_CONTENT_TYPES = '[Content_Types].xml' 33 FILE_CONTENT_TYPES = '[Content_Types].xml'
@@ -142,7 +143,7 @@ def get_type(filename): @@ -142,7 +143,7 @@ def get_type(filename):
142 is_xls = False 143 is_xls = False
143 is_ppt = False 144 is_ppt = False
144 for _, elem, _ in parser.iter_xml(FILE_CONTENT_TYPES): 145 for _, elem, _ in parser.iter_xml(FILE_CONTENT_TYPES):
145 - logging.debug(u' ' + debug_str(elem)) 146 + logger.debug(u' ' + debug_str(elem))
146 try: 147 try:
147 content_type = elem.attrib['ContentType'] 148 content_type = elem.attrib['ContentType']
148 except KeyError: # ContentType not an attr 149 except KeyError: # ContentType not an attr
@@ -160,7 +161,7 @@ def get_type(filename): @@ -160,7 +161,7 @@ def get_type(filename):
160 if not is_doc and not is_xls and not is_ppt: 161 if not is_doc and not is_xls and not is_ppt:
161 return DOCTYPE_NONE 162 return DOCTYPE_NONE
162 else: 163 else:
163 - logging.warning('Encountered contradictory content types') 164 + logger.warning('Encountered contradictory content types')
164 return DOCTYPE_MIXED 165 return DOCTYPE_MIXED
165 166
166 167
@@ -220,7 +221,7 @@ class ZipSubFile(object): @@ -220,7 +221,7 @@ class ZipSubFile(object):
220 self.name = filename 221 self.name = filename
221 if size is None: 222 if size is None:
222 self.size = container.getinfo(filename).file_size 223 self.size = container.getinfo(filename).file_size
223 - logging.debug('zip stream has size {0}'.format(self.size)) 224 + logger.debug('zip stream has size {0}'.format(self.size))
224 else: 225 else:
225 self.size = size 226 self.size = size
226 if 'w' in mode.lower(): 227 if 'w' in mode.lower():
@@ -484,10 +485,10 @@ class XmlParser(object): @@ -484,10 +485,10 @@ class XmlParser(object):
484 want_tags = [] 485 want_tags = []
485 elif isstr(tags): 486 elif isstr(tags):
486 want_tags = [tags, ] 487 want_tags = [tags, ]
487 - logging.debug('looking for tags: {0}'.format(tags)) 488 + logger.debug('looking for tags: {0}'.format(tags))
488 else: 489 else:
489 want_tags = tags 490 want_tags = tags
490 - logging.debug('looking for tags: {0}'.format(tags)) 491 + logger.debug('looking for tags: {0}'.format(tags))
491 492
492 for subfile, handle in self.iter_files(subfiles): 493 for subfile, handle in self.iter_files(subfiles):
493 events = ('start', 'end') 494 events = ('start', 'end')
@@ -499,7 +500,7 @@ class XmlParser(object): @@ -499,7 +500,7 @@ class XmlParser(object):
499 continue 500 continue
500 if event == 'start': 501 if event == 'start':
501 if elem.tag in want_tags: 502 if elem.tag in want_tags:
502 - logging.debug('remember start of tag {0} at {1}' 503 + logger.debug('remember start of tag {0} at {1}'
503 .format(elem.tag, depth)) 504 .format(elem.tag, depth))
504 inside_tags.append((elem.tag, depth)) 505 inside_tags.append((elem.tag, depth))
505 depth += 1 506 depth += 1
@@ -515,18 +516,18 @@ class XmlParser(object): @@ -515,18 +516,18 @@ class XmlParser(object):
515 if inside_tags[-1] == curr_tag: 516 if inside_tags[-1] == curr_tag:
516 inside_tags.pop() 517 inside_tags.pop()
517 else: 518 else:
518 - logging.error('found end for wanted tag {0} ' 519 + logger.error('found end for wanted tag {0} '
519 'but last start tag {1} does not' 520 'but last start tag {1} does not'
520 ' match'.format(curr_tag, 521 ' match'.format(curr_tag,
521 inside_tags[-1])) 522 inside_tags[-1]))
522 # try to recover: close all deeper tags 523 # try to recover: close all deeper tags
523 while inside_tags and \ 524 while inside_tags and \
524 inside_tags[-1][1] >= depth: 525 inside_tags[-1][1] >= depth:
525 - logging.debug('recover: pop {0}' 526 + logger.debug('recover: pop {0}'
526 .format(inside_tags[-1])) 527 .format(inside_tags[-1]))
527 inside_tags.pop() 528 inside_tags.pop()
528 except IndexError: # no inside_tag[-1] 529 except IndexError: # no inside_tag[-1]
529 - logging.error('found end of {0} at depth {1} but ' 530 + logger.error('found end of {0} at depth {1} but '
530 'no start event') 531 'no start event')
531 # yield element 532 # yield element
532 if is_wanted or not want_tags: 533 if is_wanted or not want_tags:
@@ -543,12 +544,12 @@ class XmlParser(object): @@ -543,12 +544,12 @@ class XmlParser(object):
543 if subfile is None: # this is no zip subfile but single xml 544 if subfile is None: # this is no zip subfile but single xml
544 raise BadOOXML(self.filename, 'is neither zip nor xml') 545 raise BadOOXML(self.filename, 'is neither zip nor xml')
545 elif subfile.endswith('.xml'): 546 elif subfile.endswith('.xml'):
546 - logger = logging.warning 547 + log = logger.warning
547 else: 548 else:
548 - logger = logging.debug  
549 - logger(' xml-parsing for {0} failed ({1}). '  
550 - .format(subfile, err) +  
551 - 'Run iter_non_xml to investigate.') 549 + log = logger.debug
  550 + log(' xml-parsing for {0} failed ({1}). '
  551 + .format(subfile, err) +
  552 + 'Run iter_non_xml to investigate.')
552 assert(depth == 0) 553 assert(depth == 0)
553 554
554 def get_content_types(self): 555 def get_content_types(self):
@@ -571,14 +572,14 @@ class XmlParser(object): @@ -571,14 +572,14 @@ class XmlParser(object):
571 if extension.startswith('.'): 572 if extension.startswith('.'):
572 extension = extension[1:] 573 extension = extension[1:]
573 defaults.append((extension, elem.attrib['ContentType'])) 574 defaults.append((extension, elem.attrib['ContentType']))
574 - logging.debug('found content type for extension {0[0]}: {0[1]}' 575 + logger.debug('found content type for extension {0[0]}: {0[1]}'
575 .format(defaults[-1])) 576 .format(defaults[-1]))
576 elif elem.tag.endswith('Override'): 577 elif elem.tag.endswith('Override'):
577 subfile = elem.attrib['PartName'] 578 subfile = elem.attrib['PartName']
578 if subfile.startswith('/'): 579 if subfile.startswith('/'):
579 subfile = subfile[1:] 580 subfile = subfile[1:]
580 files.append((subfile, elem.attrib['ContentType'])) 581 files.append((subfile, elem.attrib['ContentType']))
581 - logging.debug('found content type for subfile {0[0]}: {0[1]}' 582 + logger.debug('found content type for subfile {0[0]}: {0[1]}'
582 .format(files[-1])) 583 .format(files[-1]))
583 return dict(files), dict(defaults) 584 return dict(files), dict(defaults)
584 585
@@ -595,7 +596,7 @@ class XmlParser(object): @@ -595,7 +596,7 @@ class XmlParser(object):
595 To handle binary parts of an xlsb file, use xls_parser.parse_xlsb_part 596 To handle binary parts of an xlsb file, use xls_parser.parse_xlsb_part
596 """ 597 """
597 if not self.did_iter_all: 598 if not self.did_iter_all:
598 - logging.warning('Did not iterate through complete file. ' 599 + logger.warning('Did not iterate through complete file. '
599 'Should run iter_xml() without args, first.') 600 'Should run iter_xml() without args, first.')
600 if not self.subfiles_no_xml: 601 if not self.subfiles_no_xml:
601 return 602 return
@@ -628,7 +629,7 @@ def test(): @@ -628,7 +629,7 @@ def test():
628 629
629 see module doc for more info 630 see module doc for more info
630 """ 631 """
631 - logging.basicConfig(level=logging.DEBUG) 632 + log_helper.enable_logging(False, logger.DEBUG)
632 if len(sys.argv) != 2: 633 if len(sys.argv) != 2:
633 print(u'To test this code, give me a single file as arg') 634 print(u'To test this code, give me a single file as arg')
634 return 2 635 return 2
@@ -647,6 +648,9 @@ def test(): @@ -647,6 +648,9 @@ def test():
647 if index > 100: 648 if index > 100:
648 print(u'...') 649 print(u'...')
649 break 650 break
  651 +
  652 + log_helper.end_logging()
  653 +
650 return 0 654 return 0
651 655
652 656
tests/json/__init__.py renamed to tests/common/__init__.py
tests/common/log_helper/__init__.py 0 → 100644
tests/common/log_helper/log_helper_test_imported.py 0 → 100644
  1 +"""
  2 +Dummy file that logs messages, meant to be imported
  3 +by the main test file
  4 +"""
  5 +
  6 +from oletools.common.log_helper import log_helper
  7 +import logging
  8 +
DEBUG_MESSAGE = 'imported: debug log'
INFO_MESSAGE = 'imported: info log'
WARNING_MESSAGE = 'imported: warning log'
ERROR_MESSAGE = 'imported: error log'
CRITICAL_MESSAGE = 'imported: critical log'

# Silent logger: emits nothing until the importing script calls
# log_helper.enable_logging(); default level is ERROR.
logger = log_helper.get_or_create_silent_logger('test_imported', logging.ERROR)


def log():
    """Emit one message at each of the five standard log levels."""
    for emit, text in (
            (logger.debug, DEBUG_MESSAGE),
            (logger.info, INFO_MESSAGE),
            (logger.warning, WARNING_MESSAGE),
            (logger.error, ERROR_MESSAGE),
            (logger.critical, CRITICAL_MESSAGE),
    ):
        emit(text)
tests/common/log_helper/log_helper_test_main.py 0 → 100644
  1 +""" Test log_helpers """
  2 +
  3 +import sys
  4 +from tests.common.log_helper import log_helper_test_imported
  5 +from oletools.common.log_helper import log_helper
  6 +
# Messages logged by this module, one per level; the test file imports these
# constants to compare them against captured subprocess output.
DEBUG_MESSAGE = 'main: debug log'
INFO_MESSAGE = 'main: info log'
WARNING_MESSAGE = 'main: warning log'
ERROR_MESSAGE = 'main: error log'
CRITICAL_MESSAGE = 'main: critical log'

# Silent by default: produces no output until enable_logging() is called.
logger = log_helper.get_or_create_silent_logger('test_main')
  14 +
  15 +
def init_logging_and_log(args):
    """
    Exercise the logging scenarios selected by *args*.

    Supported argument combinations and expected outcomes:
    - ['<level>']: log without enabling - should never print
    - ['as-json', '<level>']: log as JSON without enabling - should never print
    - ['enable', '<level>']: enable, then log - should print messages
    - ['as-json', 'enable', '<level>']: enable, then log as JSON - messages
      are printed as JSON
    - ['enable', 'as-json', 'throw', '<level>']: raise an unhandled exception
      after logging - output must still be JSON-compatible

    The log level is always the last argument passed.
    """
    log_level = args[-1]
    json_mode = 'as-json' in args

    if 'enable' in args:
        log_helper.enable_logging(json_mode, log_level, stream=sys.stdout)

    _log()

    # Deliberately skip end_logging() here: the 'throw' scenario checks that
    # JSON output stays well-formed even after an unhandled exception.
    if 'throw' in args:
        raise Exception('An exception occurred before ending the logging')

    log_helper.end_logging()
  45 +
  46 +
def _log():
    """Log one message per level here, then trigger the imported module's logs."""
    for emit, text in (
            (logger.debug, DEBUG_MESSAGE),
            (logger.info, INFO_MESSAGE),
            (logger.warning, WARNING_MESSAGE),
            (logger.error, ERROR_MESSAGE),
            (logger.critical, CRITICAL_MESSAGE),
    ):
        emit(text)
    log_helper_test_imported.log()


if __name__ == '__main__':
    init_logging_and_log(sys.argv[1:])
tests/common/log_helper/test_log_helper.py 0 → 100644
  1 +""" Test the log helper
  2 +
  3 +This tests the generic log helper.
  4 +Check if it handles imported modules correctly
  5 +and that the default silent logger won't log when nothing is enabled
  6 +"""
  7 +
import json
import os
import subprocess
import sys
import unittest
from os.path import abspath, dirname, join, relpath

from tests.common.log_helper import log_helper_test_imported
from tests.common.log_helper import log_helper_test_main
# this is the common base of "tests" and "oletools" dirs
ROOT_DIRECTORY = abspath(join(__file__, '..', '..', '..', '..'))
# helper-script path relative to ROOT_DIRECTORY; the subprocess is spawned
# with cwd=ROOT_DIRECTORY, so the relative path resolves correctly there
TEST_FILE = relpath(join(dirname(__file__), 'log_helper_test_main.py'), ROOT_DIRECTORY)
# run the child with the same interpreter that runs this test suite
PYTHON_EXECUTABLE = sys.executable

# all five messages logged by the main helper module
# NOTE(review): not referenced anywhere in this file -- possibly unused
MAIN_LOG_MESSAGES = [
    log_helper_test_main.DEBUG_MESSAGE,
    log_helper_test_main.INFO_MESSAGE,
    log_helper_test_main.WARNING_MESSAGE,
    log_helper_test_main.ERROR_MESSAGE,
    log_helper_test_main.CRITICAL_MESSAGE
]
  28 +
  29 +
  30 +class TestLogHelper(unittest.TestCase):
  31 + def test_it_doesnt_log_when_not_enabled(self):
  32 + output = self._run_test(['debug'])
  33 + self.assertTrue(len(output) == 0)
  34 +
  35 + def test_it_doesnt_log_json_when_not_enabled(self):
  36 + output = self._run_test(['as-json', 'debug'])
  37 + self.assertTrue(len(output) == 0)
  38 +
  39 + def test_logs_when_enabled(self):
  40 + output = self._run_test(['enable', 'warning'])
  41 +
  42 + expected_messages = [
  43 + log_helper_test_main.WARNING_MESSAGE,
  44 + log_helper_test_main.ERROR_MESSAGE,
  45 + log_helper_test_main.CRITICAL_MESSAGE,
  46 + log_helper_test_imported.WARNING_MESSAGE,
  47 + log_helper_test_imported.ERROR_MESSAGE,
  48 + log_helper_test_imported.CRITICAL_MESSAGE
  49 + ]
  50 +
  51 + for msg in expected_messages:
  52 + self.assertIn(msg, output)
  53 +
  54 + def test_logs_json_when_enabled(self):
  55 + output = self._run_test(['enable', 'as-json', 'critical'])
  56 +
  57 + self._assert_json_messages(output, [
  58 + log_helper_test_main.CRITICAL_MESSAGE,
  59 + log_helper_test_imported.CRITICAL_MESSAGE
  60 + ])
  61 +
  62 + def test_json_correct_on_exceptions(self):
  63 + """
  64 + Test that even on unhandled exceptions our JSON is always correct
  65 + """
  66 + output = self._run_test(['enable', 'as-json', 'throw', 'critical'], False)
  67 + self._assert_json_messages(output, [
  68 + log_helper_test_main.CRITICAL_MESSAGE,
  69 + log_helper_test_imported.CRITICAL_MESSAGE
  70 + ])
  71 +
  72 + def _assert_json_messages(self, output, messages):
  73 + try:
  74 + json_data = json.loads(output)
  75 + self.assertEquals(len(json_data), len(messages))
  76 +
  77 + for i in range(len(messages)):
  78 + self.assertEquals(messages[i], json_data[i]['msg'])
  79 + except ValueError:
  80 + self.fail('Invalid json:\n' + output)
  81 +
  82 + self.assertNotEqual(len(json_data), 0, msg='Output was empty')
  83 +
  84 + def _run_test(self, args, should_succeed=True):
  85 + """
  86 + Use subprocess to better simulate the real scenario and avoid
  87 + logging conflicts when running multiple tests (since logging depends on singletons,
  88 + we might get errors or false positives between sequential tests runs)
  89 + """
  90 + child = subprocess.Popen(
  91 + [PYTHON_EXECUTABLE, TEST_FILE] + args,
  92 + shell=False,
  93 + env={'PYTHONPATH': ROOT_DIRECTORY},
  94 + universal_newlines=True,
  95 + cwd=ROOT_DIRECTORY,
  96 + stdin=None,
  97 + stdout=subprocess.PIPE,
  98 + stderr=subprocess.PIPE
  99 + )
  100 + (output, output_err) = child.communicate()
  101 +
  102 + if not isinstance(output, str):
  103 + output = output.decode('utf-8')
  104 +
  105 + self.assertEquals(child.returncode == 0, should_succeed)
  106 +
  107 + return output.strip()
  108 +
  109 +
  110 +# just in case somebody calls this file as a script
  111 +if __name__ == '__main__':
  112 + unittest.main()
tests/json/test_output.py deleted
1 -""" Test validity of json output  
2 -  
3 -Some scripts have a json output flag. Verify that at default log levels output  
4 -can be captured as-is and parsed by a json parser -- checking the return code  
5 -if desired.  
6 -"""  
7 -  
8 -import unittest  
9 -import sys  
10 -import json  
11 -import os  
12 -from os.path import join  
13 -from oletools import msodde  
14 -from tests.test_utils import OutputCapture, DATA_BASE_DIR  
15 -  
16 -if sys.version_info[0] <= 2:  
17 - from oletools import olevba  
18 -else:  
19 - from oletools import olevba3 as olevba  
20 -  
21 -  
22 -class TestValidJson(unittest.TestCase):  
23 - """  
24 - Ensure that script output is valid json.  
25 - If check_return_code is True we also ignore the output  
26 - of runs that didn't succeed.  
27 - """  
28 -  
29 - @staticmethod  
30 - def iter_test_files():  
31 - """ Iterate over all test files in DATA_BASE_DIR """  
32 - for dirpath, _, filenames in os.walk(DATA_BASE_DIR):  
33 - for filename in filenames:  
34 - yield join(dirpath, filename)  
35 -  
36 - def run_and_parse(self, program, args, print_output=False, check_return_code=True):  
37 - """ run single program with single file and parse output """  
38 - with OutputCapture() as capturer: # capture stdout  
39 - try:  
40 - return_code = program(args)  
41 - except Exception:  
42 - return_code = 1 # would result in non-zero exit  
43 - except SystemExit as se:  
44 - return_code = se.code or 0 # se.code can be None  
45 - if check_return_code and return_code is not 0:  
46 - if print_output:  
47 - print('Command failed ({0}) -- not parsing output'  
48 - .format(return_code))  
49 - return [] # no need to test  
50 -  
51 - self.assertNotEqual(return_code, None,  
52 - msg='self-test fail: return_code not set')  
53 -  
54 - # now test output  
55 - if print_output:  
56 - print(capturer.get_data())  
57 - try:  
58 - json_data = json.loads(capturer.get_data())  
59 - except ValueError:  
60 - self.fail('Invalid json:\n' + capturer.get_data())  
61 - self.assertNotEqual(len(json_data), 0, msg='Output was empty')  
62 - return json_data  
63 -  
64 - def run_all_files(self, program, args_without_filename, print_output=False):  
65 - """ run test for a single program over all test files """  
66 - n_files = 0  
67 - for testfile in self.iter_test_files(): # loop over all input  
68 - args = args_without_filename + [testfile, ]  
69 - self.run_and_parse(program, args, print_output)  
70 - n_files += 1  
71 - self.assertNotEqual(n_files, 0,  
72 - msg='self-test fail: No test files found')  
73 -  
74 - def test_msodde(self):  
75 - """ Test msodde.py """  
76 - self.run_all_files(msodde.main, ['-j', ])  
77 -  
78 - def test_olevba(self):  
79 - """ Test olevba.py with default args """  
80 - self.run_all_files(olevba.main, ['-j', ])  
81 -  
82 - def test_olevba_analysis(self):  
83 - """ Test olevba.py with -a """  
84 - self.run_all_files(olevba.main, ['-j', '-a', ])  
85 -  
86 - def test_olevba_recurse(self):  
87 - """ Test olevba.py with -r """  
88 - json_data = self.run_and_parse(olevba.main,  
89 - ['-j', '-r', join(DATA_BASE_DIR, '*')],  
90 - check_return_code=False)  
91 - self.assertNotEqual(len(json_data), 0,  
92 - msg='olevba[3] returned non-zero or no output')  
93 - self.assertNotEqual(json_data[-1]['n_processed'], 0,  
94 - msg='self-test fail: No test files found!')  
95 -  
96 -  
97 -# just in case somebody calls this file as a script  
98 -if __name__ == '__main__':  
99 - unittest.main()  
tests/msodde/test_basic.py
@@ -10,15 +10,13 @@ from __future__ import print_function @@ -10,15 +10,13 @@ from __future__ import print_function
10 10
11 import unittest 11 import unittest
12 from oletools import msodde 12 from oletools import msodde
13 -from tests.test_utils import OutputCapture, DATA_BASE_DIR as BASE_DIR  
14 -import shlex 13 +from tests.test_utils import DATA_BASE_DIR as BASE_DIR
15 from os.path import join 14 from os.path import join
16 from traceback import print_exc 15 from traceback import print_exc
17 16
18 17
19 class TestReturnCode(unittest.TestCase): 18 class TestReturnCode(unittest.TestCase):
20 """ check return codes and exception behaviour (not text output) """ 19 """ check return codes and exception behaviour (not text output) """
21 -  
22 def test_valid_doc(self): 20 def test_valid_doc(self):
23 """ check that a valid doc file leads to 0 exit status """ 21 """ check that a valid doc file leads to 0 exit status """
24 for filename in ( 22 for filename in (
@@ -59,87 +57,73 @@ class TestReturnCode(unittest.TestCase): @@ -59,87 +57,73 @@ class TestReturnCode(unittest.TestCase):
59 57
60 def do_test_validity(self, args, expect_error=False): 58 def do_test_validity(self, args, expect_error=False):
61 """ helper for test_valid_doc[x] """ 59 """ helper for test_valid_doc[x] """
62 - args = shlex.split(args)  
63 - return_code = -1  
64 have_exception = False 60 have_exception = False
65 try: 61 try:
66 - return_code = msodde.main(args) 62 + msodde.process_file(args, msodde.FIELD_FILTER_BLACKLIST)
67 except Exception: 63 except Exception:
68 have_exception = True 64 have_exception = True
69 print_exc() 65 print_exc()
70 except SystemExit as exc: # sys.exit() was called 66 except SystemExit as exc: # sys.exit() was called
71 - return_code = exc.code 67 + have_exception = True
72 if exc.code is None: 68 if exc.code is None:
73 - return_code = 0 69 + have_exception = False
74 70
75 - self.assertEqual(expect_error, have_exception or (return_code != 0),  
76 - msg='Args={0}, expect={1}, exc={2}, return={3}'  
77 - .format(args, expect_error, have_exception,  
78 - return_code)) 71 + self.assertEqual(expect_error, have_exception,
  72 + msg='Args={0}, expect={1}, exc={2}'
  73 + .format(args, expect_error, have_exception))
79 74
80 75
81 class TestDdeLinks(unittest.TestCase): 76 class TestDdeLinks(unittest.TestCase):
82 """ capture output of msodde and check dde-links are found correctly """ 77 """ capture output of msodde and check dde-links are found correctly """
83 78
84 - def get_dde_from_output(self, capturer): 79 + @staticmethod
  80 + def get_dde_from_output(output):
85 """ helper to read dde links from captured output 81 """ helper to read dde links from captured output
86 -  
87 - duplicate in tests/msodde/test_csv  
88 """ 82 """
89 - have_start_line = False  
90 - result = []  
91 - for line in capturer:  
92 - if not line.strip():  
93 - continue # skip empty lines  
94 - if have_start_line:  
95 - result.append(line)  
96 - elif line == 'DDE Links:':  
97 - have_start_line = True  
98 -  
99 - self.assertTrue(have_start_line) # ensure output was complete  
100 - return result 83 + return [o for o in output.splitlines()]
101 84
102 def test_with_dde(self): 85 def test_with_dde(self):
103 """ check that dde links appear on stdout """ 86 """ check that dde links appear on stdout """
104 filename = 'dde-test-from-office2003.doc' 87 filename = 'dde-test-from-office2003.doc'
105 - with OutputCapture() as capturer:  
106 - msodde.main([join(BASE_DIR, 'msodde', filename)])  
107 - self.assertNotEqual(len(self.get_dde_from_output(capturer)), 0, 88 + output = msodde.process_file(
  89 + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST)
  90 + self.assertNotEqual(len(self.get_dde_from_output(output)), 0,
108 msg='Found no dde links in output of ' + filename) 91 msg='Found no dde links in output of ' + filename)
109 92
110 def test_no_dde(self): 93 def test_no_dde(self):
111 """ check that no dde links appear on stdout """ 94 """ check that no dde links appear on stdout """
112 filename = 'harmless-clean.doc' 95 filename = 'harmless-clean.doc'
113 - with OutputCapture() as capturer:  
114 - msodde.main([join(BASE_DIR, 'msodde', filename)])  
115 - self.assertEqual(len(self.get_dde_from_output(capturer)), 0, 96 + output = msodde.process_file(
  97 + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST)
  98 + self.assertEqual(len(self.get_dde_from_output(output)), 0,
116 msg='Found dde links in output of ' + filename) 99 msg='Found dde links in output of ' + filename)
117 100
118 def test_with_dde_utf16le(self): 101 def test_with_dde_utf16le(self):
119 """ check that dde links appear on stdout """ 102 """ check that dde links appear on stdout """
120 filename = 'dde-test-from-office2013-utf_16le-korean.doc' 103 filename = 'dde-test-from-office2013-utf_16le-korean.doc'
121 - with OutputCapture() as capturer:  
122 - msodde.main([join(BASE_DIR, 'msodde', filename)])  
123 - self.assertNotEqual(len(self.get_dde_from_output(capturer)), 0, 104 + output = msodde.process_file(
  105 + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST)
  106 + self.assertNotEqual(len(self.get_dde_from_output(output)), 0,
124 msg='Found no dde links in output of ' + filename) 107 msg='Found no dde links in output of ' + filename)
125 108
126 def test_excel(self): 109 def test_excel(self):
127 """ check that dde links are found in excel 2007+ files """ 110 """ check that dde links are found in excel 2007+ files """
128 expect = ['DDE-Link cmd /c calc.exe', ] 111 expect = ['DDE-Link cmd /c calc.exe', ]
129 for extn in 'xlsx', 'xlsm', 'xlsb': 112 for extn in 'xlsx', 'xlsm', 'xlsb':
130 - with OutputCapture() as capturer:  
131 - msodde.main([join(BASE_DIR, 'msodde', 'dde-test.' + extn), ])  
132 - self.assertEqual(expect, self.get_dde_from_output(capturer), 113 + output = msodde.process_file(
  114 + join(BASE_DIR, 'msodde', 'dde-test.' + extn), msodde.FIELD_FILTER_BLACKLIST)
  115 +
  116 + self.assertEqual(expect, self.get_dde_from_output(output),
133 msg='unexpected output for dde-test.{0}: {1}' 117 msg='unexpected output for dde-test.{0}: {1}'
134 - .format(extn, capturer.get_data())) 118 + .format(extn, output))
135 119
136 def test_xml(self): 120 def test_xml(self):
137 """ check that dde in xml from word / excel is found """ 121 """ check that dde in xml from word / excel is found """
138 for name_part in 'excel2003', 'word2003', 'word2007': 122 for name_part in 'excel2003', 'word2003', 'word2007':
139 filename = 'dde-in-' + name_part + '.xml' 123 filename = 'dde-in-' + name_part + '.xml'
140 - with OutputCapture() as capturer:  
141 - msodde.main([join(BASE_DIR, 'msodde', filename), ])  
142 - links = self.get_dde_from_output(capturer) 124 + output = msodde.process_file(
  125 + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST)
  126 + links = self.get_dde_from_output(output)
143 self.assertEqual(len(links), 1, 'found {0} dde-links in {1}' 127 self.assertEqual(len(links), 1, 'found {0} dde-links in {1}'
144 .format(len(links), filename)) 128 .format(len(links), filename))
145 self.assertTrue('cmd' in links[0], 'no "cmd" in dde-link for {0}' 129 self.assertTrue('cmd' in links[0], 'no "cmd" in dde-link for {0}'
@@ -150,16 +134,16 @@ class TestDdeLinks(unittest.TestCase): @@ -150,16 +134,16 @@ class TestDdeLinks(unittest.TestCase):
150 def test_clean_rtf_blacklist(self): 134 def test_clean_rtf_blacklist(self):
151 """ find a lot of hyperlinks in rtf spec """ 135 """ find a lot of hyperlinks in rtf spec """
152 filename = 'RTF-Spec-1.7.rtf' 136 filename = 'RTF-Spec-1.7.rtf'
153 - with OutputCapture() as capturer:  
154 - msodde.main([join(BASE_DIR, 'msodde', filename)])  
155 - self.assertEqual(len(self.get_dde_from_output(capturer)), 1413) 137 + output = msodde.process_file(
  138 + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST)
  139 + self.assertEqual(len(self.get_dde_from_output(output)), 1413)
156 140
157 def test_clean_rtf_ddeonly(self): 141 def test_clean_rtf_ddeonly(self):
158 """ find no dde links in rtf spec """ 142 """ find no dde links in rtf spec """
159 filename = 'RTF-Spec-1.7.rtf' 143 filename = 'RTF-Spec-1.7.rtf'
160 - with OutputCapture() as capturer:  
161 - msodde.main(['-d', join(BASE_DIR, 'msodde', filename)])  
162 - self.assertEqual(len(self.get_dde_from_output(capturer)), 0, 144 + output = msodde.process_file(
  145 + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_DDE)
  146 + self.assertEqual(len(self.get_dde_from_output(output)), 0,
163 msg='Found dde links in output of ' + filename) 147 msg='Found dde links in output of ' + filename)
164 148
165 149
tests/msodde/test_csv.py
@@ -9,7 +9,7 @@ import os @@ -9,7 +9,7 @@ import os
9 from os.path import join 9 from os.path import join
10 10
11 from oletools import msodde 11 from oletools import msodde
12 -from tests.test_utils import OutputCapture, DATA_BASE_DIR 12 +from tests.test_utils import DATA_BASE_DIR
13 13
14 14
15 class TestCSV(unittest.TestCase): 15 class TestCSV(unittest.TestCase):
@@ -69,11 +69,8 @@ class TestCSV(unittest.TestCase): @@ -69,11 +69,8 @@ class TestCSV(unittest.TestCase):
69 def test_file(self): 69 def test_file(self):
70 """ test simple small example file """ 70 """ test simple small example file """
71 filename = join(DATA_BASE_DIR, 'msodde', 'dde-in-csv.csv') 71 filename = join(DATA_BASE_DIR, 'msodde', 'dde-in-csv.csv')
72 - with OutputCapture() as capturer:  
73 - capturer.reload_module(msodde) # re-create logger  
74 - ret_code = msodde.main([filename, ])  
75 - self.assertEqual(ret_code, 0)  
76 - links = self.get_dde_from_output(capturer) 72 + output = msodde.process_file(filename, msodde.FIELD_FILTER_BLACKLIST)
  73 + links = self.get_dde_from_output(output)
77 self.assertEqual(len(links), 1) 74 self.assertEqual(len(links), 1)
78 self.assertEqual(links[0], 75 self.assertEqual(links[0],
79 r"cmd '/k \..\..\..\Windows\System32\calc.exe'") 76 r"cmd '/k \..\..\..\Windows\System32\calc.exe'")
@@ -91,12 +88,10 @@ class TestCSV(unittest.TestCase): @@ -91,12 +88,10 @@ class TestCSV(unittest.TestCase):
91 if self.DO_DEBUG: 88 if self.DO_DEBUG:
92 args += ['-l', 'debug'] 89 args += ['-l', 'debug']
93 90
94 - with OutputCapture() as capturer:  
95 - capturer.reload_module(msodde) # re-create logger  
96 - ret_code = msodde.main(args)  
97 - self.assertEqual(ret_code, 0, 'checking sample resulted in '  
98 - 'error:\n' + sample_text)  
99 - return capturer 91 + processed_args = msodde.process_args(args)
  92 +
  93 + return msodde.process_file(
  94 + processed_args.filepath, processed_args.field_filter_mode)
100 95
101 except Exception: 96 except Exception:
102 raise 97 raise
@@ -111,25 +106,11 @@ class TestCSV(unittest.TestCase): @@ -111,25 +106,11 @@ class TestCSV(unittest.TestCase):
111 os.remove(filename) 106 os.remove(filename)
112 filename = None # just in case 107 filename = None # just in case
113 108
114 - def get_dde_from_output(self, capturer): 109 + @staticmethod
  110 + def get_dde_from_output(output):
115 """ helper to read dde links from captured output 111 """ helper to read dde links from captured output
116 -  
117 - duplicate in tests/msodde/test_basic  
118 """ 112 """
119 - have_start_line = False  
120 - result = []  
121 - for line in capturer:  
122 - if self.DO_DEBUG:  
123 - print('captured: ' + line)  
124 - if not line.strip():  
125 - continue # skip empty lines  
126 - if have_start_line:  
127 - result.append(line)  
128 - elif line == 'DDE Links:':  
129 - have_start_line = True  
130 -  
131 - self.assertTrue(have_start_line) # ensure output was complete  
132 - return result 113 + return [o for o in output.splitlines()]
133 114
134 def test_regex(self): 115 def test_regex(self):
135 """ check that regex captures other ways to include dde commands 116 """ check that regex captures other ways to include dde commands
tests/test_utils/__init__.py
1 -from .output_capture import OutputCapture  
2 -  
3 from os.path import dirname, join 1 from os.path import dirname, join
4 2
5 # Directory with test data, independent of current working directory 3 # Directory with test data, independent of current working directory
tests/test_utils/output_capture.py deleted
1 -""" class OutputCapture to test what scripts print to stdout """  
2 -  
3 -from __future__ import print_function  
4 -import sys  
5 -import logging  
6 -  
7 -  
8 -# python 2/3 version conflict:  
9 -if sys.version_info.major <= 2:  
10 - from StringIO import StringIO  
11 - # reload is a builtin  
12 -else:  
13 - from io import StringIO  
14 - if sys.version_info.minor < 4:  
15 - from imp import reload  
16 - else:  
17 - from importlib import reload  
18 -  
19 -  
20 -class OutputCapture:  
21 - """ context manager that captures stdout  
22 -  
23 - use as follows::  
24 -  
25 - with OutputCapture() as capturer:  
26 - run_my_script(some_args)  
27 -  
28 - # either test line-by-line ...  
29 - for line in capturer:  
30 - some_test(line)  
31 - # ...or test all output in one go  
32 - some_test(capturer.get_data())  
33 -  
34 - In order to solve issues with old logger instances still remembering closed  
35 - StringIO instances as "their" stdout, logging is shutdown and restarted  
36 - upon entering this Context Manager. This means that you may have to reload  
37 - your module, as well.  
38 - """  
39 -  
40 - def __init__(self):  
41 - self.buffer = StringIO()  
42 - self.orig_stdout = None  
43 - self.data = None  
44 -  
45 - def __enter__(self):  
46 - # Avoid problems with old logger instances that still remember an old  
47 - # closed StringIO as their sys.stdout  
48 - logging.shutdown()  
49 - reload(logging)  
50 -  
51 - # replace sys.stdout with own buffer.  
52 - self.orig_stdout = sys.stdout  
53 - sys.stdout = self.buffer  
54 - return self  
55 -  
56 - def __exit__(self, exc_type, exc_value, traceback):  
57 - sys.stdout = self.orig_stdout # re-set to original  
58 - self.data = self.buffer.getvalue()  
59 - self.buffer.close() # close buffer  
60 - self.buffer = None  
61 -  
62 - if exc_type: # there has been an error  
63 - print('Got error during output capture!')  
64 - print('Print captured output and re-raise:')  
65 - for line in self.data.splitlines():  
66 - print(line.rstrip()) # print output before re-raising  
67 -  
68 - def get_data(self):  
69 - """ retrieve all the captured data """  
70 - if self.buffer is not None:  
71 - return self.buffer.getvalue()  
72 - elif self.data is not None:  
73 - return self.data  
74 - else: # should not be possible  
75 - raise RuntimeError('programming error or someone messed with data!')  
76 -  
77 - def __iter__(self):  
78 - for line in self.get_data().splitlines():  
79 - yield line  
80 -  
81 - def reload_module(self, mod):  
82 - """ Wrapper around reload function for different python versions """  
83 - return reload(mod)