Add type annotations

Prepare to move InputsForTest to macro_collector.py.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
This commit is contained in:
Gilles Peskine 2021-03-29 20:37:40 +02:00
parent 4d59f01bcd
commit e30ad46f91

View file

@ -28,13 +28,15 @@ import os
import re
import subprocess
import sys
from typing import Dict, Iterable, Iterator, List, Optional, Pattern, Set, Tuple, Union
import scripts_path # pylint: disable=unused-import
from mbedtls_dev import c_build_helper
from mbedtls_dev import macro_collector
from mbedtls_dev.macro_collector import PSAMacroEnumerator
from mbedtls_dev import typing_util
class ReadFileLineException(Exception):
def __init__(self, filename, line_number):
def __init__(self, filename: str, line_number: Union[int, str]) -> None:
message = 'in {} at {}'.format(filename, line_number)
super(ReadFileLineException, self).__init__(message)
self.filename = filename
@ -59,36 +61,37 @@ class read_file_lines:
except that if process(line) raises an exception, then the read_file_lines
snippet annotates the exception with the file name and line number.
"""
def __init__(self, filename, binary=False):
def __init__(self, filename: str, binary: bool = False) -> None:
self.filename = filename
self.line_number = 'entry'
self.generator = None
self.line_number = 'entry' #type: Union[int, str]
self.generator = None #type: Optional[Iterable[Tuple[int, str]]]
self.binary = binary
def __enter__(self):
def __enter__(self) -> 'read_file_lines':
self.generator = enumerate(open(self.filename,
'rb' if self.binary else 'r'))
return self
def __iter__(self):
def __iter__(self) -> Iterator[str]:
assert self.generator is not None
for line_number, content in self.generator:
self.line_number = line_number
yield content
self.line_number = 'exit'
def __exit__(self, exc_type, exc_value, exc_traceback):
def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
if exc_type is not None:
raise ReadFileLineException(self.filename, self.line_number) \
from exc_value
class InputsForTest(macro_collector.PSAMacroEnumerator):
class InputsForTest(PSAMacroEnumerator):
# pylint: disable=too-many-instance-attributes
"""Accumulate information about macros to test.
enumerate
This includes macro names as well as information about their arguments
when applicable.
"""
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.all_declared = set()
self.all_declared = set() #type: Set[str]
# Sets of names per type
self.statuses.add('PSA_SUCCESS')
self.algorithms.add('0xffffffff')
@ -118,7 +121,7 @@ class InputsForTest(macro_collector.PSAMacroEnumerator):
'DH_GROUP': self.dh_groups,
'KEY_TYPE': self.key_types,
'KEY_USAGE': self.key_usage_flags,
}
} #type: Dict[str, Set[str]]
# Test functions
self.table_by_test_function = {
# Any function ending in _algorithm also gets added to
@ -141,13 +144,13 @@ class InputsForTest(macro_collector.PSAMacroEnumerator):
'asymmetric_signature_wildcard': [self.algorithms],
'asymmetric_encryption_algorithm': [],
'other_algorithm': [],
}
} #type: Dict[str, List[Set[str]]]
self.arguments_for['mac_length'] += ['1', '63']
self.arguments_for['min_mac_length'] += ['1', '63']
self.arguments_for['tag_length'] += ['1', '63']
self.arguments_for['min_tag_length'] += ['1', '63']
def get_names(self, type_word):
def get_names(self, type_word: str) -> Set[str]:
"""Return the set of known names of values of the given type."""
return {
'status': self.statuses,
@ -176,7 +179,7 @@ class InputsForTest(macro_collector.PSAMacroEnumerator):
# auxiliary macros.
'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE',
])
def parse_header_line(self, line):
def parse_header_line(self, line: str) -> None:
"""Parse a C header line, looking for "#define PSA_xxx"."""
m = re.match(self._header_line_re, line)
if not m:
@ -193,8 +196,8 @@ class InputsForTest(macro_collector.PSAMacroEnumerator):
if m.group(3):
self.argspecs[name] = self._argument_split(m.group(3))
_nonascii_re = re.compile(rb'[^\x00-\x7f]+')
def parse_header(self, filename):
_nonascii_re = re.compile(rb'[^\x00-\x7f]+') #type: Pattern
def parse_header(self, filename: str) -> None:
"""Parse a C header file, looking for "#define PSA_xxx"."""
with read_file_lines(filename, binary=True) as lines:
for line in lines:
@ -202,19 +205,19 @@ class InputsForTest(macro_collector.PSAMacroEnumerator):
self.parse_header_line(line)
_macro_identifier_re = re.compile(r'[A-Z]\w+')
def generate_undeclared_names(self, expr):
def generate_undeclared_names(self, expr: str) -> Iterable[str]:
for name in re.findall(self._macro_identifier_re, expr):
if name not in self.all_declared:
yield name
def accept_test_case_line(self, function, argument):
def accept_test_case_line(self, function: str, argument: str) -> bool:
#pylint: disable=unused-argument
undeclared = list(self.generate_undeclared_names(argument))
if undeclared:
raise Exception('Undeclared names in test case', undeclared)
return True
def add_test_case_line(self, function, argument):
def add_test_case_line(self, function: str, argument: str) -> None:
"""Parse a test case data line, looking for algorithm metadata tests."""
sets = []
if function.endswith('_algorithm'):
@ -234,7 +237,7 @@ class InputsForTest(macro_collector.PSAMacroEnumerator):
# its arguments. The actual definition is partly positional, but this
# regex is good enough in practice.
_test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')
def parse_test_cases(self, filename):
def parse_test_cases(self, filename: str) -> None:
"""Parse a test case file (*.data), looking for algorithm metadata tests."""
with read_file_lines(filename) as lines:
for line in lines:
@ -242,7 +245,9 @@ class InputsForTest(macro_collector.PSAMacroEnumerator):
if m:
self.add_test_case_line(m.group(1), m.group(2))
def gather_inputs(headers, test_suites, inputs_class=InputsForTest):
def gather_inputs(headers: Iterable[str],
test_suites: Iterable[str],
inputs_class=InputsForTest) -> PSAMacroEnumerator:
"""Read the list of inputs to test psa_constant_names with."""
inputs = inputs_class()
for header in headers:
@ -252,7 +257,10 @@ def gather_inputs(headers, test_suites, inputs_class=InputsForTest):
inputs.gather_arguments()
return inputs
def run_c(type_word, expressions, include_path=None, keep_c=False):
def run_c(type_word: str,
expressions: Iterable[str],
include_path: Optional[str] = None,
keep_c: bool = False) -> List[str]:
"""Generate and run a program to print out numerical values of C expressions."""
if type_word == 'status':
cast_to = 'long'
@ -271,14 +279,17 @@ def run_c(type_word, expressions, include_path=None, keep_c=False):
)
NORMALIZE_STRIP_RE = re.compile(r'\s+')
def normalize(expr):
def normalize(expr: str) -> str:
"""Normalize the C expression so as not to care about trivial differences.
Currently "trivial differences" means whitespace.
"""
return re.sub(NORMALIZE_STRIP_RE, '', expr)
def collect_values(inputs, type_word, include_path=None, keep_c=False):
def collect_values(inputs: InputsForTest,
type_word: str,
include_path: Optional[str] = None,
keep_c: bool = False) -> Tuple[List[str], List[str]]:
"""Generate expressions using known macro names and calculate their values.
Return a list of pairs of (expr, value) where expr is an expression and
@ -296,12 +307,12 @@ class Tests:
Error = namedtuple('Error',
['type', 'expression', 'value', 'output'])
def __init__(self, options):
def __init__(self, options) -> None:
self.options = options
self.count = 0
self.errors = []
self.errors = [] #type: List[Tests.Error]
def run_one(self, inputs, type_word):
def run_one(self, inputs: InputsForTest, type_word: str) -> None:
"""Test psa_constant_names for the specified type.
Run the program on the names for this type.
@ -311,9 +322,10 @@ class Tests:
expressions, values = collect_values(inputs, type_word,
include_path=self.options.include,
keep_c=self.options.keep_c)
output = subprocess.check_output([self.options.program, type_word] +
values)
outputs = output.decode('ascii').strip().split('\n')
output_bytes = subprocess.check_output([self.options.program,
type_word] + values)
output = output_bytes.decode('ascii')
outputs = output.strip().split('\n')
self.count += len(expressions)
for expr, value, output in zip(expressions, values, outputs):
if self.options.show:
@ -324,13 +336,13 @@ class Tests:
value=value,
output=output))
def run_all(self, inputs):
def run_all(self, inputs: InputsForTest) -> None:
"""Run psa_constant_names on all the gathered inputs."""
for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
'key_type', 'key_usage']:
self.run_one(inputs, type_word)
def report(self, out):
def report(self, out: typing_util.Writable) -> None:
"""Describe each case where the output is not as expected.
Write the errors to ``out``.