diff --git a/changelog.md b/changelog.md index 7d030543..1b9d3bf9 100644 --- a/changelog.md +++ b/changelog.md @@ -1,6 +1,11 @@ Upcoming (TBD) ============== +Features +--------- +* Add `--resume` to replay `--checkpoint` files with `--batch`. + + Bug Fixes --------- * Make LLM timings use the same format as other timings. diff --git a/mycli/main.py b/mycli/main.py index ae6ca3c5..6fda91e7 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -1346,7 +1346,12 @@ class CliArgs: ) checkpoint: TextIOWrapper | None = clickdc.option( type=click.File(mode='a', encoding='utf-8'), - help='In batch or --execute mode, log successful queries to a file.', + help='In batch or --execute mode, log successful queries to a file, and skipped with --resume.', + ) + resume: bool = clickdc.option( + '--resume', + is_flag=True, + help='In batch mode, resume after replaying statements in the --checkpoint file.', ) defaults_group_suffix: str | None = clickdc.option( type=str, @@ -1514,6 +1519,10 @@ def get_password_from_file(password_file: str | None) -> str | None: if cli_args.password is None and os.environ.get("MYSQL_PWD") is not None: cli_args.password = os.environ.get("MYSQL_PWD") + if cli_args.resume and not cli_args.checkpoint: + click.secho('Error: --resume requires a --checkpoint file.', err=True, fg='red') + sys.exit(1) + mycli = MyCli( prompt=cli_args.prompt, toolbar_format=cli_args.toolbar, diff --git a/mycli/main_modes/batch.py b/mycli/main_modes/batch.py index ba23e839..80c0f7d8 100644 --- a/mycli/main_modes/batch.py +++ b/mycli/main_modes/batch.py @@ -1,5 +1,6 @@ from __future__ import annotations +from io import TextIOWrapper import os import sys import time @@ -19,6 +20,53 @@ from mycli.main import CliArgs, MyCli +class CheckpointReplayError(Exception): + pass + + +def replay_checkpoint_file( + batch_path: str, + checkpoint: TextIOWrapper | None, + resume: bool, +) -> int: + if not resume: + return 0 + + if checkpoint is None: + return 0 + + if batch_path == '-': + raise 
CheckpointReplayError('--resume is incompatible with reading from the standard input.') + + checkpoint_name = checkpoint.name + checkpoint.flush() + completed_count = 0 + try: + with click.open_file(batch_path) as batch_h, click.open_file(checkpoint_name, mode='r', encoding='utf-8') as checkpoint_h: + try: + batch_gen = statements_from_filehandle(batch_h) + except ValueError as e: + raise CheckpointReplayError(f'Error reading --batch file: {batch_path}: {e}') from None + for checkpoint_statement, _checkpoint_counter in statements_from_filehandle(checkpoint_h): + try: + batch_statement, _batch_counter = next(batch_gen) + except StopIteration: + raise CheckpointReplayError('Checkpoint script longer than batch script.') from None + except ValueError as e: + raise CheckpointReplayError(f'Error reading --batch file: {batch_path}: {e}') from None + if checkpoint_statement != batch_statement: + raise CheckpointReplayError(f'Statement mismatch: {checkpoint_statement}.') + completed_count += 1 + except ValueError as e: + raise CheckpointReplayError(f'Error reading --checkpoint file: {checkpoint.name}: {e}') from None + except FileNotFoundError as e: + raise CheckpointReplayError(f'FileNotFoundError: {e}') from None + except OSError as e: + raise CheckpointReplayError(f'OSError: {e}') from None + + return completed_count + + def dispatch_batch_statements( mycli: 'MyCli', cli_args: 'CliArgs', @@ -70,6 +118,7 @@ def main_batch_with_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int: click.secho('--progress is only compatible with a plain file.', err=True, fg='red') return 1 try: + completed_statement_count = replay_checkpoint_file(cli_args.batch, cli_args.checkpoint, cli_args.resume) batch_count_h = click.open_file(cli_args.batch) for _statement, _counter in statements_from_filehandle(batch_count_h): goal_statements += 1 @@ -82,6 +131,10 @@ def main_batch_with_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int: except ValueError as e: click.secho(f'Error reading 
--batch file: {cli_args.batch}: {e}', err=True, fg='red') return 1 + except CheckpointReplayError as e: + name = cli_args.checkpoint.name if cli_args.checkpoint else 'None' + click.secho(f'Error replaying --checkpoint file: {name}: {e}', err=True, fg='red') + return 1 try: if goal_statements: pb_style = prompt_toolkit.styles.Style.from_dict({'bar-a': 'reverse'}) @@ -98,6 +151,8 @@ def main_batch_with_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int: with ProgressBar(style=pb_style, formatters=custom_formatters, output=err_output) as pb: for _pb_counter in pb(range(goal_statements)): statement, statement_counter = next(batch_gen) + if statement_counter < completed_statement_count: + continue dispatch_batch_statements(mycli, cli_args, statement, statement_counter) except (ValueError, StopIteration, IOError, OSError, pymysql.err.Error) as e: click.secho(str(e), err=True, fg='red') @@ -113,12 +168,19 @@ def main_batch_without_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int: if not sys.stdin.isatty() and cli_args.batch != '-': click.secho('Ignoring STDIN since --batch was also given.', err=True, fg='red') try: + completed_statement_count = replay_checkpoint_file(cli_args.batch, cli_args.checkpoint, cli_args.resume) batch_h = click.open_file(cli_args.batch) except (OSError, FileNotFoundError): click.secho(f'Failed to open --batch file: {cli_args.batch}', err=True, fg='red') return 1 + except CheckpointReplayError as e: + name = cli_args.checkpoint.name if cli_args.checkpoint else 'None' + click.secho(f'Error replaying --checkpoint file: {name}: {e}', err=True, fg='red') + return 1 try: for statement, counter in statements_from_filehandle(batch_h): + if counter < completed_statement_count: + continue dispatch_batch_statements(mycli, cli_args, statement, counter) except (ValueError, StopIteration, IOError, OSError, pymysql.err.Error) as e: click.secho(str(e), err=True, fg='red') diff --git a/test/pytests/test_main.py b/test/pytests/test_main.py index 
def test_resume_requires_checkpoint() -> None:
    """Passing --resume without --checkpoint exits 1 with the dedicated error message."""
    runner = CliRunner()

    result = runner.invoke(click_entrypoint, args=['--resume'])

    assert result.exit_code == 1
    # Pin the whole message: a bare 'Error:' would also be satisfied by any
    # unrelated failure that happens to exit with status 1.
    assert 'Error: --resume requires a --checkpoint file.' in result.output
def test_replay_checkpoint_file_rejects_batch_read_error(monkeypatch, tmp_path: Path) -> None:
    """A ValueError from the statement reader is reported as a --batch read error."""
    import re  # local import: only needed to escape the path for the regex match below

    batch_path = write_batch_file(tmp_path, 'select 1;\n')

    def raise_bad_batch(_handle):
        # Fail at call time, before any statement is produced.
        raise ValueError('bad batch')

    monkeypatch.setattr(batch_mode, 'statements_from_filehandle', raise_bad_batch)

    # pytest.raises(match=...) is a regular expression; escape the message so
    # path characters ('.', backslashes on Windows) are matched literally.
    expected = re.escape(f'Error reading --batch file: {batch_path}: bad batch')

    with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint:
        with pytest.raises(batch_mode.CheckpointReplayError, match=expected):
            batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
def test_replay_checkpoint_file_rejects_checkpoint_read_error(monkeypatch, tmp_path: Path) -> None:
    """A ValueError while reading the checkpoint is reported as a --checkpoint read error."""
    import re  # local import: only needed to escape the path for the regex match below

    batch_path = write_batch_file(tmp_path, 'select 1;\n')

    def fake_statements_from_filehandle(handle):
        # The batch handle yields one good statement; any other handle
        # (the checkpoint) fails at call time.
        if handle.name == batch_path:
            return iter([('select 1;', 0)])
        raise ValueError('bad checkpoint')

    monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle)

    with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint:
        # pytest.raises(match=...) is a regular expression; escape the message
        # so path characters ('.', backslashes on Windows) match literally.
        expected = re.escape(f'Error reading --checkpoint file: {checkpoint.name}: bad checkpoint')
        with pytest.raises(batch_mode.CheckpointReplayError, match=expected):
            batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
def test_main_batch_without_progress_bar_skips_checkpoint_prefix(monkeypatch, tmp_path: Path) -> None:
    """Only the statement after the two checkpointed ones is dispatched."""
    dispatched: list[tuple[str, int]] = []

    def record_dispatch(_mycli, _cli_args, statement, counter) -> None:
        dispatched.append((statement, counter))

    batch_path = write_batch_file(tmp_path, 'select 1;\nselect 2;\nselect 3;\n')
    monkeypatch.setattr(batch_mode, 'dispatch_batch_statements', record_dispatch)
    monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True))

    with open_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') as checkpoint:
        exit_code = main_batch_without_progress_bar(
            DummyMyCli(),
            DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True),
        )

    assert exit_code == 0
    assert dispatched == [('select 3;', 2)]
def test_main_batch_without_progress_bar_succeeds_when_checkpoint_skips_all(monkeypatch, tmp_path: Path) -> None:
    """A checkpoint that covers every batch statement yields success and no dispatches."""
    dispatched: list[tuple[str, int]] = []

    def record_dispatch(_mycli, _cli_args, statement, counter) -> None:
        dispatched.append((statement, counter))

    batch_path = write_batch_file(tmp_path, 'select 1;\nselect 2;\n')
    monkeypatch.setattr(batch_mode, 'dispatch_batch_statements', record_dispatch)
    monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True))

    with open_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') as checkpoint:
        exit_code = main_batch_without_progress_bar(
            DummyMyCli(),
            DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True),
        )

    assert exit_code == 0
    assert dispatched == []
def test_main_batch_with_progress_bar_returns_error_when_checkpoint_replay_fails(monkeypatch, tmp_path: Path) -> None:
    """A checkpoint that disagrees with the batch makes progress-bar mode return 1 and print one red error."""
    captured: list[tuple[str, bool, str]] = []

    def capture_secho(message, err, fg) -> None:
        captured.append((message, err, fg))

    batch_path = write_batch_file(tmp_path, 'select 1;\n')
    monkeypatch.setattr(batch_mode.click, 'secho', capture_secho)
    monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True))

    with open_checkpoint_file(tmp_path, 'select 9;\n') as checkpoint:
        exit_code = main_batch_with_progress_bar(
            DummyMyCli(),
            DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True),
        )

    expected_message = f'Error replaying --checkpoint file: {checkpoint.name}: Statement mismatch: select 9;.'
    assert exit_code == 1
    assert captured == [(expected_message, True, 'red')]
result, _batch_file_name = invoke_click_batch( + runner, + mycli_main, + 'select 2;\nselect 3;\n', + [f'--checkpoint={checkpoint_path}', '--resume'], + ) + + assert result.exit_code == 0 + assert MockMyCli.ran_queries == ['select 3;'] + + def test_batch_file_with_progress_requires_plain_file(monkeypatch, tmp_path) -> None: mycli_main, mycli_main_batch, MockMyCli = noninteractive_mock_mycli(monkeypatch) runner = CliRunner()