From 2d8154b981b558a93dcbb0168ed8dcbce707dd77 Mon Sep 17 00:00:00 2001 From: Roland Walker Date: Sat, 18 Apr 2026 13:01:02 -0400 Subject: [PATCH] skip --checkpoint file statements with --resume In --batch mode, when the batch input script is not STDIN, and when --checkpoint is also given, --resume causes mycli to replay the checkpoint file, looking for leading matching statements, and skip execution of batch statements already present in the checkpoint file. Motivation: resumption of interrupted batch scripts. The number of statements in the checkpoint file must not exceed the number of statements in the batch script, and the checkpoint statements must form a leading match, or mycli will exit without executing anything. Once execution is picked up again from the midpoint of the --batch script, we continue to append _new_ statements to the checkpoint file, after each statement is successfully executed. That behavior is unchanged. This allows the checkpoint file to be used again if the batch script is interrupted multiple times. The --progress bar and included ETA calculation account for the statements replayed from the checkpoint file, and show corrected views. Further work could include creating a [batch] section in myclirc and adding a default value for resumption, with a --no-resume option. Note: some SQL statements change server/session state or start transactions. But _any_ successful statement will be checkpointed and then not executed upon resumption in --resume mode. It is incumbent on the user to account for such state when resuming from a checkpoint. 
--- changelog.md | 1 + mycli/main.py | 15 +- mycli/main_modes/batch.py | 62 +++++++ test/pytests/test_main.py | 9 + test/pytests/test_main_modes_batch.py | 239 +++++++++++++++++++++++++- 5 files changed, 322 insertions(+), 4 deletions(-) diff --git a/changelog.md b/changelog.md index fecc97bf..ae91fce7 100644 --- a/changelog.md +++ b/changelog.md @@ -4,6 +4,7 @@ Upcoming (TBD) Features --------- * Remove undocumented `%mycli` Jupyter magic. +* Add `--resume` to replay `--checkpoint` files with `--batch`. Bug Fixes diff --git a/mycli/main.py b/mycli/main.py index ae6ca3c5..9a577d92 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -1346,7 +1346,12 @@ class CliArgs: ) checkpoint: TextIOWrapper | None = clickdc.option( type=click.File(mode='a', encoding='utf-8'), - help='In batch or --execute mode, log successful queries to a file.', + help='In batch or --execute mode, log successful queries to a file, and skipped with --resume.', + ) + resume: bool = clickdc.option( + '--resume', + is_flag=True, + help='In batch mode, resume after replaying statements in the --checkpoint file.', ) defaults_group_suffix: str | None = clickdc.option( type=str, @@ -1514,6 +1519,14 @@ def get_password_from_file(password_file: str | None) -> str | None: if cli_args.password is None and os.environ.get("MYSQL_PWD") is not None: cli_args.password = os.environ.get("MYSQL_PWD") + if cli_args.resume and not cli_args.checkpoint: + click.secho('Error: --resume requires a --checkpoint file.', err=True, fg='red') + sys.exit(1) + + if cli_args.resume and not cli_args.batch: + click.secho('Error: --resume requires a --batch file.', err=True, fg='red') + sys.exit(1) + mycli = MyCli( prompt=cli_args.prompt, toolbar_format=cli_args.toolbar, diff --git a/mycli/main_modes/batch.py b/mycli/main_modes/batch.py index ba23e839..80c0f7d8 100644 --- a/mycli/main_modes/batch.py +++ b/mycli/main_modes/batch.py @@ -1,5 +1,6 @@ from __future__ import annotations +from io import TextIOWrapper import os import sys 
import time @@ -19,6 +20,53 @@ from mycli.main import CliArgs, MyCli +class CheckpointReplayError(Exception): + pass + + +def replay_checkpoint_file( + batch_path: str, + checkpoint: TextIOWrapper | None, + resume: bool, +) -> int: + if not resume: + return 0 + + if checkpoint is None: + return 0 + + if batch_path == '-': + raise CheckpointReplayError('--resume is incompatible with reading from the standard input.') + + checkpoint_name = checkpoint.name + checkpoint.flush() + completed_count = 0 + try: + with click.open_file(batch_path) as batch_h, click.open_file(checkpoint_name, mode='r', encoding='utf-8') as checkpoint_h: + try: + batch_gen = statements_from_filehandle(batch_h) + except ValueError as e: + raise CheckpointReplayError(f'Error reading --batch file: {batch_path}: {e}') from None + for checkpoint_statement, _checkpoint_counter in statements_from_filehandle(checkpoint_h): + try: + batch_statement, _batch_counter = next(batch_gen) + except StopIteration: + raise CheckpointReplayError('Checkpoint script longer than batch script.') from None + except ValueError as e: + raise CheckpointReplayError(f'Error reading --batch file: {batch_path}: {e}') from None + if checkpoint_statement != batch_statement: + raise CheckpointReplayError(f'Statement mismatch: {checkpoint_statement}.') + completed_count += 1 + except ValueError as e: + raise CheckpointReplayError(f'Error reading --checkpoint file: {checkpoint.name}: {e}') from None + except FileNotFoundError as e: + raise CheckpointReplayError(f'FileNotFoundError: {e}') from None + except OSError as e: + raise CheckpointReplayError(f'OSError: {e}') from None + + return completed_count + + def dispatch_batch_statements( mycli: 'MyCli', cli_args: 'CliArgs', @@ -70,6 +118,7 @@ def main_batch_with_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int: click.secho('--progress is only compatible with a plain file.', err=True, fg='red') return 1 try: + completed_statement_count = 
replay_checkpoint_file(cli_args.batch, cli_args.checkpoint, cli_args.resume) batch_count_h = click.open_file(cli_args.batch) for _statement, _counter in statements_from_filehandle(batch_count_h): goal_statements += 1 @@ -82,6 +131,10 @@ def main_batch_with_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int: except ValueError as e: click.secho(f'Error reading --batch file: {cli_args.batch}: {e}', err=True, fg='red') return 1 + except CheckpointReplayError as e: + name = cli_args.checkpoint.name if cli_args.checkpoint else 'None' + click.secho(f'Error replaying --checkpoint file: {name}: {e}', err=True, fg='red') + return 1 try: if goal_statements: pb_style = prompt_toolkit.styles.Style.from_dict({'bar-a': 'reverse'}) @@ -98,6 +151,8 @@ def main_batch_with_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int: with ProgressBar(style=pb_style, formatters=custom_formatters, output=err_output) as pb: for _pb_counter in pb(range(goal_statements)): statement, statement_counter = next(batch_gen) + if statement_counter < completed_statement_count: + continue dispatch_batch_statements(mycli, cli_args, statement, statement_counter) except (ValueError, StopIteration, IOError, OSError, pymysql.err.Error) as e: click.secho(str(e), err=True, fg='red') @@ -113,12 +168,19 @@ def main_batch_without_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int: if not sys.stdin.isatty() and cli_args.batch != '-': click.secho('Ignoring STDIN since --batch was also given.', err=True, fg='red') try: + completed_statement_count = replay_checkpoint_file(cli_args.batch, cli_args.checkpoint, cli_args.resume) batch_h = click.open_file(cli_args.batch) except (OSError, FileNotFoundError): click.secho(f'Failed to open --batch file: {cli_args.batch}', err=True, fg='red') return 1 + except CheckpointReplayError as e: + name = cli_args.checkpoint.name if cli_args.checkpoint else 'None' + click.secho(f'Error replaying --checkpoint file: {name}: {e}', err=True, fg='red') + return 1 try: for 
statement, counter in statements_from_filehandle(batch_h): + if counter < completed_statement_count: + continue dispatch_batch_statements(mycli, cli_args, statement, counter) except (ValueError, StopIteration, IOError, OSError, pymysql.err.Error) as e: click.secho(str(e), err=True, fg='red') diff --git a/test/pytests/test_main.py b/test/pytests/test_main.py index 84019590..226b047a 100644 --- a/test/pytests/test_main.py +++ b/test/pytests/test_main.py @@ -2126,6 +2126,15 @@ def test_execute_arg_warns_about_ignoring_stdin(monkeypatch): assert 'Ignoring STDIN' in result.output +def test_resume_requires_checkpoint() -> None: + runner = CliRunner() + + result = runner.invoke(click_entrypoint, args=['--resume']) + + assert result.exit_code == 1 + assert 'Error:' in result.output + + def test_execute_arg_supersedes_batch_file(monkeypatch): mycli_main, mycli_main_batch, MockMyCli = noninteractive_mock_mycli(monkeypatch) runner = CliRunner() diff --git a/test/pytests/test_main_modes_batch.py b/test/pytests/test_main_modes_batch.py index 06ff1800..9d7fd9a2 100644 --- a/test/pytests/test_main_modes_batch.py +++ b/test/pytests/test_main_modes_batch.py @@ -1,7 +1,9 @@ from __future__ import annotations from dataclasses import dataclass +from io import TextIOWrapper import os +from pathlib import Path import sys from tempfile import NamedTemporaryFile from types import SimpleNamespace @@ -23,8 +25,9 @@ class DummyCliArgs: format: str = 'tsv' noninteractive: bool = True throttle: float = 0.0 - checkpoint: str | None = None + checkpoint: str | TextIOWrapper | None = None batch: str | None = None + resume: bool = False @dataclass @@ -47,9 +50,9 @@ def __init__(self, destructive_warning: bool = False, run_query_error: Exception self.destructive_keywords = ('drop',) self.logger = DummyLogger() self.run_query_error = run_query_error - self.ran_queries: list[tuple[str, str | None, bool]] = [] + self.ran_queries: list[tuple[str, str | TextIOWrapper | None, bool]] = [] - def 
run_query(self, query: str, checkpoint: str | None = None, new_line: bool = True) -> None: + def run_query(self, query: str, checkpoint: str | TextIOWrapper | None = None, new_line: bool = True) -> None: if self.run_query_error is not None: raise self.run_query_error self.ran_queries.append((query, checkpoint, new_line)) @@ -142,6 +145,98 @@ def invoke_click_batch( os.remove(batch_file.name) +def write_batch_file(tmp_path: Path, contents: str) -> str: + batch_path = tmp_path / 'batch.sql' + batch_path.write_text(contents, encoding='utf-8') + return str(batch_path) + + +def open_checkpoint_file(tmp_path: Path, contents: str) -> TextIOWrapper: + checkpoint_path = tmp_path / 'checkpoint.sql' + checkpoint_path.write_text(contents, encoding='utf-8') + return checkpoint_path.open('a', encoding='utf-8') + + +def test_replay_checkpoint_file_returns_zero_without_replayable_batch(tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + + assert batch_mode.replay_checkpoint_file(batch_path, None, resume=True) == 0 + + with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: + with pytest.raises(batch_mode.CheckpointReplayError, match='incompatible with reading from the standard input'): + batch_mode.replay_checkpoint_file('-', checkpoint, resume=True) + + +def test_replay_checkpoint_file_rejects_checkpoint_longer_than_batch(tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + + with open_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') as checkpoint: + with pytest.raises(batch_mode.CheckpointReplayError, match='Checkpoint script longer than batch script.'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) + + +def test_replay_checkpoint_file_rejects_batch_read_error(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + + monkeypatch.setattr(batch_mode, 'statements_from_filehandle', lambda _handle: (_ for _ in ()).throw(ValueError('bad 
batch'))) + + with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: + with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) + + +def test_replay_checkpoint_file_rejects_batch_iteration_error(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + + def raise_on_next(): + raise ValueError('bad batch iterator') + yield + + def fake_statements_from_filehandle(handle): + if handle.name == batch_path: + return raise_on_next() + return iter([('select 1;', 0)]) + + monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle) + + with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: + with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch iterator'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) + + +def test_replay_checkpoint_file_rejects_checkpoint_read_error(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + + def fake_statements_from_filehandle(handle): + if handle.name == batch_path: + return iter([('select 1;', 0)]) + return (_ for _ in ()).throw(ValueError('bad checkpoint')) + + monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle) + + with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: + with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --checkpoint file: {checkpoint.name}: bad checkpoint'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) + + +def test_replay_checkpoint_file_rejects_missing_files(tmp_path: Path) -> None: + batch_path = str(tmp_path / 'missing.sql') + + with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: + with pytest.raises(batch_mode.CheckpointReplayError, 
match='FileNotFoundError'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) + + +def test_replay_checkpoint_file_rejects_open_errors(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + + monkeypatch.setattr(batch_mode.click, 'open_file', lambda *_args, **_kwargs: (_ for _ in ()).throw(OSError('open failed'))) + + with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: + with pytest.raises(batch_mode.CheckpointReplayError, match='OSError'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) + + @pytest.mark.parametrize( ('format_name', 'batch_counter', 'expected'), ( @@ -401,6 +496,126 @@ def test_main_batch_without_progress_bar_processes_statements(monkeypatch) -> No assert batch_handle.closed is True +def test_main_batch_without_progress_bar_skips_checkpoint_prefix(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\nselect 2;\nselect 3;\n') + dispatch_calls: list[tuple[str, int]] = [] + + monkeypatch.setattr( + batch_mode, + 'dispatch_batch_statements', + lambda _mycli, _cli_args, statement, counter: dispatch_calls.append((statement, counter)), + ) + monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) + + with open_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') as checkpoint: + cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + + result = main_batch_without_progress_bar(DummyMyCli(), cli_args) + + assert result == 0 + assert dispatch_calls == [('select 3;', 2)] + + +def test_main_batch_without_progress_bar_skips_only_matching_duplicate_prefix(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\nselect 1;\nselect 2;\n') + dispatch_calls: list[tuple[str, int]] = [] + + monkeypatch.setattr( + batch_mode, + 'dispatch_batch_statements', + lambda _mycli, _cli_args, statement, counter: dispatch_calls.append((statement, counter)), + ) + 
monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) + + with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: + cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + + result = main_batch_without_progress_bar(DummyMyCli(), cli_args) + + assert result == 0 + assert dispatch_calls == [('select 1;', 1), ('select 2;', 2)] + + +def test_main_batch_without_progress_bar_fails_on_mismatched_checkpoint(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\nselect 2;\n') + dispatch_calls: list[tuple[str, int]] = [] + + monkeypatch.setattr( + batch_mode, + 'dispatch_batch_statements', + lambda _mycli, _cli_args, statement, counter: dispatch_calls.append((statement, counter)), + ) + monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) + + with open_checkpoint_file(tmp_path, 'select 9;\n') as checkpoint: + cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + + result = main_batch_without_progress_bar(DummyMyCli(), cli_args) + + assert result == 1 + assert dispatch_calls == [] + + +def test_main_batch_without_progress_bar_succeeds_when_checkpoint_skips_all(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\nselect 2;\n') + dispatch_calls: list[tuple[str, int]] = [] + + monkeypatch.setattr( + batch_mode, + 'dispatch_batch_statements', + lambda _mycli, _cli_args, statement, counter: dispatch_calls.append((statement, counter)), + ) + monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) + + with open_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') as checkpoint: + cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + + result = main_batch_without_progress_bar(DummyMyCli(), cli_args) + + assert result == 0 + assert dispatch_calls == [] + + +def test_main_batch_with_progress_bar_skips_checkpoint_prefix_and_counts_all_statements(monkeypatch, tmp_path: Path) -> None: 
+ batch_path = write_batch_file(tmp_path, 'select 1;\nselect 2;\nselect 3;\n') + dispatch_calls: list[tuple[str, int]] = [] + + DummyProgressBar.calls.clear() + monkeypatch.setattr(batch_mode, 'ProgressBar', DummyProgressBar) + monkeypatch.setattr(batch_mode.prompt_toolkit.output, 'create_output', lambda **_kwargs: object()) + monkeypatch.setattr( + batch_mode, + 'dispatch_batch_statements', + lambda _mycli, _cli_args, statement, counter: dispatch_calls.append((statement, counter)), + ) + monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) + + with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: + cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + + result = main_batch_with_progress_bar(DummyMyCli(), cli_args) + + assert result == 0 + assert dispatch_calls == [('select 2;', 1), ('select 3;', 2)] + assert DummyProgressBar.calls == [[0, 1, 2]] + + +def test_main_batch_with_progress_bar_returns_error_when_checkpoint_replay_fails(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + messages: list[tuple[str, bool, str]] = [] + + monkeypatch.setattr(batch_mode.click, 'secho', lambda message, err, fg: messages.append((message, err, fg))) + monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) + + with open_checkpoint_file(tmp_path, 'select 9;\n') as checkpoint: + cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + + result = main_batch_with_progress_bar(DummyMyCli(), cli_args) + + assert result == 1 + assert messages == [(f'Error replaying --checkpoint file: {checkpoint.name}: Statement mismatch: select 9;.', True, 'red')] + + def test_main_batch_without_progress_bar_returns_error_when_iteration_fails(monkeypatch) -> None: messages: list[tuple[str, bool, str]] = [] batch_handle = DummyFile('run') @@ -473,6 +688,24 @@ def test_click_batch_file_modes(monkeypatch, contents: str, extra_args: list[str assert DummyProgressBar.calls 
== expected_progress +def test_click_batch_file_skips_checkpoint_prefix(monkeypatch, tmp_path: Path) -> None: + mycli_main, _mycli_main_batch, MockMyCli = noninteractive_mock_mycli(monkeypatch) + runner = CliRunner() + MockMyCli.ran_queries = [] + checkpoint_path = tmp_path / 'checkpoint.sql' + checkpoint_path.write_text('select 2;\n', encoding='utf-8') + + result, _batch_file_name = invoke_click_batch( + runner, + mycli_main, + 'select 2;\nselect 3;\n', + [f'--checkpoint={checkpoint_path}', '--resume'], + ) + + assert result.exit_code == 0 + assert MockMyCli.ran_queries == ['select 3;'] + + def test_batch_file_with_progress_requires_plain_file(monkeypatch, tmp_path) -> None: mycli_main, mycli_main_batch, MockMyCli = noninteractive_mock_mycli(monkeypatch) runner = CliRunner()