diff --git a/tests/msodde/test_basic.py b/tests/msodde/test_basic.py index 1966a2f..ac3121c 100644 --- a/tests/msodde/test_basic.py +++ b/tests/msodde/test_basic.py @@ -10,15 +10,13 @@ from __future__ import print_function import unittest from oletools import msodde -from tests.test_utils import OutputCapture, DATA_BASE_DIR as BASE_DIR -import shlex +from tests.test_utils import DATA_BASE_DIR as BASE_DIR from os.path import join from traceback import print_exc class TestReturnCode(unittest.TestCase): """ check return codes and exception behaviour (not text output) """ - def test_valid_doc(self): """ check that a valid doc file leads to 0 exit status """ for filename in ( @@ -59,87 +57,73 @@ class TestReturnCode(unittest.TestCase): def do_test_validity(self, args, expect_error=False): """ helper for test_valid_doc[x] """ - args = shlex.split(args) - return_code = -1 have_exception = False try: - return_code = msodde.main(args) + msodde.process_file(args, msodde.FIELD_FILTER_BLACKLIST) except Exception: have_exception = True print_exc() except SystemExit as exc: # sys.exit() was called - return_code = exc.code + have_exception = True if exc.code is None: - return_code = 0 + have_exception = False - self.assertEqual(expect_error, have_exception or (return_code != 0), - msg='Args={0}, expect={1}, exc={2}, return={3}' - .format(args, expect_error, have_exception, - return_code)) + self.assertEqual(expect_error, have_exception, + msg='Args={0}, expect={1}, exc={2}' + .format(args, expect_error, have_exception)) class TestDdeLinks(unittest.TestCase): """ capture output of msodde and check dde-links are found correctly """ - def get_dde_from_output(self, capturer): + @staticmethod + def get_dde_from_output(output): """ helper to read dde links from captured output - - duplicate in tests/msodde/test_csv """ - have_start_line = False - result = [] - for line in capturer: - if not line.strip(): - continue # skip empty lines - if have_start_line: - result.append(line) - elif line == 'DDE Links:': - have_start_line = True - - self.assertTrue(have_start_line) # ensure output was complete - return result + return [o for o in output.splitlines()] def test_with_dde(self): """ check that dde links appear on stdout """ filename = 'dde-test-from-office2003.doc' - with OutputCapture() as capturer: - msodde.main([join(BASE_DIR, 'msodde', filename)]) - self.assertNotEqual(len(self.get_dde_from_output(capturer)), 0, + output = msodde.process_file( + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST) + self.assertNotEqual(len(self.get_dde_from_output(output)), 0, msg='Found no dde links in output of ' + filename) def test_no_dde(self): """ check that no dde links appear on stdout """ filename = 'harmless-clean.doc' - with OutputCapture() as capturer: - msodde.main([join(BASE_DIR, 'msodde', filename)]) - self.assertEqual(len(self.get_dde_from_output(capturer)), 0, + output = msodde.process_file( + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST) + self.assertEqual(len(self.get_dde_from_output(output)), 0, msg='Found dde links in output of ' + filename) def test_with_dde_utf16le(self): """ check that dde links appear on stdout """ filename = 'dde-test-from-office2013-utf_16le-korean.doc' - with OutputCapture() as capturer: - msodde.main([join(BASE_DIR, 'msodde', filename)]) - self.assertNotEqual(len(self.get_dde_from_output(capturer)), 0, + output = msodde.process_file( + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST) + self.assertNotEqual(len(self.get_dde_from_output(output)), 0, msg='Found no dde links in output of ' + filename) def test_excel(self): """ check that dde links are found in excel 2007+ files """ expect = ['DDE-Link cmd /c calc.exe', ] for extn in 'xlsx', 'xlsm', 'xlsb': - with OutputCapture() as capturer: - msodde.main([join(BASE_DIR, 'msodde', 'dde-test.' + extn), ]) - self.assertEqual(expect, self.get_dde_from_output(capturer), + output = msodde.process_file( + join(BASE_DIR, 'msodde', 'dde-test.' + extn), msodde.FIELD_FILTER_BLACKLIST) + + self.assertEqual(expect, self.get_dde_from_output(output), msg='unexpected output for dde-test.{0}: {1}' - .format(extn, capturer.get_data())) + .format(extn, output)) def test_xml(self): """ check that dde in xml from word / excel is found """ for name_part in 'excel2003', 'word2003', 'word2007': filename = 'dde-in-' + name_part + '.xml' - with OutputCapture() as capturer: - msodde.main([join(BASE_DIR, 'msodde', filename), ]) - links = self.get_dde_from_output(capturer) + output = msodde.process_file( + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST) + links = self.get_dde_from_output(output) self.assertEqual(len(links), 1, 'found {0} dde-links in {1}' .format(len(links), filename)) self.assertTrue('cmd' in links[0], 'no "cmd" in dde-link for {0}' @@ -150,16 +134,16 @@ class TestDdeLinks(unittest.TestCase): def test_clean_rtf_blacklist(self): """ find a lot of hyperlinks in rtf spec """ filename = 'RTF-Spec-1.7.rtf' - with OutputCapture() as capturer: - msodde.main([join(BASE_DIR, 'msodde', filename)]) - self.assertEqual(len(self.get_dde_from_output(capturer)), 1413) + output = msodde.process_file( + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_BLACKLIST) + self.assertEqual(len(self.get_dde_from_output(output)), 1413) def test_clean_rtf_ddeonly(self): """ find no dde links in rtf spec """ filename = 'RTF-Spec-1.7.rtf' - with OutputCapture() as capturer: - msodde.main(['-d', join(BASE_DIR, 'msodde', filename)]) - self.assertEqual(len(self.get_dde_from_output(capturer)), 0, + output = msodde.process_file( + join(BASE_DIR, 'msodde', filename), msodde.FIELD_FILTER_DDE) + self.assertEqual(len(self.get_dde_from_output(output)), 0, msg='Found dde links in output of ' + filename) diff --git a/tests/msodde/test_csv.py b/tests/msodde/test_csv.py index 2c1e7f1..92131b4 100644 --- a/tests/msodde/test_csv.py +++ b/tests/msodde/test_csv.py @@ -9,7 +9,7 @@ import os from os.path import join from oletools import msodde -from tests.test_utils import OutputCapture, DATA_BASE_DIR +from tests.test_utils import DATA_BASE_DIR class TestCSV(unittest.TestCase): @@ -69,11 +69,8 @@ class TestCSV(unittest.TestCase): def test_file(self): """ test simple small example file """ filename = join(DATA_BASE_DIR, 'msodde', 'dde-in-csv.csv') - with OutputCapture() as capturer: - capturer.reload_module(msodde) # re-create logger - ret_code = msodde.main([filename, ]) - self.assertEqual(ret_code, 0) - links = self.get_dde_from_output(capturer) + output = msodde.process_file(filename, msodde.FIELD_FILTER_BLACKLIST) + links = self.get_dde_from_output(output) self.assertEqual(len(links), 1) self.assertEqual(links[0], r"cmd '/k \..\..\..\Windows\System32\calc.exe'") @@ -91,12 +88,10 @@ class TestCSV(unittest.TestCase): if self.DO_DEBUG: args += ['-l', 'debug'] - with OutputCapture() as capturer: - capturer.reload_module(msodde) # re-create logger - ret_code = msodde.main(args) - self.assertEqual(ret_code, 0, 'checking sample resulted in ' - 'error:\n' + sample_text) - return capturer + processed_args = msodde.process_args(args) + + return msodde.process_file( + processed_args.filepath, processed_args.field_filter_mode) except Exception: raise @@ -111,25 +106,11 @@ class TestCSV(unittest.TestCase): os.remove(filename) filename = None # just in case - def get_dde_from_output(self, capturer): + @staticmethod + def get_dde_from_output(output): """ helper to read dde links from captured output - - duplicate in tests/msodde/test_basic """ - have_start_line = False - result = [] - for line in capturer: - if self.DO_DEBUG: - print('captured: ' + line) - if not line.strip(): - continue # skip empty lines - if have_start_line: - result.append(line) - elif line == 'DDE Links:': - have_start_line = True - - self.assertTrue(have_start_line) # ensure output was complete - return result + return [o for o in output.splitlines()] def test_regex(self): """ check that regex captures other ways to include dde commands