diff --git a/tests/scripts/generate_psa_tests.py b/tests/scripts/generate_psa_tests.py index aae92d659..6baf53e10 100755 --- a/tests/scripts/generate_psa_tests.py +++ b/tests/scripts/generate_psa_tests.py @@ -69,41 +69,13 @@ def hack_dependencies_not_implemented(dependencies: List[str]) -> None: for dep in dependencies): dependencies.append('DEPENDENCY_NOT_IMPLEMENTED_YET') -def test_case_for_key_type_not_supported( - verb: str, key_type: str, bits: int, - dependencies: List[str], - *args: str, - param_descr: str = '' -) -> test_case.TestCase: - """Return one test case exercising a key creation method - for an unsupported key type or size. - """ - hack_dependencies_not_implemented(dependencies) - tc = test_case.TestCase() - short_key_type = re.sub(r'PSA_(KEY_TYPE|ECC_FAMILY)_', r'', key_type) - adverb = 'not' if dependencies else 'never' - if param_descr: - adverb = param_descr + ' ' + adverb - tc.set_description('PSA {} {} {}-bit {} supported' - .format(verb, short_key_type, bits, adverb)) - tc.set_dependencies(dependencies) - tc.set_function(verb + '_not_supported') - tc.set_arguments([key_type] + list(args)) - return tc -class TestGenerator: - """Gather information and generate test data.""" +class Information: + """Gather information about PSA constructors.""" - def __init__(self, options): - self.test_suite_directory = self.get_option(options, 'directory', - 'tests/suites') + def __init__(self) -> None: self.constructors = self.read_psa_interface() - @staticmethod - def get_option(options, name: str, default: T) -> T: - value = getattr(options, name, None) - return default if value is None else value - @staticmethod def remove_unwanted_macros( constructors: macro_collector.PSAMacroCollector @@ -126,14 +98,34 @@ class TestGenerator: self.remove_unwanted_macros(constructors) return constructors - def write_test_data_file(self, basename: str, - test_cases: Iterable[test_case.TestCase]) -> None: - """Write the test cases to a .data file. - The output file is ``basename + '.data'`` in the test suite directory. - """ - filename = os.path.join(self.test_suite_directory, basename + '.data') - test_case.write_data_file(filename, test_cases) +def test_case_for_key_type_not_supported( + verb: str, key_type: str, bits: int, + dependencies: List[str], + *args: str, + param_descr: str = '' +) -> test_case.TestCase: + """Return one test case exercising a key creation method + for an unsupported key type or size. + """ + hack_dependencies_not_implemented(dependencies) + tc = test_case.TestCase() + short_key_type = re.sub(r'PSA_(KEY_TYPE|ECC_FAMILY)_', r'', key_type) + adverb = 'not' if dependencies else 'never' + if param_descr: + adverb = param_descr + ' ' + adverb + tc.set_description('PSA {} {} {}-bit {} supported' + .format(verb, short_key_type, bits, adverb)) + tc.set_dependencies(dependencies) + tc.set_function(verb + '_not_supported') + tc.set_arguments([key_type] + list(args)) + return tc + +class NotSupported: + """Generate test cases for when something is not supported.""" + + def __init__(self, info: Information) -> None: + self.constructors = info.constructors ALWAYS_SUPPORTED = frozenset([ 'PSA_KEY_TYPE_DERIVE', @@ -187,7 +179,7 @@ class TestGenerator: # To be added: derive return test_cases - def generate_not_supported(self) -> None: + def generate_not_supported(self) -> List[test_case.TestCase]: """Generate test cases that exercise the creation of keys of unsupported types.""" test_cases = [] for key_type in sorted(self.constructors.key_types): @@ -202,13 +194,37 @@ class TestGenerator: kt, param_descr='type') test_cases += self.test_cases_for_key_type_not_supported( kt, 0, param_descr='curve') + return test_cases + + +class TestGenerator: + """Generate test data.""" + + def __init__(self, options) -> None: + self.test_suite_directory = self.get_option(options, 'directory', + 'tests/suites') + self.info = Information() + + @staticmethod + def get_option(options, name: str, default: T) -> T: + value = getattr(options, name, None) + return default if value is None else value + + def write_test_data_file(self, basename: str, + test_cases: Iterable[test_case.TestCase]) -> None: + """Write the test cases to a .data file. + + The output file is ``basename + '.data'`` in the test suite directory. + """ + filename = os.path.join(self.test_suite_directory, basename + '.data') + test_case.write_data_file(filename, test_cases) + + def generate_all(self) -> None: + test_cases = NotSupported(self.info).generate_not_supported() self.write_test_data_file( 'test_suite_psa_crypto_not_supported.generated', test_cases) - def generate_all(self): - self.generate_not_supported() - def main(args): """Command line entry point.""" parser = argparse.ArgumentParser(description=__doc__)