diff --git a/oletools/common/log_helper/__init__.py b/oletools/common/log_helper/__init__.py new file mode 100644 index 0000000..7a027c2 --- /dev/null +++ b/oletools/common/log_helper/__init__.py @@ -0,0 +1,5 @@ +from . import log_helper as log_helper_ + +log_helper = log_helper_.LogHelper() + +__all__ = ['log_helper'] diff --git a/oletools/common/log_helper/_json_formatter.py b/oletools/common/log_helper/_json_formatter.py new file mode 100644 index 0000000..4c5e337 --- /dev/null +++ b/oletools/common/log_helper/_json_formatter.py @@ -0,0 +1,24 @@ +import logging +import json + + +class JsonFormatter(logging.Formatter): + """ + Format every message to be logged as a JSON object + """ + _is_first_line = True + + def format(self, record): + """ + Since we don't buffer messages, we prepend every message except the first + with a comma to keep the overall output JSON-compatible, so we need to + keep track of whether we are printing the first line. + """ + json_dict = dict(msg=record.msg, level=record.levelname) + formatted_message = ' ' + json.dumps(json_dict) + + if self._is_first_line: + self._is_first_line = False + return formatted_message + + return ', ' + formatted_message diff --git a/oletools/common/log_helper/_logger_adapter.py b/oletools/common/log_helper/_logger_adapter.py new file mode 100644 index 0000000..75e331a --- /dev/null +++ b/oletools/common/log_helper/_logger_adapter.py @@ -0,0 +1,30 @@ +import logging +from . import _root_logger_wrapper + + +class OletoolsLoggerAdapter(logging.LoggerAdapter): + """ + Adapter class for all loggers returned by the logging module. + """ + _json_enabled = None + + def print_str(self, message): + """ + This function replaces normal print() calls so we can format them as JSON + when needed or just print them right away otherwise. + """ + if self._json_enabled and self._json_enabled(): + # Messages from this function should always be printed, + # so when using JSON we log at the level currently set on the root logger + self.log(_root_logger_wrapper.level(), message) + else: + print(message) + + def set_json_enabled_function(self, json_enabled): + """ + Set a function to be called to check whether JSON output is enabled. + """ + self._json_enabled = json_enabled + + def level(self): + return self.logger.level diff --git a/oletools/common/log_helper/_root_logger_wrapper.py b/oletools/common/log_helper/_root_logger_wrapper.py new file mode 100644 index 0000000..273d5c6 --- /dev/null +++ b/oletools/common/log_helper/_root_logger_wrapper.py @@ -0,0 +1,24 @@ +import logging + + +def is_logging_initialized(): + """ + We use the same strategy as the logging module when checking whether + logging was initialized - look for handlers in the root logger + """ + return len(logging.root.handlers) > 0 + + +def set_formatter(fmt): + """ + Set the formatter to be used by every handler of the root logger. + """ + if not is_logging_initialized(): + return + + for handler in logging.root.handlers: + handler.setFormatter(fmt) + + +def level(): + return logging.root.level diff --git a/oletools/common/log_helper/log_helper.py b/oletools/common/log_helper/log_helper.py new file mode 100644 index 0000000..7a7fb02 --- /dev/null +++ b/oletools/common/log_helper/log_helper.py @@ -0,0 +1,194 @@ +""" +log_helper.py + +General logging helpers + +.. codeauthor:: Intra2net AG
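+ +This module provides the LogHelper class behind the package's single +log_helper instance (created in __init__.py); it builds on the +OletoolsLoggerAdapter, JsonFormatter and _root_logger_wrapper helpers +from the same package.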
+""" + +# === LICENSE ================================================================= +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +# ----------------------------------------------------------------------------- +# CHANGELOG: +# 2017-12-07 v0.01 CH: - first version +# 2018-02-05 v0.02 SA: - fixed log level selection and reformatted code +# 2018-02-06 v0.03 SA: - refactored code to deal with NullHandlers +# 2018-02-07 v0.04 SA: - fixed control of handlers propagation +# 2018-04-23 v0.05 SA: - refactored the whole logger to use an OOP approach + +# ----------------------------------------------------------------------------- +# TODO: + + +from ._json_formatter import JsonFormatter +from ._logger_adapter import OletoolsLoggerAdapter +from . import _root_logger_wrapper +import logging +import sys + + +LOG_LEVELS = { + 'debug': logging.DEBUG, + 'info': logging.INFO, + 'warning': logging.WARNING, + 'error': logging.ERROR, + 'critical': logging.CRITICAL +} + +DEFAULT_LOGGER_NAME = 'oletools' +DEFAULT_MESSAGE_FORMAT = '%(levelname)-8s %(message)s' + + +class LogHelper: + def __init__(self): + self._all_names = set() # set so we do not have duplicates + self._use_json = False + self._is_enabled = False + + def get_or_create_silent_logger(self, name=DEFAULT_LOGGER_NAME, level=logging.CRITICAL + 1): + """ + Get a logger or create one if it doesn't exist, setting a NullHandler + as the handler (to avoid printing to the console). + By default we also use a higher logging level so every message will + be ignored. + This prevents oletools from logging unnecessarily when imported + by external tools. + """ + return self._get_or_create_logger(name, level, logging.NullHandler()) + + def enable_logging(self, use_json, level, log_format=DEFAULT_MESSAGE_FORMAT, stream=None): + """ + This function initializes the root logger and enables logging. + We set the level of the root logger to the level passed in, via + logging.basicConfig. + We also set the level of every logger we created to 0 (logging.NOTSET), meaning that + the level of the root logger will be used to tell if messages should be logged.
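+ + For example (an illustrative sketch, not an excerpt from oletools; + 'mytool' is a made-up name): + + logger = log_helper.get_or_create_silent_logger('mytool') + log_helper.enable_logging(use_json=False, level='debug') + logger.debug('forwarded to the root logger and printed') + log_helper.end_logging() +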
+ Additionally, since our loggers use the NullHandler, they won't log anything themselves, + but due to having propagation enabled they will pass messages to the root logger, + which in turn will log to the stream set in this function. + Since the root logger is the one doing the work, when using JSON we set its formatter + so that every message logged is JSON-compatible. + """ + if self._is_enabled: + raise ValueError('Logging is already enabled; re-enabling it is not supported.') + + log_level = LOG_LEVELS[level] + logging.basicConfig(level=log_level, format=log_format, stream=stream) + self._is_enabled = True + + self._use_json = use_json + sys.excepthook = self._get_except_hook(sys.excepthook) + + # since there could be loggers already created we go through all of them + # and set their levels to 0 so they will use the root logger's level + for name in self._all_names: + logger = self.get_or_create_silent_logger(name) + self._set_logger_level(logger, logging.NOTSET) + + # add a JSON formatter to the root logger, which will be used by every logger + if self._use_json: + _root_logger_wrapper.set_formatter(JsonFormatter()) + print('[') + + def end_logging(self): + """ + Must be called at the end of the main function if the caller wants + json-compatible output + """ + if not self._is_enabled: + return + self._is_enabled = False + + # end logging + self._all_names = set() + logging.shutdown() + + # end json list + if self._use_json: + print(']') + self._use_json = False + + def _get_except_hook(self, old_hook): + """ + Global hook for exceptions so we can always end logging. + We wrap any hook currently set instead of replacing it, so that global + hooks set elsewhere keep working. + Note that this is only called by enable_logging, which in turn is called by + the main() function in oletools' scripts. When scripts are being imported this + code won't execute and won't affect global hooks. + """ + def hook(exctype, value, traceback): + self.end_logging() + old_hook(exctype, value, traceback) + + return hook + + def _get_or_create_logger(self, name, level, handler=None): + """ + Get or create a new logger. A newly created logger gets the handler and + level passed in; a logger that already exists is left unchanged. + We also wrap the logger in an adapter so we can easily extend its functionality. + """ + + # logging.getLogger creates a logger if it doesn't exist, + # so we need to check before calling it + if handler and not self._log_exists(name): + logger = logging.getLogger(name) + logger.addHandler(handler) + self._set_logger_level(logger, level) + else: + logger = logging.getLogger(name) + + # Keep track of every logger we created so we can easily change + # their levels whenever needed + self._all_names.add(name) + + adapted_logger = OletoolsLoggerAdapter(logger, None) + adapted_logger.set_json_enabled_function(lambda: self._use_json) + + return adapted_logger + + @staticmethod + def _set_logger_level(logger, level): + """ + If logging is already initialized, we set the level of our logger + to 0, meaning that it will reuse the level of the root logger. + That means that if the root logger level changes, we will keep following + its level and avoid logging unnecessarily.
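+ For example (illustrative): a logger created after enable_logging() has + run gets level NOTSET and simply defers to whatever level the root + logger currently has.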
+ """ + + # if this log was wrapped, unwrap it to set the level + if isinstance(logger, OletoolsLoggerAdapter): + logger = logger.logger + + if _root_logger_wrapper.is_logging_initialized(): + logger.setLevel(logging.NOTSET) + else: + logger.setLevel(level) + + @staticmethod + def _log_exists(name): + """ + We check the log manager instead of our global _all_names variable + since the logger could have been created outside of the helper + """ + return name in logging.Logger.manager.loggerDict diff --git a/oletools/msodde.py b/oletools/msodde.py index 0b6ff4f..69eac6c 100644 --- a/oletools/msodde.py +++ b/oletools/msodde.py @@ -53,8 +53,6 @@ import argparse import os from os.path import abspath, dirname import sys -import json -import logging import re import csv @@ -63,6 +61,7 @@ import olefile from oletools import ooxml from oletools import xls_parser from oletools import rtfobj +from oletools.common.log_helper import log_helper # ----------------------------------------------------------------------------- # CHANGELOG: @@ -212,63 +211,12 @@ THIS IS WORK IN PROGRESS - Check updates regularly! Please report any issue at https://github.com/decalage2/oletools/issues """ % __version__ -BANNER_JSON = dict(type='meta', version=__version__, name='msodde', - link='http://decalage.info/python/oletools', - message='THIS IS WORK IN PROGRESS - Check updates regularly! ' - 'Please report any issue at ' - 'https://github.com/decalage2/oletools/issues') - # === LOGGING ================================================================= DEFAULT_LOG_LEVEL = "warning" # Default log level -LOG_LEVELS = { - 'debug': logging.DEBUG, - 'info': logging.INFO, - 'warning': logging.WARNING, - 'error': logging.ERROR, - 'critical': logging.CRITICAL -} - - -class NullHandler(logging.Handler): - """ - Log Handler without output, to avoid printing messages if logging is not - configured by the main application. - Python 2.7 has logging.NullHandler, but this is necessary for 2.6: - see https://docs.python.org/2.6/library/logging.html#configuring-logging-for-a-library - """ - def emit(self, record): - pass - - -def get_logger(name, level=logging.CRITICAL+1): - """ - Create a suitable logger object for this module. - The goal is not to change settings of the root logger, to avoid getting - other modules' logs on the screen. - If a logger exists with same name, reuse it. (Else it would have duplicate - handlers and messages would be doubled.) - The level is set to CRITICAL+1 by default, to avoid any logging. - """ - # First, test if there is already a logger with the same name, else it - # will generate duplicate messages (due to duplicate handlers): - if name in logging.Logger.manager.loggerDict: - # NOTE: another less intrusive but more "hackish" solution would be to - # use getLogger then test if its effective level is not default. 
- logger = logging.getLogger(name) - # make sure level is OK: - logger.setLevel(level) - return logger - # get a new logger: - logger = logging.getLogger(name) - # only add a NullHandler for this logger, it is up to the application - # to configure its own logging: - logger.addHandler(NullHandler()) - logger.setLevel(level) - return logger # a global logger object used for debugging: -log = get_logger('msodde') +logger = log_helper.get_or_create_silent_logger('msodde') # === UNICODE IN PY2 ========================================================= @@ -312,7 +260,7 @@ def ensure_stdout_handles_unicode(): encoding = 'utf8' # logging is probably not initialized yet, but just in case - log.debug('wrapping sys.stdout with encoder using {0}'.format(encoding)) + logger.debug('wrapping sys.stdout with encoder using {0}'.format(encoding)) wrapper = codecs.getwriter(encoding) sys.stdout = wrapper(sys.stdout) @@ -396,7 +344,7 @@ def process_doc_field(data): """ check if field instructions start with DDE expects unicode input, returns unicode output (empty if not dde) """ - log.debug('processing field \'{0}\''.format(data)) + logger.debug('processing field \'{0}\''.format(data)) if data.lstrip().lower().startswith(u'dde'): return data @@ -434,7 +382,7 @@ def process_doc_stream(stream): if char == OLE_FIELD_START: if have_start and max_size_exceeded: - log.debug('big field was not a field after all') + logger.debug('big field was not a field after all') have_start = True have_sep = False max_size_exceeded = False @@ -446,7 +394,7 @@ def process_doc_stream(stream): # now we are after start char but not at end yet if char == OLE_FIELD_SEP: if have_sep: - log.debug('unexpected field: has multiple separators!') + logger.debug('unexpected field: has multiple separators!') have_sep = True elif char == OLE_FIELD_END: # have complete field now, process it @@ -464,7 +412,7 @@ def process_doc_stream(stream): if max_size_exceeded: pass elif len(field_contents) > OLE_FIELD_MAX_SIZE: - log.debug('field exceeds max size of {0}. Ignore rest' + logger.debug('field exceeds max size of {0}. Ignore rest' .format(OLE_FIELD_MAX_SIZE)) max_size_exceeded = True @@ -482,9 +430,9 @@ def process_doc_stream(stream): field_contents += u'?' if max_size_exceeded: - log.debug('big field was not a field after all') + logger.debug('big field was not a field after all') - log.debug('Checked {0} characters, found {1} fields' + logger.debug('Checked {0} characters, found {1} fields' .format(idx, len(result_parts))) return result_parts @@ -498,7 +446,7 @@ def process_doc(filepath): empty if none were found. 
dde-links will still begin with the dde[auto] key word (possibly after some whitespace) """ - log.debug('process_doc') + logger.debug('process_doc') ole = olefile.OleFileIO(filepath, path_encoding=None) links = [] @@ -508,7 +456,7 @@ def process_doc(filepath): # this direntry is not part of the tree --> unused or orphan direntry = ole._load_direntry(sid) is_stream = direntry.entry_type == olefile.STGTY_STREAM - log.debug('direntry {:2d} {}: {}' + logger.debug('direntry {:2d} {}: {}' .format(sid, '[orphan]' if is_orphan else direntry.name, 'is stream of size {}'.format(direntry.size) if is_stream else @@ -593,7 +541,7 @@ def process_docx(filepath, field_filter_mode=None): ddetext += unquote(elem.text) # apply field command filter - log.debug('filtering with mode "{0}"'.format(field_filter_mode)) + logger.debug('filtering with mode "{0}"'.format(field_filter_mode)) if field_filter_mode in (FIELD_FILTER_ALL, None): clean_fields = all_fields elif field_filter_mode == FIELD_FILTER_DDE: @@ -652,7 +600,7 @@ def field_is_blacklisted(contents): index = FIELD_BLACKLIST_CMDS.index(words[0].lower()) except ValueError: # first word is no blacklisted command return False - log.debug('trying to match "{0}" to blacklist command {1}' + logger.debug('trying to match "{0}" to blacklist command {1}' .format(contents, FIELD_BLACKLIST[index])) _, nargs_required, nargs_optional, sw_with_arg, sw_solo, sw_format \ = FIELD_BLACKLIST[index] @@ -664,11 +612,11 @@ def field_is_blacklisted(contents): break nargs += 1 if nargs < nargs_required: - log.debug('too few args: found {0}, but need at least {1} in "{2}"' + logger.debug('too few args: found {0}, but need at least {1} in "{2}"' .format(nargs, nargs_required, contents)) return False elif nargs > nargs_required + nargs_optional: - log.debug('too many args: found {0}, but need at most {1}+{2} in "{3}"' + logger.debug('too many args: found {0}, but need at most {1}+{2} in "{3}"' .format(nargs, nargs_required, nargs_optional, contents)) return False @@ -678,14 +626,14 @@ def field_is_blacklisted(contents): for word in words[1+nargs:]: if expect_arg: # this is an argument for the last switch if arg_choices and (word not in arg_choices): - log.debug('Found invalid switch argument "{0}" in "{1}"' + logger.debug('Found invalid switch argument "{0}" in "{1}"' .format(word, contents)) return False expect_arg = False arg_choices = [] # in general, do not enforce choices continue # "no further questions, your honor" elif not FIELD_SWITCH_REGEX.match(word): - log.debug('expected switch, found "{0}" in "{1}"' + logger.debug('expected switch, found "{0}" in "{1}"' .format(word, contents)) return False # we want a switch and we got a valid one @@ -707,7 +655,7 @@ def field_is_blacklisted(contents): if 'numeric' in sw_format: arg_choices = [] # too many choices to list them here else: - log.debug('unexpected switch {0} in "{1}"' + logger.debug('unexpected switch {0} in "{1}"' .format(switch, contents)) return False @@ -733,11 +681,11 @@ def process_xlsx(filepath): # binary parts, e.g. 
contained in .xlsb for subfile, content_type, handle in parser.iter_non_xml(): try: - logging.info('Parsing non-xml subfile {0} with content type {1}' + logger.info('Parsing non-xml subfile {0} with content type {1}' .format(subfile, content_type)) for record in xls_parser.parse_xlsb_part(handle, content_type, subfile): - logging.debug('{0}: {1}'.format(subfile, record)) + logger.debug('{0}: {1}'.format(subfile, record)) if isinstance(record, xls_parser.XlsbBeginSupBook) and \ record.link_type == \ xls_parser.XlsbBeginSupBook.LINK_TYPE_DDE: @@ -747,14 +695,14 @@ def process_xlsx(filepath): if content_type.startswith('application/vnd.ms-excel.') or \ content_type.startswith('application/vnd.ms-office.'): # pylint: disable=bad-indentation # should really be able to parse these either as xml or records - log_func = logging.warning + log_func = logger.warning elif content_type.startswith('image/') or content_type == \ 'application/vnd.openxmlformats-officedocument.' + \ 'spreadsheetml.printerSettings': # understandable that these are not record-base - log_func = logging.debug + log_func = logger.debug else: # default - log_func = logging.info + log_func = logger.info log_func('Failed to parse {0} of content type {1}' .format(subfile, content_type)) # in any case: continue with next @@ -774,15 +722,15 @@ class RtfFieldParser(rtfobj.RtfParser): def open_destination(self, destination): if destination.cword == b'fldinst': - log.debug('*** Start field data at index %Xh' % destination.start) + logger.debug('*** Start field data at index %Xh' % destination.start) def close_destination(self, destination): if destination.cword == b'fldinst': - log.debug('*** Close field data at index %Xh' % self.index) - log.debug('Field text: %r' % destination.data) + logger.debug('*** Close field data at index %Xh' % self.index) + logger.debug('Field text: %r' % destination.data) # remove extra spaces and newline chars: field_clean = destination.data.translate(None, b'\r\n').strip() - log.debug('Cleaned Field text: %r' % field_clean) + logger.debug('Cleaned Field text: %r' % field_clean) self.fields.append(field_clean) def control_symbol(self, matchobject): @@ -804,7 +752,7 @@ def process_rtf(file_handle, field_filter_mode=None): rtfparser.parse() all_fields = [field.decode('ascii') for field in rtfparser.fields] # apply field command filter - log.debug('found {1} fields, filtering with mode "{0}"' + logger.debug('found {1} fields, filtering with mode "{0}"' .format(field_filter_mode, len(all_fields))) if field_filter_mode in (FIELD_FILTER_ALL, None): clean_fields = all_fields @@ -853,7 +801,7 @@ def process_csv(filepath): if is_small and not results: # easy to mis-sniff small files. Try different delimiters - log.debug('small file, no results; try all delimiters') + logger.debug('small file, no results; try all delimiters') file_handle.seek(0) other_delim = CSV_DELIMITERS.replace(dialect.delimiter, '') for delim in other_delim: @@ -861,12 +809,12 @@ def process_csv(filepath): file_handle.seek(0) results, _ = process_csv_dialect(file_handle, delim) except csv.Error: # e.g. 
sniffing fails - log.debug('failed to csv-parse with delimiter {0!r}' + logger.debug('failed to csv-parse with delimiter {0!r}' .format(delim)) if is_small and not results: # try whole file as single cell, since sniffing fails in this case - log.debug('last attempt: take whole file as single unquoted cell') + logger.debug('last attempt: take whole file as single unquoted cell') file_handle.seek(0) match = CSV_DDE_FORMAT.match(file_handle.read(CSV_SMALL_THRESH)) if match: @@ -882,7 +830,7 @@ def process_csv_dialect(file_handle, delimiters): dialect = csv.Sniffer().sniff(file_handle.read(CSV_SMALL_THRESH), delimiters=delimiters) dialect.strict = False # microsoft is never strict - log.debug('sniffed csv dialect with delimiter {0!r} ' + logger.debug('sniffed csv dialect with delimiter {0!r} ' 'and quote char {1!r}' .format(dialect.delimiter, dialect.quotechar)) @@ -924,7 +872,7 @@ def process_excel_xml(filepath): break if formula is None: continue - log.debug('found cell with formula {0}'.format(formula)) + logger.debug('found cell with formula {0}'.format(formula)) match = re.match(XML_DDE_FORMAT, formula) if match: dde_links.append(u' '.join(match.groups()[:2])) @@ -934,40 +882,40 @@ def process_excel_xml(filepath): def process_file(filepath, field_filter_mode=None): """ decides which of the process_* functions to call """ if olefile.isOleFile(filepath): - log.debug('Is OLE. Checking streams to see whether this is xls') + logger.debug('Is OLE. Checking streams to see whether this is xls') if xls_parser.is_xls(filepath): - log.debug('Process file as excel 2003 (xls)') + logger.debug('Process file as excel 2003 (xls)') return process_xls(filepath) else: - log.debug('Process file as word 2003 (doc)') + logger.debug('Process file as word 2003 (doc)') return process_doc(filepath) with open(filepath, 'rb') as file_handle: if file_handle.read(4) == RTF_START: - log.debug('Process file as rtf') + logger.debug('Process file as rtf') return process_rtf(file_handle, field_filter_mode) try: doctype = ooxml.get_type(filepath) - log.debug('Detected file type: {0}'.format(doctype)) + logger.debug('Detected file type: {0}'.format(doctype)) except Exception as exc: - log.debug('Exception trying to xml-parse file: {0}'.format(exc)) + logger.debug('Exception trying to xml-parse file: {0}'.format(exc)) doctype = None if doctype == ooxml.DOCTYPE_EXCEL: - log.debug('Process file as excel 2007+ (xlsx)') + logger.debug('Process file as excel 2007+ (xlsx)') return process_xlsx(filepath) elif doctype in (ooxml.DOCTYPE_EXCEL_XML, ooxml.DOCTYPE_EXCEL_XML2003): - log.debug('Process file as xml from excel 2003/2007+') + logger.debug('Process file as xml from excel 2003/2007+') return process_excel_xml(filepath) elif doctype in (ooxml.DOCTYPE_WORD_XML, ooxml.DOCTYPE_WORD_XML2003): - log.debug('Process file as xml from word 2003/2007+') + logger.debug('Process file as xml from word 2003/2007+') return process_docx(filepath) elif doctype is None: - log.debug('Process file as csv') + logger.debug('Process file as csv') return process_csv(filepath) else: # could be docx; if not: this is the old default code path - log.debug('Process file as word 2007+ (docx)') + logger.debug('Process file as word 2007+ (docx)') return process_docx(filepath, field_filter_mode) @@ -985,27 +933,14 @@ def main(cmd_line_args=None): # Setup logging to the console: # here we use stdout instead of stderr by default, so that the output # can be redirected properly. 
- logging.basicConfig(level=LOG_LEVELS[args.loglevel], stream=sys.stdout, - format='%(levelname)-8s %(message)s') - # enable logging in the modules: - log.setLevel(logging.NOTSET) - - if args.json and args.loglevel.lower() == 'debug': - log.warning('Debug log output will not be json-compatible!') + log_helper.enable_logging(args.json, args.loglevel, stream=sys.stdout) if args.nounquote: global NO_QUOTES NO_QUOTES = True - if args.json: - jout = [] - jout.append(BANNER_JSON) - else: - # print banner with version - print(BANNER) - - if not args.json: - print('Opening file: %s' % args.filepath) + logger.print_str(BANNER) + logger.print_str('Opening file: %s' % args.filepath) text = '' return_code = 1 @@ -1013,22 +948,12 @@ def main(cmd_line_args=None): text = process_file(args.filepath, args.field_filter_mode) return_code = 0 except Exception as exc: - if args.json: - jout.append(dict(type='error', error=type(exc).__name__, - message=str(exc))) - else: - raise # re-raise last known exception, keeping trace intact - - if args.json: - for line in text.splitlines(): - if line.strip(): - jout.append(dict(type='dde-link', link=line.strip())) - json.dump(jout, sys.stdout, check_circular=False, indent=4) - print() # add a newline after closing "]" - return return_code # required if we catch an exception in json-mode - else: - print ('DDE Links:') - print(text) + logger.exception(str(exc)) + + logger.print_str('DDE Links:') + logger.print_str(text) + + log_helper.end_logging() return return_code diff --git a/oletools/ooxml.py b/oletools/ooxml.py index 5250ae3..78ef489 100644 --- a/oletools/ooxml.py +++ b/oletools/ooxml.py @@ -14,7 +14,7 @@ TODO: may have to tell apart single xml types: office2003 looks much different """ import sys -import logging +from oletools.common.log_helper import log_helper from zipfile import ZipFile, BadZipfile, is_zipfile from os.path import splitext import io @@ -27,6 +27,7 @@ try: except ImportError: import xml.etree.cElementTree as ET +logger = log_helper.get_or_create_silent_logger('ooxml') #: subfiles that have to be part of every ooxml file FILE_CONTENT_TYPES = '[Content_Types].xml' @@ -142,7 +143,7 @@ def get_type(filename): is_xls = False is_ppt = False for _, elem, _ in parser.iter_xml(FILE_CONTENT_TYPES): - logging.debug(u' ' + debug_str(elem)) + logger.debug(u' ' + debug_str(elem)) try: content_type = elem.attrib['ContentType'] except KeyError: # ContentType not an attr @@ -160,7 +161,7 @@ def get_type(filename): if not is_doc and not is_xls and not is_ppt: return DOCTYPE_NONE else: - logging.warning('Encountered contradictory content types') + logger.warning('Encountered contradictory content types') return DOCTYPE_MIXED @@ -220,7 +221,7 @@ class ZipSubFile(object): self.name = filename if size is None: self.size = container.getinfo(filename).file_size - logging.debug('zip stream has size {0}'.format(self.size)) + logger.debug('zip stream has size {0}'.format(self.size)) else: self.size = size if 'w' in mode.lower(): @@ -484,10 +485,10 @@ class XmlParser(object): want_tags = [] elif isstr(tags): want_tags = [tags, ] - logging.debug('looking for tags: {0}'.format(tags)) + logger.debug('looking for tags: {0}'.format(tags)) else: want_tags = tags - logging.debug('looking for tags: {0}'.format(tags)) + logger.debug('looking for tags: {0}'.format(tags)) for subfile, handle in self.iter_files(subfiles): events = ('start', 'end') @@ -499,7 +500,7 @@ class XmlParser(object): continue if event == 'start': if elem.tag in want_tags: - logging.debug('remember start of tag {0} at {1}'
+ logger.debug('remember start of tag {0} at {1}' .format(elem.tag, depth)) inside_tags.append((elem.tag, depth)) depth += 1 @@ -515,18 +516,18 @@ class XmlParser(object): if inside_tags[-1] == curr_tag: inside_tags.pop() else: - logging.error('found end for wanted tag {0} ' + logger.error('found end for wanted tag {0} ' 'but last start tag {1} does not' ' match'.format(curr_tag, inside_tags[-1])) # try to recover: close all deeper tags while inside_tags and \ inside_tags[-1][1] >= depth: - logging.debug('recover: pop {0}' + logger.debug('recover: pop {0}' .format(inside_tags[-1])) inside_tags.pop() except IndexError: # no inside_tags[-1] - logging.error('found end of {0} at depth {1} but ' - 'no start event') + logger.error('found end of {0} at depth {1} but ' + 'no start event'.format(elem.tag, depth)) # yield element if is_wanted or not want_tags: @@ -543,12 +544,12 @@ class XmlParser(object): if subfile is None: # this is no zip subfile but single xml raise BadOOXML(self.filename, 'is neither zip nor xml') elif subfile.endswith('.xml'): - logger = logging.warning + log = logger.warning else: - logger = logging.debug - logger(' xml-parsing for {0} failed ({1}). ' - .format(subfile, err) + - 'Run iter_non_xml to investigate.') + log = logger.debug + log(' xml-parsing for {0} failed ({1}). ' + .format(subfile, err) + + 'Run iter_non_xml to investigate.') assert(depth == 0) def get_content_types(self): @@ -571,14 +572,14 @@ class XmlParser(object): if extension.startswith('.'): extension = extension[1:] defaults.append((extension, elem.attrib['ContentType'])) - logging.debug('found content type for extension {0[0]}: {0[1]}' + logger.debug('found content type for extension {0[0]}: {0[1]}' .format(defaults[-1])) elif elem.tag.endswith('Override'): subfile = elem.attrib['PartName'] if subfile.startswith('/'): subfile = subfile[1:] files.append((subfile, elem.attrib['ContentType'])) - logging.debug('found content type for subfile {0[0]}: {0[1]}' + logger.debug('found content type for subfile {0[0]}: {0[1]}' .format(files[-1])) return dict(files), dict(defaults) @@ -595,7 +596,7 @@ class XmlParser(object): To handle binary parts of an xlsb file, use xls_parser.parse_xlsb_part """ if not self.did_iter_all: - logging.warning('Did not iterate through complete file. ' + logger.warning('Did not iterate through complete file. '
'Should run iter_xml() without args, first.') if not self.subfiles_no_xml: return @@ -628,7 +629,7 @@ def test(): see module doc for more info """ - logging.basicConfig(level=logging.DEBUG) + log_helper.enable_logging(False, 'debug') if len(sys.argv) != 2: print(u'To test this code, give me a single file as arg') return 2 @@ -647,6 +648,9 @@ def test(): if index > 100: print(u'...') break + + log_helper.end_logging() + return 0 diff --git a/tests/json/__init__.py b/tests/common/__init__.py index e69de29..e69de29 100644 --- a/tests/json/__init__.py +++ b/tests/common/__init__.py diff --git a/tests/common/log_helper/__init__.py b/tests/common/log_helper/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/common/log_helper/__init__.py diff --git a/tests/common/log_helper/log_helper_test_imported.py b/tests/common/log_helper/log_helper_test_imported.py new file mode 100644 index 0000000..b3777af --- /dev/null +++ b/tests/common/log_helper/log_helper_test_imported.py @@ -0,0 +1,23 @@ +""" +Dummy file that logs messages, meant to be imported +by the main test file +""" + +from oletools.common.log_helper import log_helper +import logging + +DEBUG_MESSAGE = 'imported: debug log' +INFO_MESSAGE = 'imported: info log' +WARNING_MESSAGE = 'imported: warning log' +ERROR_MESSAGE = 'imported: error log' +CRITICAL_MESSAGE = 'imported: critical log' + +logger = log_helper.get_or_create_silent_logger('test_imported', logging.ERROR) + + +def log(): + logger.debug(DEBUG_MESSAGE) + logger.info(INFO_MESSAGE) + logger.warning(WARNING_MESSAGE) + logger.error(ERROR_MESSAGE) + logger.critical(CRITICAL_MESSAGE) diff --git a/tests/common/log_helper/log_helper_test_main.py b/tests/common/log_helper/log_helper_test_main.py new file mode 100644 index 0000000..0f6057a --- /dev/null +++ b/tests/common/log_helper/log_helper_test_main.py @@ -0,0 +1,57 @@ +""" Test log_helpers """ + +import sys +from tests.common.log_helper import log_helper_test_imported +from oletools.common.log_helper import log_helper + +DEBUG_MESSAGE = 'main: debug log' +INFO_MESSAGE = 'main: info log' +WARNING_MESSAGE = 'main: warning log' +ERROR_MESSAGE = 'main: error log' +CRITICAL_MESSAGE = 'main: critical log' + +logger = log_helper.get_or_create_silent_logger('test_main') + + +def init_logging_and_log(args): + """ + Try to cover possible logging scenarios.
For each scenario covered, here's the expected args and outcome: + - Log without enabling: [''] + * logging when being imported - should never print + - Log as JSON without enabling: ['as-json', ''] + * logging as JSON when being imported - should never print + - Enable and log: ['enable', ''] + * logging when being run as script - should log messages + - Enable and log as JSON: ['as-json', 'enable', ''] + * logging as JSON when being run as script - should log messages as JSON + - Enable, log as JSON and throw: ['enable', 'as-json', 'throw', ''] + * should produce JSON-compatible output, even after an unhandled exception + """ + + # the level should always be the last argument passed + level = args[-1] + use_json = 'as-json' in args + throw = 'throw' in args + + if 'enable' in args: + log_helper.enable_logging(use_json, level, stream=sys.stdout) + + _log() + + if throw: + raise Exception('An exception occurred before ending the logging') + + log_helper.end_logging() + + +def _log(): + logger.debug(DEBUG_MESSAGE) + logger.info(INFO_MESSAGE) + logger.warning(WARNING_MESSAGE) + logger.error(ERROR_MESSAGE) + logger.critical(CRITICAL_MESSAGE) + log_helper_test_imported.log() + + +if __name__ == '__main__': + init_logging_and_log(sys.argv[1:]) diff --git a/tests/common/log_helper/test_log_helper.py b/tests/common/log_helper/test_log_helper.py new file mode 100644 index 0000000..03dee68 --- /dev/null +++ b/tests/common/log_helper/test_log_helper.py @@ -0,0 +1,112 @@ +""" Test the log helper + +This tests the generic log helper. +Check if it handles imported modules correctly +and that the default silent logger won't log when nothing is enabled +""" + +import unittest +import sys +import json +import subprocess +from tests.common.log_helper import log_helper_test_main +from tests.common.log_helper import log_helper_test_imported +from os.path import dirname, join, relpath, abspath + +# this is the common base of "tests" and "oletools" dirs +ROOT_DIRECTORY = abspath(join(__file__, '..', '..', '..', '..')) +TEST_FILE = relpath(join(dirname(__file__), 'log_helper_test_main.py'), ROOT_DIRECTORY) +PYTHON_EXECUTABLE = sys.executable + +MAIN_LOG_MESSAGES = [ + log_helper_test_main.DEBUG_MESSAGE, + log_helper_test_main.INFO_MESSAGE, + log_helper_test_main.WARNING_MESSAGE, + log_helper_test_main.ERROR_MESSAGE, + log_helper_test_main.CRITICAL_MESSAGE +] + + +class TestLogHelper(unittest.TestCase): + def test_it_doesnt_log_when_not_enabled(self): + output = self._run_test(['debug']) + self.assertTrue(len(output) == 0) + + def test_it_doesnt_log_json_when_not_enabled(self): + output = self._run_test(['as-json', 'debug']) + self.assertTrue(len(output) == 0) + + def test_logs_when_enabled(self): + output = self._run_test(['enable', 'warning']) + + expected_messages = [ + log_helper_test_main.WARNING_MESSAGE, + log_helper_test_main.ERROR_MESSAGE, + log_helper_test_main.CRITICAL_MESSAGE, + log_helper_test_imported.WARNING_MESSAGE, + log_helper_test_imported.ERROR_MESSAGE, + log_helper_test_imported.CRITICAL_MESSAGE + ] + + for msg in expected_messages: + self.assertIn(msg, output) + + def test_logs_json_when_enabled(self): + output = self._run_test(['enable', 'as-json', 'critical']) + + self._assert_json_messages(output, [ + log_helper_test_main.CRITICAL_MESSAGE, + log_helper_test_imported.CRITICAL_MESSAGE + ]) + + def test_json_correct_on_exceptions(self): + """ + Test that even on unhandled exceptions our JSON is always correct + """ + output = self._run_test(['enable', 'as-json', 'throw', 'critical'], False) 
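+ # the child process exits non-zero because of the forced exception + # (hence should_succeed=False), yet its output must still be valid JSON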
+ self._assert_json_messages(output, [ + log_helper_test_main.CRITICAL_MESSAGE, + log_helper_test_imported.CRITICAL_MESSAGE + ]) + + def _assert_json_messages(self, output, messages): + try: + json_data = json.loads(output) + self.assertEqual(len(json_data), len(messages)) + + for i in range(len(messages)): + self.assertEqual(messages[i], json_data[i]['msg']) + except ValueError: + self.fail('Invalid json:\n' + output) + + self.assertNotEqual(len(json_data), 0, msg='Output was empty') + + def _run_test(self, args, should_succeed=True): + """ + Use subprocess to better simulate the real scenario and avoid + logging conflicts when running multiple tests (since logging depends on singletons, + we might get errors or false positives between sequential test runs) + """ + child = subprocess.Popen( + [PYTHON_EXECUTABLE, TEST_FILE] + args, + shell=False, + env={'PYTHONPATH': ROOT_DIRECTORY}, + universal_newlines=True, + cwd=ROOT_DIRECTORY, + stdin=None, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + (output, output_err) = child.communicate() + + if not isinstance(output, str): + output = output.decode('utf-8') + + self.assertEqual(child.returncode == 0, should_succeed) + + return output.strip() + + +# just in case somebody calls this file as a script +if __name__ == '__main__': + unittest.main() diff --git a/tests/json/test_output.py b/tests/json/test_output.py deleted file mode 100644 index f1fd48b..0000000 --- a/tests/json/test_output.py +++ /dev/null @@ -1,99 +0,0 @@ -""" Test validity of json output - -Some scripts have a json output flag. Verify that at default log levels output -can be captured as-is and parsed by a json parser -- checking the return code -if desired. -""" - -import unittest -import sys -import json -import os -from os.path import join -from oletools import msodde -from tests.test_utils import OutputCapture, DATA_BASE_DIR - -if sys.version_info[0] <= 2: - from oletools import olevba -else: - from oletools import olevba3 as olevba - - -class TestValidJson(unittest.TestCase): - """ - Ensure that script output is valid json. - If check_return_code is True we also ignore the output - of runs that didn't succeed.
- """ - - @staticmethod - def iter_test_files(): - """ Iterate over all test files in DATA_BASE_DIR """ - for dirpath, _, filenames in os.walk(DATA_BASE_DIR): - for filename in filenames: - yield join(dirpath, filename) - - def run_and_parse(self, program, args, print_output=False, check_return_code=True): - """ run single program with single file and parse output """ - with OutputCapture() as capturer: # capture stdout - try: - return_code = program(args) - except Exception: - return_code = 1 # would result in non-zero exit - except SystemExit as se: - return_code = se.code or 0 # se.code can be None - if check_return_code and return_code is not 0: - if print_output: - print('Command failed ({0}) -- not parsing output' - .format(return_code)) - return [] # no need to test - - self.assertNotEqual(return_code, None, - msg='self-test fail: return_code not set') - - # now test output - if print_output: - print(capturer.get_data()) - try: - json_data = json.loads(capturer.get_data()) - except ValueError: - self.fail('Invalid json:\n' + capturer.get_data()) - self.assertNotEqual(len(json_data), 0, msg='Output was empty') - return json_data - - def run_all_files(self, program, args_without_filename, print_output=False): - """ run test for a single program over all test files """ - n_files = 0 - for testfile in self.iter_test_files(): # loop over all input - args = args_without_filename + [testfile, ] - self.run_and_parse(program, args, print_output) - n_files += 1 - self.assertNotEqual(n_files, 0, - msg='self-test fail: No test files found') - - def test_msodde(self): - """ Test msodde.py """ - self.run_all_files(msodde.main, ['-j', ]) - - def test_olevba(self): - """ Test olevba.py with default args """ - self.run_all_files(olevba.main, ['-j', ]) - - def test_olevba_analysis(self): - """ Test olevba.py with -a """ - self.run_all_files(olevba.main, ['-j', '-a', ]) - - def test_olevba_recurse(self): - """ Test olevba.py with -r """ - json_data = self.run_and_parse(olevba.main, - ['-j', '-r', join(DATA_BASE_DIR, '*')], - check_return_code=False) - self.assertNotEqual(len(json_data), 0, - msg='olevba[3] returned non-zero or no output') - self.assertNotEqual(json_data[-1]['n_processed'], 0, - msg='self-test fail: No test files found!') - - -# just in case somebody calls this file as a script -if __name__ == '__main__': - unittest.main() diff --git a/tests/msodde/test_basic.py b/tests/msodde/test_basic.py index 1966a2f..ac3121c 100644 --- a/tests/msodde/test_basic.py +++ b/tests/msodde/test_basic.py @@ -10,15 +10,13 @@ from __future__ import print_function import unittest from oletools import msodde -from tests.test_utils import OutputCapture, DATA_BASE_DIR as BASE_DIR -import shlex +from tests.test_utils import DATA_BASE_DIR as BASE_DIR from os.path import join from traceback import print_exc class TestReturnCode(unittest.TestCase): """ check return codes and exception behaviour (not text output) """ - def test_valid_doc(self): """ check that a valid doc file leads to 0 exit status """ for filename in ( @@ -59,87 +57,73 @@ class TestReturnCode(unittest.TestCase): def do_test_validity(self, args, expect_error=False): """ helper for test_valid_doc[x] """ - args = shlex.split(args) - return_code = -1 have_exception = False try: - return_code = msodde.main(args) + msodde.process_file(args, msodde.FIELD_FILTER_BLACKLIST) except Exception: have_exception = True print_exc() except SystemExit as exc: # sys.exit() was called - return_code = exc.code + have_exception = True if exc.code is None: - return_code 
= 0 + have_exception = False - self.assertEqual(expect_error, have_exception or (return_code != 0), - msg='Args={0}, expect={1}, exc={2}, return={3}' - .format(args, expect_error, have_exception, - return_code)) + self.assertEqual(expect_error, have_exception, + msg='Args={0}, expect={1}, exc={2}' + .format(args, expect_error, have_exception)) class TestDdeLinks(unittest.TestCase): """ capture output of msodde and check dde-links are found correctly """ - def get_dde_from_output(self, capturer): + @staticmethod + def get_dde_from_output(output): """ helper to read dde links from captured output - - duplicate in tests/msodde/test_csv """ - have_start_line = False - result = [] - for line in capturer: - if not line.strip(): - continue # skip empty lines - if have_start_line: - result.append(line) - elif line == 'DDE Links:': - have_start_line = True - - self.assertTrue(have_start_line) # ensure output was complete - return result + return [o for o in output.splitlines()] def test_with_dde(self): """ check that dde links appear on stdout """ filename = 'dde-test-from-office2003.doc' - with OutputCapture() as capturer: - msodde.main([join(BASE_DIR, 'msodde', filename)]) - self.assertNotEqual(len(self.get_dde_from_output(capturer)), 0, + output = msodde.process_file( + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST) + self.assertNotEqual(len(self.get_dde_from_output(output)), 0, msg='Found no dde links in output of ' + filename) def test_no_dde(self): """ check that no dde links appear on stdout """ filename = 'harmless-clean.doc' - with OutputCapture() as capturer: - msodde.main([join(BASE_DIR, 'msodde', filename)]) - self.assertEqual(len(self.get_dde_from_output(capturer)), 0, + output = msodde.process_file( + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST) + self.assertEqual(len(self.get_dde_from_output(output)), 0, msg='Found dde links in output of ' + filename) def test_with_dde_utf16le(self): """ check that dde links appear on stdout """ filename = 'dde-test-from-office2013-utf_16le-korean.doc' - with OutputCapture() as capturer: - msodde.main([join(BASE_DIR, 'msodde', filename)]) - self.assertNotEqual(len(self.get_dde_from_output(capturer)), 0, + output = msodde.process_file( + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST) + self.assertNotEqual(len(self.get_dde_from_output(output)), 0, msg='Found no dde links in output of ' + filename) def test_excel(self): """ check that dde links are found in excel 2007+ files """ expect = ['DDE-Link cmd /c calc.exe', ] for extn in 'xlsx', 'xlsm', 'xlsb': - with OutputCapture() as capturer: - msodde.main([join(BASE_DIR, 'msodde', 'dde-test.' + extn), ]) - self.assertEqual(expect, self.get_dde_from_output(capturer), + output = msodde.process_file( + join(BASE_DIR, 'msodde', 'dde-test.' 
+ extn), msodde.FIELD_FILTER_BLACKLIST) + + self.assertEqual(expect, self.get_dde_from_output(output), msg='unexpected output for dde-test.{0}: {1}' - .format(extn, capturer.get_data())) + .format(extn, output)) def test_xml(self): """ check that dde in xml from word / excel is found """ for name_part in 'excel2003', 'word2003', 'word2007': filename = 'dde-in-' + name_part + '.xml' - with OutputCapture() as capturer: - msodde.main([join(BASE_DIR, 'msodde', filename), ]) - links = self.get_dde_from_output(capturer) + output = msodde.process_file( + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST) + links = self.get_dde_from_output(output) self.assertEqual(len(links), 1, 'found {0} dde-links in {1}' .format(len(links), filename)) self.assertTrue('cmd' in links[0], 'no "cmd" in dde-link for {0}' @@ -150,16 +134,16 @@ class TestDdeLinks(unittest.TestCase): def test_clean_rtf_blacklist(self): """ find a lot of hyperlinks in rtf spec """ filename = 'RTF-Spec-1.7.rtf' - with OutputCapture() as capturer: - msodde.main([join(BASE_DIR, 'msodde', filename)]) - self.assertEqual(len(self.get_dde_from_output(capturer)), 1413) + output = msodde.process_file( + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST) + self.assertEqual(len(self.get_dde_from_output(output)), 1413) def test_clean_rtf_ddeonly(self): """ find no dde links in rtf spec """ filename = 'RTF-Spec-1.7.rtf' - with OutputCapture() as capturer: - msodde.main(['-d', join(BASE_DIR, 'msodde', filename)]) - self.assertEqual(len(self.get_dde_from_output(capturer)), 0, + output = msodde.process_file( + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_DDE) + self.assertEqual(len(self.get_dde_from_output(output)), 0, msg='Found dde links in output of ' + filename) diff --git a/tests/msodde/test_csv.py b/tests/msodde/test_csv.py index 2c1e7f1..92131b4 100644 --- a/tests/msodde/test_csv.py +++ b/tests/msodde/test_csv.py @@ -9,7 +9,7 @@ import os from os.path import join from oletools import msodde -from tests.test_utils import OutputCapture, DATA_BASE_DIR +from tests.test_utils import DATA_BASE_DIR class TestCSV(unittest.TestCase): @@ -69,11 +69,8 @@ class TestCSV(unittest.TestCase): def test_file(self): """ test simple small example file """ filename = join(DATA_BASE_DIR, 'msodde', 'dde-in-csv.csv') - with OutputCapture() as capturer: - capturer.reload_module(msodde) # re-create logger - ret_code = msodde.main([filename, ]) - self.assertEqual(ret_code, 0) - links = self.get_dde_from_output(capturer) + output = msodde.process_file(filename, msodde.FIELD_FILTER_BLACKLIST) + links = self.get_dde_from_output(output) self.assertEqual(len(links), 1) self.assertEqual(links[0], r"cmd '/k \..\..\..\Windows\System32\calc.exe'") @@ -91,12 +88,10 @@ class TestCSV(unittest.TestCase): if self.DO_DEBUG: args += ['-l', 'debug'] - with OutputCapture() as capturer: - capturer.reload_module(msodde) # re-create logger - ret_code = msodde.main(args) - self.assertEqual(ret_code, 0, 'checking sample resulted in ' - 'error:\n' + sample_text) - return capturer + processed_args = msodde.process_args(args) + + return msodde.process_file( + processed_args.filepath, processed_args.field_filter_mode) except Exception: raise @@ -111,25 +106,11 @@ class TestCSV(unittest.TestCase): os.remove(filename) filename = None # just in case - def get_dde_from_output(self, capturer): + @staticmethod + def get_dde_from_output(output): """ helper to read dde links from captured output - - duplicate in tests/msodde/test_basic """ - have_start_line = 
False - result = [] - for line in capturer: - if self.DO_DEBUG: - print('captured: ' + line) - if not line.strip(): - continue # skip empty lines - if have_start_line: - result.append(line) - elif line == 'DDE Links:': - have_start_line = True - - self.assertTrue(have_start_line) # ensure output was complete - return result + return [o for o in output.splitlines()] def test_regex(self): """ check that regex captures other ways to include dde commands diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index fca8642..c6671c7 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -1,5 +1,3 @@ -from .output_capture import OutputCapture - from os.path import dirname, join # Directory with test data, independent of current working directory diff --git a/tests/test_utils/output_capture.py b/tests/test_utils/output_capture.py deleted file mode 100644 index 0a6c6a2..0000000 --- a/tests/test_utils/output_capture.py +++ /dev/null @@ -1,83 +0,0 @@ -""" class OutputCapture to test what scripts print to stdout """ - -from __future__ import print_function -import sys -import logging - - -# python 2/3 version conflict: -if sys.version_info.major <= 2: - from StringIO import StringIO - # reload is a builtin -else: - from io import StringIO - if sys.version_info.minor < 4: - from imp import reload - else: - from importlib import reload - - -class OutputCapture: - """ context manager that captures stdout - - use as follows:: - - with OutputCapture() as capturer: - run_my_script(some_args) - - # either test line-by-line ... - for line in capturer: - some_test(line) - # ...or test all output in one go - some_test(capturer.get_data()) - - In order to solve issues with old logger instances still remembering closed - StringIO instances as "their" stdout, logging is shutdown and restarted - upon entering this Context Manager. This means that you may have to reload - your module, as well. - """ - - def __init__(self): - self.buffer = StringIO() - self.orig_stdout = None - self.data = None - - def __enter__(self): - # Avoid problems with old logger instances that still remember an old - # closed StringIO as their sys.stdout - logging.shutdown() - reload(logging) - - # replace sys.stdout with own buffer. - self.orig_stdout = sys.stdout - sys.stdout = self.buffer - return self - - def __exit__(self, exc_type, exc_value, traceback): - sys.stdout = self.orig_stdout # re-set to original - self.data = self.buffer.getvalue() - self.buffer.close() # close buffer - self.buffer = None - - if exc_type: # there has been an error - print('Got error during output capture!') - print('Print captured output and re-raise:') - for line in self.data.splitlines(): - print(line.rstrip()) # print output before re-raising - - def get_data(self): - """ retrieve all the captured data """ - if self.buffer is not None: - return self.buffer.getvalue() - elif self.data is not None: - return self.data - else: # should not be possible - raise RuntimeError('programming error or someone messed with data!') - - def __iter__(self): - for line in self.get_data().splitlines(): - yield line - - def reload_module(self, mod): - """ Wrapper around reload function for different python versions """ - return reload(mod)
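
Adoption sketch (illustrative only, not part of the patch; the module name 'mytool' and the messages are made up) showing how a tool is expected to use the new helper, based on the API introduced above:

    from oletools.common.log_helper import log_helper

    # at import time: create a silent logger, so applications that merely
    # import the tool see no output
    logger = log_helper.get_or_create_silent_logger('mytool')

    def main(use_json=False):
        # in main(): switch logging on; the level must be one of the
        # LOG_LEVELS keys ('debug', 'info', 'warning', 'error', 'critical')
        log_helper.enable_logging(use_json, 'debug')
        logger.debug('visible now that logging is enabled')
        logger.print_str('replaces print(); emitted as JSON when use_json is True')
        log_helper.end_logging()  # closes the JSON list when use_json is True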