Obey Python naming and method structure conventions

* Rename internal methods and fields to start with an underscore.
* Rename global constants to uppercase.
* Change methods that don't use self to be class methods or static
  methods as appropriate.

No behavior change in this commit.
This commit is contained in:
Gilles Peskine 2019-05-27 18:29:47 +02:00
parent 9d4d750041
commit 42a0a0aeea
2 changed files with 106 additions and 98 deletions

View file

@ -10,7 +10,7 @@ import os
import re import re
import sys import sys
output_template = '''\ OUTPUT_TEMPLATE = '''\
/* Automatically generated by generate_psa_constant.py. DO NOT EDIT. */ /* Automatically generated by generate_psa_constant.py. DO NOT EDIT. */
static const char *psa_strerror(psa_status_t status) static const char *psa_strerror(psa_status_t status)
@ -154,19 +154,19 @@ static int psa_snprint_key_usage(char *buffer, size_t buffer_size,
/* End of automatically generated file. */ /* End of automatically generated file. */
''' '''
key_type_from_curve_template = '''if (%(tester)s(type)) { KEY_TYPE_FROM_CURVE_TEMPLATE = '''if (%(tester)s(type)) {
append_with_curve(&buffer, buffer_size, &required_size, append_with_curve(&buffer, buffer_size, &required_size,
"%(builder)s", %(builder_length)s, "%(builder)s", %(builder_length)s,
PSA_KEY_TYPE_GET_CURVE(type)); PSA_KEY_TYPE_GET_CURVE(type));
} else ''' } else '''
key_type_from_group_template = '''if (%(tester)s(type)) { KEY_TYPE_FROM_GROUP_TEMPLATE = '''if (%(tester)s(type)) {
append_with_group(&buffer, buffer_size, &required_size, append_with_group(&buffer, buffer_size, &required_size,
"%(builder)s", %(builder_length)s, "%(builder)s", %(builder_length)s,
PSA_KEY_TYPE_GET_GROUP(type)); PSA_KEY_TYPE_GET_GROUP(type));
} else ''' } else '''
algorithm_from_hash_template = '''if (%(tester)s(core_alg)) { ALGORITHM_FROM_HASH_TEMPLATE = '''if (%(tester)s(core_alg)) {
append(&buffer, buffer_size, &required_size, append(&buffer, buffer_size, &required_size,
"%(builder)s(", %(builder_length)s + 1); "%(builder)s(", %(builder_length)s + 1);
append_with_alg(&buffer, buffer_size, &required_size, append_with_alg(&buffer, buffer_size, &required_size,
@ -175,7 +175,7 @@ algorithm_from_hash_template = '''if (%(tester)s(core_alg)) {
append(&buffer, buffer_size, &required_size, ")", 1); append(&buffer, buffer_size, &required_size, ")", 1);
} else ''' } else '''
bit_test_template = '''\ BIT_TEST_TEMPLATE = '''\
if (%(var)s & %(flag)s) { if (%(var)s & %(flag)s) {
if (required_size != 0) { if (required_size != 0) {
append(&buffer, buffer_size, &required_size, " | ", 3); append(&buffer, buffer_size, &required_size, " | ", 3);
@ -274,102 +274,104 @@ class MacroCollector:
for line in header_file: for line in header_file:
self.read_line(line) self.read_line(line)
def make_return_case(self, name): @staticmethod
def _make_return_case(name):
return 'case %(name)s: return "%(name)s";' % {'name': name} return 'case %(name)s: return "%(name)s";' % {'name': name}
def make_append_case(self, name): @staticmethod
def _make_append_case(name):
template = ('case %(name)s: ' template = ('case %(name)s: '
'append(&buffer, buffer_size, &required_size, "%(name)s", %(length)d); ' 'append(&buffer, buffer_size, &required_size, "%(name)s", %(length)d); '
'break;') 'break;')
return template % {'name': name, 'length': len(name)} return template % {'name': name, 'length': len(name)}
def make_inner_append_case(self, name): @staticmethod
template = ('case %(name)s: ' def _make_bit_test(var, flag):
'append(buffer, buffer_size, required_size, "%(name)s", %(length)d); ' return BIT_TEST_TEMPLATE % {'var': var,
'break;')
return template % {'name': name, 'length': len(name)}
def make_bit_test(self, var, flag):
return bit_test_template % {'var': var,
'flag': flag, 'flag': flag,
'length': len(flag)} 'length': len(flag)}
def make_status_cases(self): def _make_status_cases(self):
return '\n '.join(map(self.make_return_case, return '\n '.join(map(self._make_return_case,
sorted(self.statuses))) sorted(self.statuses)))
def make_ecc_curve_cases(self): def _make_ecc_curve_cases(self):
return '\n '.join(map(self.make_return_case, return '\n '.join(map(self._make_return_case,
sorted(self.ecc_curves))) sorted(self.ecc_curves)))
def make_dh_group_cases(self): def _make_dh_group_cases(self):
return '\n '.join(map(self.make_return_case, return '\n '.join(map(self._make_return_case,
sorted(self.dh_groups))) sorted(self.dh_groups)))
def make_key_type_cases(self): def _make_key_type_cases(self):
return '\n '.join(map(self.make_append_case, return '\n '.join(map(self._make_append_case,
sorted(self.key_types))) sorted(self.key_types)))
def make_key_type_from_curve_code(self, builder, tester): @staticmethod
return key_type_from_curve_template % {'builder': builder, def _make_key_type_from_curve_code(builder, tester):
return KEY_TYPE_FROM_CURVE_TEMPLATE % {'builder': builder,
'builder_length': len(builder), 'builder_length': len(builder),
'tester': tester} 'tester': tester}
def make_key_type_from_group_code(self, builder, tester): @staticmethod
return key_type_from_group_template % {'builder': builder, def _make_key_type_from_group_code(builder, tester):
return KEY_TYPE_FROM_GROUP_TEMPLATE % {'builder': builder,
'builder_length': len(builder), 'builder_length': len(builder),
'tester': tester} 'tester': tester}
def make_ecc_key_type_code(self): def _make_ecc_key_type_code(self):
d = self.key_types_from_curve d = self.key_types_from_curve
make = self.make_key_type_from_curve_code make = self._make_key_type_from_curve_code
return ''.join([make(k, d[k]) for k in sorted(d.keys())]) return ''.join([make(k, d[k]) for k in sorted(d.keys())])
def make_dh_key_type_code(self): def _make_dh_key_type_code(self):
d = self.key_types_from_group d = self.key_types_from_group
make = self.make_key_type_from_group_code make = self._make_key_type_from_group_code
return ''.join([make(k, d[k]) for k in sorted(d.keys())]) return ''.join([make(k, d[k]) for k in sorted(d.keys())])
def make_hash_algorithm_cases(self): def _make_hash_algorithm_cases(self):
return '\n '.join(map(self.make_return_case, return '\n '.join(map(self._make_return_case,
sorted(self.hash_algorithms))) sorted(self.hash_algorithms)))
def make_ka_algorithm_cases(self): def _make_ka_algorithm_cases(self):
return '\n '.join(map(self.make_return_case, return '\n '.join(map(self._make_return_case,
sorted(self.ka_algorithms))) sorted(self.ka_algorithms)))
def make_algorithm_cases(self): def _make_algorithm_cases(self):
return '\n '.join(map(self.make_append_case, return '\n '.join(map(self._make_append_case,
sorted(self.algorithms))) sorted(self.algorithms)))
def make_algorithm_from_hash_code(self, builder, tester): @staticmethod
return algorithm_from_hash_template % {'builder': builder, def _make_algorithm_from_hash_code(builder, tester):
return ALGORITHM_FROM_HASH_TEMPLATE % {'builder': builder,
'builder_length': len(builder), 'builder_length': len(builder),
'tester': tester} 'tester': tester}
def make_algorithm_code(self): def _make_algorithm_code(self):
d = self.algorithms_from_hash d = self.algorithms_from_hash
make = self.make_algorithm_from_hash_code make = self._make_algorithm_from_hash_code
return ''.join([make(k, d[k]) for k in sorted(d.keys())]) return ''.join([make(k, d[k]) for k in sorted(d.keys())])
def make_key_usage_code(self): def _make_key_usage_code(self):
return '\n'.join([self.make_bit_test('usage', bit) return '\n'.join([self._make_bit_test('usage', bit)
for bit in sorted(self.key_usages)]) for bit in sorted(self.key_usages)])
def write_file(self, output_file): def write_file(self, output_file):
"""Generate the pretty-printer function code from the gathered
constant definitions."""
data = {} data = {}
data['status_cases'] = self.make_status_cases() data['status_cases'] = self._make_status_cases()
data['ecc_curve_cases'] = self.make_ecc_curve_cases() data['ecc_curve_cases'] = self._make_ecc_curve_cases()
data['dh_group_cases'] = self.make_dh_group_cases() data['dh_group_cases'] = self._make_dh_group_cases()
data['key_type_cases'] = self.make_key_type_cases() data['key_type_cases'] = self._make_key_type_cases()
data['key_type_code'] = (self.make_ecc_key_type_code() + data['key_type_code'] = (self._make_ecc_key_type_code() +
self.make_dh_key_type_code()) self._make_dh_key_type_code())
data['hash_algorithm_cases'] = self.make_hash_algorithm_cases() data['hash_algorithm_cases'] = self._make_hash_algorithm_cases()
data['ka_algorithm_cases'] = self.make_ka_algorithm_cases() data['ka_algorithm_cases'] = self._make_ka_algorithm_cases()
data['algorithm_cases'] = self.make_algorithm_cases() data['algorithm_cases'] = self._make_algorithm_cases()
data['algorithm_code'] = self.make_algorithm_code() data['algorithm_code'] = self._make_algorithm_code()
data['key_usage_code'] = self.make_key_usage_code() data['key_usage_code'] = self._make_key_usage_code()
output_file.write(output_template % data) output_file.write(OUTPUT_TEMPLATE % data)
def generate_psa_constants(header_file_names, output_file_name): def generate_psa_constants(header_file_names, output_file_name):
collector = MacroCollector() collector = MacroCollector()

View file

@ -44,10 +44,10 @@ snippet annotates the exception with the file name and line number.'''
self.line_number = line_number self.line_number = line_number
yield content yield content
self.line_number = 'exit' self.line_number = 'exit'
def __exit__(self, type, value, traceback): def __exit__(self, exc_type, exc_value, exc_traceback):
if type is not None: if exc_type is not None:
raise ReadFileLineException(self.filename, self.line_number) \ raise ReadFileLineException(self.filename, self.line_number) \
from value from exc_value
class Inputs: class Inputs:
'''Accumulate information about macros to test. '''Accumulate information about macros to test.
@ -98,7 +98,8 @@ Call this after parsing all the inputs.'''
self.arguments_for['curve'] = sorted(self.ecc_curves) self.arguments_for['curve'] = sorted(self.ecc_curves)
self.arguments_for['group'] = sorted(self.dh_groups) self.arguments_for['group'] = sorted(self.dh_groups)
def format_arguments(self, name, arguments): @staticmethod
def _format_arguments(name, arguments):
'''Format a macro call with arguments..''' '''Format a macro call with arguments..'''
return name + '(' + ', '.join(arguments) + ')' return name + '(' + ', '.join(arguments) + ')'
@ -117,27 +118,32 @@ where each argument takes each possible value at least once.'''
return return
argument_lists = [self.arguments_for[arg] for arg in argspec] argument_lists = [self.arguments_for[arg] for arg in argspec]
arguments = [values[0] for values in argument_lists] arguments = [values[0] for values in argument_lists]
yield self.format_arguments(name, arguments) yield self._format_arguments(name, arguments)
for i in range(len(arguments)): for i in range(len(arguments)):
for value in argument_lists[i][1:]: for value in argument_lists[i][1:]:
arguments[i] = value arguments[i] = value
yield self.format_arguments(name, arguments) yield self._format_arguments(name, arguments)
arguments[i] = argument_lists[0][0] arguments[i] = argument_lists[0][0]
except BaseException as e: except BaseException as e:
raise Exception('distribute_arguments({})'.format(name)) from e raise Exception('distribute_arguments({})'.format(name)) from e
_argument_split_re = re.compile(r' *, *')
@classmethod
def _argument_split(cls, arguments):
return re.split(cls._argument_split_re, arguments)
# Regex for interesting header lines. # Regex for interesting header lines.
# Groups: 1=macro name, 2=type, 3=argument list (optional). # Groups: 1=macro name, 2=type, 3=argument list (optional).
header_line_re = \ _header_line_re = \
re.compile(r'#define +' + re.compile(r'#define +' +
r'(PSA_((?:KEY_)?[A-Z]+)_\w+)' + r'(PSA_((?:KEY_)?[A-Z]+)_\w+)' +
r'(?:\(([^\n()]*)\))?') r'(?:\(([^\n()]*)\))?')
# Regex of macro names to exclude. # Regex of macro names to exclude.
excluded_name_re = re.compile('_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z') _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
# Additional excluded macros. # Additional excluded macros.
# PSA_ALG_ECDH and PSA_ALG_FFDH are excluded for now as the script # PSA_ALG_ECDH and PSA_ALG_FFDH are excluded for now as the script
# currently doesn't support them. Deprecated errors are also excluded. # currently doesn't support them. Deprecated errors are also excluded.
excluded_names = set(['PSA_ALG_AEAD_WITH_DEFAULT_TAG_LENGTH', _excluded_names = set(['PSA_ALG_AEAD_WITH_DEFAULT_TAG_LENGTH',
'PSA_ALG_FULL_LENGTH_MAC', 'PSA_ALG_FULL_LENGTH_MAC',
'PSA_ALG_ECDH', 'PSA_ALG_ECDH',
'PSA_ALG_FFDH', 'PSA_ALG_FFDH',
@ -146,22 +152,22 @@ where each argument takes each possible value at least once.'''
'PSA_ERROR_EMPTY_SLOT', 'PSA_ERROR_EMPTY_SLOT',
'PSA_ERROR_INSUFFICIENT_CAPACITY', 'PSA_ERROR_INSUFFICIENT_CAPACITY',
]) ])
argument_split_re = re.compile(r' *, *')
def parse_header_line(self, line): def parse_header_line(self, line):
'''Parse a C header line, looking for "#define PSA_xxx".''' '''Parse a C header line, looking for "#define PSA_xxx".'''
m = re.match(self.header_line_re, line) m = re.match(self._header_line_re, line)
if not m: if not m:
return return
name = m.group(1) name = m.group(1)
if re.search(self.excluded_name_re, name) or \ if re.search(self._excluded_name_re, name) or \
name in self.excluded_names: name in self._excluded_names:
return return
dest = self.table_by_prefix.get(m.group(2)) dest = self.table_by_prefix.get(m.group(2))
if dest is None: if dest is None:
return return
dest.add(name) dest.add(name)
if m.group(3): if m.group(3):
self.argspecs[name] = re.split(self.argument_split_re, m.group(3)) self.argspecs[name] = self._argument_split(m.group(3))
def parse_header(self, filename): def parse_header(self, filename):
'''Parse a C header file, looking for "#define PSA_xxx".''' '''Parse a C header file, looking for "#define PSA_xxx".'''
@ -193,12 +199,12 @@ where each argument takes each possible value at least once.'''
# Regex matching a *.data line containing a test function call and # Regex matching a *.data line containing a test function call and
# its arguments. The actual definition is partly positional, but this # its arguments. The actual definition is partly positional, but this
# regex is good enough in practice. # regex is good enough in practice.
test_case_line_re = re.compile('(?!depends_on:)(\w+):([^\n :][^:\n]*)') _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')
def parse_test_cases(self, filename): def parse_test_cases(self, filename):
'''Parse a test case file (*.data), looking for algorithm metadata tests.''' '''Parse a test case file (*.data), looking for algorithm metadata tests.'''
with read_file_lines(filename) as lines: with read_file_lines(filename) as lines:
for line in lines: for line in lines:
m = re.match(self.test_case_line_re, line) m = re.match(self._test_case_line_re, line)
if m: if m:
self.add_test_case_line(m.group(1), m.group(2)) self.add_test_case_line(m.group(1), m.group(2))
@ -221,9 +227,9 @@ def remove_file_if_exists(filename):
except: except:
pass pass
def run_c(options, type, names): def run_c(options, type_word, names):
'''Generate and run a program to print out numerical values for names.''' '''Generate and run a program to print out numerical values for names.'''
if type == 'status': if type_word == 'status':
cast_to = 'long' cast_to = 'long'
printf_format = '%ld' printf_format = '%ld'
else: else:
@ -232,7 +238,7 @@ def run_c(options, type, names):
c_name = None c_name = None
exe_name = None exe_name = None
try: try:
c_fd, c_name = tempfile.mkstemp(prefix='tmp-{}-'.format(type), c_fd, c_name = tempfile.mkstemp(prefix='tmp-{}-'.format(type_word),
suffix='.c', suffix='.c',
dir='programs/psa') dir='programs/psa')
exe_suffix = '.exe' if platform.system() == 'Windows' else '' exe_suffix = '.exe' if platform.system() == 'Windows' else ''
@ -240,7 +246,7 @@ def run_c(options, type, names):
remove_file_if_exists(exe_name) remove_file_if_exists(exe_name)
c_file = os.fdopen(c_fd, 'w', encoding='ascii') c_file = os.fdopen(c_fd, 'w', encoding='ascii')
c_file.write('/* Generated by test_psa_constant_names.py for {} values */' c_file.write('/* Generated by test_psa_constant_names.py for {} values */'
.format(type)) .format(type_word))
c_file.write(''' c_file.write('''
#include <stdio.h> #include <stdio.h>
#include <psa/crypto.h> #include <psa/crypto.h>
@ -260,7 +266,7 @@ int main(void)
['-o', exe_name, c_name]) ['-o', exe_name, c_name])
if options.keep_c: if options.keep_c:
sys.stderr.write('List of {} tests kept at {}\n' sys.stderr.write('List of {} tests kept at {}\n'
.format(type, c_name)) .format(type_word, c_name))
else: else:
os.remove(c_name) os.remove(c_name)
output = subprocess.check_output([exe_name]) output = subprocess.check_output([exe_name])
@ -268,31 +274,31 @@ int main(void)
finally: finally:
remove_file_if_exists(exe_name) remove_file_if_exists(exe_name)
normalize_strip_re = re.compile(r'\s+') NORMALIZE_STRIP_RE = re.compile(r'\s+')
def normalize(expr): def normalize(expr):
'''Normalize the C expression so as not to care about trivial differences. '''Normalize the C expression so as not to care about trivial differences.
Currently "trivial differences" means whitespace.''' Currently "trivial differences" means whitespace.'''
expr = re.sub(normalize_strip_re, '', expr, len(expr)) expr = re.sub(NORMALIZE_STRIP_RE, '', expr, len(expr))
return expr.strip().split('\n') return expr.strip().split('\n')
def do_test(options, inputs, type, names): def do_test(options, inputs, type_word, names):
'''Test psa_constant_names for the specified type. '''Test psa_constant_names for the specified type.
Run program on names. Run program on names.
Use inputs to figure out what arguments to pass to macros that take arguments.''' Use inputs to figure out what arguments to pass to macros that take arguments.'''
names = sorted(itertools.chain(*map(inputs.distribute_arguments, names))) names = sorted(itertools.chain(*map(inputs.distribute_arguments, names)))
values = run_c(options, type, names) values = run_c(options, type_word, names)
output = subprocess.check_output([options.program, type] + values) output = subprocess.check_output([options.program, type_word] + values)
outputs = output.decode('ascii').strip().split('\n') outputs = output.decode('ascii').strip().split('\n')
errors = [(type, name, value, output) errors = [(type_word, name, value, output)
for (name, value, output) in zip(names, values, outputs) for (name, value, output) in zip(names, values, outputs)
if normalize(name) != normalize(output)] if normalize(name) != normalize(output)]
return len(names), errors return len(names), errors
def report_errors(errors): def report_errors(errors):
'''Describe each case where the output is not as expected.''' '''Describe each case where the output is not as expected.'''
for type, name, value, output in errors: for type_word, name, value, output in errors:
print('For {} "{}", got "{}" (value: {})' print('For {} "{}", got "{}" (value: {})'
.format(type, name, output, value)) .format(type_word, name, output, value))
def run_tests(options, inputs): def run_tests(options, inputs):
'''Run psa_constant_names on all the gathered inputs. '''Run psa_constant_names on all the gathered inputs.
@ -301,13 +307,13 @@ that were tested and errors is the list of cases where the output was
not as expected.''' not as expected.'''
count = 0 count = 0
errors = [] errors = []
for type, names in [('status', inputs.statuses), for type_word, names in [('status', inputs.statuses),
('algorithm', inputs.algorithms), ('algorithm', inputs.algorithms),
('ecc_curve', inputs.ecc_curves), ('ecc_curve', inputs.ecc_curves),
('dh_group', inputs.dh_groups), ('dh_group', inputs.dh_groups),
('key_type', inputs.key_types), ('key_type', inputs.key_types),
('key_usage', inputs.key_usage_flags)]: ('key_usage', inputs.key_usage_flags)]:
c, e = do_test(options, inputs, type, names) c, e = do_test(options, inputs, type_word, names)
count += c count += c
errors += e errors += e
return count, errors return count, errors