Commit 1542df504ef321511df84f15c16a15aed252c58c
Committed by
GitHub
Merge pull request #308 from christian-intra2net/central-logger-json
Unified logging with json option
Showing
17 changed files
with
588 additions
and
409 deletions
oletools/common/log_helper/__init__.py
0 → 100644
oletools/common/log_helper/_json_formatter.py
0 → 100644
| 1 | +import logging | ||
| 2 | +import json | ||
| 3 | + | ||
| 4 | + | ||
class JsonFormatter(logging.Formatter):
    """
    Format every message to be logged as a JSON object
    (one ``{"msg": ..., "level": ...}`` object per record).
    """
    # Class-level default; the first call to format() shadows it with an
    # instance attribute, so each formatter instance tracks its own state.
    _is_first_line = True

    def format(self, record):
        """
        Since we don't buffer messages, we always prepend messages with a comma to make
        the output JSON-compatible. The only exception is when printing the first line,
        so we need to keep track of it.

        :param record: the logging.LogRecord to serialize
        :returns: a JSON object string, comma-prefixed for all but the first record
        """
        # Use getMessage() so %-style lazy arguments passed to the logging
        # call (e.g. logger.info('x=%s', x)) are interpolated; record.msg
        # alone would emit the raw, un-substituted format string.
        json_dict = dict(msg=record.getMessage(), level=record.levelname)
        formatted_message = ' ' + json.dumps(json_dict)

        if self._is_first_line:
            self._is_first_line = False
            return formatted_message

        return ', ' + formatted_message
oletools/common/log_helper/_logger_adapter.py
0 → 100644
| 1 | +import logging | ||
| 2 | +from . import _root_logger_wrapper | ||
| 3 | + | ||
| 4 | + | ||
class OletoolsLoggerAdapter(logging.LoggerAdapter):
    """
    Adapter class for all loggers returned by the logging module.
    """
    # Callable that reports whether JSON output is active; None until
    # set_json_enabled_function() is called.
    _json_enabled = None

    def print_str(self, message):
        """
        Replacement for plain print() calls: when JSON output is active,
        route the text through the logging machinery so it is formatted
        as JSON; otherwise just print it directly.
        """
        check = self._json_enabled
        if check and check():
            # Messages from this function should always be printed, so
            # when using JSON we log at the root logger's current level.
            self.log(_root_logger_wrapper.level(), message)
        else:
            print(message)

    def set_json_enabled_function(self, json_enabled):
        """
        Register the callable used to decide whether JSON output is enabled.
        """
        self._json_enabled = json_enabled

    def level(self):
        """Return the level of the wrapped logger."""
        return self.logger.level
oletools/common/log_helper/_root_logger_wrapper.py
0 → 100644
| 1 | +import logging | ||
| 2 | + | ||
| 3 | + | ||
def is_logging_initialized():
    """
    Report whether logging has been set up, using the same heuristic as
    the logging module itself: the root logger has at least one handler.
    """
    return bool(logging.root.handlers)
| 10 | + | ||
| 11 | + | ||
def set_formatter(fmt):
    """
    Set the formatter to be used by every handler of the root logger.
    """
    # If logging has not been initialized the root logger has no handlers,
    # so this loop simply does nothing -- no explicit guard needed.
    for root_handler in logging.root.handlers:
        root_handler.setFormatter(fmt)
| 21 | + | ||
| 22 | + | ||
def level():
    """Return the current level of the root logger."""
    return logging.getLogger().level
oletools/common/log_helper/log_helper.py
0 → 100644
| 1 | +""" | ||
| 2 | +log_helper.py | ||
| 3 | + | ||
| 4 | +General logging helpers | ||
| 5 | + | ||
| 6 | +.. codeauthor:: Intra2net AG <info@intra2net> | ||
| 7 | +""" | ||
| 8 | + | ||
| 9 | +# === LICENSE ================================================================= | ||
| 10 | +# | ||
| 11 | +# Redistribution and use in source and binary forms, with or without | ||
| 12 | +# modification, are permitted provided that the following conditions are met: | ||
| 13 | +# | ||
| 14 | +# * Redistributions of source code must retain the above copyright notice, | ||
| 15 | +# this list of conditions and the following disclaimer. | ||
| 16 | +# * Redistributions in binary form must reproduce the above copyright notice, | ||
| 17 | +# this list of conditions and the following disclaimer in the documentation | ||
| 18 | +# and/or other materials provided with the distribution. | ||
| 19 | +# | ||
| 20 | +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
| 21 | +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
| 22 | +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
| 23 | +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||
| 24 | +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||
| 25 | +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||
| 26 | +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||
| 27 | +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||
| 28 | +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||
| 29 | +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | ||
| 30 | +# POSSIBILITY OF SUCH DAMAGE. | ||
| 31 | + | ||
| 32 | +# ----------------------------------------------------------------------------- | ||
| 33 | +# CHANGELOG: | ||
| 34 | +# 2017-12-07 v0.01 CH: - first version | ||
| 35 | +# 2018-02-05 v0.02 SA: - fixed log level selection and reformatted code | ||
| 36 | +# 2018-02-06 v0.03 SA: - refactored code to deal with NullHandlers | ||
| 37 | +# 2018-02-07 v0.04 SA: - fixed control of handlers propagation | ||
| 38 | +# 2018-04-23 v0.05 SA: - refactored the whole logger to use an OOP approach | ||
| 39 | + | ||
| 40 | +# ----------------------------------------------------------------------------- | ||
| 41 | +# TODO: | ||
| 42 | + | ||
| 43 | + | ||
| 44 | +from ._json_formatter import JsonFormatter | ||
| 45 | +from ._logger_adapter import OletoolsLoggerAdapter | ||
| 46 | +from . import _root_logger_wrapper | ||
| 47 | +import logging | ||
| 48 | +import sys | ||
| 49 | + | ||
| 50 | + | ||
# Map of the level names accepted by enable_logging() to the numeric
# levels of the logging module.
LOG_LEVELS = {
    'debug': logging.DEBUG,
    'info': logging.INFO,
    'warning': logging.WARNING,
    'error': logging.ERROR,
    'critical': logging.CRITICAL
}

# Logger name used when callers do not supply one.
DEFAULT_LOGGER_NAME = 'oletools'
# Default non-JSON output format: level name left-padded to 8 chars.
DEFAULT_MESSAGE_FORMAT = '%(levelname)-8s %(message)s'
| 61 | + | ||
| 62 | + | ||
class LogHelper:
    """
    Central registry for oletools loggers: hands out silent (NullHandler)
    loggers to modules and turns logging on/off for command-line runs,
    with optional JSON-compatible output framing.
    """

    def __init__(self):
        self._all_names = set()  # set so we do not have duplicates
        # when True, output is framed so the whole run prints one JSON list
        self._use_json = False
        # flips to True in enable_logging(), back to False in end_logging()
        self._is_enabled = False

    def get_or_create_silent_logger(self, name=DEFAULT_LOGGER_NAME, level=logging.CRITICAL + 1):
        """
        Get a logger or create one if it doesn't exist, setting a NullHandler
        as the handler (to avoid printing to the console).
        By default we also use a higher logging level so every message will
        be ignored.
        This will prevent oletools from logging unnecessarily when being imported
        from external tools.

        :param name: logger name (default: 'oletools')
        :param level: initial level; CRITICAL+1 silences every message
        :returns: an OletoolsLoggerAdapter wrapping the logger
        """
        return self._get_or_create_logger(name, level, logging.NullHandler())

    def enable_logging(self, use_json, level, log_format=DEFAULT_MESSAGE_FORMAT, stream=None):
        """
        This function initializes the root logger and enables logging.
        We set the level of the root logger to the one passed by calling logging.basicConfig.
        We also set the level of every logger we created to 0 (logging.NOTSET), meaning that
        the level of the root logger will be used to tell if messages should be logged.
        Additionally, since our loggers use the NullHandler, they won't log anything themselves,
        but due to having propagation enabled they will pass messages to the root logger,
        which in turn will log to the stream set in this function.
        Since the root logger is the one doing the work, when using JSON we set its formatter
        so that every message logged is JSON-compatible.

        :param use_json: whether to emit JSON-compatible output
        :param level: one of the LOG_LEVELS keys ('debug' ... 'critical')
        :param log_format: format string for non-JSON output
        :param stream: stream passed to logging.basicConfig (None = stderr)
        :raises ValueError: if logging was already enabled and not ended
        :raises KeyError: if `level` is not a key of LOG_LEVELS
        """
        if self._is_enabled:
            raise ValueError('re-enabling logging. Not sure whether that is ok...')

        log_level = LOG_LEVELS[level]
        logging.basicConfig(level=log_level, format=log_format, stream=stream)
        self._is_enabled = True

        self._use_json = use_json
        # wrap the current excepthook so end_logging() runs even on a crash
        sys.excepthook = self._get_except_hook(sys.excepthook)

        # since there could be loggers already created we go through all of them
        # and set their levels to 0 so they will use the root logger's level
        for name in self._all_names:
            logger = self.get_or_create_silent_logger(name)
            self._set_logger_level(logger, logging.NOTSET)

        # add a JSON formatter to the root logger, which will be used by every logger
        if self._use_json:
            _root_logger_wrapper.set_formatter(JsonFormatter())
            # opening bracket of the JSON list; the matching ']' is
            # printed by end_logging()
            print('[')

    def end_logging(self):
        """
        Must be called at the end of the main function if the caller wants
        json-compatible output
        """
        if not self._is_enabled:
            return
        self._is_enabled = False

        # end logging
        self._all_names = set()
        logging.shutdown()

        # end json list
        if self._use_json:
            print(']')
        self._use_json = False

    def _get_except_hook(self, old_hook):
        """
        Global hook for exceptions so we can always end logging.
        We wrap any hook currently set to avoid overwriting global hooks set by oletools.
        Note that this is only called by enable_logging, which in turn is called by
        the main() function in oletools' scripts. When scripts are being imported this
        code won't execute and won't affect global hooks.

        :param old_hook: the excepthook currently installed, called after cleanup
        :returns: a wrapper suitable for sys.excepthook
        """
        def hook(exctype, value, traceback):
            self.end_logging()
            old_hook(exctype, value, traceback)

        return hook

    def _get_or_create_logger(self, name, level, handler=None):
        """
        Get or create a new logger. This newly created logger will have the
        handler and level that was passed, but if it already exists it's not changed.
        We also wrap the logger in an adapter so we can easily extend its functionality.
        """

        # logging.getLogger creates a logger if it doesn't exist,
        # so we need to check before calling it
        if handler and not self._log_exists(name):
            logger = logging.getLogger(name)
            logger.addHandler(handler)
            self._set_logger_level(logger, level)
        else:
            logger = logging.getLogger(name)

        # Keep track of every logger we created so we can easily change
        # their levels whenever needed
        self._all_names.add(name)

        adapted_logger = OletoolsLoggerAdapter(logger, None)
        # the adapter asks this helper (lazily) whether JSON output is on
        adapted_logger.set_json_enabled_function(lambda: self._use_json)

        return adapted_logger

    @staticmethod
    def _set_logger_level(logger, level):
        """
        If the logging is already initialized, we set the level of our logger
        to 0, meaning that it will reuse the level of the root logger.
        That means that if the root logger level changes, we will keep using
        its level and not logging unnecessarily.
        """

        # if this log was wrapped, unwrap it to set the level
        if isinstance(logger, OletoolsLoggerAdapter):
            logger = logger.logger

        if _root_logger_wrapper.is_logging_initialized():
            logger.setLevel(logging.NOTSET)
        else:
            logger.setLevel(level)

    @staticmethod
    def _log_exists(name):
        """
        We check the log manager instead of our global _all_names variable
        since the logger could have been created outside of the helper
        """
        return name in logging.Logger.manager.loggerDict
oletools/msodde.py
| @@ -53,8 +53,6 @@ import argparse | @@ -53,8 +53,6 @@ import argparse | ||
| 53 | import os | 53 | import os |
| 54 | from os.path import abspath, dirname | 54 | from os.path import abspath, dirname |
| 55 | import sys | 55 | import sys |
| 56 | -import json | ||
| 57 | -import logging | ||
| 58 | import re | 56 | import re |
| 59 | import csv | 57 | import csv |
| 60 | 58 | ||
| @@ -63,6 +61,7 @@ import olefile | @@ -63,6 +61,7 @@ import olefile | ||
| 63 | from oletools import ooxml | 61 | from oletools import ooxml |
| 64 | from oletools import xls_parser | 62 | from oletools import xls_parser |
| 65 | from oletools import rtfobj | 63 | from oletools import rtfobj |
| 64 | +from oletools.common.log_helper import log_helper | ||
| 66 | 65 | ||
| 67 | # ----------------------------------------------------------------------------- | 66 | # ----------------------------------------------------------------------------- |
| 68 | # CHANGELOG: | 67 | # CHANGELOG: |
| @@ -212,63 +211,12 @@ THIS IS WORK IN PROGRESS - Check updates regularly! | @@ -212,63 +211,12 @@ THIS IS WORK IN PROGRESS - Check updates regularly! | ||
| 212 | Please report any issue at https://github.com/decalage2/oletools/issues | 211 | Please report any issue at https://github.com/decalage2/oletools/issues |
| 213 | """ % __version__ | 212 | """ % __version__ |
| 214 | 213 | ||
| 215 | -BANNER_JSON = dict(type='meta', version=__version__, name='msodde', | ||
| 216 | - link='http://decalage.info/python/oletools', | ||
| 217 | - message='THIS IS WORK IN PROGRESS - Check updates regularly! ' | ||
| 218 | - 'Please report any issue at ' | ||
| 219 | - 'https://github.com/decalage2/oletools/issues') | ||
| 220 | - | ||
| 221 | # === LOGGING ================================================================= | 214 | # === LOGGING ================================================================= |
| 222 | 215 | ||
| 223 | DEFAULT_LOG_LEVEL = "warning" # Default log level | 216 | DEFAULT_LOG_LEVEL = "warning" # Default log level |
| 224 | -LOG_LEVELS = { | ||
| 225 | - 'debug': logging.DEBUG, | ||
| 226 | - 'info': logging.INFO, | ||
| 227 | - 'warning': logging.WARNING, | ||
| 228 | - 'error': logging.ERROR, | ||
| 229 | - 'critical': logging.CRITICAL | ||
| 230 | -} | ||
| 231 | - | ||
| 232 | - | ||
| 233 | -class NullHandler(logging.Handler): | ||
| 234 | - """ | ||
| 235 | - Log Handler without output, to avoid printing messages if logging is not | ||
| 236 | - configured by the main application. | ||
| 237 | - Python 2.7 has logging.NullHandler, but this is necessary for 2.6: | ||
| 238 | - see https://docs.python.org/2.6/library/logging.html#configuring-logging-for-a-library | ||
| 239 | - """ | ||
| 240 | - def emit(self, record): | ||
| 241 | - pass | ||
| 242 | - | ||
| 243 | - | ||
| 244 | -def get_logger(name, level=logging.CRITICAL+1): | ||
| 245 | - """ | ||
| 246 | - Create a suitable logger object for this module. | ||
| 247 | - The goal is not to change settings of the root logger, to avoid getting | ||
| 248 | - other modules' logs on the screen. | ||
| 249 | - If a logger exists with same name, reuse it. (Else it would have duplicate | ||
| 250 | - handlers and messages would be doubled.) | ||
| 251 | - The level is set to CRITICAL+1 by default, to avoid any logging. | ||
| 252 | - """ | ||
| 253 | - # First, test if there is already a logger with the same name, else it | ||
| 254 | - # will generate duplicate messages (due to duplicate handlers): | ||
| 255 | - if name in logging.Logger.manager.loggerDict: | ||
| 256 | - # NOTE: another less intrusive but more "hackish" solution would be to | ||
| 257 | - # use getLogger then test if its effective level is not default. | ||
| 258 | - logger = logging.getLogger(name) | ||
| 259 | - # make sure level is OK: | ||
| 260 | - logger.setLevel(level) | ||
| 261 | - return logger | ||
| 262 | - # get a new logger: | ||
| 263 | - logger = logging.getLogger(name) | ||
| 264 | - # only add a NullHandler for this logger, it is up to the application | ||
| 265 | - # to configure its own logging: | ||
| 266 | - logger.addHandler(NullHandler()) | ||
| 267 | - logger.setLevel(level) | ||
| 268 | - return logger | ||
| 269 | 217 | ||
| 270 | # a global logger object used for debugging: | 218 | # a global logger object used for debugging: |
| 271 | -log = get_logger('msodde') | 219 | +logger = log_helper.get_or_create_silent_logger('msodde') |
| 272 | 220 | ||
| 273 | 221 | ||
| 274 | # === UNICODE IN PY2 ========================================================= | 222 | # === UNICODE IN PY2 ========================================================= |
| @@ -312,7 +260,7 @@ def ensure_stdout_handles_unicode(): | @@ -312,7 +260,7 @@ def ensure_stdout_handles_unicode(): | ||
| 312 | encoding = 'utf8' | 260 | encoding = 'utf8' |
| 313 | 261 | ||
| 314 | # logging is probably not initialized yet, but just in case | 262 | # logging is probably not initialized yet, but just in case |
| 315 | - log.debug('wrapping sys.stdout with encoder using {0}'.format(encoding)) | 263 | + logger.debug('wrapping sys.stdout with encoder using {0}'.format(encoding)) |
| 316 | 264 | ||
| 317 | wrapper = codecs.getwriter(encoding) | 265 | wrapper = codecs.getwriter(encoding) |
| 318 | sys.stdout = wrapper(sys.stdout) | 266 | sys.stdout = wrapper(sys.stdout) |
| @@ -396,7 +344,7 @@ def process_doc_field(data): | @@ -396,7 +344,7 @@ def process_doc_field(data): | ||
| 396 | """ check if field instructions start with DDE | 344 | """ check if field instructions start with DDE |
| 397 | 345 | ||
| 398 | expects unicode input, returns unicode output (empty if not dde) """ | 346 | expects unicode input, returns unicode output (empty if not dde) """ |
| 399 | - log.debug('processing field {0}'.format(data)) | 347 | + logger.debug('processing field {0}'.format(data)) |
| 400 | 348 | ||
| 401 | if data.lstrip().lower().startswith(u'dde'): | 349 | if data.lstrip().lower().startswith(u'dde'): |
| 402 | return data | 350 | return data |
| @@ -434,7 +382,7 @@ def process_doc_stream(stream): | @@ -434,7 +382,7 @@ def process_doc_stream(stream): | ||
| 434 | 382 | ||
| 435 | if char == OLE_FIELD_START: | 383 | if char == OLE_FIELD_START: |
| 436 | if have_start and max_size_exceeded: | 384 | if have_start and max_size_exceeded: |
| 437 | - log.debug('big field was not a field after all') | 385 | + logger.debug('big field was not a field after all') |
| 438 | have_start = True | 386 | have_start = True |
| 439 | have_sep = False | 387 | have_sep = False |
| 440 | max_size_exceeded = False | 388 | max_size_exceeded = False |
| @@ -446,7 +394,7 @@ def process_doc_stream(stream): | @@ -446,7 +394,7 @@ def process_doc_stream(stream): | ||
| 446 | # now we are after start char but not at end yet | 394 | # now we are after start char but not at end yet |
| 447 | if char == OLE_FIELD_SEP: | 395 | if char == OLE_FIELD_SEP: |
| 448 | if have_sep: | 396 | if have_sep: |
| 449 | - log.debug('unexpected field: has multiple separators!') | 397 | + logger.debug('unexpected field: has multiple separators!') |
| 450 | have_sep = True | 398 | have_sep = True |
| 451 | elif char == OLE_FIELD_END: | 399 | elif char == OLE_FIELD_END: |
| 452 | # have complete field now, process it | 400 | # have complete field now, process it |
| @@ -464,7 +412,7 @@ def process_doc_stream(stream): | @@ -464,7 +412,7 @@ def process_doc_stream(stream): | ||
| 464 | if max_size_exceeded: | 412 | if max_size_exceeded: |
| 465 | pass | 413 | pass |
| 466 | elif len(field_contents) > OLE_FIELD_MAX_SIZE: | 414 | elif len(field_contents) > OLE_FIELD_MAX_SIZE: |
| 467 | - log.debug('field exceeds max size of {0}. Ignore rest' | 415 | + logger.debug('field exceeds max size of {0}. Ignore rest' |
| 468 | .format(OLE_FIELD_MAX_SIZE)) | 416 | .format(OLE_FIELD_MAX_SIZE)) |
| 469 | max_size_exceeded = True | 417 | max_size_exceeded = True |
| 470 | 418 | ||
| @@ -482,9 +430,9 @@ def process_doc_stream(stream): | @@ -482,9 +430,9 @@ def process_doc_stream(stream): | ||
| 482 | field_contents += u'?' | 430 | field_contents += u'?' |
| 483 | 431 | ||
| 484 | if max_size_exceeded: | 432 | if max_size_exceeded: |
| 485 | - log.debug('big field was not a field after all') | 433 | + logger.debug('big field was not a field after all') |
| 486 | 434 | ||
| 487 | - log.debug('Checked {0} characters, found {1} fields' | 435 | + logger.debug('Checked {0} characters, found {1} fields' |
| 488 | .format(idx, len(result_parts))) | 436 | .format(idx, len(result_parts))) |
| 489 | 437 | ||
| 490 | return result_parts | 438 | return result_parts |
| @@ -498,7 +446,7 @@ def process_doc(filepath): | @@ -498,7 +446,7 @@ def process_doc(filepath): | ||
| 498 | empty if none were found. dde-links will still begin with the dde[auto] key | 446 | empty if none were found. dde-links will still begin with the dde[auto] key |
| 499 | word (possibly after some whitespace) | 447 | word (possibly after some whitespace) |
| 500 | """ | 448 | """ |
| 501 | - log.debug('process_doc') | 449 | + logger.debug('process_doc') |
| 502 | ole = olefile.OleFileIO(filepath, path_encoding=None) | 450 | ole = olefile.OleFileIO(filepath, path_encoding=None) |
| 503 | 451 | ||
| 504 | links = [] | 452 | links = [] |
| @@ -508,7 +456,7 @@ def process_doc(filepath): | @@ -508,7 +456,7 @@ def process_doc(filepath): | ||
| 508 | # this direntry is not part of the tree --> unused or orphan | 456 | # this direntry is not part of the tree --> unused or orphan |
| 509 | direntry = ole._load_direntry(sid) | 457 | direntry = ole._load_direntry(sid) |
| 510 | is_stream = direntry.entry_type == olefile.STGTY_STREAM | 458 | is_stream = direntry.entry_type == olefile.STGTY_STREAM |
| 511 | - log.debug('direntry {:2d} {}: {}' | 459 | + logger.debug('direntry {:2d} {}: {}' |
| 512 | .format(sid, '[orphan]' if is_orphan else direntry.name, | 460 | .format(sid, '[orphan]' if is_orphan else direntry.name, |
| 513 | 'is stream of size {}'.format(direntry.size) | 461 | 'is stream of size {}'.format(direntry.size) |
| 514 | if is_stream else | 462 | if is_stream else |
| @@ -593,7 +541,7 @@ def process_docx(filepath, field_filter_mode=None): | @@ -593,7 +541,7 @@ def process_docx(filepath, field_filter_mode=None): | ||
| 593 | ddetext += unquote(elem.text) | 541 | ddetext += unquote(elem.text) |
| 594 | 542 | ||
| 595 | # apply field command filter | 543 | # apply field command filter |
| 596 | - log.debug('filtering with mode "{0}"'.format(field_filter_mode)) | 544 | + logger.debug('filtering with mode "{0}"'.format(field_filter_mode)) |
| 597 | if field_filter_mode in (FIELD_FILTER_ALL, None): | 545 | if field_filter_mode in (FIELD_FILTER_ALL, None): |
| 598 | clean_fields = all_fields | 546 | clean_fields = all_fields |
| 599 | elif field_filter_mode == FIELD_FILTER_DDE: | 547 | elif field_filter_mode == FIELD_FILTER_DDE: |
| @@ -652,7 +600,7 @@ def field_is_blacklisted(contents): | @@ -652,7 +600,7 @@ def field_is_blacklisted(contents): | ||
| 652 | index = FIELD_BLACKLIST_CMDS.index(words[0].lower()) | 600 | index = FIELD_BLACKLIST_CMDS.index(words[0].lower()) |
| 653 | except ValueError: # first word is no blacklisted command | 601 | except ValueError: # first word is no blacklisted command |
| 654 | return False | 602 | return False |
| 655 | - log.debug('trying to match "{0}" to blacklist command {1}' | 603 | + logger.debug('trying to match "{0}" to blacklist command {1}' |
| 656 | .format(contents, FIELD_BLACKLIST[index])) | 604 | .format(contents, FIELD_BLACKLIST[index])) |
| 657 | _, nargs_required, nargs_optional, sw_with_arg, sw_solo, sw_format \ | 605 | _, nargs_required, nargs_optional, sw_with_arg, sw_solo, sw_format \ |
| 658 | = FIELD_BLACKLIST[index] | 606 | = FIELD_BLACKLIST[index] |
| @@ -664,11 +612,11 @@ def field_is_blacklisted(contents): | @@ -664,11 +612,11 @@ def field_is_blacklisted(contents): | ||
| 664 | break | 612 | break |
| 665 | nargs += 1 | 613 | nargs += 1 |
| 666 | if nargs < nargs_required: | 614 | if nargs < nargs_required: |
| 667 | - log.debug('too few args: found {0}, but need at least {1} in "{2}"' | 615 | + logger.debug('too few args: found {0}, but need at least {1} in "{2}"' |
| 668 | .format(nargs, nargs_required, contents)) | 616 | .format(nargs, nargs_required, contents)) |
| 669 | return False | 617 | return False |
| 670 | elif nargs > nargs_required + nargs_optional: | 618 | elif nargs > nargs_required + nargs_optional: |
| 671 | - log.debug('too many args: found {0}, but need at most {1}+{2} in "{3}"' | 619 | + logger.debug('too many args: found {0}, but need at most {1}+{2} in "{3}"' |
| 672 | .format(nargs, nargs_required, nargs_optional, contents)) | 620 | .format(nargs, nargs_required, nargs_optional, contents)) |
| 673 | return False | 621 | return False |
| 674 | 622 | ||
| @@ -678,14 +626,14 @@ def field_is_blacklisted(contents): | @@ -678,14 +626,14 @@ def field_is_blacklisted(contents): | ||
| 678 | for word in words[1+nargs:]: | 626 | for word in words[1+nargs:]: |
| 679 | if expect_arg: # this is an argument for the last switch | 627 | if expect_arg: # this is an argument for the last switch |
| 680 | if arg_choices and (word not in arg_choices): | 628 | if arg_choices and (word not in arg_choices): |
| 681 | - log.debug('Found invalid switch argument "{0}" in "{1}"' | 629 | + logger.debug('Found invalid switch argument "{0}" in "{1}"' |
| 682 | .format(word, contents)) | 630 | .format(word, contents)) |
| 683 | return False | 631 | return False |
| 684 | expect_arg = False | 632 | expect_arg = False |
| 685 | arg_choices = [] # in general, do not enforce choices | 633 | arg_choices = [] # in general, do not enforce choices |
| 686 | continue # "no further questions, your honor" | 634 | continue # "no further questions, your honor" |
| 687 | elif not FIELD_SWITCH_REGEX.match(word): | 635 | elif not FIELD_SWITCH_REGEX.match(word): |
| 688 | - log.debug('expected switch, found "{0}" in "{1}"' | 636 | + logger.debug('expected switch, found "{0}" in "{1}"' |
| 689 | .format(word, contents)) | 637 | .format(word, contents)) |
| 690 | return False | 638 | return False |
| 691 | # we want a switch and we got a valid one | 639 | # we want a switch and we got a valid one |
| @@ -707,7 +655,7 @@ def field_is_blacklisted(contents): | @@ -707,7 +655,7 @@ def field_is_blacklisted(contents): | ||
| 707 | if 'numeric' in sw_format: | 655 | if 'numeric' in sw_format: |
| 708 | arg_choices = [] # too many choices to list them here | 656 | arg_choices = [] # too many choices to list them here |
| 709 | else: | 657 | else: |
| 710 | - log.debug('unexpected switch {0} in "{1}"' | 658 | + logger.debug('unexpected switch {0} in "{1}"' |
| 711 | .format(switch, contents)) | 659 | .format(switch, contents)) |
| 712 | return False | 660 | return False |
| 713 | 661 | ||
| @@ -733,11 +681,11 @@ def process_xlsx(filepath): | @@ -733,11 +681,11 @@ def process_xlsx(filepath): | ||
| 733 | # binary parts, e.g. contained in .xlsb | 681 | # binary parts, e.g. contained in .xlsb |
| 734 | for subfile, content_type, handle in parser.iter_non_xml(): | 682 | for subfile, content_type, handle in parser.iter_non_xml(): |
| 735 | try: | 683 | try: |
| 736 | - logging.info('Parsing non-xml subfile {0} with content type {1}' | 684 | + logger.info('Parsing non-xml subfile {0} with content type {1}' |
| 737 | .format(subfile, content_type)) | 685 | .format(subfile, content_type)) |
| 738 | for record in xls_parser.parse_xlsb_part(handle, content_type, | 686 | for record in xls_parser.parse_xlsb_part(handle, content_type, |
| 739 | subfile): | 687 | subfile): |
| 740 | - logging.debug('{0}: {1}'.format(subfile, record)) | 688 | + logger.debug('{0}: {1}'.format(subfile, record)) |
| 741 | if isinstance(record, xls_parser.XlsbBeginSupBook) and \ | 689 | if isinstance(record, xls_parser.XlsbBeginSupBook) and \ |
| 742 | record.link_type == \ | 690 | record.link_type == \ |
| 743 | xls_parser.XlsbBeginSupBook.LINK_TYPE_DDE: | 691 | xls_parser.XlsbBeginSupBook.LINK_TYPE_DDE: |
| @@ -747,14 +695,14 @@ def process_xlsx(filepath): | @@ -747,14 +695,14 @@ def process_xlsx(filepath): | ||
| 747 | if content_type.startswith('application/vnd.ms-excel.') or \ | 695 | if content_type.startswith('application/vnd.ms-excel.') or \ |
| 748 | content_type.startswith('application/vnd.ms-office.'): # pylint: disable=bad-indentation | 696 | content_type.startswith('application/vnd.ms-office.'): # pylint: disable=bad-indentation |
| 749 | # should really be able to parse these either as xml or records | 697 | # should really be able to parse these either as xml or records |
| 750 | - log_func = logging.warning | 698 | + log_func = logger.warning |
| 751 | elif content_type.startswith('image/') or content_type == \ | 699 | elif content_type.startswith('image/') or content_type == \ |
| 752 | 'application/vnd.openxmlformats-officedocument.' + \ | 700 | 'application/vnd.openxmlformats-officedocument.' + \ |
| 753 | 'spreadsheetml.printerSettings': | 701 | 'spreadsheetml.printerSettings': |
| 754 | # understandable that these are not record-base | 702 | # understandable that these are not record-base |
| 755 | - log_func = logging.debug | 703 | + log_func = logger.debug |
| 756 | else: # default | 704 | else: # default |
| 757 | - log_func = logging.info | 705 | + log_func = logger.info |
| 758 | log_func('Failed to parse {0} of content type {1}' | 706 | log_func('Failed to parse {0} of content type {1}' |
| 759 | .format(subfile, content_type)) | 707 | .format(subfile, content_type)) |
| 760 | # in any case: continue with next | 708 | # in any case: continue with next |
| @@ -774,15 +722,15 @@ class RtfFieldParser(rtfobj.RtfParser): | @@ -774,15 +722,15 @@ class RtfFieldParser(rtfobj.RtfParser): | ||
| 774 | 722 | ||
| 775 | def open_destination(self, destination): | 723 | def open_destination(self, destination): |
| 776 | if destination.cword == b'fldinst': | 724 | if destination.cword == b'fldinst': |
| 777 | - log.debug('*** Start field data at index %Xh' % destination.start) | 725 | + logger.debug('*** Start field data at index %Xh' % destination.start) |
| 778 | 726 | ||
| 779 | def close_destination(self, destination): | 727 | def close_destination(self, destination): |
| 780 | if destination.cword == b'fldinst': | 728 | if destination.cword == b'fldinst': |
| 781 | - log.debug('*** Close field data at index %Xh' % self.index) | ||
| 782 | - log.debug('Field text: %r' % destination.data) | 729 | + logger.debug('*** Close field data at index %Xh' % self.index) |
| 730 | + logger.debug('Field text: %r' % destination.data) | ||
| 783 | # remove extra spaces and newline chars: | 731 | # remove extra spaces and newline chars: |
| 784 | field_clean = destination.data.translate(None, b'\r\n').strip() | 732 | field_clean = destination.data.translate(None, b'\r\n').strip() |
| 785 | - log.debug('Cleaned Field text: %r' % field_clean) | 733 | + logger.debug('Cleaned Field text: %r' % field_clean) |
| 786 | self.fields.append(field_clean) | 734 | self.fields.append(field_clean) |
| 787 | 735 | ||
| 788 | def control_symbol(self, matchobject): | 736 | def control_symbol(self, matchobject): |
| @@ -804,7 +752,7 @@ def process_rtf(file_handle, field_filter_mode=None): | @@ -804,7 +752,7 @@ def process_rtf(file_handle, field_filter_mode=None): | ||
| 804 | rtfparser.parse() | 752 | rtfparser.parse() |
| 805 | all_fields = [field.decode('ascii') for field in rtfparser.fields] | 753 | all_fields = [field.decode('ascii') for field in rtfparser.fields] |
| 806 | # apply field command filter | 754 | # apply field command filter |
| 807 | - log.debug('found {1} fields, filtering with mode "{0}"' | 755 | + logger.debug('found {1} fields, filtering with mode "{0}"' |
| 808 | .format(field_filter_mode, len(all_fields))) | 756 | .format(field_filter_mode, len(all_fields))) |
| 809 | if field_filter_mode in (FIELD_FILTER_ALL, None): | 757 | if field_filter_mode in (FIELD_FILTER_ALL, None): |
| 810 | clean_fields = all_fields | 758 | clean_fields = all_fields |
| @@ -853,7 +801,7 @@ def process_csv(filepath): | @@ -853,7 +801,7 @@ def process_csv(filepath): | ||
| 853 | 801 | ||
| 854 | if is_small and not results: | 802 | if is_small and not results: |
| 855 | # easy to mis-sniff small files. Try different delimiters | 803 | # easy to mis-sniff small files. Try different delimiters |
| 856 | - log.debug('small file, no results; try all delimiters') | 804 | + logger.debug('small file, no results; try all delimiters') |
| 857 | file_handle.seek(0) | 805 | file_handle.seek(0) |
| 858 | other_delim = CSV_DELIMITERS.replace(dialect.delimiter, '') | 806 | other_delim = CSV_DELIMITERS.replace(dialect.delimiter, '') |
| 859 | for delim in other_delim: | 807 | for delim in other_delim: |
| @@ -861,12 +809,12 @@ def process_csv(filepath): | @@ -861,12 +809,12 @@ def process_csv(filepath): | ||
| 861 | file_handle.seek(0) | 809 | file_handle.seek(0) |
| 862 | results, _ = process_csv_dialect(file_handle, delim) | 810 | results, _ = process_csv_dialect(file_handle, delim) |
| 863 | except csv.Error: # e.g. sniffing fails | 811 | except csv.Error: # e.g. sniffing fails |
| 864 | - log.debug('failed to csv-parse with delimiter {0!r}' | 812 | + logger.debug('failed to csv-parse with delimiter {0!r}' |
| 865 | .format(delim)) | 813 | .format(delim)) |
| 866 | 814 | ||
| 867 | if is_small and not results: | 815 | if is_small and not results: |
| 868 | # try whole file as single cell, since sniffing fails in this case | 816 | # try whole file as single cell, since sniffing fails in this case |
| 869 | - log.debug('last attempt: take whole file as single unquoted cell') | 817 | + logger.debug('last attempt: take whole file as single unquoted cell') |
| 870 | file_handle.seek(0) | 818 | file_handle.seek(0) |
| 871 | match = CSV_DDE_FORMAT.match(file_handle.read(CSV_SMALL_THRESH)) | 819 | match = CSV_DDE_FORMAT.match(file_handle.read(CSV_SMALL_THRESH)) |
| 872 | if match: | 820 | if match: |
| @@ -882,7 +830,7 @@ def process_csv_dialect(file_handle, delimiters): | @@ -882,7 +830,7 @@ def process_csv_dialect(file_handle, delimiters): | ||
| 882 | dialect = csv.Sniffer().sniff(file_handle.read(CSV_SMALL_THRESH), | 830 | dialect = csv.Sniffer().sniff(file_handle.read(CSV_SMALL_THRESH), |
| 883 | delimiters=delimiters) | 831 | delimiters=delimiters) |
| 884 | dialect.strict = False # microsoft is never strict | 832 | dialect.strict = False # microsoft is never strict |
| 885 | - log.debug('sniffed csv dialect with delimiter {0!r} ' | 833 | + logger.debug('sniffed csv dialect with delimiter {0!r} ' |
| 886 | 'and quote char {1!r}' | 834 | 'and quote char {1!r}' |
| 887 | .format(dialect.delimiter, dialect.quotechar)) | 835 | .format(dialect.delimiter, dialect.quotechar)) |
| 888 | 836 | ||
| @@ -924,7 +872,7 @@ def process_excel_xml(filepath): | @@ -924,7 +872,7 @@ def process_excel_xml(filepath): | ||
| 924 | break | 872 | break |
| 925 | if formula is None: | 873 | if formula is None: |
| 926 | continue | 874 | continue |
| 927 | - log.debug('found cell with formula {0}'.format(formula)) | 875 | + logger.debug('found cell with formula {0}'.format(formula)) |
| 928 | match = re.match(XML_DDE_FORMAT, formula) | 876 | match = re.match(XML_DDE_FORMAT, formula) |
| 929 | if match: | 877 | if match: |
| 930 | dde_links.append(u' '.join(match.groups()[:2])) | 878 | dde_links.append(u' '.join(match.groups()[:2])) |
| @@ -934,40 +882,40 @@ def process_excel_xml(filepath): | @@ -934,40 +882,40 @@ def process_excel_xml(filepath): | ||
| 934 | def process_file(filepath, field_filter_mode=None): | 882 | def process_file(filepath, field_filter_mode=None): |
| 935 | """ decides which of the process_* functions to call """ | 883 | """ decides which of the process_* functions to call """ |
| 936 | if olefile.isOleFile(filepath): | 884 | if olefile.isOleFile(filepath): |
| 937 | - log.debug('Is OLE. Checking streams to see whether this is xls') | 885 | + logger.debug('Is OLE. Checking streams to see whether this is xls') |
| 938 | if xls_parser.is_xls(filepath): | 886 | if xls_parser.is_xls(filepath): |
| 939 | - log.debug('Process file as excel 2003 (xls)') | 887 | + logger.debug('Process file as excel 2003 (xls)') |
| 940 | return process_xls(filepath) | 888 | return process_xls(filepath) |
| 941 | else: | 889 | else: |
| 942 | - log.debug('Process file as word 2003 (doc)') | 890 | + logger.debug('Process file as word 2003 (doc)') |
| 943 | return process_doc(filepath) | 891 | return process_doc(filepath) |
| 944 | 892 | ||
| 945 | with open(filepath, 'rb') as file_handle: | 893 | with open(filepath, 'rb') as file_handle: |
| 946 | if file_handle.read(4) == RTF_START: | 894 | if file_handle.read(4) == RTF_START: |
| 947 | - log.debug('Process file as rtf') | 895 | + logger.debug('Process file as rtf') |
| 948 | return process_rtf(file_handle, field_filter_mode) | 896 | return process_rtf(file_handle, field_filter_mode) |
| 949 | 897 | ||
| 950 | try: | 898 | try: |
| 951 | doctype = ooxml.get_type(filepath) | 899 | doctype = ooxml.get_type(filepath) |
| 952 | - log.debug('Detected file type: {0}'.format(doctype)) | 900 | + logger.debug('Detected file type: {0}'.format(doctype)) |
| 953 | except Exception as exc: | 901 | except Exception as exc: |
| 954 | - log.debug('Exception trying to xml-parse file: {0}'.format(exc)) | 902 | + logger.debug('Exception trying to xml-parse file: {0}'.format(exc)) |
| 955 | doctype = None | 903 | doctype = None |
| 956 | 904 | ||
| 957 | if doctype == ooxml.DOCTYPE_EXCEL: | 905 | if doctype == ooxml.DOCTYPE_EXCEL: |
| 958 | - log.debug('Process file as excel 2007+ (xlsx)') | 906 | + logger.debug('Process file as excel 2007+ (xlsx)') |
| 959 | return process_xlsx(filepath) | 907 | return process_xlsx(filepath) |
| 960 | elif doctype in (ooxml.DOCTYPE_EXCEL_XML, ooxml.DOCTYPE_EXCEL_XML2003): | 908 | elif doctype in (ooxml.DOCTYPE_EXCEL_XML, ooxml.DOCTYPE_EXCEL_XML2003): |
| 961 | - log.debug('Process file as xml from excel 2003/2007+') | 909 | + logger.debug('Process file as xml from excel 2003/2007+') |
| 962 | return process_excel_xml(filepath) | 910 | return process_excel_xml(filepath) |
| 963 | elif doctype in (ooxml.DOCTYPE_WORD_XML, ooxml.DOCTYPE_WORD_XML2003): | 911 | elif doctype in (ooxml.DOCTYPE_WORD_XML, ooxml.DOCTYPE_WORD_XML2003): |
| 964 | - log.debug('Process file as xml from word 2003/2007+') | 912 | + logger.debug('Process file as xml from word 2003/2007+') |
| 965 | return process_docx(filepath) | 913 | return process_docx(filepath) |
| 966 | elif doctype is None: | 914 | elif doctype is None: |
| 967 | - log.debug('Process file as csv') | 915 | + logger.debug('Process file as csv') |
| 968 | return process_csv(filepath) | 916 | return process_csv(filepath) |
| 969 | else: # could be docx; if not: this is the old default code path | 917 | else: # could be docx; if not: this is the old default code path |
| 970 | - log.debug('Process file as word 2007+ (docx)') | 918 | + logger.debug('Process file as word 2007+ (docx)') |
| 971 | return process_docx(filepath, field_filter_mode) | 919 | return process_docx(filepath, field_filter_mode) |
| 972 | 920 | ||
| 973 | 921 | ||
| @@ -985,27 +933,14 @@ def main(cmd_line_args=None): | @@ -985,27 +933,14 @@ def main(cmd_line_args=None): | ||
| 985 | # Setup logging to the console: | 933 | # Setup logging to the console: |
| 986 | # here we use stdout instead of stderr by default, so that the output | 934 | # here we use stdout instead of stderr by default, so that the output |
| 987 | # can be redirected properly. | 935 | # can be redirected properly. |
| 988 | - logging.basicConfig(level=LOG_LEVELS[args.loglevel], stream=sys.stdout, | ||
| 989 | - format='%(levelname)-8s %(message)s') | ||
| 990 | - # enable logging in the modules: | ||
| 991 | - log.setLevel(logging.NOTSET) | ||
| 992 | - | ||
| 993 | - if args.json and args.loglevel.lower() == 'debug': | ||
| 994 | - log.warning('Debug log output will not be json-compatible!') | 936 | + log_helper.enable_logging(args.json, args.loglevel, stream=sys.stdout) |
| 995 | 937 | ||
| 996 | if args.nounquote: | 938 | if args.nounquote: |
| 997 | global NO_QUOTES | 939 | global NO_QUOTES |
| 998 | NO_QUOTES = True | 940 | NO_QUOTES = True |
| 999 | 941 | ||
| 1000 | - if args.json: | ||
| 1001 | - jout = [] | ||
| 1002 | - jout.append(BANNER_JSON) | ||
| 1003 | - else: | ||
| 1004 | - # print banner with version | ||
| 1005 | - print(BANNER) | ||
| 1006 | - | ||
| 1007 | - if not args.json: | ||
| 1008 | - print('Opening file: %s' % args.filepath) | 942 | + logger.print_str(BANNER) |
| 943 | + logger.print_str('Opening file: %s' % args.filepath) | ||
| 1009 | 944 | ||
| 1010 | text = '' | 945 | text = '' |
| 1011 | return_code = 1 | 946 | return_code = 1 |
| @@ -1013,22 +948,12 @@ def main(cmd_line_args=None): | @@ -1013,22 +948,12 @@ def main(cmd_line_args=None): | ||
| 1013 | text = process_file(args.filepath, args.field_filter_mode) | 948 | text = process_file(args.filepath, args.field_filter_mode) |
| 1014 | return_code = 0 | 949 | return_code = 0 |
| 1015 | except Exception as exc: | 950 | except Exception as exc: |
| 1016 | - if args.json: | ||
| 1017 | - jout.append(dict(type='error', error=type(exc).__name__, | ||
| 1018 | - message=str(exc))) | ||
| 1019 | - else: | ||
| 1020 | - raise # re-raise last known exception, keeping trace intact | ||
| 1021 | - | ||
| 1022 | - if args.json: | ||
| 1023 | - for line in text.splitlines(): | ||
| 1024 | - if line.strip(): | ||
| 1025 | - jout.append(dict(type='dde-link', link=line.strip())) | ||
| 1026 | - json.dump(jout, sys.stdout, check_circular=False, indent=4) | ||
| 1027 | - print() # add a newline after closing "]" | ||
| 1028 | - return return_code # required if we catch an exception in json-mode | ||
| 1029 | - else: | ||
| 1030 | - print ('DDE Links:') | ||
| 1031 | - print(text) | 951 | + logger.exception(exc.message) |
| 952 | + | ||
| 953 | + logger.print_str('DDE Links:') | ||
| 954 | + logger.print_str(text) | ||
| 955 | + | ||
| 956 | + log_helper.end_logging() | ||
| 1032 | 957 | ||
| 1033 | return return_code | 958 | return return_code |
| 1034 | 959 |
oletools/ooxml.py
| @@ -14,7 +14,7 @@ TODO: may have to tell apart single xml types: office2003 looks much different | @@ -14,7 +14,7 @@ TODO: may have to tell apart single xml types: office2003 looks much different | ||
| 14 | """ | 14 | """ |
| 15 | 15 | ||
| 16 | import sys | 16 | import sys |
| 17 | -import logging | 17 | +from oletools.common.log_helper import log_helper |
| 18 | from zipfile import ZipFile, BadZipfile, is_zipfile | 18 | from zipfile import ZipFile, BadZipfile, is_zipfile |
| 19 | from os.path import splitext | 19 | from os.path import splitext |
| 20 | import io | 20 | import io |
| @@ -27,6 +27,7 @@ try: | @@ -27,6 +27,7 @@ try: | ||
| 27 | except ImportError: | 27 | except ImportError: |
| 28 | import xml.etree.cElementTree as ET | 28 | import xml.etree.cElementTree as ET |
| 29 | 29 | ||
| 30 | +logger = log_helper.get_or_create_silent_logger('ooxml') | ||
| 30 | 31 | ||
| 31 | #: subfiles that have to be part of every ooxml file | 32 | #: subfiles that have to be part of every ooxml file |
| 32 | FILE_CONTENT_TYPES = '[Content_Types].xml' | 33 | FILE_CONTENT_TYPES = '[Content_Types].xml' |
| @@ -142,7 +143,7 @@ def get_type(filename): | @@ -142,7 +143,7 @@ def get_type(filename): | ||
| 142 | is_xls = False | 143 | is_xls = False |
| 143 | is_ppt = False | 144 | is_ppt = False |
| 144 | for _, elem, _ in parser.iter_xml(FILE_CONTENT_TYPES): | 145 | for _, elem, _ in parser.iter_xml(FILE_CONTENT_TYPES): |
| 145 | - logging.debug(u' ' + debug_str(elem)) | 146 | + logger.debug(u' ' + debug_str(elem)) |
| 146 | try: | 147 | try: |
| 147 | content_type = elem.attrib['ContentType'] | 148 | content_type = elem.attrib['ContentType'] |
| 148 | except KeyError: # ContentType not an attr | 149 | except KeyError: # ContentType not an attr |
| @@ -160,7 +161,7 @@ def get_type(filename): | @@ -160,7 +161,7 @@ def get_type(filename): | ||
| 160 | if not is_doc and not is_xls and not is_ppt: | 161 | if not is_doc and not is_xls and not is_ppt: |
| 161 | return DOCTYPE_NONE | 162 | return DOCTYPE_NONE |
| 162 | else: | 163 | else: |
| 163 | - logging.warning('Encountered contradictory content types') | 164 | + logger.warning('Encountered contradictory content types') |
| 164 | return DOCTYPE_MIXED | 165 | return DOCTYPE_MIXED |
| 165 | 166 | ||
| 166 | 167 | ||
| @@ -220,7 +221,7 @@ class ZipSubFile(object): | @@ -220,7 +221,7 @@ class ZipSubFile(object): | ||
| 220 | self.name = filename | 221 | self.name = filename |
| 221 | if size is None: | 222 | if size is None: |
| 222 | self.size = container.getinfo(filename).file_size | 223 | self.size = container.getinfo(filename).file_size |
| 223 | - logging.debug('zip stream has size {0}'.format(self.size)) | 224 | + logger.debug('zip stream has size {0}'.format(self.size)) |
| 224 | else: | 225 | else: |
| 225 | self.size = size | 226 | self.size = size |
| 226 | if 'w' in mode.lower(): | 227 | if 'w' in mode.lower(): |
| @@ -484,10 +485,10 @@ class XmlParser(object): | @@ -484,10 +485,10 @@ class XmlParser(object): | ||
| 484 | want_tags = [] | 485 | want_tags = [] |
| 485 | elif isstr(tags): | 486 | elif isstr(tags): |
| 486 | want_tags = [tags, ] | 487 | want_tags = [tags, ] |
| 487 | - logging.debug('looking for tags: {0}'.format(tags)) | 488 | + logger.debug('looking for tags: {0}'.format(tags)) |
| 488 | else: | 489 | else: |
| 489 | want_tags = tags | 490 | want_tags = tags |
| 490 | - logging.debug('looking for tags: {0}'.format(tags)) | 491 | + logger.debug('looking for tags: {0}'.format(tags)) |
| 491 | 492 | ||
| 492 | for subfile, handle in self.iter_files(subfiles): | 493 | for subfile, handle in self.iter_files(subfiles): |
| 493 | events = ('start', 'end') | 494 | events = ('start', 'end') |
| @@ -499,7 +500,7 @@ class XmlParser(object): | @@ -499,7 +500,7 @@ class XmlParser(object): | ||
| 499 | continue | 500 | continue |
| 500 | if event == 'start': | 501 | if event == 'start': |
| 501 | if elem.tag in want_tags: | 502 | if elem.tag in want_tags: |
| 502 | - logging.debug('remember start of tag {0} at {1}' | 503 | + logger.debug('remember start of tag {0} at {1}' |
| 503 | .format(elem.tag, depth)) | 504 | .format(elem.tag, depth)) |
| 504 | inside_tags.append((elem.tag, depth)) | 505 | inside_tags.append((elem.tag, depth)) |
| 505 | depth += 1 | 506 | depth += 1 |
| @@ -515,18 +516,18 @@ class XmlParser(object): | @@ -515,18 +516,18 @@ class XmlParser(object): | ||
| 515 | if inside_tags[-1] == curr_tag: | 516 | if inside_tags[-1] == curr_tag: |
| 516 | inside_tags.pop() | 517 | inside_tags.pop() |
| 517 | else: | 518 | else: |
| 518 | - logging.error('found end for wanted tag {0} ' | 519 | + logger.error('found end for wanted tag {0} ' |
| 519 | 'but last start tag {1} does not' | 520 | 'but last start tag {1} does not' |
| 520 | ' match'.format(curr_tag, | 521 | ' match'.format(curr_tag, |
| 521 | inside_tags[-1])) | 522 | inside_tags[-1])) |
| 522 | # try to recover: close all deeper tags | 523 | # try to recover: close all deeper tags |
| 523 | while inside_tags and \ | 524 | while inside_tags and \ |
| 524 | inside_tags[-1][1] >= depth: | 525 | inside_tags[-1][1] >= depth: |
| 525 | - logging.debug('recover: pop {0}' | 526 | + logger.debug('recover: pop {0}' |
| 526 | .format(inside_tags[-1])) | 527 | .format(inside_tags[-1])) |
| 527 | inside_tags.pop() | 528 | inside_tags.pop() |
| 528 | except IndexError: # no inside_tag[-1] | 529 | except IndexError: # no inside_tag[-1] |
| 529 | - logging.error('found end of {0} at depth {1} but ' | 530 | + logger.error('found end of {0} at depth {1} but ' |
| 530 | 'no start event') | 531 | 'no start event') |
| 531 | # yield element | 532 | # yield element |
| 532 | if is_wanted or not want_tags: | 533 | if is_wanted or not want_tags: |
| @@ -543,12 +544,12 @@ class XmlParser(object): | @@ -543,12 +544,12 @@ class XmlParser(object): | ||
| 543 | if subfile is None: # this is no zip subfile but single xml | 544 | if subfile is None: # this is no zip subfile but single xml |
| 544 | raise BadOOXML(self.filename, 'is neither zip nor xml') | 545 | raise BadOOXML(self.filename, 'is neither zip nor xml') |
| 545 | elif subfile.endswith('.xml'): | 546 | elif subfile.endswith('.xml'): |
| 546 | - logger = logging.warning | 547 | + log = logger.warning |
| 547 | else: | 548 | else: |
| 548 | - logger = logging.debug | ||
| 549 | - logger(' xml-parsing for {0} failed ({1}). ' | ||
| 550 | - .format(subfile, err) + | ||
| 551 | - 'Run iter_non_xml to investigate.') | 549 | + log = logger.debug |
| 550 | + log(' xml-parsing for {0} failed ({1}). ' | ||
| 551 | + .format(subfile, err) + | ||
| 552 | + 'Run iter_non_xml to investigate.') | ||
| 552 | assert(depth == 0) | 553 | assert(depth == 0) |
| 553 | 554 | ||
| 554 | def get_content_types(self): | 555 | def get_content_types(self): |
| @@ -571,14 +572,14 @@ class XmlParser(object): | @@ -571,14 +572,14 @@ class XmlParser(object): | ||
| 571 | if extension.startswith('.'): | 572 | if extension.startswith('.'): |
| 572 | extension = extension[1:] | 573 | extension = extension[1:] |
| 573 | defaults.append((extension, elem.attrib['ContentType'])) | 574 | defaults.append((extension, elem.attrib['ContentType'])) |
| 574 | - logging.debug('found content type for extension {0[0]}: {0[1]}' | 575 | + logger.debug('found content type for extension {0[0]}: {0[1]}' |
| 575 | .format(defaults[-1])) | 576 | .format(defaults[-1])) |
| 576 | elif elem.tag.endswith('Override'): | 577 | elif elem.tag.endswith('Override'): |
| 577 | subfile = elem.attrib['PartName'] | 578 | subfile = elem.attrib['PartName'] |
| 578 | if subfile.startswith('/'): | 579 | if subfile.startswith('/'): |
| 579 | subfile = subfile[1:] | 580 | subfile = subfile[1:] |
| 580 | files.append((subfile, elem.attrib['ContentType'])) | 581 | files.append((subfile, elem.attrib['ContentType'])) |
| 581 | - logging.debug('found content type for subfile {0[0]}: {0[1]}' | 582 | + logger.debug('found content type for subfile {0[0]}: {0[1]}' |
| 582 | .format(files[-1])) | 583 | .format(files[-1])) |
| 583 | return dict(files), dict(defaults) | 584 | return dict(files), dict(defaults) |
| 584 | 585 | ||
| @@ -595,7 +596,7 @@ class XmlParser(object): | @@ -595,7 +596,7 @@ class XmlParser(object): | ||
| 595 | To handle binary parts of an xlsb file, use xls_parser.parse_xlsb_part | 596 | To handle binary parts of an xlsb file, use xls_parser.parse_xlsb_part |
| 596 | """ | 597 | """ |
| 597 | if not self.did_iter_all: | 598 | if not self.did_iter_all: |
| 598 | - logging.warning('Did not iterate through complete file. ' | 599 | + logger.warning('Did not iterate through complete file. ' |
| 599 | 'Should run iter_xml() without args, first.') | 600 | 'Should run iter_xml() without args, first.') |
| 600 | if not self.subfiles_no_xml: | 601 | if not self.subfiles_no_xml: |
| 601 | return | 602 | return |
| @@ -628,7 +629,7 @@ def test(): | @@ -628,7 +629,7 @@ def test(): | ||
| 628 | 629 | ||
| 629 | see module doc for more info | 630 | see module doc for more info |
| 630 | """ | 631 | """ |
| 631 | - logging.basicConfig(level=logging.DEBUG) | 632 | + log_helper.enable_logging(False, 'debug') |
| 632 | if len(sys.argv) != 2: | 633 | if len(sys.argv) != 2: |
| 633 | print(u'To test this code, give me a single file as arg') | 634 | print(u'To test this code, give me a single file as arg') |
| 634 | return 2 | 635 | return 2 |
| @@ -647,6 +648,9 @@ def test(): | @@ -647,6 +648,9 @@ def test(): | ||
| 647 | if index > 100: | 648 | if index > 100: |
| 648 | print(u'...') | 649 | print(u'...') |
| 649 | break | 650 | break |
| 651 | + | ||
| 652 | + log_helper.end_logging() | ||
| 653 | + | ||
| 650 | return 0 | 654 | return 0 |
| 651 | 655 | ||
| 652 | 656 |
tests/json/__init__.py renamed to tests/common/__init__.py
tests/common/log_helper/__init__.py
0 → 100644
tests/common/log_helper/log_helper_test_imported.py
0 → 100644
| 1 | +""" | ||
| 2 | +Dummy file that logs messages, meant to be imported | ||
| 3 | +by the main test file | ||
| 4 | +""" | ||
| 5 | + | ||
| 6 | +from oletools.common.log_helper import log_helper | ||
| 7 | +import logging | ||
| 8 | + | ||
| 9 | +DEBUG_MESSAGE = 'imported: debug log' | ||
| 10 | +INFO_MESSAGE = 'imported: info log' | ||
| 11 | +WARNING_MESSAGE = 'imported: warning log' | ||
| 12 | +ERROR_MESSAGE = 'imported: error log' | ||
| 13 | +CRITICAL_MESSAGE = 'imported: critical log' | ||
| 14 | + | ||
| 15 | +logger = log_helper.get_or_create_silent_logger('test_imported', logging.ERROR) | ||
| 16 | + | ||
| 17 | + | ||
| 18 | +def log(): | ||
| 19 | + logger.debug(DEBUG_MESSAGE) | ||
| 20 | + logger.info(INFO_MESSAGE) | ||
| 21 | + logger.warning(WARNING_MESSAGE) | ||
| 22 | + logger.error(ERROR_MESSAGE) | ||
| 23 | + logger.critical(CRITICAL_MESSAGE) |
tests/common/log_helper/log_helper_test_main.py
0 → 100644
| 1 | +""" Test log_helpers """ | ||
| 2 | + | ||
| 3 | +import sys | ||
| 4 | +from tests.common.log_helper import log_helper_test_imported | ||
| 5 | +from oletools.common.log_helper import log_helper | ||
| 6 | + | ||
| 7 | +DEBUG_MESSAGE = 'main: debug log' | ||
| 8 | +INFO_MESSAGE = 'main: info log' | ||
| 9 | +WARNING_MESSAGE = 'main: warning log' | ||
| 10 | +ERROR_MESSAGE = 'main: error log' | ||
| 11 | +CRITICAL_MESSAGE = 'main: critical log' | ||
| 12 | + | ||
| 13 | +logger = log_helper.get_or_create_silent_logger('test_main') | ||
| 14 | + | ||
| 15 | + | ||
| 16 | +def init_logging_and_log(args): | ||
| 17 | + """ | ||
| 18 | + Try to cover possible logging scenarios. For each scenario covered, here's the expected args and outcome: | ||
| 19 | + - Log without enabling: ['<level>'] | ||
| 20 | + * logging when being imported - should never print | ||
| 21 | + - Log as JSON without enabling: ['as-json', '<level>'] | ||
| 22 | + * logging as JSON when being imported - should never print | ||
| 23 | + - Enable and log: ['enable', '<level>'] | ||
| 24 | + * logging when being run as script - should log messages | ||
| 25 | + - Enable and log as JSON: ['as-json', 'enable', '<level>'] | ||
| 26 | + * logging as JSON when being run as script - should log messages as JSON | ||
| 27 | + - Enable, log as JSON and throw: ['enable', 'as-json', 'throw', '<level>'] | ||
| 28 | + * should produce JSON-compatible output, even after an unhandled exception | ||
| 29 | + """ | ||
| 30 | + | ||
| 31 | + # the level should always be the last argument passed | ||
| 32 | + level = args[-1] | ||
| 33 | + use_json = 'as-json' in args | ||
| 34 | + throw = 'throw' in args | ||
| 35 | + | ||
| 36 | + if 'enable' in args: | ||
| 37 | + log_helper.enable_logging(use_json, level, stream=sys.stdout) | ||
| 38 | + | ||
| 39 | + _log() | ||
| 40 | + | ||
| 41 | + if throw: | ||
| 42 | + raise Exception('An exception occurred before ending the logging') | ||
| 43 | + | ||
| 44 | + log_helper.end_logging() | ||
| 45 | + | ||
| 46 | + | ||
| 47 | +def _log(): | ||
| 48 | + logger.debug(DEBUG_MESSAGE) | ||
| 49 | + logger.info(INFO_MESSAGE) | ||
| 50 | + logger.warning(WARNING_MESSAGE) | ||
| 51 | + logger.error(ERROR_MESSAGE) | ||
| 52 | + logger.critical(CRITICAL_MESSAGE) | ||
| 53 | + log_helper_test_imported.log() | ||
| 54 | + | ||
| 55 | + | ||
| 56 | +if __name__ == '__main__': | ||
| 57 | + init_logging_and_log(sys.argv[1:]) |
tests/common/log_helper/test_log_helper.py
0 → 100644
| 1 | +""" Test the log helper | ||
| 2 | + | ||
| 3 | +This tests the generic log helper. | ||
| 4 | +Check if it handles imported modules correctly | ||
| 5 | +and that the default silent logger won't log when nothing is enabled | ||
| 6 | +""" | ||
| 7 | + | ||
| 8 | +import unittest | ||
| 9 | +import sys | ||
| 10 | +import json | ||
| 11 | +import subprocess | ||
| 12 | +from tests.common.log_helper import log_helper_test_main | ||
| 13 | +from tests.common.log_helper import log_helper_test_imported | ||
| 14 | +from os.path import dirname, join, relpath, abspath | ||
| 15 | + | ||
| 16 | +# this is the common base of "tests" and "oletools" dirs | ||
| 17 | +ROOT_DIRECTORY = abspath(join(__file__, '..', '..', '..', '..')) | ||
| 18 | +TEST_FILE = relpath(join(dirname(__file__), 'log_helper_test_main.py'), ROOT_DIRECTORY) | ||
| 19 | +PYTHON_EXECUTABLE = sys.executable | ||
| 20 | + | ||
| 21 | +MAIN_LOG_MESSAGES = [ | ||
| 22 | + log_helper_test_main.DEBUG_MESSAGE, | ||
| 23 | + log_helper_test_main.INFO_MESSAGE, | ||
| 24 | + log_helper_test_main.WARNING_MESSAGE, | ||
| 25 | + log_helper_test_main.ERROR_MESSAGE, | ||
| 26 | + log_helper_test_main.CRITICAL_MESSAGE | ||
| 27 | +] | ||
| 28 | + | ||
| 29 | + | ||
| 30 | +class TestLogHelper(unittest.TestCase): | ||
| 31 | + def test_it_doesnt_log_when_not_enabled(self): | ||
| 32 | + output = self._run_test(['debug']) | ||
| 33 | + self.assertTrue(len(output) == 0) | ||
| 34 | + | ||
| 35 | + def test_it_doesnt_log_json_when_not_enabled(self): | ||
| 36 | + output = self._run_test(['as-json', 'debug']) | ||
| 37 | + self.assertTrue(len(output) == 0) | ||
| 38 | + | ||
| 39 | + def test_logs_when_enabled(self): | ||
| 40 | + output = self._run_test(['enable', 'warning']) | ||
| 41 | + | ||
| 42 | + expected_messages = [ | ||
| 43 | + log_helper_test_main.WARNING_MESSAGE, | ||
| 44 | + log_helper_test_main.ERROR_MESSAGE, | ||
| 45 | + log_helper_test_main.CRITICAL_MESSAGE, | ||
| 46 | + log_helper_test_imported.WARNING_MESSAGE, | ||
| 47 | + log_helper_test_imported.ERROR_MESSAGE, | ||
| 48 | + log_helper_test_imported.CRITICAL_MESSAGE | ||
| 49 | + ] | ||
| 50 | + | ||
| 51 | + for msg in expected_messages: | ||
| 52 | + self.assertIn(msg, output) | ||
| 53 | + | ||
| 54 | + def test_logs_json_when_enabled(self): | ||
| 55 | + output = self._run_test(['enable', 'as-json', 'critical']) | ||
| 56 | + | ||
| 57 | + self._assert_json_messages(output, [ | ||
| 58 | + log_helper_test_main.CRITICAL_MESSAGE, | ||
| 59 | + log_helper_test_imported.CRITICAL_MESSAGE | ||
| 60 | + ]) | ||
| 61 | + | ||
| 62 | + def test_json_correct_on_exceptions(self): | ||
| 63 | + """ | ||
| 64 | + Test that even on unhandled exceptions our JSON is always correct | ||
| 65 | + """ | ||
| 66 | + output = self._run_test(['enable', 'as-json', 'throw', 'critical'], False) | ||
| 67 | + self._assert_json_messages(output, [ | ||
| 68 | + log_helper_test_main.CRITICAL_MESSAGE, | ||
| 69 | + log_helper_test_imported.CRITICAL_MESSAGE | ||
| 70 | + ]) | ||
| 71 | + | ||
| 72 | + def _assert_json_messages(self, output, messages): | ||
| 73 | + try: | ||
| 74 | + json_data = json.loads(output) | ||
| 75 | + self.assertEqual(len(json_data), len(messages)) |
| 76 | + | ||
| 77 | + for i in range(len(messages)): | ||
| 78 | + self.assertEqual(messages[i], json_data[i]['msg']) |
| 79 | + except ValueError: | ||
| 80 | + self.fail('Invalid json:\n' + output) | ||
| 81 | + | ||
| 82 | + self.assertNotEqual(len(json_data), 0, msg='Output was empty') | ||
| 83 | + | ||
| 84 | + def _run_test(self, args, should_succeed=True): | ||
| 85 | + """ | ||
| 86 | + Use subprocess to better simulate the real scenario and avoid | ||
| 87 | + logging conflicts when running multiple tests (since logging depends on singletons, | ||
| 88 | + we might get errors or false positives between sequential tests runs) | ||
| 89 | + """ | ||
| 90 | + child = subprocess.Popen( | ||
| 91 | + [PYTHON_EXECUTABLE, TEST_FILE] + args, | ||
| 92 | + shell=False, | ||
| 93 | + env={'PYTHONPATH': ROOT_DIRECTORY}, | ||
| 94 | + universal_newlines=True, | ||
| 95 | + cwd=ROOT_DIRECTORY, | ||
| 96 | + stdin=None, | ||
| 97 | + stdout=subprocess.PIPE, | ||
| 98 | + stderr=subprocess.PIPE | ||
| 99 | + ) | ||
| 100 | + (output, output_err) = child.communicate() | ||
| 101 | + | ||
| 102 | + if not isinstance(output, str): | ||
| 103 | + output = output.decode('utf-8') | ||
| 104 | + | ||
| 105 | + self.assertEqual(child.returncode == 0, should_succeed) |
| 106 | + | ||
| 107 | + return output.strip() | ||
| 108 | + | ||
| 109 | + | ||
| 110 | +# just in case somebody calls this file as a script | ||
| 111 | +if __name__ == '__main__': | ||
| 112 | + unittest.main() |
tests/json/test_output.py deleted
| 1 | -""" Test validity of json output | ||
| 2 | - | ||
| 3 | -Some scripts have a json output flag. Verify that at default log levels output | ||
| 4 | -can be captured as-is and parsed by a json parser -- checking the return code | ||
| 5 | -if desired. | ||
| 6 | -""" | ||
| 7 | - | ||
| 8 | -import unittest | ||
| 9 | -import sys | ||
| 10 | -import json | ||
| 11 | -import os | ||
| 12 | -from os.path import join | ||
| 13 | -from oletools import msodde | ||
| 14 | -from tests.test_utils import OutputCapture, DATA_BASE_DIR | ||
| 15 | - | ||
| 16 | -if sys.version_info[0] <= 2: | ||
| 17 | - from oletools import olevba | ||
| 18 | -else: | ||
| 19 | - from oletools import olevba3 as olevba | ||
| 20 | - | ||
| 21 | - | ||
| 22 | -class TestValidJson(unittest.TestCase): | ||
| 23 | - """ | ||
| 24 | - Ensure that script output is valid json. | ||
| 25 | - If check_return_code is True we also ignore the output | ||
| 26 | - of runs that didn't succeed. | ||
| 27 | - """ | ||
| 28 | - | ||
| 29 | - @staticmethod | ||
| 30 | - def iter_test_files(): | ||
| 31 | - """ Iterate over all test files in DATA_BASE_DIR """ | ||
| 32 | - for dirpath, _, filenames in os.walk(DATA_BASE_DIR): | ||
| 33 | - for filename in filenames: | ||
| 34 | - yield join(dirpath, filename) | ||
| 35 | - | ||
| 36 | - def run_and_parse(self, program, args, print_output=False, check_return_code=True): | ||
| 37 | - """ run single program with single file and parse output """ | ||
| 38 | - with OutputCapture() as capturer: # capture stdout | ||
| 39 | - try: | ||
| 40 | - return_code = program(args) | ||
| 41 | - except Exception: | ||
| 42 | - return_code = 1 # would result in non-zero exit | ||
| 43 | - except SystemExit as se: | ||
| 44 | - return_code = se.code or 0 # se.code can be None | ||
| 45 | - if check_return_code and return_code is not 0: | ||
| 46 | - if print_output: | ||
| 47 | - print('Command failed ({0}) -- not parsing output' | ||
| 48 | - .format(return_code)) | ||
| 49 | - return [] # no need to test | ||
| 50 | - | ||
| 51 | - self.assertNotEqual(return_code, None, | ||
| 52 | - msg='self-test fail: return_code not set') | ||
| 53 | - | ||
| 54 | - # now test output | ||
| 55 | - if print_output: | ||
| 56 | - print(capturer.get_data()) | ||
| 57 | - try: | ||
| 58 | - json_data = json.loads(capturer.get_data()) | ||
| 59 | - except ValueError: | ||
| 60 | - self.fail('Invalid json:\n' + capturer.get_data()) | ||
| 61 | - self.assertNotEqual(len(json_data), 0, msg='Output was empty') | ||
| 62 | - return json_data | ||
| 63 | - | ||
| 64 | - def run_all_files(self, program, args_without_filename, print_output=False): | ||
| 65 | - """ run test for a single program over all test files """ | ||
| 66 | - n_files = 0 | ||
| 67 | - for testfile in self.iter_test_files(): # loop over all input | ||
| 68 | - args = args_without_filename + [testfile, ] | ||
| 69 | - self.run_and_parse(program, args, print_output) | ||
| 70 | - n_files += 1 | ||
| 71 | - self.assertNotEqual(n_files, 0, | ||
| 72 | - msg='self-test fail: No test files found') | ||
| 73 | - | ||
| 74 | - def test_msodde(self): | ||
| 75 | - """ Test msodde.py """ | ||
| 76 | - self.run_all_files(msodde.main, ['-j', ]) | ||
| 77 | - | ||
| 78 | - def test_olevba(self): | ||
| 79 | - """ Test olevba.py with default args """ | ||
| 80 | - self.run_all_files(olevba.main, ['-j', ]) | ||
| 81 | - | ||
| 82 | - def test_olevba_analysis(self): | ||
| 83 | - """ Test olevba.py with -a """ | ||
| 84 | - self.run_all_files(olevba.main, ['-j', '-a', ]) | ||
| 85 | - | ||
| 86 | - def test_olevba_recurse(self): | ||
| 87 | - """ Test olevba.py with -r """ | ||
| 88 | - json_data = self.run_and_parse(olevba.main, | ||
| 89 | - ['-j', '-r', join(DATA_BASE_DIR, '*')], | ||
| 90 | - check_return_code=False) | ||
| 91 | - self.assertNotEqual(len(json_data), 0, | ||
| 92 | - msg='olevba[3] returned non-zero or no output') | ||
| 93 | - self.assertNotEqual(json_data[-1]['n_processed'], 0, | ||
| 94 | - msg='self-test fail: No test files found!') | ||
| 95 | - | ||
| 96 | - | ||
| 97 | -# just in case somebody calls this file as a script | ||
| 98 | -if __name__ == '__main__': | ||
| 99 | - unittest.main() |
tests/msodde/test_basic.py
| @@ -10,15 +10,13 @@ from __future__ import print_function | @@ -10,15 +10,13 @@ from __future__ import print_function | ||
| 10 | 10 | ||
| 11 | import unittest | 11 | import unittest |
| 12 | from oletools import msodde | 12 | from oletools import msodde |
| 13 | -from tests.test_utils import OutputCapture, DATA_BASE_DIR as BASE_DIR | ||
| 14 | -import shlex | 13 | +from tests.test_utils import DATA_BASE_DIR as BASE_DIR |
| 15 | from os.path import join | 14 | from os.path import join |
| 16 | from traceback import print_exc | 15 | from traceback import print_exc |
| 17 | 16 | ||
| 18 | 17 | ||
| 19 | class TestReturnCode(unittest.TestCase): | 18 | class TestReturnCode(unittest.TestCase): |
| 20 | """ check return codes and exception behaviour (not text output) """ | 19 | """ check return codes and exception behaviour (not text output) """ |
| 21 | - | ||
| 22 | def test_valid_doc(self): | 20 | def test_valid_doc(self): |
| 23 | """ check that a valid doc file leads to 0 exit status """ | 21 | """ check that a valid doc file leads to 0 exit status """ |
| 24 | for filename in ( | 22 | for filename in ( |
| @@ -59,87 +57,73 @@ class TestReturnCode(unittest.TestCase): | @@ -59,87 +57,73 @@ class TestReturnCode(unittest.TestCase): | ||
| 59 | 57 | ||
| 60 | def do_test_validity(self, args, expect_error=False): | 58 | def do_test_validity(self, args, expect_error=False): |
| 61 | """ helper for test_valid_doc[x] """ | 59 | """ helper for test_valid_doc[x] """ |
| 62 | - args = shlex.split(args) | ||
| 63 | - return_code = -1 | ||
| 64 | have_exception = False | 60 | have_exception = False |
| 65 | try: | 61 | try: |
| 66 | - return_code = msodde.main(args) | 62 | + msodde.process_file(args, msodde.FIELD_FILTER_BLACKLIST) |
| 67 | except Exception: | 63 | except Exception: |
| 68 | have_exception = True | 64 | have_exception = True |
| 69 | print_exc() | 65 | print_exc() |
| 70 | except SystemExit as exc: # sys.exit() was called | 66 | except SystemExit as exc: # sys.exit() was called |
| 71 | - return_code = exc.code | 67 | + have_exception = True |
| 72 | if exc.code is None: | 68 | if exc.code is None: |
| 73 | - return_code = 0 | 69 | + have_exception = False |
| 74 | 70 | ||
| 75 | - self.assertEqual(expect_error, have_exception or (return_code != 0), | ||
| 76 | - msg='Args={0}, expect={1}, exc={2}, return={3}' | ||
| 77 | - .format(args, expect_error, have_exception, | ||
| 78 | - return_code)) | 71 | + self.assertEqual(expect_error, have_exception, |
| 72 | + msg='Args={0}, expect={1}, exc={2}' | ||
| 73 | + .format(args, expect_error, have_exception)) | ||
| 79 | 74 | ||
| 80 | 75 | ||
| 81 | class TestDdeLinks(unittest.TestCase): | 76 | class TestDdeLinks(unittest.TestCase): |
| 82 | """ capture output of msodde and check dde-links are found correctly """ | 77 | """ capture output of msodde and check dde-links are found correctly """ |
| 83 | 78 | ||
| 84 | - def get_dde_from_output(self, capturer): | 79 | + @staticmethod |
| 80 | + def get_dde_from_output(output): | ||
| 85 | """ helper to read dde links from captured output | 81 | """ helper to read dde links from captured output |
| 86 | - | ||
| 87 | - duplicate in tests/msodde/test_csv | ||
| 88 | """ | 82 | """ |
| 89 | - have_start_line = False | ||
| 90 | - result = [] | ||
| 91 | - for line in capturer: | ||
| 92 | - if not line.strip(): | ||
| 93 | - continue # skip empty lines | ||
| 94 | - if have_start_line: | ||
| 95 | - result.append(line) | ||
| 96 | - elif line == 'DDE Links:': | ||
| 97 | - have_start_line = True | ||
| 98 | - | ||
| 99 | - self.assertTrue(have_start_line) # ensure output was complete | ||
| 100 | - return result | 83 | + return [o for o in output.splitlines()] |
| 101 | 84 | ||
| 102 | def test_with_dde(self): | 85 | def test_with_dde(self): |
| 103 | """ check that dde links appear on stdout """ | 86 | """ check that dde links appear on stdout """ |
| 104 | filename = 'dde-test-from-office2003.doc' | 87 | filename = 'dde-test-from-office2003.doc' |
| 105 | - with OutputCapture() as capturer: | ||
| 106 | - msodde.main([join(BASE_DIR, 'msodde', filename)]) | ||
| 107 | - self.assertNotEqual(len(self.get_dde_from_output(capturer)), 0, | 88 | + output = msodde.process_file( |
| 89 | + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST) | ||
| 90 | + self.assertNotEqual(len(self.get_dde_from_output(output)), 0, | ||
| 108 | msg='Found no dde links in output of ' + filename) | 91 | msg='Found no dde links in output of ' + filename) |
| 109 | 92 | ||
| 110 | def test_no_dde(self): | 93 | def test_no_dde(self): |
| 111 | """ check that no dde links appear on stdout """ | 94 | """ check that no dde links appear on stdout """ |
| 112 | filename = 'harmless-clean.doc' | 95 | filename = 'harmless-clean.doc' |
| 113 | - with OutputCapture() as capturer: | ||
| 114 | - msodde.main([join(BASE_DIR, 'msodde', filename)]) | ||
| 115 | - self.assertEqual(len(self.get_dde_from_output(capturer)), 0, | 96 | + output = msodde.process_file( |
| 97 | + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST) | ||
| 98 | + self.assertEqual(len(self.get_dde_from_output(output)), 0, | ||
| 116 | msg='Found dde links in output of ' + filename) | 99 | msg='Found dde links in output of ' + filename) |
| 117 | 100 | ||
| 118 | def test_with_dde_utf16le(self): | 101 | def test_with_dde_utf16le(self): |
| 119 | """ check that dde links appear on stdout """ | 102 | """ check that dde links appear on stdout """ |
| 120 | filename = 'dde-test-from-office2013-utf_16le-korean.doc' | 103 | filename = 'dde-test-from-office2013-utf_16le-korean.doc' |
| 121 | - with OutputCapture() as capturer: | ||
| 122 | - msodde.main([join(BASE_DIR, 'msodde', filename)]) | ||
| 123 | - self.assertNotEqual(len(self.get_dde_from_output(capturer)), 0, | 104 | + output = msodde.process_file( |
| 105 | + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST) | ||
| 106 | + self.assertNotEqual(len(self.get_dde_from_output(output)), 0, | ||
| 124 | msg='Found no dde links in output of ' + filename) | 107 | msg='Found no dde links in output of ' + filename) |
| 125 | 108 | ||
| 126 | def test_excel(self): | 109 | def test_excel(self): |
| 127 | """ check that dde links are found in excel 2007+ files """ | 110 | """ check that dde links are found in excel 2007+ files """ |
| 128 | expect = ['DDE-Link cmd /c calc.exe', ] | 111 | expect = ['DDE-Link cmd /c calc.exe', ] |
| 129 | for extn in 'xlsx', 'xlsm', 'xlsb': | 112 | for extn in 'xlsx', 'xlsm', 'xlsb': |
| 130 | - with OutputCapture() as capturer: | ||
| 131 | - msodde.main([join(BASE_DIR, 'msodde', 'dde-test.' + extn), ]) | ||
| 132 | - self.assertEqual(expect, self.get_dde_from_output(capturer), | 113 | + output = msodde.process_file( |
| 114 | + join(BASE_DIR, 'msodde', 'dde-test.' + extn), msodde.FIELD_FILTER_BLACKLIST) | ||
| 115 | + | ||
| 116 | + self.assertEqual(expect, self.get_dde_from_output(output), | ||
| 133 | msg='unexpected output for dde-test.{0}: {1}' | 117 | msg='unexpected output for dde-test.{0}: {1}' |
| 134 | - .format(extn, capturer.get_data())) | 118 | + .format(extn, output)) |
| 135 | 119 | ||
| 136 | def test_xml(self): | 120 | def test_xml(self): |
| 137 | """ check that dde in xml from word / excel is found """ | 121 | """ check that dde in xml from word / excel is found """ |
| 138 | for name_part in 'excel2003', 'word2003', 'word2007': | 122 | for name_part in 'excel2003', 'word2003', 'word2007': |
| 139 | filename = 'dde-in-' + name_part + '.xml' | 123 | filename = 'dde-in-' + name_part + '.xml' |
| 140 | - with OutputCapture() as capturer: | ||
| 141 | - msodde.main([join(BASE_DIR, 'msodde', filename), ]) | ||
| 142 | - links = self.get_dde_from_output(capturer) | 124 | + output = msodde.process_file( |
| 125 | + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST) | ||
| 126 | + links = self.get_dde_from_output(output) | ||
| 143 | self.assertEqual(len(links), 1, 'found {0} dde-links in {1}' | 127 | self.assertEqual(len(links), 1, 'found {0} dde-links in {1}' |
| 144 | .format(len(links), filename)) | 128 | .format(len(links), filename)) |
| 145 | self.assertTrue('cmd' in links[0], 'no "cmd" in dde-link for {0}' | 129 | self.assertTrue('cmd' in links[0], 'no "cmd" in dde-link for {0}' |
| @@ -150,16 +134,16 @@ class TestDdeLinks(unittest.TestCase): | @@ -150,16 +134,16 @@ class TestDdeLinks(unittest.TestCase): | ||
| 150 | def test_clean_rtf_blacklist(self): | 134 | def test_clean_rtf_blacklist(self): |
| 151 | """ find a lot of hyperlinks in rtf spec """ | 135 | """ find a lot of hyperlinks in rtf spec """ |
| 152 | filename = 'RTF-Spec-1.7.rtf' | 136 | filename = 'RTF-Spec-1.7.rtf' |
| 153 | - with OutputCapture() as capturer: | ||
| 154 | - msodde.main([join(BASE_DIR, 'msodde', filename)]) | ||
| 155 | - self.assertEqual(len(self.get_dde_from_output(capturer)), 1413) | 137 | + output = msodde.process_file( |
| 138 | + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST) | ||
| 139 | + self.assertEqual(len(self.get_dde_from_output(output)), 1413) | ||
| 156 | 140 | ||
| 157 | def test_clean_rtf_ddeonly(self): | 141 | def test_clean_rtf_ddeonly(self): |
| 158 | """ find no dde links in rtf spec """ | 142 | """ find no dde links in rtf spec """ |
| 159 | filename = 'RTF-Spec-1.7.rtf' | 143 | filename = 'RTF-Spec-1.7.rtf' |
| 160 | - with OutputCapture() as capturer: | ||
| 161 | - msodde.main(['-d', join(BASE_DIR, 'msodde', filename)]) | ||
| 162 | - self.assertEqual(len(self.get_dde_from_output(capturer)), 0, | 144 | + output = msodde.process_file( |
| 145 | + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_DDE) | ||
| 146 | + self.assertEqual(len(self.get_dde_from_output(output)), 0, | ||
| 163 | msg='Found dde links in output of ' + filename) | 147 | msg='Found dde links in output of ' + filename) |
| 164 | 148 | ||
| 165 | 149 |
tests/msodde/test_csv.py
| @@ -9,7 +9,7 @@ import os | @@ -9,7 +9,7 @@ import os | ||
| 9 | from os.path import join | 9 | from os.path import join |
| 10 | 10 | ||
| 11 | from oletools import msodde | 11 | from oletools import msodde |
| 12 | -from tests.test_utils import OutputCapture, DATA_BASE_DIR | 12 | +from tests.test_utils import DATA_BASE_DIR |
| 13 | 13 | ||
| 14 | 14 | ||
| 15 | class TestCSV(unittest.TestCase): | 15 | class TestCSV(unittest.TestCase): |
| @@ -69,11 +69,8 @@ class TestCSV(unittest.TestCase): | @@ -69,11 +69,8 @@ class TestCSV(unittest.TestCase): | ||
| 69 | def test_file(self): | 69 | def test_file(self): |
| 70 | """ test simple small example file """ | 70 | """ test simple small example file """ |
| 71 | filename = join(DATA_BASE_DIR, 'msodde', 'dde-in-csv.csv') | 71 | filename = join(DATA_BASE_DIR, 'msodde', 'dde-in-csv.csv') |
| 72 | - with OutputCapture() as capturer: | ||
| 73 | - capturer.reload_module(msodde) # re-create logger | ||
| 74 | - ret_code = msodde.main([filename, ]) | ||
| 75 | - self.assertEqual(ret_code, 0) | ||
| 76 | - links = self.get_dde_from_output(capturer) | 72 | + output = msodde.process_file(filename, msodde.FIELD_FILTER_BLACKLIST) |
| 73 | + links = self.get_dde_from_output(output) | ||
| 77 | self.assertEqual(len(links), 1) | 74 | self.assertEqual(len(links), 1) |
| 78 | self.assertEqual(links[0], | 75 | self.assertEqual(links[0], |
| 79 | r"cmd '/k \..\..\..\Windows\System32\calc.exe'") | 76 | r"cmd '/k \..\..\..\Windows\System32\calc.exe'") |
| @@ -91,12 +88,10 @@ class TestCSV(unittest.TestCase): | @@ -91,12 +88,10 @@ class TestCSV(unittest.TestCase): | ||
| 91 | if self.DO_DEBUG: | 88 | if self.DO_DEBUG: |
| 92 | args += ['-l', 'debug'] | 89 | args += ['-l', 'debug'] |
| 93 | 90 | ||
| 94 | - with OutputCapture() as capturer: | ||
| 95 | - capturer.reload_module(msodde) # re-create logger | ||
| 96 | - ret_code = msodde.main(args) | ||
| 97 | - self.assertEqual(ret_code, 0, 'checking sample resulted in ' | ||
| 98 | - 'error:\n' + sample_text) | ||
| 99 | - return capturer | 91 | + processed_args = msodde.process_args(args) |
| 92 | + | ||
| 93 | + return msodde.process_file( | ||
| 94 | + processed_args.filepath, processed_args.field_filter_mode) | ||
| 100 | 95 | ||
| 101 | except Exception: | 96 | except Exception: |
| 102 | raise | 97 | raise |
| @@ -111,25 +106,11 @@ class TestCSV(unittest.TestCase): | @@ -111,25 +106,11 @@ class TestCSV(unittest.TestCase): | ||
| 111 | os.remove(filename) | 106 | os.remove(filename) |
| 112 | filename = None # just in case | 107 | filename = None # just in case |
| 113 | 108 | ||
| 114 | - def get_dde_from_output(self, capturer): | 109 | + @staticmethod |
| 110 | + def get_dde_from_output(output): | ||
| 115 | """ helper to read dde links from captured output | 111 | """ helper to read dde links from captured output |
| 116 | - | ||
| 117 | - duplicate in tests/msodde/test_basic | ||
| 118 | """ | 112 | """ |
| 119 | - have_start_line = False | ||
| 120 | - result = [] | ||
| 121 | - for line in capturer: | ||
| 122 | - if self.DO_DEBUG: | ||
| 123 | - print('captured: ' + line) | ||
| 124 | - if not line.strip(): | ||
| 125 | - continue # skip empty lines | ||
| 126 | - if have_start_line: | ||
| 127 | - result.append(line) | ||
| 128 | - elif line == 'DDE Links:': | ||
| 129 | - have_start_line = True | ||
| 130 | - | ||
| 131 | - self.assertTrue(have_start_line) # ensure output was complete | ||
| 132 | - return result | 113 | + return [o for o in output.splitlines()] |
| 133 | 114 | ||
| 134 | def test_regex(self): | 115 | def test_regex(self): |
| 135 | """ check that regex captures other ways to include dde commands | 116 | """ check that regex captures other ways to include dde commands |
tests/test_utils/__init__.py
tests/test_utils/output_capture.py deleted
| 1 | -""" class OutputCapture to test what scripts print to stdout """ | ||
| 2 | - | ||
| 3 | -from __future__ import print_function | ||
| 4 | -import sys | ||
| 5 | -import logging | ||
| 6 | - | ||
| 7 | - | ||
| 8 | -# python 2/3 version conflict: | ||
| 9 | -if sys.version_info.major <= 2: | ||
| 10 | - from StringIO import StringIO | ||
| 11 | - # reload is a builtin | ||
| 12 | -else: | ||
| 13 | - from io import StringIO | ||
| 14 | - if sys.version_info.minor < 4: | ||
| 15 | - from imp import reload | ||
| 16 | - else: | ||
| 17 | - from importlib import reload | ||
| 18 | - | ||
| 19 | - | ||
| 20 | -class OutputCapture: | ||
| 21 | - """ context manager that captures stdout | ||
| 22 | - | ||
| 23 | - use as follows:: | ||
| 24 | - | ||
| 25 | - with OutputCapture() as capturer: | ||
| 26 | - run_my_script(some_args) | ||
| 27 | - | ||
| 28 | - # either test line-by-line ... | ||
| 29 | - for line in capturer: | ||
| 30 | - some_test(line) | ||
| 31 | - # ...or test all output in one go | ||
| 32 | - some_test(capturer.get_data()) | ||
| 33 | - | ||
| 34 | - In order to solve issues with old logger instances still remembering closed | ||
| 35 | - StringIO instances as "their" stdout, logging is shutdown and restarted | ||
| 36 | - upon entering this Context Manager. This means that you may have to reload | ||
| 37 | - your module, as well. | ||
| 38 | - """ | ||
| 39 | - | ||
| 40 | - def __init__(self): | ||
| 41 | - self.buffer = StringIO() | ||
| 42 | - self.orig_stdout = None | ||
| 43 | - self.data = None | ||
| 44 | - | ||
| 45 | - def __enter__(self): | ||
| 46 | - # Avoid problems with old logger instances that still remember an old | ||
| 47 | - # closed StringIO as their sys.stdout | ||
| 48 | - logging.shutdown() | ||
| 49 | - reload(logging) | ||
| 50 | - | ||
| 51 | - # replace sys.stdout with own buffer. | ||
| 52 | - self.orig_stdout = sys.stdout | ||
| 53 | - sys.stdout = self.buffer | ||
| 54 | - return self | ||
| 55 | - | ||
| 56 | - def __exit__(self, exc_type, exc_value, traceback): | ||
| 57 | - sys.stdout = self.orig_stdout # re-set to original | ||
| 58 | - self.data = self.buffer.getvalue() | ||
| 59 | - self.buffer.close() # close buffer | ||
| 60 | - self.buffer = None | ||
| 61 | - | ||
| 62 | - if exc_type: # there has been an error | ||
| 63 | - print('Got error during output capture!') | ||
| 64 | - print('Print captured output and re-raise:') | ||
| 65 | - for line in self.data.splitlines(): | ||
| 66 | - print(line.rstrip()) # print output before re-raising | ||
| 67 | - | ||
| 68 | - def get_data(self): | ||
| 69 | - """ retrieve all the captured data """ | ||
| 70 | - if self.buffer is not None: | ||
| 71 | - return self.buffer.getvalue() | ||
| 72 | - elif self.data is not None: | ||
| 73 | - return self.data | ||
| 74 | - else: # should not be possible | ||
| 75 | - raise RuntimeError('programming error or someone messed with data!') | ||
| 76 | - | ||
| 77 | - def __iter__(self): | ||
| 78 | - for line in self.get_data().splitlines(): | ||
| 79 | - yield line | ||
| 80 | - | ||
| 81 | - def reload_module(self, mod): | ||
| 82 | - """ Wrapper around reload function for different python versions """ | ||
| 83 | - return reload(mod) |