Commit 0bc67280679a460b285db76ec1243115e07e74a0

Authored by Christian Herdtweck
1 parent b22b36c5

test: Use call_and_capture in olevba tests

Running main() within same interpreter capturing SystemExit and
sys.stdout/err always causes trouble (at least in unittest).
Fork another python process and capture from there, that is easier
and cleaner in my view.
Showing 1 changed file with 39 additions and 87 deletions
tests/olevba/test_basic.py
... ... @@ -3,44 +3,12 @@ Test basic functionality of olevba[3]
3 3 """
4 4  
5 5 import unittest
6   -import sys
7 6 import os
8 7 from os.path import join
9   -from contextlib import contextmanager
10   -try:
11   - from cStringIO import StringIO
12   -except ImportError: # py3:
13   - from io import StringIO
14   -if sys.version_info.major <= 2:
15   - from oletools import olevba
16   -else:
17   - from oletools import olevba3 as olevba
  8 +import re
18 9  
19 10 # Directory with test data, independent of current working directory
20   -from tests.test_utils import DATA_BASE_DIR
21   -
22   -
23   -@contextmanager
24   -def capture_output():
25   - """
26   - Temporarily replace stdout/stderr with buffers to capture output.
27   -
28   - Once we only support python>=3.4: this is already built into python as
29   - :py:func:`contextlib.redirect_stdout`.
30   -
31   - Not quite sure why, but seems to only work once per test function ...
32   - """
33   - orig_stdout = sys.stdout
34   - orig_stderr = sys.stderr
35   -
36   - try:
37   - sys.stdout = StringIO()
38   - sys.stderr = StringIO()
39   - yield sys.stdout, sys.stderr
40   -
41   - finally:
42   - sys.stdout = orig_stdout
43   - sys.stderr = orig_stderr
  11 +from tests.test_utils import DATA_BASE_DIR, call_and_capture
44 12  
45 13  
46 14 class TestOlevbaBasic(unittest.TestCase):
... ... @@ -57,62 +25,48 @@ class TestOlevbaBasic(unittest.TestCase):
57 25 def do_test_behaviour(self, filename):
58 26 """Helper for test_{text,empty}_behaviour."""
59 27 input_file = join(DATA_BASE_DIR, 'basic', filename)
60   - ret_code = -1
61   -
62   - # run olevba, capturing its output and return code
63   - with capture_output() as (stdout, stderr):
64   - with self.assertRaises(SystemExit) as raise_context:
65   - olevba.main([input_file, ])
66   - ret_code = raise_context.exception.code
67   -
68   - # check that return code is 0
69   - self.assertEqual(ret_code, 0)
70   -
71   - # check there are only warnings in stderr
72   - stderr = stderr.getvalue()
73   - skip_line = False
74   - for line in stderr.splitlines():
75   - if skip_line:
76   - skip_line = False
77   - continue
78   - self.assertTrue(line.startswith('WARNING ') or
79   - 'ResourceWarning' in line,
80   - msg='Line "{}" in stderr is unexpected for {}'\
81   - .format(line.rstrip(), filename))
82   - if 'ResourceWarning' in line:
83   - skip_line = True
84   - self.assertIn('not encrypted', stderr)
85   -
86   - # check stdout
87   - stdout = stdout.getvalue().lower()
88   - self.assertIn(input_file.lower(), stdout)
89   - self.assertIn('type: text', stdout)
90   - self.assertIn('no suspicious', stdout)
91   - self.assertNotIn('error', stdout)
92   - self.assertNotIn('warn', stdout)
  28 + output, _ = call_and_capture('olevba', args=(input_file, ))
  29 +
  30 + # check output
  31 + self.assertTrue(re.search(r'^Type:\s+Text\s*$', output, re.MULTILINE),
  32 + msg='"Type: Text" not found in output:\n' + output)
  33 + self.assertTrue(re.search(r'^No suspicious .+ found.$', output,
  34 + re.MULTILINE),
  35 + msg='"No suspicous...found" not found in output:\n' + \
  36 + output)
  37 + self.assertNotIn('error', output.lower())
  38 +
  39 + # check warnings
  40 + for line in output.splitlines():
  41 + if line.startswith('WARNING ') and 'encrypted' in line:
  42 + continue # encryption warnings are ok
  43 + elif 'warn' in line.lower():
  44 + raise self.fail('Found "warn" in output line: "{}"'
  45 + .format(line.rstrip()))
  46 + self.assertIn('not encrypted', output)
93 47  
94 48 def test_rtf_behaviour(self):
95 49 """Test behaviour of olevba when presented with an rtf file."""
96 50 input_file = join(DATA_BASE_DIR, 'msodde', 'RTF-Spec-1.7.rtf')
97   - ret_code = -1
98   -
99   - # run olevba, capturing its output and return code
100   - with capture_output() as (stdout, stderr):
101   - with self.assertRaises(SystemExit) as raise_context:
102   - olevba.main([input_file, ])
103   - ret_code = raise_context.exception.code
  51 + output, ret_code = call_and_capture('olevba', args=(input_file, ),
  52 + accept_nonzero_exit=True)
104 53  
105 54 # check that return code is olevba.RETURN_OPEN_ERROR
106 55 self.assertEqual(ret_code, 5)
107   - stdout = stdout.getvalue().lower()
108   - self.assertNotIn('error', stdout)
109   - self.assertNotIn('warn', stdout)
110 56  
111   - stderr = stderr.getvalue().lower()
112   - self.assertIn('fileopenerror', stderr)
113   - self.assertIn('is rtf', stderr)
114   - self.assertIn('rtfobj.py', stderr)
115   - self.assertIn('not encrypted', stderr)
  57 + # check output:
  58 + self.assertIn('FileOpenError', output)
  59 + self.assertIn('is RTF', output)
  60 + self.assertIn('rtfobj.py', output)
  61 + self.assertIn('not encrypted', output)
  62 +
  63 + # check warnings
  64 + for line in output.splitlines():
  65 + if line.startswith('WARNING ') and 'encrypted' in line:
  66 + continue # encryption warnings are ok
  67 + elif 'warn' in line.lower():
  68 + raise self.fail('Found "warn" in output line: "{}"'
  69 + .format(line.rstrip()))
116 70  
117 71 def test_crypt_return(self):
118 72 """
... ... @@ -136,11 +90,9 @@ class TestOlevbaBasic(unittest.TestCase):
136 90 continue
137 91 full_name = join(CRYPT_DIR, filename)
138 92 for args in ADD_ARGS:
139   - try:
140   - olevba.main(args + [full_name, ])
141   - self.fail('Olevba should have exited')
142   - except SystemExit as sys_exit:
143   - ret_code = sys_exit.code or 0 # sys_exit.code can be None
  93 + _, ret_code = call_and_capture('olevba',
  94 + args=[full_name, ] + args,
  95 + accept_nonzero_exit=True)
144 96 self.assertEqual(ret_code, CRYPT_RETURN_CODE,
145 97 msg='Wrong return code {} for args {}'\
146 98 .format(ret_code, args + [filename, ]))
... ...