Merge pull request #5605 from gilles-peskine-arm/storage-format-check-mononicity-2.28

Backport 2.28: Check storage format tests for regressions
This commit is contained in:
Gilles Peskine 2022-03-07 17:02:29 +01:00 committed by GitHub
commit ecd119d241
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 264 additions and 88 deletions

View file

@ -34,7 +34,7 @@ Use a similar approach for files other than keys where possible and relevant.
Test cases should normally not be removed from the code base: if something has worked before, it should keep working in future versions, so we should keep testing it. Test cases should normally not be removed from the code base: if something has worked before, it should keep working in future versions, so we should keep testing it.
This cannot be enforced solely by looking at a single version of Mbed TLS, since there would be no indication that more test cases used to exist. It can only be enforced through review of library changes. The review may be assisted by a tool that compares the old and the new version, in the same way that `abi-check.py` compares the library's API and ABI. This cannot be enforced solely by looking at a single version of Mbed TLS, since there would be no indication that more test cases used to exist. It can only be enforced through review of library changes. The review is assisted by a tool that compares the old and the new version, which is implemented in `scripts/abi_check.py`. This tool fails the CI if a load-and-check test case disappears (changed test cases are raised as false positives).
If the way certain keys are stored changes, and we don't deliberately decide to stop supporting old keys (which should only be done by retiring a version of the storage format), then we should keep the corresponding test cases in load-only mode: create a file with the expected content, load it and check the data that it contains. If the way certain keys are stored changes, and we don't deliberately decide to stop supporting old keys (which should only be done by retiring a version of the storage format), then we should keep the corresponding test cases in load-only mode: create a file with the expected content, load it and check the data that it contains.

View file

@ -1,14 +1,26 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
Purpose This script compares the interfaces of two versions of Mbed TLS, looking
for backward incompatibilities between two different Git revisions within
an Mbed TLS repository. It must be run from the root of a Git working tree.
This script is a small wrapper around the abi-compliance-checker and For the source (API) and runtime (ABI) interface compatibility, this script
abi-dumper tools, applying them to compare the ABI and API of the library is a small wrapper around the abi-compliance-checker and abi-dumper tools,
files from two different Git revisions within an Mbed TLS repository. applying them to compare the header and library files.
The results of the comparison are either formatted as HTML and stored at
a configurable location, or are given as a brief list of problems. For the storage format, this script compares the automatically generated
Returns 0 on success, 1 on ABI/API non-compliance, and 2 if there is an error storage tests and the manual read tests, and complains if there is a
while running the script. Note: must be run from Mbed TLS root. reduction in coverage. A change in test data will be signaled as a
coverage reduction since the old test data is no longer present. A change in
how test data is presented will be signaled as well; this would be a false
positive.
The results of the API/ABI comparison are either formatted as HTML and stored
at a configurable location, or are given as a brief list of problems.
Returns 0 on success, 1 on non-compliance, and 2 if there is an error
while running the script.
You must run this test from an Mbed TLS root.
""" """
# Copyright The Mbed TLS Contributors # Copyright The Mbed TLS Contributors
@ -26,7 +38,9 @@ while running the script. Note: must be run from Mbed TLS root.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import glob
import os import os
import re
import sys import sys
import traceback import traceback
import shutil import shutil
@ -51,6 +65,9 @@ class AbiChecker:
configuration.report_dir: directory for output files configuration.report_dir: directory for output files
configuration.keep_all_reports: if false, delete old reports configuration.keep_all_reports: if false, delete old reports
configuration.brief: if true, output shorter report to stdout configuration.brief: if true, output shorter report to stdout
configuration.check_abi: if true, compare ABIs
configuration.check_api: if true, compare APIs
configuration.check_storage: if true, compare storage format tests
configuration.skip_file: path to file containing symbols and types to skip configuration.skip_file: path to file containing symbols and types to skip
""" """
self.repo_path = "." self.repo_path = "."
@ -64,6 +81,11 @@ class AbiChecker:
self.old_version = old_version self.old_version = old_version
self.new_version = new_version self.new_version = new_version
self.skip_file = configuration.skip_file self.skip_file = configuration.skip_file
self.check_abi = configuration.check_abi
self.check_api = configuration.check_api
if self.check_abi != self.check_api:
raise Exception('Checking API without ABI or vice versa is not supported')
self.check_storage_tests = configuration.check_storage
self.brief = configuration.brief self.brief = configuration.brief
self.git_command = "git" self.git_command = "git"
self.make_command = "make" self.make_command = "make"
@ -208,6 +230,93 @@ class AbiChecker:
self.log.debug(abi_dump_output.decode("utf-8")) self.log.debug(abi_dump_output.decode("utf-8"))
version.abi_dumps[mbed_module] = output_path version.abi_dumps[mbed_module] = output_path
@staticmethod
def _normalize_storage_test_case_data(line):
"""Eliminate cosmetic or irrelevant details in storage format test cases."""
line = re.sub(r'\s+', r'', line)
return line
def _read_storage_tests(self,
directory,
filename,
is_generated,
storage_tests):
"""Record storage tests from the given file.
Populate the storage_tests dictionary with test cases read from
filename under directory.
"""
at_paragraph_start = True
description = None
full_path = os.path.join(directory, filename)
with open(full_path) as fd:
for line_number, line in enumerate(fd, 1):
line = line.strip()
if not line:
at_paragraph_start = True
continue
if line.startswith('#'):
continue
if at_paragraph_start:
description = line.strip()
at_paragraph_start = False
continue
if line.startswith('depends_on:'):
continue
# We've reached a test case data line
test_case_data = self._normalize_storage_test_case_data(line)
if not is_generated:
# In manual test data, only look at read tests.
function_name = test_case_data.split(':', 1)[0]
if 'read' not in function_name.split('_'):
continue
metadata = SimpleNamespace(
filename=filename,
line_number=line_number,
description=description
)
storage_tests[test_case_data] = metadata
@staticmethod
def _list_generated_test_data_files(git_worktree_path):
"""List the generated test data files."""
output = subprocess.check_output(
['tests/scripts/generate_psa_tests.py', '--list'],
cwd=git_worktree_path,
).decode('ascii')
return [line for line in output.split('\n') if line]
def _get_storage_format_tests(self, version, git_worktree_path):
"""Record the storage format tests for the specified git version.
The storage format tests are the test suite data files whose name
contains "storage_format".
The version must be checked out at git_worktree_path.
This function creates or updates the generated data files.
"""
# Existing test data files. This may be missing some automatically
# generated files if they haven't been generated yet.
storage_data_files = set(glob.glob(
'tests/suites/test_suite_*storage_format*.data'
))
# Discover and (re)generate automatically generated data files.
to_be_generated = set()
for filename in self._list_generated_test_data_files(git_worktree_path):
if 'storage_format' in filename:
storage_data_files.add(filename)
to_be_generated.add(filename)
subprocess.check_call(
['tests/scripts/generate_psa_tests.py'] + sorted(to_be_generated),
cwd=git_worktree_path,
)
for test_file in sorted(storage_data_files):
self._read_storage_tests(git_worktree_path,
test_file,
test_file in to_be_generated,
version.storage_tests)
def _cleanup_worktree(self, git_worktree_path): def _cleanup_worktree(self, git_worktree_path):
"""Remove the specified git worktree.""" """Remove the specified git worktree."""
shutil.rmtree(git_worktree_path) shutil.rmtree(git_worktree_path)
@ -219,11 +328,14 @@ class AbiChecker:
self.log.debug(worktree_output.decode("utf-8")) self.log.debug(worktree_output.decode("utf-8"))
def _get_abi_dump_for_ref(self, version): def _get_abi_dump_for_ref(self, version):
"""Generate the ABI dumps for the specified git revision.""" """Generate the interface information for the specified git revision."""
git_worktree_path = self._get_clean_worktree_for_git_revision(version) git_worktree_path = self._get_clean_worktree_for_git_revision(version)
self._update_git_submodules(git_worktree_path, version) self._update_git_submodules(git_worktree_path, version)
self._build_shared_libraries(git_worktree_path, version) if self.check_abi:
self._get_abi_dumps_from_shared_libraries(version) self._build_shared_libraries(git_worktree_path, version)
self._get_abi_dumps_from_shared_libraries(version)
if self.check_storage_tests:
self._get_storage_format_tests(version, git_worktree_path)
self._cleanup_worktree(git_worktree_path) self._cleanup_worktree(git_worktree_path)
def _remove_children_with_tag(self, parent, tag): def _remove_children_with_tag(self, parent, tag):
@ -301,6 +413,37 @@ class AbiChecker:
os.remove(output_path) os.remove(output_path)
return True return True
@staticmethod
def _is_storage_format_compatible(old_tests, new_tests,
compatibility_report):
"""Check whether all tests present in old_tests are also in new_tests.
Append a message regarding compatibility to compatibility_report.
"""
missing = frozenset(old_tests.keys()).difference(new_tests.keys())
for test_data in sorted(missing):
metadata = old_tests[test_data]
compatibility_report.append(
'Test case from {} line {} "{}" has disappeared: {}'.format(
metadata.filename, metadata.line_number,
metadata.description, test_data
)
)
compatibility_report.append(
'FAIL: {}/{} storage format test cases have changed or disappeared.'.format(
len(missing), len(old_tests)
) if missing else
'PASS: All {} storage format test cases are preserved.'.format(
len(old_tests)
)
)
compatibility_report.append(
'Info: number of storage format tests cases: {} -> {}.'.format(
len(old_tests), len(new_tests)
)
)
return not missing
def get_abi_compatibility_report(self): def get_abi_compatibility_report(self):
"""Generate a report of the differences between the reference ABI """Generate a report of the differences between the reference ABI
and the new ABI. ABI dumps from self.old_version and self.new_version and the new ABI. ABI dumps from self.old_version and self.new_version
@ -310,12 +453,22 @@ class AbiChecker:
self._pretty_revision(self.new_version) self._pretty_revision(self.new_version)
)] )]
compliance_return_code = 0 compliance_return_code = 0
shared_modules = list(set(self.old_version.modules.keys()) &
set(self.new_version.modules.keys())) if self.check_abi:
for mbed_module in shared_modules: shared_modules = list(set(self.old_version.modules.keys()) &
if not self._is_library_compatible(mbed_module, set(self.new_version.modules.keys()))
compatibility_report): for mbed_module in shared_modules:
if not self._is_library_compatible(mbed_module,
compatibility_report):
compliance_return_code = 1
if self.check_storage_tests:
if not self._is_storage_format_compatible(
self.old_version.storage_tests,
self.new_version.storage_tests,
compatibility_report):
compliance_return_code = 1 compliance_return_code = 1
for version in [self.old_version, self.new_version]: for version in [self.old_version, self.new_version]:
for mbed_module, mbed_module_dump in version.abi_dumps.items(): for mbed_module, mbed_module_dump in version.abi_dumps.items():
os.remove(mbed_module_dump) os.remove(mbed_module_dump)
@ -328,7 +481,8 @@ class AbiChecker:
"""Generate a report of ABI differences """Generate a report of ABI differences
between self.old_rev and self.new_rev.""" between self.old_rev and self.new_rev."""
self.check_repo_path() self.check_repo_path()
self.check_abi_tools_are_installed() if self.check_api or self.check_abi:
self.check_abi_tools_are_installed()
self._get_abi_dump_for_ref(self.old_version) self._get_abi_dump_for_ref(self.old_version)
self._get_abi_dump_for_ref(self.new_version) self._get_abi_dump_for_ref(self.new_version)
return self.get_abi_compatibility_report() return self.get_abi_compatibility_report()
@ -337,17 +491,7 @@ class AbiChecker:
def run_main(): def run_main():
try: try:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description=( description=__doc__
"""This script is a small wrapper around the
abi-compliance-checker and abi-dumper tools, applying them
to compare the ABI and API of the library files from two
different Git revisions within an Mbed TLS repository.
The results of the comparison are either formatted as HTML and
stored at a configurable location, or are given as a brief list
of problems. Returns 0 on success, 1 on ABI/API non-compliance,
and 2 if there is an error while running the script.
Note: must be run from Mbed TLS root."""
)
) )
parser.add_argument( parser.add_argument(
"-v", "--verbose", action="store_true", "-v", "--verbose", action="store_true",
@ -397,6 +541,24 @@ def run_main():
"(typically \"-s identifiers\" after running " "(typically \"-s identifiers\" after running "
"\"tests/scripts/list-identifiers.sh --internal\")") "\"tests/scripts/list-identifiers.sh --internal\")")
) )
parser.add_argument(
"--check-abi",
action='store_true', default=True,
help="Perform ABI comparison (default: yes)"
)
parser.add_argument("--no-check-abi", action='store_false', dest='check_abi')
parser.add_argument(
"--check-api",
action='store_true', default=True,
help="Perform API comparison (default: yes)"
)
parser.add_argument("--no-check-api", action='store_false', dest='check_api')
parser.add_argument(
"--check-storage",
action='store_true', default=True,
help="Perform storage tests comparison (default: yes)"
)
parser.add_argument("--no-check-storage", action='store_false', dest='check_storage')
parser.add_argument( parser.add_argument(
"-b", "--brief", action="store_true", "-b", "--brief", action="store_true",
help="output only the list of issues to stdout, instead of a full report", help="output only the list of issues to stdout, instead of a full report",
@ -413,6 +575,7 @@ def run_main():
crypto_repository=abi_args.old_crypto_repo, crypto_repository=abi_args.old_crypto_repo,
crypto_revision=abi_args.old_crypto_rev, crypto_revision=abi_args.old_crypto_rev,
abi_dumps={}, abi_dumps={},
storage_tests={},
modules={} modules={}
) )
new_version = SimpleNamespace( new_version = SimpleNamespace(
@ -423,6 +586,7 @@ def run_main():
crypto_repository=abi_args.new_crypto_repo, crypto_repository=abi_args.new_crypto_repo,
crypto_revision=abi_args.new_crypto_rev, crypto_revision=abi_args.new_crypto_rev,
abi_dumps={}, abi_dumps={},
storage_tests={},
modules={} modules={}
) )
configuration = SimpleNamespace( configuration = SimpleNamespace(
@ -430,6 +594,9 @@ def run_main():
report_dir=abi_args.report_dir, report_dir=abi_args.report_dir,
keep_all_reports=abi_args.keep_all_reports, keep_all_reports=abi_args.keep_all_reports,
brief=abi_args.brief, brief=abi_args.brief,
check_abi=abi_args.check_abi,
check_api=abi_args.check_api,
check_storage=abi_args.check_storage,
skip_file=abi_args.skip_file skip_file=abi_args.skip_file
) )
abi_check = AbiChecker(old_version, new_version, configuration) abi_check = AbiChecker(old_version, new_version, configuration)

View file

@ -63,15 +63,15 @@ class LostContent(Exception):
# The category names we use in the changelog. # The category names we use in the changelog.
# If you edit this, update ChangeLog.d/README.md. # If you edit this, update ChangeLog.d/README.md.
STANDARD_CATEGORIES = ( STANDARD_CATEGORIES = (
b'API changes', 'API changes',
b'Default behavior changes', 'Default behavior changes',
b'Requirement changes', 'Requirement changes',
b'New deprecations', 'New deprecations',
b'Removals', 'Removals',
b'Features', 'Features',
b'Security', 'Security',
b'Bugfix', 'Bugfix',
b'Changes', 'Changes',
) )
# The maximum line length for an entry # The maximum line length for an entry
@ -122,13 +122,13 @@ class ChangelogFormat:
class TextChangelogFormat(ChangelogFormat): class TextChangelogFormat(ChangelogFormat):
"""The traditional Mbed TLS changelog format.""" """The traditional Mbed TLS changelog format."""
_unreleased_version_text = b'= mbed TLS x.x.x branch released xxxx-xx-xx' _unreleased_version_text = '= mbed TLS x.x.x branch released xxxx-xx-xx'
@classmethod @classmethod
def is_released_version(cls, title): def is_released_version(cls, title):
# Look for an incomplete release date # Look for an incomplete release date
return not re.search(br'[0-9x]{4}-[0-9x]{2}-[0-9x]?x', title) return not re.search(r'[0-9x]{4}-[0-9x]{2}-[0-9x]?x', title)
_top_version_re = re.compile(br'(?:\A|\n)(=[^\n]*\n+)(.*?\n)(?:=|$)', _top_version_re = re.compile(r'(?:\A|\n)(=[^\n]*\n+)(.*?\n)(?:=|$)',
re.DOTALL) re.DOTALL)
@classmethod @classmethod
def extract_top_version(cls, changelog_file_content): def extract_top_version(cls, changelog_file_content):
@ -140,17 +140,17 @@ class TextChangelogFormat(ChangelogFormat):
top_version_body = m.group(2) top_version_body = m.group(2)
if cls.is_released_version(top_version_title): if cls.is_released_version(top_version_title):
top_version_end = top_version_start top_version_end = top_version_start
top_version_title = cls._unreleased_version_text + b'\n\n' top_version_title = cls._unreleased_version_text + '\n\n'
top_version_body = b'' top_version_body = ''
return (changelog_file_content[:top_version_start], return (changelog_file_content[:top_version_start],
top_version_title, top_version_body, top_version_title, top_version_body,
changelog_file_content[top_version_end:]) changelog_file_content[top_version_end:])
@classmethod @classmethod
def version_title_text(cls, version_title): def version_title_text(cls, version_title):
return re.sub(br'\n.*', version_title, re.DOTALL) return re.sub(r'\n.*', version_title, re.DOTALL)
_category_title_re = re.compile(br'(^\w.*)\n+', re.MULTILINE) _category_title_re = re.compile(r'(^\w.*)\n+', re.MULTILINE)
@classmethod @classmethod
def split_categories(cls, version_body): def split_categories(cls, version_body):
"""A category title is a line with the title in column 0.""" """A category title is a line with the title in column 0."""
@ -163,10 +163,10 @@ class TextChangelogFormat(ChangelogFormat):
title_starts = [m.start(1) for m in title_matches] title_starts = [m.start(1) for m in title_matches]
body_starts = [m.end(0) for m in title_matches] body_starts = [m.end(0) for m in title_matches]
body_ends = title_starts[1:] + [len(version_body)] body_ends = title_starts[1:] + [len(version_body)]
bodies = [version_body[body_start:body_end].rstrip(b'\n') + b'\n' bodies = [version_body[body_start:body_end].rstrip('\n') + '\n'
for (body_start, body_end) in zip(body_starts, body_ends)] for (body_start, body_end) in zip(body_starts, body_ends)]
title_lines = [version_body[:pos].count(b'\n') for pos in title_starts] title_lines = [version_body[:pos].count('\n') for pos in title_starts]
body_lines = [version_body[:pos].count(b'\n') for pos in body_starts] body_lines = [version_body[:pos].count('\n') for pos in body_starts]
return [CategoryContent(title_match.group(1), title_line, return [CategoryContent(title_match.group(1), title_line,
body, body_line) body, body_line)
for title_match, title_line, body, body_line for title_match, title_line, body, body_line
@ -176,9 +176,9 @@ class TextChangelogFormat(ChangelogFormat):
def format_category(cls, title, body): def format_category(cls, title, body):
# `split_categories` ensures that each body ends with a newline. # `split_categories` ensures that each body ends with a newline.
# Make sure that there is additionally a blank line between categories. # Make sure that there is additionally a blank line between categories.
if not body.endswith(b'\n\n'): if not body.endswith('\n\n'):
body += b'\n' body += '\n'
return title + b'\n' + body return title + '\n' + body
class ChangeLog: class ChangeLog:
"""An Mbed TLS changelog. """An Mbed TLS changelog.
@ -199,10 +199,10 @@ class ChangeLog:
# Only accept dotted version numbers (e.g. "3.1", not "3"). # Only accept dotted version numbers (e.g. "3.1", not "3").
# Refuse ".x" in a version number where x is a letter: this indicates # Refuse ".x" in a version number where x is a letter: this indicates
# a version that is not yet released. Something like "3.1a" is accepted. # a version that is not yet released. Something like "3.1a" is accepted.
_version_number_re = re.compile(br'[0-9]+\.[0-9A-Za-z.]+') _version_number_re = re.compile(r'[0-9]+\.[0-9A-Za-z.]+')
_incomplete_version_number_re = re.compile(br'.*\.[A-Za-z]') _incomplete_version_number_re = re.compile(r'.*\.[A-Za-z]')
_only_url_re = re.compile(br'^\s*\w+://\S+\s*$') _only_url_re = re.compile(r'^\s*\w+://\S+\s*$')
_has_url_re = re.compile(br'.*://.*') _has_url_re = re.compile(r'.*://.*')
def add_categories_from_text(self, filename, line_offset, def add_categories_from_text(self, filename, line_offset,
text, allow_unknown_category): text, allow_unknown_category):
@ -218,7 +218,7 @@ class ChangeLog:
raise InputFormatError(filename, raise InputFormatError(filename,
line_offset + category.title_line, line_offset + category.title_line,
'Unknown category: "{}"', 'Unknown category: "{}"',
category.name.decode('utf8')) category.name)
body_split = category.body.splitlines() body_split = category.body.splitlines()
@ -250,8 +250,8 @@ class ChangeLog:
# Split the top version section into categories. # Split the top version section into categories.
self.categories = OrderedDict() self.categories = OrderedDict()
for category in STANDARD_CATEGORIES: for category in STANDARD_CATEGORIES:
self.categories[category] = b'' self.categories[category] = ''
offset = (self.header + self.top_version_title).count(b'\n') + 1 offset = (self.header + self.top_version_title).count('\n') + 1
self.add_categories_from_text(input_stream.name, offset, self.add_categories_from_text(input_stream.name, offset,
top_version_body, True) top_version_body, True)
@ -264,7 +264,7 @@ class ChangeLog:
def write(self, filename): def write(self, filename):
"""Write the changelog to the specified file. """Write the changelog to the specified file.
""" """
with open(filename, 'wb') as out: with open(filename, 'w', encoding='utf-8') as out:
out.write(self.header) out.write(self.header)
out.write(self.top_version_title) out.write(self.top_version_title)
for title, body in self.categories.items(): for title, body in self.categories.items():
@ -303,7 +303,7 @@ class EntryFileSortKey:
hashes = subprocess.check_output(['git', 'log', '--format=%H', hashes = subprocess.check_output(['git', 'log', '--format=%H',
'--follow', '--follow',
'--', filename]) '--', filename])
m = re.search(b'(.+)$', hashes) m = re.search('(.+)$', hashes.decode('ascii'))
if not m: if not m:
# The git output is empty. This means that the file was # The git output is empty. This means that the file was
# never checked in. # never checked in.
@ -320,8 +320,8 @@ class EntryFileSortKey:
""" """
text = subprocess.check_output(['git', 'rev-list', text = subprocess.check_output(['git', 'rev-list',
'--merges', *options, '--merges', *options,
b'..'.join([some_hash, target])]) '..'.join([some_hash, target])])
return text.rstrip(b'\n').split(b'\n') return text.decode('ascii').rstrip('\n').split('\n')
@classmethod @classmethod
def merge_hash(cls, some_hash): def merge_hash(cls, some_hash):
@ -329,7 +329,7 @@ class EntryFileSortKey:
Return None if the given commit was never merged. Return None if the given commit was never merged.
""" """
target = b'HEAD' target = 'HEAD'
# List the merges from some_hash to the target in two ways. # List the merges from some_hash to the target in two ways.
# The ancestry list is the ones that are both descendants of # The ancestry list is the ones that are both descendants of
# some_hash and ancestors of the target. # some_hash and ancestors of the target.
@ -407,14 +407,17 @@ def check_output(generated_output_file, main_input_file, merged_files):
is also present in an output file. This is not perfect but good enough is also present in an output file. This is not perfect but good enough
for now. for now.
""" """
generated_output = set(open(generated_output_file, 'rb')) with open(generated_output_file, 'r', encoding='utf-8') as out_fd:
for line in open(main_input_file, 'rb'): generated_output = set(out_fd)
if line not in generated_output: with open(main_input_file, 'r', encoding='utf-8') as in_fd:
raise LostContent('original file', line) for line in in_fd:
for merged_file in merged_files: if line not in generated_output:
for line in open(merged_file, 'rb'): raise LostContent('original file', line)
if line not in generated_output: for merged_file in merged_files:
raise LostContent(merged_file, line) with open(merged_file, 'r', encoding='utf-8') as in_fd:
for line in in_fd:
if line not in generated_output:
raise LostContent(merged_file, line)
def finish_output(changelog, output_file, input_file, merged_files): def finish_output(changelog, output_file, input_file, merged_files):
"""Write the changelog to the output file. """Write the changelog to the output file.
@ -455,14 +458,14 @@ def merge_entries(options):
Write the new changelog to options.output. Write the new changelog to options.output.
Remove the merged entries if options.keep_entries is false. Remove the merged entries if options.keep_entries is false.
""" """
with open(options.input, 'rb') as input_file: with open(options.input, 'r', encoding='utf-8') as input_file:
changelog = ChangeLog(input_file, TextChangelogFormat) changelog = ChangeLog(input_file, TextChangelogFormat)
files_to_merge = list_files_to_merge(options) files_to_merge = list_files_to_merge(options)
if not files_to_merge: if not files_to_merge:
sys.stderr.write('There are no pending changelog entries.\n') sys.stderr.write('There are no pending changelog entries.\n')
return return
for filename in files_to_merge: for filename in files_to_merge:
with open(filename, 'rb') as input_file: with open(filename, 'r', encoding='utf-8') as input_file:
changelog.add_file(input_file) changelog.add_file(input_file)
finish_output(changelog, options.output, options.input, files_to_merge) finish_output(changelog, options.output, options.input, files_to_merge)
if not options.keep_entries: if not options.keep_entries:

View file

@ -18,7 +18,7 @@
import itertools import itertools
import re import re
from typing import Dict, Iterable, Iterator, List, Optional, Pattern, Set, Tuple, Union from typing import Dict, IO, Iterable, Iterator, List, Optional, Pattern, Set, Tuple, Union
class ReadFileLineException(Exception): class ReadFileLineException(Exception):
@ -50,12 +50,13 @@ class read_file_lines:
""" """
def __init__(self, filename: str, binary: bool = False) -> None: def __init__(self, filename: str, binary: bool = False) -> None:
self.filename = filename self.filename = filename
self.file = None #type: Optional[IO[str]]
self.line_number = 'entry' #type: Union[int, str] self.line_number = 'entry' #type: Union[int, str]
self.generator = None #type: Optional[Iterable[Tuple[int, str]]] self.generator = None #type: Optional[Iterable[Tuple[int, str]]]
self.binary = binary self.binary = binary
def __enter__(self) -> 'read_file_lines': def __enter__(self) -> 'read_file_lines':
self.generator = enumerate(open(self.filename, self.file = open(self.filename, 'rb' if self.binary else 'r')
'rb' if self.binary else 'r')) self.generator = enumerate(self.file)
return self return self
def __iter__(self) -> Iterator[str]: def __iter__(self) -> Iterator[str]:
assert self.generator is not None assert self.generator is not None
@ -64,6 +65,8 @@ class read_file_lines:
yield content yield content
self.line_number = 'exit' self.line_number = 'exit'
def __exit__(self, exc_type, exc_value, exc_traceback) -> None: def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
if self.file is not None:
self.file.close()
if exc_type is not None: if exc_type is not None:
raise ReadFileLineException(self.filename, self.line_number) \ raise ReadFileLineException(self.filename, self.line_number) \
from exc_value from exc_value

View file

@ -56,18 +56,19 @@ class Requirements:
* Comments (``#`` at the beginning of the line or after whitespace). * Comments (``#`` at the beginning of the line or after whitespace).
* ``-r FILENAME`` to include another file. * ``-r FILENAME`` to include another file.
""" """
for line in open(filename): with open(filename) as fd:
line = line.strip() for line in fd:
line = re.sub(r'(\A|\s+)#.*', r'', line) line = line.strip()
if not line: line = re.sub(r'(\A|\s+)#.*', r'', line)
continue if not line:
m = re.match(r'-r\s+', line) continue
if m: m = re.match(r'-r\s+', line)
nested_file = os.path.join(os.path.dirname(filename), if m:
line[m.end(0):]) nested_file = os.path.join(os.path.dirname(filename),
self.add_file(nested_file) line[m.end(0):])
continue self.add_file(nested_file)
self.requirements.append(self.adjust_requirement(line)) continue
self.requirements.append(self.adjust_requirement(line))
def write(self, out: typing_util.Writable) -> None: def write(self, out: typing_util.Writable) -> None:
"""List the gathered requirements.""" """List the gathered requirements."""

View file

@ -723,6 +723,8 @@ class TestGenerator:
filename = self.filename_for(basename) filename = self.filename_for(basename)
test_case.write_data_file(filename, test_cases) test_case.write_data_file(filename, test_cases)
# Note that targets whose name contains 'storage_format' have their content
# validated by `abi_check.py`.
TARGETS = { TARGETS = {
'test_suite_psa_crypto_generate_key.generated': 'test_suite_psa_crypto_generate_key.generated':
lambda info: KeyGenerate(info).test_cases_for_key_generation(), lambda info: KeyGenerate(info).test_cases_for_key_generation(),