diff --git a/tests/olevba/test_crypto.py b/tests/olevba/test_crypto.py index b2dc84d..787a1ad 100644 --- a/tests/olevba/test_crypto.py +++ b/tests/olevba/test_crypto.py @@ -2,13 +2,11 @@ import sys import unittest -import os from os.path import join as pjoin -from subprocess import check_output, CalledProcessError import json from collections import OrderedDict -from tests.test_utils import DATA_BASE_DIR, SOURCE_BASE_DIR +from tests.test_utils import DATA_BASE_DIR, call_and_capture from oletools import crypto @@ -34,25 +32,11 @@ class OlevbaCryptoWriteProtectTest(unittest.TestCase): """ def test_autostart(self): """Check that autostart macro is found in xls[mb] sample file.""" - # create a PYTHONPATH environment var to prefer our olevba - env = os.environ - try: - env['PYTHONPATH'] = SOURCE_BASE_DIR + os.pathsep + \ - os.environ['PYTHONPATH'] - except KeyError: - env['PYTHONPATH'] = SOURCE_BASE_DIR - for suffix in 'xlsm', 'xlsb': example_file = pjoin( DATA_BASE_DIR, 'encrypted', 'autostart-encrypt-standardpassword.' + suffix) - try: - output = check_output([sys.executable, '-m', 'olevba', '-j', - example_file], - universal_newlines=True, env=env) - except CalledProcessError as err: - print(err.output) - raise + output, _ = call_and_capture('olevba', args=('-j', example_file)) data = json.loads(output, object_pairs_hook=OrderedDict) # debug: json.dump(data, sys.stdout, indent=4) self.assertEqual(len(data), 4) diff --git a/tests/test_utils/utils.py b/tests/test_utils/utils.py index b3b9005..8a35936 100644 --- a/tests/test_utils/utils.py +++ b/tests/test_utils/utils.py @@ -2,8 +2,10 @@ """Utils generally useful for unittests.""" +import sys import os from os.path import dirname, join, abspath +from subprocess import check_output, STDOUT, CalledProcessError # Base dir of project, contains subdirs "tests" and "oletools" and README.md @@ -14,3 +16,48 @@ DATA_BASE_DIR = join(PROJECT_ROOT, 'tests', 'test-data') # Directory with source code SOURCE_BASE_DIR = join(PROJECT_ROOT, 'oletools') + + +def call_and_capture(module, args=None, accept_nonzero_exit=False): + """ + Run module as script, capturing and returning output and return code. + + This is the best way to capture a module's stdout and stderr; trying to + modify sys.stdout/sys.stderr to StringIO-Buffers frequently causes trouble. + + Only drawback sofar: stdout and stderr are merged into one (which is + what users see on their shell as well). + + :param str module: name of module to test, e.g. `olevba` + :param args: arguments for module's main function + :param bool fail_nonzero: Raise error if command returns non-0 return code + :returns: ret_code, output + :rtype: int, str + """ + # create a PYTHONPATH environment var to prefer our current code + env = os.environ.copy() + try: + env['PYTHONPATH'] = SOURCE_BASE_DIR + os.pathsep + \ + os.environ['PYTHONPATH'] + except KeyError: + env['PYTHONPATH'] = SOURCE_BASE_DIR + + # ensure args is a tuple + my_args = tuple(args) if args else () + + ret_code = -1 + try: + output = check_output((sys.executable, '-m', module) + my_args, + universal_newlines=True, env=env, + stderr=STDOUT) + ret_code = 0 + + except CalledProcessError as err: + if accept_nonzero_exit: + ret_code = err.returncode + output = err.output + else: + print(err.output) + raise + + return output, ret_code