diff --git a/scripts/generate_psa_constants.py b/scripts/generate_psa_constants.py index d0d1f3f49..ff07ecd45 100755 --- a/scripts/generate_psa_constants.py +++ b/scripts/generate_psa_constants.py @@ -304,7 +304,7 @@ class CaseBuilder(macro_collector.PSAMacroCollector): def _make_key_usage_code(self): return '\n'.join([self._make_bit_test('usage', bit) - for bit in sorted(self.key_usages)]) + for bit in sorted(self.key_usage_flags)]) def write_file(self, output_file): """Generate the pretty-printer function code from the gathered diff --git a/scripts/mbedtls_dev/macro_collector.py b/scripts/mbedtls_dev/macro_collector.py index c9e6ec337..a2192baf4 100644 --- a/scripts/mbedtls_dev/macro_collector.py +++ b/scripts/mbedtls_dev/macro_collector.py @@ -126,7 +126,7 @@ class PSAMacroEnumerator: return itertools.chain(*map(self.distribute_arguments, names)) -class PSAMacroCollector: +class PSAMacroCollector(PSAMacroEnumerator): """Collect PSA crypto macro definitions from C header files. """ @@ -138,18 +138,11 @@ class PSAMacroCollector: * include_intermediate: if true, include intermediate macros such as PSA_XXX_BASE that do not designate semantic values. """ + super().__init__() self.include_intermediate = include_intermediate - self.statuses = set() #type: Set[str] - self.key_types = set() #type: Set[str] self.key_types_from_curve = {} #type: Dict[str, str] self.key_types_from_group = {} #type: Dict[str, str] - self.ecc_curves = set() #type: Set[str] - self.dh_groups = set() #type: Set[str] - self.algorithms = set() #type: Set[str] - self.hash_algorithms = set() #type: Set[str] - self.ka_algorithms = set() #type: Set[str] self.algorithms_from_hash = {} #type: Dict[str, str] - self.key_usages = set() #type: Set[str] def is_internal_name(self, name: str) -> bool: """Whether this is an internal macro. Internal macros will be skipped.""" @@ -160,6 +153,30 @@ class PSAMacroCollector: return True return name.endswith('_FLAG') or name.endswith('_MASK') + def record_algorithm_subtype(self, name: str, expansion: str) -> None: + """Record the subtype of an algorithm constructor. + + Given a ``PSA_ALG_xxx`` macro name and its expansion, if the algorithm + is of a subtype that is tracked in its own set, add it to the relevant + set. + """ + # This code is very ad hoc and fragile. It should be replaced by + # something more robust. + if re.match(r'MAC(?:_|\Z)', name): + self.mac_algorithms.add(name) + elif re.match(r'KDF(?:_|\Z)', name): + self.kdf_algorithms.add(name) + elif re.search(r'0x020000[0-9A-Fa-f]{2}', expansion): + self.hash_algorithms.add(name) + elif re.search(r'0x03[0-9A-Fa-f]{6}', expansion): + self.mac_algorithms.add(name) + elif re.search(r'0x05[0-9A-Fa-f]{6}', expansion): + self.aead_algorithms.add(name) + elif re.search(r'0x09[0-9A-Fa-f]{2}0000', expansion): + self.ka_algorithms.add(name) + elif re.search(r'0x08[0-9A-Fa-f]{6}', expansion): + self.kdf_algorithms.add(name) + # "#define" followed by a macro name with either no parameters # or a single parameter and a non-empty expansion. # Grab the macro name in group 1, the parameter name if any in group 2 @@ -180,6 +197,8 @@ class PSAMacroCollector: return name, parameter, expansion = m.groups() expansion = re.sub(r'/\*.*?\*/|//.*', r' ', expansion) + if parameter: + self.argspecs[name] = [parameter] if re.match(self._deprecated_definition_re, expansion): # Skip deprecated values, which are assumed to be # backward compatibility aliases that share @@ -207,12 +226,7 @@ class PSAMacroCollector: # Ad hoc skipping of duplicate names for some numerical values return self.algorithms.add(name) - # Ad hoc detection of hash algorithms - if re.search(r'0x020000[0-9A-Fa-f]{2}', expansion): - self.hash_algorithms.add(name) - # Ad hoc detection of key agreement algorithms - if re.search(r'0x09[0-9A-Fa-f]{2}0000', expansion): - self.ka_algorithms.add(name) + self.record_algorithm_subtype(name, expansion) elif name.startswith('PSA_ALG_') and parameter == 'hash_alg': if name in ['PSA_ALG_DSA', 'PSA_ALG_ECDSA']: # A naming irregularity @@ -221,7 +235,7 @@ class PSAMacroCollector: tester = name[:8] + 'IS_' + name[8:] self.algorithms_from_hash[name] = tester elif name.startswith('PSA_KEY_USAGE_') and not parameter: - self.key_usages.add(name) + self.key_usage_flags.add(name) else: # Other macro without parameter return