From e30ad46f91d3d44cba0aee6f97fee819c9d1ca85 Mon Sep 17 00:00:00 2001 From: Gilles Peskine Date: Mon, 29 Mar 2021 20:37:40 +0200 Subject: [PATCH] Add type annotations Prepare to move InputsForTest to macro_collector.py. Signed-off-by: Gilles Peskine --- tests/scripts/test_psa_constant_names.py | 80 ++++++++++++++---------- 1 file changed, 46 insertions(+), 34 deletions(-) diff --git a/tests/scripts/test_psa_constant_names.py b/tests/scripts/test_psa_constant_names.py index 15b83d8e8..237a344eb 100755 --- a/tests/scripts/test_psa_constant_names.py +++ b/tests/scripts/test_psa_constant_names.py @@ -28,13 +28,15 @@ import os import re import subprocess import sys +from typing import Dict, Iterable, Iterator, List, Optional, Pattern, Set, Tuple, Union import scripts_path # pylint: disable=unused-import from mbedtls_dev import c_build_helper -from mbedtls_dev import macro_collector +from mbedtls_dev.macro_collector import PSAMacroEnumerator +from mbedtls_dev import typing_util class ReadFileLineException(Exception): - def __init__(self, filename, line_number): + def __init__(self, filename: str, line_number: Union[int, str]) -> None: message = 'in {} at {}'.format(filename, line_number) super(ReadFileLineException, self).__init__(message) self.filename = filename @@ -59,36 +61,37 @@ class read_file_lines: except that if process(line) raises an exception, then the read_file_lines snippet annotates the exception with the file name and line number. """ - def __init__(self, filename, binary=False): + def __init__(self, filename: str, binary: bool = False) -> None: self.filename = filename - self.line_number = 'entry' - self.generator = None + self.line_number = 'entry' #type: Union[int, str] + self.generator = None #type: Optional[Iterable[Tuple[int, str]]] self.binary = binary - def __enter__(self): + def __enter__(self) -> 'read_file_lines': self.generator = enumerate(open(self.filename, 'rb' if self.binary else 'r')) return self - def __iter__(self): + def __iter__(self) -> Iterator[str]: + assert self.generator is not None for line_number, content in self.generator: self.line_number = line_number yield content self.line_number = 'exit' - def __exit__(self, exc_type, exc_value, exc_traceback): + def __exit__(self, exc_type, exc_value, exc_traceback) -> None: if exc_type is not None: raise ReadFileLineException(self.filename, self.line_number) \ from exc_value -class InputsForTest(macro_collector.PSAMacroEnumerator): +class InputsForTest(PSAMacroEnumerator): # pylint: disable=too-many-instance-attributes """Accumulate information about macros to test. - +enumerate This includes macro names as well as information about their arguments when applicable. """ - def __init__(self): + def __init__(self) -> None: super().__init__() - self.all_declared = set() + self.all_declared = set() #type: Set[str] # Sets of names per type self.statuses.add('PSA_SUCCESS') self.algorithms.add('0xffffffff') @@ -118,7 +121,7 @@ class InputsForTest(macro_collector.PSAMacroEnumerator): 'DH_GROUP': self.dh_groups, 'KEY_TYPE': self.key_types, 'KEY_USAGE': self.key_usage_flags, - } + } #type: Dict[str, Set[str]] # Test functions self.table_by_test_function = { # Any function ending in _algorithm also gets added to @@ -141,13 +144,13 @@ class InputsForTest(macro_collector.PSAMacroEnumerator): 'asymmetric_signature_wildcard': [self.algorithms], 'asymmetric_encryption_algorithm': [], 'other_algorithm': [], - } + } #type: Dict[str, List[Set[str]]] self.arguments_for['mac_length'] += ['1', '63'] self.arguments_for['min_mac_length'] += ['1', '63'] self.arguments_for['tag_length'] += ['1', '63'] self.arguments_for['min_tag_length'] += ['1', '63'] - def get_names(self, type_word): + def get_names(self, type_word: str) -> Set[str]: """Return the set of known names of values of the given type.""" return { 'status': self.statuses, @@ -176,7 +179,7 @@ class InputsForTest(macro_collector.PSAMacroEnumerator): # auxiliary macros. 'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE', ]) - def parse_header_line(self, line): + def parse_header_line(self, line: str) -> None: """Parse a C header line, looking for "#define PSA_xxx".""" m = re.match(self._header_line_re, line) if not m: @@ -193,8 +196,8 @@ class InputsForTest(macro_collector.PSAMacroEnumerator): if m.group(3): self.argspecs[name] = self._argument_split(m.group(3)) - _nonascii_re = re.compile(rb'[^\x00-\x7f]+') - def parse_header(self, filename): + _nonascii_re = re.compile(rb'[^\x00-\x7f]+') #type: Pattern + def parse_header(self, filename: str) -> None: """Parse a C header file, looking for "#define PSA_xxx".""" with read_file_lines(filename, binary=True) as lines: for line in lines: @@ -202,19 +205,19 @@ class InputsForTest(macro_collector.PSAMacroEnumerator): self.parse_header_line(line) _macro_identifier_re = re.compile(r'[A-Z]\w+') - def generate_undeclared_names(self, expr): + def generate_undeclared_names(self, expr: str) -> Iterable[str]: for name in re.findall(self._macro_identifier_re, expr): if name not in self.all_declared: yield name - def accept_test_case_line(self, function, argument): + def accept_test_case_line(self, function: str, argument: str) -> bool: #pylint: disable=unused-argument undeclared = list(self.generate_undeclared_names(argument)) if undeclared: raise Exception('Undeclared names in test case', undeclared) return True - def add_test_case_line(self, function, argument): + def add_test_case_line(self, function: str, argument: str) -> None: """Parse a test case data line, looking for algorithm metadata tests.""" sets = [] if function.endswith('_algorithm'): @@ -234,7 +237,7 @@ class InputsForTest(macro_collector.PSAMacroEnumerator): # its arguments. The actual definition is partly positional, but this # regex is good enough in practice. _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)') - def parse_test_cases(self, filename): + def parse_test_cases(self, filename: str) -> None: """Parse a test case file (*.data), looking for algorithm metadata tests.""" with read_file_lines(filename) as lines: for line in lines: @@ -242,7 +245,9 @@ class InputsForTest(macro_collector.PSAMacroEnumerator): if m: self.add_test_case_line(m.group(1), m.group(2)) -def gather_inputs(headers, test_suites, inputs_class=InputsForTest): +def gather_inputs(headers: Iterable[str], + test_suites: Iterable[str], + inputs_class=InputsForTest) -> PSAMacroEnumerator: """Read the list of inputs to test psa_constant_names with.""" inputs = inputs_class() for header in headers: @@ -252,7 +257,10 @@ def gather_inputs(headers, test_suites, inputs_class=InputsForTest): inputs.gather_arguments() return inputs -def run_c(type_word, expressions, include_path=None, keep_c=False): +def run_c(type_word: str, + expressions: Iterable[str], + include_path: Optional[str] = None, + keep_c: bool = False) -> List[str]: """Generate and run a program to print out numerical values of C expressions.""" if type_word == 'status': cast_to = 'long' @@ -271,14 +279,17 @@ def run_c(type_word, expressions, include_path=None, keep_c=False): ) NORMALIZE_STRIP_RE = re.compile(r'\s+') -def normalize(expr): +def normalize(expr: str) -> str: """Normalize the C expression so as not to care about trivial differences. Currently "trivial differences" means whitespace. """ return re.sub(NORMALIZE_STRIP_RE, '', expr) -def collect_values(inputs, type_word, include_path=None, keep_c=False): +def collect_values(inputs: InputsForTest, + type_word: str, + include_path: Optional[str] = None, + keep_c: bool = False) -> Tuple[List[str], List[str]]: """Generate expressions using known macro names and calculate their values. Return a list of pairs of (expr, value) where expr is an expression and @@ -296,12 +307,12 @@ class Tests: Error = namedtuple('Error', ['type', 'expression', 'value', 'output']) - def __init__(self, options): + def __init__(self, options) -> None: self.options = options self.count = 0 - self.errors = [] + self.errors = [] #type: List[Tests.Error] - def run_one(self, inputs, type_word): + def run_one(self, inputs: InputsForTest, type_word: str) -> None: """Test psa_constant_names for the specified type. Run the program on the names for this type. @@ -311,9 +322,10 @@ class Tests: expressions, values = collect_values(inputs, type_word, include_path=self.options.include, keep_c=self.options.keep_c) - output = subprocess.check_output([self.options.program, type_word] + - values) - outputs = output.decode('ascii').strip().split('\n') + output_bytes = subprocess.check_output([self.options.program, + type_word] + values) + output = output_bytes.decode('ascii') + outputs = output.strip().split('\n') self.count += len(expressions) for expr, value, output in zip(expressions, values, outputs): if self.options.show: @@ -324,13 +336,13 @@ class Tests: value=value, output=output)) - def run_all(self, inputs): + def run_all(self, inputs: InputsForTest) -> None: """Run psa_constant_names on all the gathered inputs.""" for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group', 'key_type', 'key_usage']: self.run_one(inputs, type_word) - def report(self, out): + def report(self, out: typing_util.Writable) -> None: """Describe each case where the output is not as expected. Write the errors to ``out``.