diff --git a/tests/scripts/test_psa_constant_names.py b/tests/scripts/test_psa_constant_names.py index e64040802..e261b4f56 100755 --- a/tests/scripts/test_psa_constant_names.py +++ b/tests/scripts/test_psa_constant_names.py @@ -338,48 +338,52 @@ def collect_values(inputs, type_word, include_path=None, keep_c=False): include_path=include_path, keep_c=keep_c) return expressions, values -def do_test(options, inputs, type_word): - """Test psa_constant_names for the specified type. +class Tests: + """An object representing tests and their results.""" - Run program on names. - Use inputs to figure out what arguments to pass to macros that - take arguments. + def __init__(self, options): + self.options = options + self.count = 0 + self.errors = [] - Return ``(count, errors)`` where ``count`` is the number of expressions - that have been tested and ``errors`` is the list of errors that were - encountered. - """ - expressions, values = collect_values(inputs, type_word, - include_path=options.include, - keep_c=options.keep_c) - output = subprocess.check_output([options.program, type_word] + values) - outputs = output.decode('ascii').strip().split('\n') - errors = [(type_word, expr, value, output) - for (expr, value, output) in zip(expressions, values, outputs) - if normalize(expr) != normalize(output)] - return len(expressions), errors + def run_one(self, inputs, type_word): + """Test psa_constant_names for the specified type. -def report_errors(errors): - """Describe each case where the output is not as expected.""" - for type_word, name, value, output in errors: - print('For {} "{}", got "{}" (value: {})' - .format(type_word, name, output, value)) + Run the program on the names for this type. + Use the inputs to figure out what arguments to pass to macros that + take arguments. + """ + expressions, values = collect_values(inputs, type_word, + include_path=self.options.include, + keep_c=self.options.keep_c) + output = subprocess.check_output([self.options.program, type_word] + + values) + outputs = output.decode('ascii').strip().split('\n') + self.count += len(expressions) + for expr, value, output in zip(expressions, values, outputs): + if normalize(expr) != normalize(output): + self.errors.append((type_word, expr, value, output)) -def run_tests(options, inputs): - """Run psa_constant_names on all the gathered inputs. + def run_all(self, inputs): + """Run psa_constant_names on all the gathered inputs.""" + for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group', + 'key_type', 'key_usage']: + self.run_one(inputs, type_word) - Return a tuple (count, errors) where count is the total number of inputs - that were tested and errors is the list of cases where the output was - not as expected. - """ - count = 0 - errors = [] - for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group', - 'key_type', 'key_usage']: - c, e = do_test(options, inputs, type_word) - count += c - errors += e - return count, errors + def report(self, out): + """Describe each case where the output is not as expected. + + Write the errors to ``out``. + Also write a total. + """ + for type_word, name, value, output in self.errors: + out.write('For {} "{}", got "{}" (value: {})\n' + .format(type_word, name, output, value)) + out.write('{} test cases'.format(self.count)) + if self.errors: + out.write(', {} FAIL\n'.format(len(self.errors))) + else: + out.write(' PASS\n') HEADERS = ['psa/crypto.h', 'psa/crypto_extra.h', 'psa/crypto_values.h'] TEST_SUITES = ['tests/suites/test_suite_psa_crypto_metadata.data'] @@ -401,12 +405,10 @@ def main(): options = parser.parse_args() headers = [os.path.join(options.include[0], h) for h in HEADERS] inputs = gather_inputs(headers, TEST_SUITES) - count, errors = run_tests(options, inputs) - report_errors(errors) - if errors == []: - print('{} test cases PASS'.format(count)) - else: - print('{} test cases, {} FAIL'.format(count, len(errors))) + tests = Tests(options) + tests.run_all(inputs) + tests.report(sys.stdout) + if tests.errors: exit(1) if __name__ == '__main__':