Add type annotations

Prepare to move InputsForTest to macro_collector.py.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
This commit is contained in:
Gilles Peskine 2021-03-29 20:37:40 +02:00
parent 4d59f01bcd
commit e30ad46f91

View file

@ -28,13 +28,15 @@ import os
import re import re
import subprocess import subprocess
import sys import sys
from typing import Dict, Iterable, Iterator, List, Optional, Pattern, Set, Tuple, Union
import scripts_path # pylint: disable=unused-import import scripts_path # pylint: disable=unused-import
from mbedtls_dev import c_build_helper from mbedtls_dev import c_build_helper
from mbedtls_dev import macro_collector from mbedtls_dev.macro_collector import PSAMacroEnumerator
from mbedtls_dev import typing_util
class ReadFileLineException(Exception): class ReadFileLineException(Exception):
def __init__(self, filename, line_number): def __init__(self, filename: str, line_number: Union[int, str]) -> None:
message = 'in {} at {}'.format(filename, line_number) message = 'in {} at {}'.format(filename, line_number)
super(ReadFileLineException, self).__init__(message) super(ReadFileLineException, self).__init__(message)
self.filename = filename self.filename = filename
@ -59,36 +61,37 @@ class read_file_lines:
except that if process(line) raises an exception, then the read_file_lines except that if process(line) raises an exception, then the read_file_lines
snippet annotates the exception with the file name and line number. snippet annotates the exception with the file name and line number.
""" """
def __init__(self, filename, binary=False): def __init__(self, filename: str, binary: bool = False) -> None:
self.filename = filename self.filename = filename
self.line_number = 'entry' self.line_number = 'entry' #type: Union[int, str]
self.generator = None self.generator = None #type: Optional[Iterable[Tuple[int, str]]]
self.binary = binary self.binary = binary
def __enter__(self): def __enter__(self) -> 'read_file_lines':
self.generator = enumerate(open(self.filename, self.generator = enumerate(open(self.filename,
'rb' if self.binary else 'r')) 'rb' if self.binary else 'r'))
return self return self
def __iter__(self): def __iter__(self) -> Iterator[str]:
assert self.generator is not None
for line_number, content in self.generator: for line_number, content in self.generator:
self.line_number = line_number self.line_number = line_number
yield content yield content
self.line_number = 'exit' self.line_number = 'exit'
def __exit__(self, exc_type, exc_value, exc_traceback): def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
if exc_type is not None: if exc_type is not None:
raise ReadFileLineException(self.filename, self.line_number) \ raise ReadFileLineException(self.filename, self.line_number) \
from exc_value from exc_value
class InputsForTest(macro_collector.PSAMacroEnumerator): class InputsForTest(PSAMacroEnumerator):
# pylint: disable=too-many-instance-attributes # pylint: disable=too-many-instance-attributes
"""Accumulate information about macros to test. """Accumulate information about macros to test.
enumerate
This includes macro names as well as information about their arguments This includes macro names as well as information about their arguments
when applicable. when applicable.
""" """
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
self.all_declared = set() self.all_declared = set() #type: Set[str]
# Sets of names per type # Sets of names per type
self.statuses.add('PSA_SUCCESS') self.statuses.add('PSA_SUCCESS')
self.algorithms.add('0xffffffff') self.algorithms.add('0xffffffff')
@ -118,7 +121,7 @@ class InputsForTest(macro_collector.PSAMacroEnumerator):
'DH_GROUP': self.dh_groups, 'DH_GROUP': self.dh_groups,
'KEY_TYPE': self.key_types, 'KEY_TYPE': self.key_types,
'KEY_USAGE': self.key_usage_flags, 'KEY_USAGE': self.key_usage_flags,
} } #type: Dict[str, Set[str]]
# Test functions # Test functions
self.table_by_test_function = { self.table_by_test_function = {
# Any function ending in _algorithm also gets added to # Any function ending in _algorithm also gets added to
@ -141,13 +144,13 @@ class InputsForTest(macro_collector.PSAMacroEnumerator):
'asymmetric_signature_wildcard': [self.algorithms], 'asymmetric_signature_wildcard': [self.algorithms],
'asymmetric_encryption_algorithm': [], 'asymmetric_encryption_algorithm': [],
'other_algorithm': [], 'other_algorithm': [],
} } #type: Dict[str, List[Set[str]]]
self.arguments_for['mac_length'] += ['1', '63'] self.arguments_for['mac_length'] += ['1', '63']
self.arguments_for['min_mac_length'] += ['1', '63'] self.arguments_for['min_mac_length'] += ['1', '63']
self.arguments_for['tag_length'] += ['1', '63'] self.arguments_for['tag_length'] += ['1', '63']
self.arguments_for['min_tag_length'] += ['1', '63'] self.arguments_for['min_tag_length'] += ['1', '63']
def get_names(self, type_word): def get_names(self, type_word: str) -> Set[str]:
"""Return the set of known names of values of the given type.""" """Return the set of known names of values of the given type."""
return { return {
'status': self.statuses, 'status': self.statuses,
@ -176,7 +179,7 @@ class InputsForTest(macro_collector.PSAMacroEnumerator):
# auxiliary macros. # auxiliary macros.
'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE', 'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE',
]) ])
def parse_header_line(self, line): def parse_header_line(self, line: str) -> None:
"""Parse a C header line, looking for "#define PSA_xxx".""" """Parse a C header line, looking for "#define PSA_xxx"."""
m = re.match(self._header_line_re, line) m = re.match(self._header_line_re, line)
if not m: if not m:
@ -193,8 +196,8 @@ class InputsForTest(macro_collector.PSAMacroEnumerator):
if m.group(3): if m.group(3):
self.argspecs[name] = self._argument_split(m.group(3)) self.argspecs[name] = self._argument_split(m.group(3))
_nonascii_re = re.compile(rb'[^\x00-\x7f]+') _nonascii_re = re.compile(rb'[^\x00-\x7f]+') #type: Pattern
def parse_header(self, filename): def parse_header(self, filename: str) -> None:
"""Parse a C header file, looking for "#define PSA_xxx".""" """Parse a C header file, looking for "#define PSA_xxx"."""
with read_file_lines(filename, binary=True) as lines: with read_file_lines(filename, binary=True) as lines:
for line in lines: for line in lines:
@ -202,19 +205,19 @@ class InputsForTest(macro_collector.PSAMacroEnumerator):
self.parse_header_line(line) self.parse_header_line(line)
_macro_identifier_re = re.compile(r'[A-Z]\w+') _macro_identifier_re = re.compile(r'[A-Z]\w+')
def generate_undeclared_names(self, expr): def generate_undeclared_names(self, expr: str) -> Iterable[str]:
for name in re.findall(self._macro_identifier_re, expr): for name in re.findall(self._macro_identifier_re, expr):
if name not in self.all_declared: if name not in self.all_declared:
yield name yield name
def accept_test_case_line(self, function, argument): def accept_test_case_line(self, function: str, argument: str) -> bool:
#pylint: disable=unused-argument #pylint: disable=unused-argument
undeclared = list(self.generate_undeclared_names(argument)) undeclared = list(self.generate_undeclared_names(argument))
if undeclared: if undeclared:
raise Exception('Undeclared names in test case', undeclared) raise Exception('Undeclared names in test case', undeclared)
return True return True
def add_test_case_line(self, function, argument): def add_test_case_line(self, function: str, argument: str) -> None:
"""Parse a test case data line, looking for algorithm metadata tests.""" """Parse a test case data line, looking for algorithm metadata tests."""
sets = [] sets = []
if function.endswith('_algorithm'): if function.endswith('_algorithm'):
@ -234,7 +237,7 @@ class InputsForTest(macro_collector.PSAMacroEnumerator):
# its arguments. The actual definition is partly positional, but this # its arguments. The actual definition is partly positional, but this
# regex is good enough in practice. # regex is good enough in practice.
_test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)') _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')
def parse_test_cases(self, filename): def parse_test_cases(self, filename: str) -> None:
"""Parse a test case file (*.data), looking for algorithm metadata tests.""" """Parse a test case file (*.data), looking for algorithm metadata tests."""
with read_file_lines(filename) as lines: with read_file_lines(filename) as lines:
for line in lines: for line in lines:
@ -242,7 +245,9 @@ class InputsForTest(macro_collector.PSAMacroEnumerator):
if m: if m:
self.add_test_case_line(m.group(1), m.group(2)) self.add_test_case_line(m.group(1), m.group(2))
def gather_inputs(headers, test_suites, inputs_class=InputsForTest): def gather_inputs(headers: Iterable[str],
test_suites: Iterable[str],
inputs_class=InputsForTest) -> PSAMacroEnumerator:
"""Read the list of inputs to test psa_constant_names with.""" """Read the list of inputs to test psa_constant_names with."""
inputs = inputs_class() inputs = inputs_class()
for header in headers: for header in headers:
@ -252,7 +257,10 @@ def gather_inputs(headers, test_suites, inputs_class=InputsForTest):
inputs.gather_arguments() inputs.gather_arguments()
return inputs return inputs
def run_c(type_word, expressions, include_path=None, keep_c=False): def run_c(type_word: str,
expressions: Iterable[str],
include_path: Optional[str] = None,
keep_c: bool = False) -> List[str]:
"""Generate and run a program to print out numerical values of C expressions.""" """Generate and run a program to print out numerical values of C expressions."""
if type_word == 'status': if type_word == 'status':
cast_to = 'long' cast_to = 'long'
@ -271,14 +279,17 @@ def run_c(type_word, expressions, include_path=None, keep_c=False):
) )
NORMALIZE_STRIP_RE = re.compile(r'\s+') NORMALIZE_STRIP_RE = re.compile(r'\s+')
def normalize(expr): def normalize(expr: str) -> str:
"""Normalize the C expression so as not to care about trivial differences. """Normalize the C expression so as not to care about trivial differences.
Currently "trivial differences" means whitespace. Currently "trivial differences" means whitespace.
""" """
return re.sub(NORMALIZE_STRIP_RE, '', expr) return re.sub(NORMALIZE_STRIP_RE, '', expr)
def collect_values(inputs, type_word, include_path=None, keep_c=False): def collect_values(inputs: InputsForTest,
type_word: str,
include_path: Optional[str] = None,
keep_c: bool = False) -> Tuple[List[str], List[str]]:
"""Generate expressions using known macro names and calculate their values. """Generate expressions using known macro names and calculate their values.
Return a list of pairs of (expr, value) where expr is an expression and Return a list of pairs of (expr, value) where expr is an expression and
@ -296,12 +307,12 @@ class Tests:
Error = namedtuple('Error', Error = namedtuple('Error',
['type', 'expression', 'value', 'output']) ['type', 'expression', 'value', 'output'])
def __init__(self, options): def __init__(self, options) -> None:
self.options = options self.options = options
self.count = 0 self.count = 0
self.errors = [] self.errors = [] #type: List[Tests.Error]
def run_one(self, inputs, type_word): def run_one(self, inputs: InputsForTest, type_word: str) -> None:
"""Test psa_constant_names for the specified type. """Test psa_constant_names for the specified type.
Run the program on the names for this type. Run the program on the names for this type.
@ -311,9 +322,10 @@ class Tests:
expressions, values = collect_values(inputs, type_word, expressions, values = collect_values(inputs, type_word,
include_path=self.options.include, include_path=self.options.include,
keep_c=self.options.keep_c) keep_c=self.options.keep_c)
output = subprocess.check_output([self.options.program, type_word] + output_bytes = subprocess.check_output([self.options.program,
values) type_word] + values)
outputs = output.decode('ascii').strip().split('\n') output = output_bytes.decode('ascii')
outputs = output.strip().split('\n')
self.count += len(expressions) self.count += len(expressions)
for expr, value, output in zip(expressions, values, outputs): for expr, value, output in zip(expressions, values, outputs):
if self.options.show: if self.options.show:
@ -324,13 +336,13 @@ class Tests:
value=value, value=value,
output=output)) output=output))
def run_all(self, inputs): def run_all(self, inputs: InputsForTest) -> None:
"""Run psa_constant_names on all the gathered inputs.""" """Run psa_constant_names on all the gathered inputs."""
for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group', for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
'key_type', 'key_usage']: 'key_type', 'key_usage']:
self.run_one(inputs, type_word) self.run_one(inputs, type_word)
def report(self, out): def report(self, out: typing_util.Writable) -> None:
"""Describe each case where the output is not as expected. """Describe each case where the output is not as expected.
Write the errors to ``out``. Write the errors to ``out``.