diff --git a/cloud_pipelines_backend/annotation/utils.py b/cloud_pipelines_backend/annotation/utils.py new file mode 100644 index 0000000..28a24e9 --- /dev/null +++ b/cloud_pipelines_backend/annotation/utils.py @@ -0,0 +1,107 @@ +import logging + +from sqlalchemy import orm + +from .. import backend_types_sql as bts +from .. import errors +from ..search import filter_query_sql + +_logger = logging.getLogger(__name__) + +_SYSTEM_KEY_RESERVED_MSG = ( + "Annotation keys starting with " + f"{filter_query_sql.SYSTEM_KEY_PREFIX!r} are reserved for system use." +) + + +def fail_if_changing_system_annotation(*, key: str) -> None: + if key.startswith(filter_query_sql.SYSTEM_KEY_PREFIX): + raise errors.ApiValidationError(_SYSTEM_KEY_RESERVED_MSG) + + +def _truncate_for_annotation( + *, + value: str, + field_name: str, + pipeline_run_id: bts.IdType, +) -> str: + """Truncate a string to fit the annotation VARCHAR column. + + Returns the value unchanged if it fits within _STR_MAX_LENGTH, + otherwise truncates and logs a warning with the run ID and + original length. + """ + max_len = bts._STR_MAX_LENGTH + if len(value) <= max_len: + return value + + _logger.warning( + f"Truncating {field_name} annotation for run {pipeline_run_id}: " + f"{len(value)} chars -> {max_len} chars" + ) + return value[:max_len] + + +def mirror_system_annotations( + *, + session: orm.Session, + pipeline_run_id: bts.IdType, + created_by: str | None, + pipeline_name: str | None, +) -> None: + """Mirror pipeline run fields as system annotations for filter_query search. + + Always creates an annotation for every run, even when the source value is + None or empty (stored as ""). This ensures data parity so every run has a + row for each system key. + """ + + # TODO: The original pipeline_run.created_by and the pipeline name stored in + # extra_data / task_spec are saved untruncated, while the annotation mirror + # is truncated to VARCHAR(255). This creates a data parity mismatch between + # the source columns and their annotation copies. Revisit this to either + # widen the annotation column or enforce the same limit at the source. + + created_by_value = created_by + if created_by_value is None: + created_by_value = "" + _logger.warning( + f"Pipeline run id {pipeline_run_id} `created_by` is None, " + 'setting it to empty string "" for data parity' + ) + + created_by_value = _truncate_for_annotation( + value=created_by_value, + field_name=filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY, + pipeline_run_id=pipeline_run_id, + ) + + session.add( + bts.PipelineRunAnnotation( + pipeline_run_id=pipeline_run_id, + key=filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY, + value=created_by_value, + ) + ) + + pipeline_name_value = pipeline_name + if pipeline_name_value is None: + pipeline_name_value = "" + _logger.warning( + f"Pipeline run id {pipeline_run_id} `pipeline_name` is None, " + 'setting it to empty string "" for data parity' + ) + + pipeline_name_value = _truncate_for_annotation( + value=pipeline_name_value, + field_name=filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME, + pipeline_run_id=pipeline_run_id, + ) + + session.add( + bts.PipelineRunAnnotation( + pipeline_run_id=pipeline_run_id, + key=filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME, + value=pipeline_name_value, + ) + ) diff --git a/cloud_pipelines_backend/api_server_sql.py b/cloud_pipelines_backend/api_server_sql.py index 413d806..cc8733b 100644 --- a/cloud_pipelines_backend/api_server_sql.py +++ b/cloud_pipelines_backend/api_server_sql.py @@ -2,7 +2,7 @@ import datetime import logging import typing -from typing import Any, Final, Optional +from typing import Any, Optional import sqlalchemy as sql from sqlalchemy import orm @@ -10,7 +10,12 @@ from . import backend_types_sql as bts from . import component_structures as structures from . import errors -from . import filter_query_sql +from .annotation import utils as annotation_utils +from .search import runs as search_runs +from .search.runs import ( + ListPipelineJobsResponse, + PipelineRunResponse, +) if typing.TYPE_CHECKING: from cloud_pipelines.orchestration.storage_providers import ( @@ -32,72 +37,7 @@ def _get_current_time() -> datetime.datetime: return datetime.datetime.now(tz=datetime.timezone.utc) -def _get_pipeline_name_from_task_spec( - *, - task_spec_dict: dict[str, Any], -) -> str | None: - """Extract pipeline name from a task_spec dict via component_ref.spec.name. - - Traversal path: - task_spec_dict -> TaskSpec -> component_ref -> spec -> name - - Returns None if any step in the chain is missing or parsing fails. - """ - try: - task_spec = structures.TaskSpec.from_json_dict(task_spec_dict) - except Exception: - return None - spec = task_spec.component_ref.spec - if spec is None: - return None - return spec.name or None - - -# ==== PipelineJobService -@dataclasses.dataclass(kw_only=True) -class PipelineRunResponse: - id: bts.IdType - root_execution_id: bts.IdType - annotations: dict[str, Any] | None = None - # status: "PipelineJobStatus" - created_by: str | None = None - created_at: datetime.datetime | None = None - pipeline_name: str | None = None - execution_status_stats: dict[str, int] | None = None - - @classmethod - def from_db(cls, pipeline_run: bts.PipelineRun) -> "PipelineRunResponse": - return PipelineRunResponse( - id=pipeline_run.id, - root_execution_id=pipeline_run.root_execution_id, - annotations=pipeline_run.annotations, - created_by=pipeline_run.created_by, - created_at=pipeline_run.created_at, - ) - - -class GetPipelineRunResponse(PipelineRunResponse): - pass - - -@dataclasses.dataclass(kw_only=True) -class ListPipelineJobsResponse: - pipeline_runs: list[PipelineRunResponse] - next_page_token: str | None = None - - class PipelineRunsApiService_Sql: - _PIPELINE_NAME_EXTRA_DATA_KEY = "pipeline_name" - _DEFAULT_PAGE_SIZE: Final[int] = 10 - _SYSTEM_KEY_RESERVED_MSG = ( - "Annotation keys starting with " - f"{filter_query_sql.SYSTEM_KEY_PREFIX!r} are reserved for system use." - ) - - def _fail_if_changing_system_annotation(self, *, key: str) -> None: - if key.startswith(filter_query_sql.SYSTEM_KEY_PREFIX): - raise errors.ApiValidationError(self._SYSTEM_KEY_RESERVED_MSG) - def create( self, session: orm.Session, @@ -130,14 +70,14 @@ def create( annotations=annotations, created_by=created_by, extra_data={ - self._PIPELINE_NAME_EXTRA_DATA_KEY: pipeline_name, + search_runs._PIPELINE_NAME_EXTRA_DATA_KEY: pipeline_name, }, ) session.add(pipeline_run) # Flush to populate pipeline_run.id (server-generated) before inserting annotation FKs. # TODO: Use ORM relationship instead of explicit flush + manual FK assignment. session.flush() - _mirror_system_annotations( + annotation_utils.mirror_system_annotations( session=session, pipeline_run_id=pipeline_run.id, created_by=created_by, @@ -207,100 +147,16 @@ def list( include_pipeline_names: bool = False, include_execution_stats: bool = False, ) -> ListPipelineJobsResponse: - where_clauses, offset, next_token = filter_query_sql.build_list_filters( - filter_value=filter, - filter_query_value=filter_query, - page_token_value=page_token, + return search_runs.get_pipeline_runs( + session=session, + page_token=page_token, + filter=filter, + filter_query=filter_query, current_user=current_user, - page_size=self._DEFAULT_PAGE_SIZE, + include_pipeline_names=include_pipeline_names, + include_execution_stats=include_execution_stats, ) - pipeline_runs = list( - session.scalars( - sql.select(bts.PipelineRun) - .where(*where_clauses) - .order_by(bts.PipelineRun.created_at.desc()) - .offset(offset) - .limit(self._DEFAULT_PAGE_SIZE) - ).all() - ) - - next_page_token = ( - next_token if len(pipeline_runs) >= self._DEFAULT_PAGE_SIZE else None - ) - - return ListPipelineJobsResponse( - pipeline_runs=[ - self._create_pipeline_run_response( - session=session, - pipeline_run=pipeline_run, - include_pipeline_names=include_pipeline_names, - include_execution_stats=include_execution_stats, - ) - for pipeline_run in pipeline_runs - ], - next_page_token=next_page_token, - ) - - def _create_pipeline_run_response( - self, - *, - session: orm.Session, - pipeline_run: bts.PipelineRun, - include_pipeline_names: bool, - include_execution_stats: bool, - ) -> PipelineRunResponse: - response = PipelineRunResponse.from_db(pipeline_run) - if include_pipeline_names: - pipeline_name = None - extra_data = pipeline_run.extra_data or {} - if self._PIPELINE_NAME_EXTRA_DATA_KEY in extra_data: - pipeline_name = extra_data[self._PIPELINE_NAME_EXTRA_DATA_KEY] - else: - execution_node = session.get( - bts.ExecutionNode, pipeline_run.root_execution_id - ) - if execution_node: - pipeline_name = _get_pipeline_name_from_task_spec( - task_spec_dict=execution_node.task_spec - ) - response.pipeline_name = pipeline_name - if include_execution_stats: - execution_status_stats = self._calculate_execution_status_stats( - session=session, root_execution_id=pipeline_run.root_execution_id - ) - response.execution_status_stats = { - status.value: count for status, count in execution_status_stats.items() - } - return response - - def _calculate_execution_status_stats( - self, session: orm.Session, root_execution_id: bts.IdType - ) -> dict[bts.ContainerExecutionStatus, int]: - query = ( - sql.select( - bts.ExecutionNode.container_execution_status, - sql.func.count().label("count"), - ) - .join( - bts.ExecutionToAncestorExecutionLink, - bts.ExecutionToAncestorExecutionLink.execution_id - == bts.ExecutionNode.id, - ) - .where( - bts.ExecutionToAncestorExecutionLink.ancestor_execution_id - == root_execution_id - ) - .where(bts.ExecutionNode.container_execution_status != None) - .group_by( - bts.ExecutionNode.container_execution_status, - ) - ) - execution_status_stat_rows = session.execute(query).tuples().all() - execution_status_stats = dict(execution_status_stat_rows) - - return execution_status_stats - def list_annotations( self, *, @@ -330,7 +186,7 @@ def set_annotation( user_name: str | None = None, skip_user_check: bool = False, ): - self._fail_if_changing_system_annotation(key=key) + annotation_utils.fail_if_changing_system_annotation(key=key) pipeline_run = session.get(bts.PipelineRun, id) if not pipeline_run: raise errors.ItemNotFoundError(f"Pipeline run {id} not found.") @@ -353,7 +209,7 @@ def delete_annotation( user_name: str | None = None, skip_user_check: bool = False, ): - self._fail_if_changing_system_annotation(key=key) + annotation_utils.fail_if_changing_system_annotation(key=key) pipeline_run = session.get(bts.PipelineRun, id) if not pipeline_run: raise errors.ItemNotFoundError(f"Pipeline run {id} not found.") @@ -469,22 +325,6 @@ class ArtifactNodeIdResponse: id: bts.IdType -@dataclasses.dataclass(kw_only=True) -class ExecutionStatusSummary: - total_executions: int = 0 - ended_executions: int = 0 - has_ended: bool = False - - def count_execution_status( - self, *, status: bts.ContainerExecutionStatus, count: int - ) -> None: - self.total_executions += count - if status in bts.CONTAINER_STATUSES_ENDED: - self.ended_executions += count - - self.has_ended = self.ended_executions == self.total_executions - - @dataclasses.dataclass class GetGraphExecutionStateResponse: child_execution_status_stats: dict[bts.IdType, dict[str, int]] @@ -1301,94 +1141,6 @@ def delete_settings( ] -def _truncate_for_annotation( - *, - value: str, - field_name: str, - pipeline_run_id: bts.IdType, -) -> str: - """Truncate a string to fit the annotation VARCHAR column. - - Returns the value unchanged if it fits within _STR_MAX_LENGTH, - otherwise truncates and logs a warning with the run ID and - original length. - """ - max_len = bts._STR_MAX_LENGTH - if len(value) <= max_len: - return value - - _logger.warning( - f"Truncating {field_name} annotation for run {pipeline_run_id}: " - f"{len(value)} chars -> {max_len} chars" - ) - return value[:max_len] - - -def _mirror_system_annotations( - *, - session: orm.Session, - pipeline_run_id: bts.IdType, - created_by: str | None, - pipeline_name: str | None, -) -> None: - """Mirror pipeline run fields as system annotations for filter_query search. - - Always creates an annotation for every run, even when the source value is - None or empty (stored as ""). This ensures data parity so every run has a - row for each system key. - """ - - # TODO: The original pipeline_run.created_by and the pipeline name stored in - # extra_data / task_spec are saved untruncated, while the annotation mirror - # is truncated to VARCHAR(255). This creates a data parity mismatch between - # the source columns and their annotation copies. Revisit this to either - # widen the annotation column or enforce the same limit at the source. - - created_by_value = created_by - if created_by_value is None: - created_by_value = "" - _logger.warning( - f"Pipeline run id {pipeline_run_id} `created_by` is None, " - 'setting it to empty string "" for data parity' - ) - - created_by_value = _truncate_for_annotation( - value=created_by_value, - field_name=filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY, - pipeline_run_id=pipeline_run_id, - ) - - session.add( - bts.PipelineRunAnnotation( - pipeline_run_id=pipeline_run_id, - key=filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY, - value=created_by_value, - ) - ) - - pipeline_name_value = pipeline_name - if pipeline_name_value is None: - pipeline_name_value = "" - _logger.warning( - f"Pipeline run id {pipeline_run_id} `pipeline_name` is None, " - 'setting it to empty string "" for data parity' - ) - - pipeline_name_value = _truncate_for_annotation( - value=pipeline_name_value, - field_name=filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME, - pipeline_run_id=pipeline_run_id, - ) - - session.add( - bts.PipelineRunAnnotation( - pipeline_run_id=pipeline_run_id, - key=filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME, - value=pipeline_name_value, - ) - ) - - def _recursively_create_all_executions_and_artifacts_root( session: orm.Session, root_task_spec: structures.TaskSpec, diff --git a/cloud_pipelines_backend/database_migrations.py b/cloud_pipelines_backend/backfill/annotations.py similarity index 99% rename from cloud_pipelines_backend/database_migrations.py rename to cloud_pipelines_backend/backfill/annotations.py index 746ba6b..e56ea2c 100644 --- a/cloud_pipelines_backend/database_migrations.py +++ b/cloud_pipelines_backend/backfill/annotations.py @@ -6,8 +6,8 @@ import sqlalchemy from sqlalchemy import orm -from . import backend_types_sql as bts -from . import filter_query_sql +from .. import backend_types_sql as bts +from ..search import filter_query_sql _logger = logging.getLogger(__name__) diff --git a/cloud_pipelines_backend/database_ops.py b/cloud_pipelines_backend/database_ops.py index caf889e..fc08e61 100644 --- a/cloud_pipelines_backend/database_ops.py +++ b/cloud_pipelines_backend/database_ops.py @@ -4,7 +4,7 @@ from sqlalchemy import orm from . import backend_types_sql as bts -from . import database_migrations +from .backfill import annotations as backfill_annotations _logger = logging.getLogger(__name__) @@ -108,7 +108,7 @@ def migrate_db( _logger.info("Skipping annotation backfills") else: with orm.Session(db_engine) as session: - database_migrations.run_all_annotation_backfills( + backfill_annotations.run_all_annotation_backfills( session=session, ) diff --git a/cloud_pipelines_backend/execution/utils.py b/cloud_pipelines_backend/execution/utils.py new file mode 100644 index 0000000..97441ec --- /dev/null +++ b/cloud_pipelines_backend/execution/utils.py @@ -0,0 +1,19 @@ +import dataclasses + +from .. import backend_types_sql as bts + + +@dataclasses.dataclass(kw_only=True) +class ExecutionStatusSummary: + total_executions: int = 0 + ended_executions: int = 0 + has_ended: bool = False + + def count_execution_status( + self, *, status: bts.ContainerExecutionStatus, count: int + ) -> None: + self.total_executions += count + if status in bts.CONTAINER_STATUSES_ENDED: + self.ended_executions += count + + self.has_ended = self.ended_executions == self.total_executions diff --git a/cloud_pipelines_backend/filter_query_models.py b/cloud_pipelines_backend/search/filter_query_models.py similarity index 100% rename from cloud_pipelines_backend/filter_query_models.py rename to cloud_pipelines_backend/search/filter_query_models.py diff --git a/cloud_pipelines_backend/filter_query_sql.py b/cloud_pipelines_backend/search/filter_query_sql.py similarity index 99% rename from cloud_pipelines_backend/filter_query_sql.py rename to cloud_pipelines_backend/search/filter_query_sql.py index b306a57..cb10857 100644 --- a/cloud_pipelines_backend/filter_query_sql.py +++ b/cloud_pipelines_backend/search/filter_query_sql.py @@ -6,8 +6,8 @@ import sqlalchemy as sql -from . import backend_types_sql as bts -from . import errors +from .. import backend_types_sql as bts +from .. import errors from . import filter_query_models SYSTEM_KEY_PREFIX: Final[str] = "system/" diff --git a/cloud_pipelines_backend/search/runs.py b/cloud_pipelines_backend/search/runs.py new file mode 100644 index 0000000..7e3d915 --- /dev/null +++ b/cloud_pipelines_backend/search/runs.py @@ -0,0 +1,179 @@ +import dataclasses +import datetime +from typing import Any, Final + +import sqlalchemy as sql +from sqlalchemy import orm + +from .. import backend_types_sql as bts +from .. import component_structures as structures +from . import filter_query_sql + +_DEFAULT_PAGE_SIZE: Final[int] = 10 +_PIPELINE_NAME_EXTRA_DATA_KEY: Final[str] = "pipeline_name" + + +@dataclasses.dataclass(kw_only=True) +class PipelineRunResponse: + id: bts.IdType + root_execution_id: bts.IdType + annotations: dict[str, Any] | None = None + created_by: str | None = None + created_at: datetime.datetime | None = None + pipeline_name: str | None = None + execution_status_stats: dict[str, int] | None = None + + @classmethod + def from_db(cls, pipeline_run: bts.PipelineRun) -> "PipelineRunResponse": + return PipelineRunResponse( + id=pipeline_run.id, + root_execution_id=pipeline_run.root_execution_id, + annotations=pipeline_run.annotations, + created_by=pipeline_run.created_by, + created_at=pipeline_run.created_at, + ) + + +@dataclasses.dataclass(kw_only=True) +class ListPipelineJobsResponse: + pipeline_runs: list[PipelineRunResponse] + next_page_token: str | None = None + + +def get_pipeline_name_from_task_spec( + *, + task_spec_dict: dict[str, Any], +) -> str | None: + """Extract pipeline name from a task_spec dict via component_ref.spec.name. + + Traversal path: + task_spec_dict -> TaskSpec -> component_ref -> spec -> name + + Returns None if any step in the chain is missing or parsing fails. + """ + try: + task_spec = structures.TaskSpec.from_json_dict(task_spec_dict) + except Exception: + return None + spec = task_spec.component_ref.spec + if spec is None: + return None + return spec.name or None + + +def _query_pipeline_runs( + *, + session: orm.Session, + where_clauses: list[sql.ColumnElement], + offset: int, + page_size: int, +) -> list[bts.PipelineRun]: + return list( + session.scalars( + sql.select(bts.PipelineRun) + .where(*where_clauses) + .order_by(bts.PipelineRun.created_at.desc()) + .offset(offset) + .limit(page_size) + ).all() + ) + + +def _calculate_execution_status_stats( + *, + session: orm.Session, + root_execution_id: bts.IdType, +) -> dict[bts.ContainerExecutionStatus, int]: + query = ( + sql.select( + bts.ExecutionNode.container_execution_status, + sql.func.count().label("count"), + ) + .join( + bts.ExecutionToAncestorExecutionLink, + bts.ExecutionToAncestorExecutionLink.execution_id == bts.ExecutionNode.id, + ) + .where( + bts.ExecutionToAncestorExecutionLink.ancestor_execution_id + == root_execution_id + ) + .where(bts.ExecutionNode.container_execution_status != None) + .group_by( + bts.ExecutionNode.container_execution_status, + ) + ) + execution_status_stat_rows = session.execute(query).tuples().all() + return dict(execution_status_stat_rows) + + +def _create_pipeline_run_response( + *, + session: orm.Session, + pipeline_run: bts.PipelineRun, + include_pipeline_names: bool, + include_execution_stats: bool, +) -> PipelineRunResponse: + response = PipelineRunResponse.from_db(pipeline_run) + if include_pipeline_names: + pipeline_name = None + extra_data = pipeline_run.extra_data or {} + if _PIPELINE_NAME_EXTRA_DATA_KEY in extra_data: + pipeline_name = extra_data[_PIPELINE_NAME_EXTRA_DATA_KEY] + else: + execution_node = session.get( + bts.ExecutionNode, pipeline_run.root_execution_id + ) + if execution_node: + pipeline_name = get_pipeline_name_from_task_spec( + task_spec_dict=execution_node.task_spec + ) + response.pipeline_name = pipeline_name + if include_execution_stats: + execution_status_stats = _calculate_execution_status_stats( + session=session, root_execution_id=pipeline_run.root_execution_id + ) + response.execution_status_stats = { + status.value: count for status, count in execution_status_stats.items() + } + return response + + +def get_pipeline_runs( + *, + session: orm.Session, + page_token: str | None = None, + filter: str | None = None, + filter_query: str | None = None, + current_user: str | None = None, + include_pipeline_names: bool = False, + include_execution_stats: bool = False, +) -> ListPipelineJobsResponse: + where_clauses, offset, next_token = filter_query_sql.build_list_filters( + filter_value=filter, + filter_query_value=filter_query, + page_token_value=page_token, + current_user=current_user, + page_size=_DEFAULT_PAGE_SIZE, + ) + + pipeline_runs = _query_pipeline_runs( + session=session, + where_clauses=where_clauses, + offset=offset, + page_size=_DEFAULT_PAGE_SIZE, + ) + + next_page_token = next_token if len(pipeline_runs) >= _DEFAULT_PAGE_SIZE else None + + return ListPipelineJobsResponse( + pipeline_runs=[ + _create_pipeline_run_response( + session=session, + pipeline_run=pipeline_run, + include_pipeline_names=include_pipeline_names, + include_execution_stats=include_execution_stats, + ) + for pipeline_run in pipeline_runs + ], + next_page_token=next_page_token, + ) diff --git a/tests/annotation/test_utils.py b/tests/annotation/test_utils.py new file mode 100644 index 0000000..512cf5c --- /dev/null +++ b/tests/annotation/test_utils.py @@ -0,0 +1,442 @@ +import pytest +import sqlalchemy +from sqlalchemy import orm + +from cloud_pipelines_backend import api_server_sql +from cloud_pipelines_backend import backend_types_sql as bts +from cloud_pipelines_backend import component_structures as structures +from cloud_pipelines_backend import database_ops +from cloud_pipelines_backend import errors +from cloud_pipelines_backend.annotation import utils as annotation_utils +from cloud_pipelines_backend.search import filter_query_sql + + +def _make_task_spec(pipeline_name: str = "test-pipeline") -> structures.TaskSpec: + return structures.TaskSpec( + component_ref=structures.ComponentReference( + spec=structures.ComponentSpec( + name=pipeline_name, + implementation=structures.ContainerImplementation( + container=structures.ContainerSpec(image="test-image:latest"), + ), + ), + ), + ) + + +@pytest.fixture() +def session_factory(): + engine = database_ops.create_db_engine(database_uri="sqlite://") + bts._TableBase.metadata.create_all(engine) + return orm.sessionmaker(engine) + + +@pytest.fixture() +def service(): + return api_server_sql.PipelineRunsApiService_Sql() + + +def _create_run(session_factory, service, **kwargs): + """Create a pipeline run using a fresh session (mirrors production per-request sessions).""" + with session_factory() as session: + return service.create(session, **kwargs) + + +class TestPipelineRunServiceCreate: + def test_create_returns_pipeline_run(self, session_factory, service): + result = _create_run( + session_factory, service, root_task=_make_task_spec("my-pipeline") + ) + assert result.id is not None + assert result.root_execution_id is not None + assert result.created_at is not None + + def test_create_with_created_by(self, session_factory, service): + result = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="user1@example.com", + ) + assert result.created_by == "user1@example.com" + + def test_create_with_annotations(self, session_factory, service): + annotations = {"team": "ml-ops", "project": "search"} + result = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + annotations=annotations, + ) + assert result.annotations == annotations + + def test_create_without_created_by(self, session_factory, service): + result = _create_run(session_factory, service, root_task=_make_task_spec()) + assert result.created_by is None + + def test_create_mirrors_name_and_created_by(self, session_factory, service): + run = _create_run( + session_factory, + service, + root_task=_make_task_spec("my-pipeline"), + created_by="alice", + ) + with session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert ( + annotations[filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME] + == "my-pipeline" + ) + assert ( + annotations[filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY] + == "alice" + ) + + def test_create_mirrors_name_only(self, session_factory, service): + run = _create_run( + session_factory, + service, + root_task=_make_task_spec("solo-pipeline"), + ) + with session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert ( + annotations[filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME] + == "solo-pipeline" + ) + assert ( + annotations[filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY] + == "" + ) + + def test_create_mirrors_created_by_only(self, session_factory, service): + task_spec = _make_task_spec("placeholder") + task_spec.component_ref.spec.name = None + run = _create_run( + session_factory, service, root_task=task_spec, created_by="alice" + ) + with session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert ( + annotations[filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY] + == "alice" + ) + assert ( + annotations[filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME] + == "" + ) + + def test_create_mirrors_empty_values_as_empty_string( + self, session_factory, service + ): + run = _create_run( + session_factory, + service, + root_task=_make_task_spec(""), + created_by="", + ) + with session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert ( + annotations[filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME] + == "" + ) + assert ( + annotations[filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY] + == "" + ) + + def test_create_mirrors_absent_values_as_empty_string( + self, session_factory, service + ): + task_spec = _make_task_spec("placeholder") + task_spec.component_ref.spec.name = None + run = _create_run(session_factory, service, root_task=task_spec) + with session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert ( + annotations[filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME] + == "" + ) + assert ( + annotations[filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY] + == "" + ) + + +class TestPipelineRunAnnotationCrud: + def test_system_annotations_coexist_with_user_annotations( + self, session_factory, service + ): + run = _create_run( + session_factory, + service, + root_task=_make_task_spec("my-pipeline"), + created_by="alice", + ) + with session_factory() as session: + service.set_annotation( + session=session, + id=run.id, + key="team", + value="ml-ops", + user_name="alice", + ) + with session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert annotations["team"] == "ml-ops" + assert ( + annotations[filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME] + == "my-pipeline" + ) + assert ( + annotations[filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY] + == "alice" + ) + + def test_set_annotation(self, session_factory, service): + run = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="user1", + ) + with session_factory() as session: + service.set_annotation( + session=session, + id=run.id, + key="team", + value="ml-ops", + user_name="user1", + ) + with session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert annotations["team"] == "ml-ops" + + def test_set_annotation_overwrites(self, session_factory, service): + run = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="user1", + ) + with session_factory() as session: + service.set_annotation( + session=session, + id=run.id, + key="team", + value="old-value", + user_name="user1", + ) + with session_factory() as session: + service.set_annotation( + session=session, + id=run.id, + key="team", + value="new-value", + user_name="user1", + ) + with session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert annotations["team"] == "new-value" + + def test_delete_annotation(self, session_factory, service): + run = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="user1", + ) + with session_factory() as session: + service.set_annotation( + session=session, + id=run.id, + key="team", + value="ml-ops", + user_name="user1", + ) + with session_factory() as session: + service.delete_annotation( + session=session, + id=run.id, + key="team", + user_name="user1", + ) + with session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert "team" not in annotations + + def test_list_annotations_only_system(self, session_factory, service): + run = _create_run(session_factory, service, root_task=_make_task_spec()) + with session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert annotations == { + filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME: "test-pipeline", + filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY: "", + } + + def test_set_annotation_rejects_system_key(self, session_factory, service): + run = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="user1", + ) + with session_factory() as session: + with pytest.raises( + errors.ApiValidationError, match="reserved for system use" + ): + service.set_annotation( + session=session, + id=run.id, + key="system/pipeline_run.created_by", + value="hacker", + user_name="user1", + ) + + def test_delete_annotation_rejects_system_key(self, session_factory, service): + run = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="user1", + ) + with session_factory() as session: + with pytest.raises( + errors.ApiValidationError, match="reserved for system use" + ): + service.delete_annotation( + session=session, + id=run.id, + key="system/pipeline_run.created_by", + user_name="user1", + ) + + +class TestTruncateForAnnotation: + """Unit tests for truncate_for_annotation() helper.""" + + def test_exact_255_unchanged(self) -> None: + value = "a" * bts._STR_MAX_LENGTH + result = annotation_utils._truncate_for_annotation( + value=value, + field_name=filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME, + pipeline_run_id="run-1", + ) + assert result == value + + def test_256_truncated_and_logs_warning(self, caplog) -> None: + value = "b" * 256 + field = filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME + with caplog.at_level("WARNING"): + result = annotation_utils._truncate_for_annotation( + value=value, + field_name=field, + pipeline_run_id="run-xyz", + ) + assert result == "b" * bts._STR_MAX_LENGTH + assert len(caplog.records) == 1 + msg = caplog.records[0].message + assert "run-xyz" in msg + assert str(field) in msg + + +class TestAnnotationValueOverflow: + """Reproduction tests using mysql_varchar_limit_session_factory (SQLite TRIGGER + enforcement). These tests prove that >255 char values are rejected, + mimicking MySQL's DataError 1406. + + Covers all write paths into pipeline_run_annotation: + - set_annotation(): long key, long value + - create() via mirror_system_annotations(): long pipeline_name, long created_by + """ + + # TODO: set_annotation() currently has no truncation guard for the + # VARCHAR(255) limit on annotation key/value columns. These tests + # document the failure. Fix deferred to a separate PR to avoid + # convoluting the backfill + mirror_system_annotations fix. + + def test_set_annotation_long_value_raises_on_overflow( + self, + mysql_varchar_limit_session_factory: orm.sessionmaker, + service: api_server_sql.PipelineRunsApiService_Sql, + ) -> None: + """set_annotation() with a 300-char value overflows the + VARCHAR(255) column and triggers IntegrityError.""" + run = _create_run( + mysql_varchar_limit_session_factory, + service, + root_task=_make_task_spec(), + created_by="user1", + ) + with mysql_varchar_limit_session_factory() as session: + with pytest.raises( + sqlalchemy.exc.IntegrityError, match="Data too long.*value" + ): + service.set_annotation( + session=session, + id=run.id, + key="team", + value="v" * 300, + user_name="user1", + ) + + def test_set_annotation_long_key_raises_on_overflow( + self, + mysql_varchar_limit_session_factory: orm.sessionmaker, + service: api_server_sql.PipelineRunsApiService_Sql, + ) -> None: + """set_annotation() with a 300-char key overflows the + VARCHAR(255) key column and triggers IntegrityError.""" + run = _create_run( + mysql_varchar_limit_session_factory, + service, + root_task=_make_task_spec(), + created_by="user1", + ) + with mysql_varchar_limit_session_factory() as session: + with pytest.raises( + sqlalchemy.exc.IntegrityError, match="Data too long.*key" + ): + service.set_annotation( + session=session, + id=run.id, + key="k" * 300, + value="short", + user_name="user1", + ) + + def test_create_run_long_pipeline_name_truncated( + self, + mysql_varchar_limit_session_factory: orm.sessionmaker, + service: api_server_sql.PipelineRunsApiService_Sql, + ) -> None: + """create() with a 300-char pipeline name is truncated to 255 + in mirror_system_annotations().""" + run = _create_run( + mysql_varchar_limit_session_factory, + service, + root_task=_make_task_spec("p" * 300), + ) + key = filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME + with mysql_varchar_limit_session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert annotations[key] == "p" * bts._STR_MAX_LENGTH + + def test_create_run_long_created_by_truncated( + self, + mysql_varchar_limit_session_factory: orm.sessionmaker, + service: api_server_sql.PipelineRunsApiService_Sql, + ) -> None: + """create() with a 300-char created_by is truncated to 255 + in mirror_system_annotations().""" + run = _create_run( + mysql_varchar_limit_session_factory, + service, + root_task=_make_task_spec(), + created_by="u" * 300, + ) + key = filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY + with mysql_varchar_limit_session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert annotations[key] == "u" * bts._STR_MAX_LENGTH diff --git a/tests/test_database_migrations.py b/tests/backfill/test_annotations.py similarity index 99% rename from tests/test_database_migrations.py rename to tests/backfill/test_annotations.py index 288168a..5a638ed 100644 --- a/tests/test_database_migrations.py +++ b/tests/backfill/test_annotations.py @@ -29,9 +29,9 @@ from cloud_pipelines_backend import api_server_sql from cloud_pipelines_backend import backend_types_sql as bts from cloud_pipelines_backend import component_structures as structures -from cloud_pipelines_backend import database_migrations +from cloud_pipelines_backend.backfill import annotations as database_migrations from cloud_pipelines_backend import database_ops -from cloud_pipelines_backend import filter_query_sql +from cloud_pipelines_backend.search import filter_query_sql def _make_task_spec( diff --git a/tests/execution/test_utils.py b/tests/execution/test_utils.py new file mode 100644 index 0000000..fc3f035 --- /dev/null +++ b/tests/execution/test_utils.py @@ -0,0 +1,57 @@ +from cloud_pipelines_backend import backend_types_sql as bts +from cloud_pipelines_backend.execution import utils as execution_utils + + +class TestExecutionStatusSummary: + def test_initial_state(self): + summary = execution_utils.ExecutionStatusSummary() + assert summary.total_executions == 0 + assert summary.ended_executions == 0 + assert summary.has_ended is False + + def test_accumulate_all_ended_statuses(self): + """Add each ended status with 2^i count for robust uniqueness.""" + summary = execution_utils.ExecutionStatusSummary() + ended_statuses = sorted(bts.CONTAINER_STATUSES_ENDED, key=lambda s: s.value) + expected_total = 0 + expected_ended = 0 + for i, status in enumerate(ended_statuses): + count = 2**i + summary.count_execution_status(status=status, count=count) + expected_total += count + expected_ended += count + assert summary.total_executions == expected_total + assert summary.ended_executions == expected_ended + assert summary.has_ended is True + + def test_accumulate_all_in_progress_statuses(self): + """Add each in-progress status with 2^i count for robust uniqueness.""" + summary = execution_utils.ExecutionStatusSummary() + in_progress_statuses = sorted( + set(bts.ContainerExecutionStatus) - bts.CONTAINER_STATUSES_ENDED, + key=lambda s: s.value, + ) + expected_total = 0 + for i, status in enumerate(in_progress_statuses): + count = 2**i + summary.count_execution_status(status=status, count=count) + expected_total += count + assert summary.total_executions == expected_total + assert summary.ended_executions == 0 + assert summary.has_ended is False + + def test_accumulate_all_statuses(self): + """Add every status with 2^i count. Summary math must be exact.""" + summary = execution_utils.ExecutionStatusSummary() + all_statuses = sorted(bts.ContainerExecutionStatus, key=lambda s: s.value) + expected_total = 0 + expected_ended = 0 + for i, status in enumerate(all_statuses): + count = 2**i + expected_total += count + if status in bts.CONTAINER_STATUSES_ENDED: + expected_ended += count + summary.count_execution_status(status=status, count=count) + assert summary.total_executions == expected_total + assert summary.ended_executions == expected_ended + assert summary.has_ended == (expected_ended == expected_total) diff --git a/tests/test_filter_query_models.py b/tests/search/test_filter_query_models.py similarity index 99% rename from tests/test_filter_query_models.py rename to tests/search/test_filter_query_models.py index c0ba51c..cdf9cb5 100644 --- a/tests/test_filter_query_models.py +++ b/tests/search/test_filter_query_models.py @@ -1,7 +1,7 @@ import pydantic import pytest -from cloud_pipelines_backend import filter_query_models +from cloud_pipelines_backend.search import filter_query_models class TestFilterQuery: diff --git a/tests/test_filter_query_sql.py b/tests/search/test_filter_query_sql.py similarity index 99% rename from tests/test_filter_query_sql.py rename to tests/search/test_filter_query_sql.py index b299e52..e8147f6 100644 --- a/tests/test_filter_query_sql.py +++ b/tests/search/test_filter_query_sql.py @@ -7,8 +7,8 @@ from cloud_pipelines_backend import backend_types_sql as bts from cloud_pipelines_backend import errors -from cloud_pipelines_backend import filter_query_models -from cloud_pipelines_backend import filter_query_sql +from cloud_pipelines_backend.search import filter_query_models +from cloud_pipelines_backend.search import filter_query_sql def _compile(clause: sql.ColumnElement) -> str: diff --git a/tests/test_api_server_sql.py b/tests/search/test_runs.py similarity index 68% rename from tests/test_api_server_sql.py rename to tests/search/test_runs.py index bf30c53..18f0008 100644 --- a/tests/test_api_server_sql.py +++ b/tests/search/test_runs.py @@ -10,62 +10,8 @@ from cloud_pipelines_backend import component_structures as structures from cloud_pipelines_backend import database_ops from cloud_pipelines_backend import errors -from cloud_pipelines_backend import filter_query_sql - - -class TestExecutionStatusSummary: - def test_initial_state(self): - summary = api_server_sql.ExecutionStatusSummary() - assert summary.total_executions == 0 - assert summary.ended_executions == 0 - assert summary.has_ended is False - - def test_accumulate_all_ended_statuses(self): - """Add each ended status with 2^i count for robust uniqueness.""" - summary = api_server_sql.ExecutionStatusSummary() - ended_statuses = sorted(bts.CONTAINER_STATUSES_ENDED, key=lambda s: s.value) - expected_total = 0 - expected_ended = 0 - for i, status in enumerate(ended_statuses): - count = 2**i - summary.count_execution_status(status=status, count=count) - expected_total += count - expected_ended += count - assert summary.total_executions == expected_total - assert summary.ended_executions == expected_ended - assert summary.has_ended is True - - def test_accumulate_all_in_progress_statuses(self): - """Add each in-progress status with 2^i count for robust uniqueness.""" - summary = api_server_sql.ExecutionStatusSummary() - in_progress_statuses = sorted( - set(bts.ContainerExecutionStatus) - bts.CONTAINER_STATUSES_ENDED, - key=lambda s: s.value, - ) - expected_total = 0 - for i, status in enumerate(in_progress_statuses): - count = 2**i - summary.count_execution_status(status=status, count=count) - expected_total += count - assert summary.total_executions == expected_total - assert summary.ended_executions == 0 - assert summary.has_ended is False - - def test_accumulate_all_statuses(self): - """Add every status with 2^i count. Summary math must be exact.""" - summary = api_server_sql.ExecutionStatusSummary() - all_statuses = sorted(bts.ContainerExecutionStatus, key=lambda s: s.value) - expected_total = 0 - expected_ended = 0 - for i, status in enumerate(all_statuses): - count = 2**i - expected_total += count - if status in bts.CONTAINER_STATUSES_ENDED: - expected_ended += count - summary.count_execution_status(status=status, count=count) - assert summary.total_executions == expected_total - assert summary.ended_executions == expected_ended - assert summary.has_ended == (expected_ended == expected_total) +from cloud_pipelines_backend.search import filter_query_sql +from cloud_pipelines_backend.search import runs as search_runs def _make_task_spec(pipeline_name: str = "test-pipeline") -> structures.TaskSpec: @@ -236,7 +182,7 @@ def test_base_response(self, session_factory, service): run = _create_run(session_factory, service, root_task=_make_task_spec()) with session_factory() as session: db_run = session.get(bts.PipelineRun, run.id) - response = service._create_pipeline_run_response( + response = search_runs._create_pipeline_run_response( session=session, pipeline_run=db_run, include_pipeline_names=False, @@ -254,7 +200,7 @@ def test_pipeline_name_from_task_spec(self, session_factory, service): ) with session_factory() as session: db_run = session.get(bts.PipelineRun, run.id) - response = service._create_pipeline_run_response( + response = search_runs._create_pipeline_run_response( session=session, pipeline_run=db_run, include_pipeline_names=True, @@ -274,7 +220,7 @@ def test_pipeline_name_from_extra_data(self, session_factory, service): session.commit() with session_factory() as session: db_run = session.get(bts.PipelineRun, run.id) - response = service._create_pipeline_run_response( + response = search_runs._create_pipeline_run_response( session=session, pipeline_run=db_run, include_pipeline_names=True, @@ -295,7 +241,7 @@ def test_pipeline_name_none_when_no_execution_node(self, session_factory, servic session.commit() with session_factory() as session: db_run = session.get(bts.PipelineRun, run.id) - response = service._create_pipeline_run_response( + response = search_runs._create_pipeline_run_response( session=session, pipeline_run=db_run, include_pipeline_names=True, @@ -307,7 +253,7 @@ def test_with_execution_stats(self, session_factory, service): run = _create_run(session_factory, service, root_task=_make_task_spec()) with session_factory() as session: db_run = session.get(bts.PipelineRun, run.id) - response = service._create_pipeline_run_response( + response = search_runs._create_pipeline_run_response( session=session, pipeline_run=db_run, include_pipeline_names=False, @@ -316,406 +262,6 @@ def test_with_execution_stats(self, session_factory, service): assert response.execution_status_stats is not None -class TestPipelineRunServiceCreate: - def test_create_returns_pipeline_run(self, session_factory, service): - result = _create_run( - session_factory, service, root_task=_make_task_spec("my-pipeline") - ) - assert result.id is not None - assert result.root_execution_id is not None - assert result.created_at is not None - - def test_create_with_created_by(self, session_factory, service): - result = _create_run( - session_factory, - service, - root_task=_make_task_spec(), - created_by="user1@example.com", - ) - assert result.created_by == "user1@example.com" - - def test_create_with_annotations(self, session_factory, service): - annotations = {"team": "ml-ops", "project": "search"} - result = _create_run( - session_factory, - service, - root_task=_make_task_spec(), - annotations=annotations, - ) - assert result.annotations == annotations - - def test_create_without_created_by(self, session_factory, service): - result = _create_run(session_factory, service, root_task=_make_task_spec()) - assert result.created_by is None - - def test_create_mirrors_name_and_created_by(self, session_factory, service): - run = _create_run( - session_factory, - service, - root_task=_make_task_spec("my-pipeline"), - created_by="alice", - ) - with session_factory() as session: - annotations = service.list_annotations(session=session, id=run.id) - assert ( - annotations[filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME] - == "my-pipeline" - ) - assert ( - annotations[filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY] - == "alice" - ) - - def test_create_mirrors_name_only(self, session_factory, service): - run = _create_run( - session_factory, - service, - root_task=_make_task_spec("solo-pipeline"), - ) - with session_factory() as session: - annotations = service.list_annotations(session=session, id=run.id) - assert ( - annotations[filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME] - == "solo-pipeline" - ) - assert ( - annotations[filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY] - == "" - ) - - def test_create_mirrors_created_by_only(self, session_factory, service): - task_spec = _make_task_spec("placeholder") - task_spec.component_ref.spec.name = None - run = _create_run( - session_factory, service, root_task=task_spec, created_by="alice" - ) - with session_factory() as session: - annotations = service.list_annotations(session=session, id=run.id) - assert ( - annotations[filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY] - == "alice" - ) - assert ( - annotations[filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME] - == "" - ) - - def test_create_mirrors_empty_values_as_empty_string( - self, session_factory, service - ): - run = _create_run( - session_factory, - service, - root_task=_make_task_spec(""), - created_by="", - ) - with session_factory() as session: - annotations = service.list_annotations(session=session, id=run.id) - assert ( - annotations[filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME] - == "" - ) - assert ( - annotations[filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY] - == "" - ) - - def test_create_mirrors_absent_values_as_empty_string( - self, session_factory, service - ): - task_spec = _make_task_spec("placeholder") - task_spec.component_ref.spec.name = None - run = _create_run(session_factory, service, root_task=task_spec) - with session_factory() as session: - annotations = service.list_annotations(session=session, id=run.id) - assert ( - annotations[filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME] - == "" - ) - assert ( - annotations[filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY] - == "" - ) - - -class TestPipelineRunAnnotationCrud: - def test_system_annotations_coexist_with_user_annotations( - self, session_factory, service - ): - run = _create_run( - session_factory, - service, - root_task=_make_task_spec("my-pipeline"), - created_by="alice", - ) - with session_factory() as session: - service.set_annotation( - session=session, - id=run.id, - key="team", - value="ml-ops", - user_name="alice", - ) - with session_factory() as session: - annotations = service.list_annotations(session=session, id=run.id) - assert annotations["team"] == "ml-ops" - assert ( - annotations[filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME] - == "my-pipeline" - ) - assert ( - annotations[filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY] - == "alice" - ) - - def test_set_annotation(self, session_factory, service): - run = _create_run( - session_factory, - service, - root_task=_make_task_spec(), - created_by="user1", - ) - with session_factory() as session: - service.set_annotation( - session=session, - id=run.id, - key="team", - value="ml-ops", - user_name="user1", - ) - with session_factory() as session: - annotations = service.list_annotations(session=session, id=run.id) - assert annotations["team"] == "ml-ops" - - def test_set_annotation_overwrites(self, session_factory, service): - run = _create_run( - session_factory, - service, - root_task=_make_task_spec(), - created_by="user1", - ) - with session_factory() as session: - service.set_annotation( - session=session, - id=run.id, - key="team", - value="old-value", - user_name="user1", - ) - with session_factory() as session: - service.set_annotation( - session=session, - id=run.id, - key="team", - value="new-value", - user_name="user1", - ) - with session_factory() as session: - annotations = service.list_annotations(session=session, id=run.id) - assert annotations["team"] == "new-value" - - def test_delete_annotation(self, session_factory, service): - run = _create_run( - session_factory, - service, - root_task=_make_task_spec(), - created_by="user1", - ) - with session_factory() as session: - service.set_annotation( - session=session, - id=run.id, - key="team", - value="ml-ops", - user_name="user1", - ) - with session_factory() as session: - service.delete_annotation( - session=session, - id=run.id, - key="team", - user_name="user1", - ) - with session_factory() as session: - annotations = service.list_annotations(session=session, id=run.id) - assert "team" not in annotations - - def test_list_annotations_only_system(self, session_factory, service): - run = _create_run(session_factory, service, root_task=_make_task_spec()) - with session_factory() as session: - annotations = service.list_annotations(session=session, id=run.id) - assert annotations == { - filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME: "test-pipeline", - filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY: "", - } - - def test_set_annotation_rejects_system_key(self, session_factory, service): - run = _create_run( - session_factory, - service, - root_task=_make_task_spec(), - created_by="user1", - ) - with session_factory() as session: - with pytest.raises( - errors.ApiValidationError, match="reserved for system use" - ): - service.set_annotation( - session=session, - id=run.id, - key="system/pipeline_run.created_by", - value="hacker", - user_name="user1", - ) - - def test_delete_annotation_rejects_system_key(self, session_factory, service): - run = _create_run( - session_factory, - service, - root_task=_make_task_spec(), - created_by="user1", - ) - with session_factory() as session: - with pytest.raises( - errors.ApiValidationError, match="reserved for system use" - ): - service.delete_annotation( - session=session, - id=run.id, - key="system/pipeline_run.created_by", - user_name="user1", - ) - - -class TestTruncateForAnnotation: - """Unit tests for _truncate_for_annotation() helper.""" - - def test_exact_255_unchanged(self) -> None: - value = "a" * bts._STR_MAX_LENGTH - result = api_server_sql._truncate_for_annotation( - value=value, - field_name=filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME, - pipeline_run_id="run-1", - ) - assert result == value - - def test_256_truncated_and_logs_warning(self, caplog) -> None: - value = "b" * 256 - field = filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME - with caplog.at_level("WARNING"): - result = api_server_sql._truncate_for_annotation( - value=value, - field_name=field, - pipeline_run_id="run-xyz", - ) - assert result == "b" * bts._STR_MAX_LENGTH - assert len(caplog.records) == 1 - msg = caplog.records[0].message - assert "run-xyz" in msg - assert str(field) in msg - - -class TestAnnotationValueOverflow: - """Reproduction tests using mysql_varchar_limit_session_factory (SQLite TRIGGER - enforcement). These tests prove that >255 char values are rejected, - mimicking MySQL's DataError 1406. - - Covers all write paths into pipeline_run_annotation: - - set_annotation(): long key, long value - - create() via _mirror_system_annotations(): long pipeline_name, long created_by - """ - - # TODO: set_annotation() currently has no truncation guard for the - # VARCHAR(255) limit on annotation key/value columns. These tests - # document the failure. Fix deferred to a separate PR to avoid - # convoluting the backfill + _mirror_system_annotations fix. - - def test_set_annotation_long_value_raises_on_overflow( - self, - mysql_varchar_limit_session_factory: orm.sessionmaker, - service: api_server_sql.PipelineRunsApiService_Sql, - ) -> None: - """set_annotation() with a 300-char value overflows the - VARCHAR(255) column and triggers IntegrityError.""" - run = _create_run( - mysql_varchar_limit_session_factory, - service, - root_task=_make_task_spec(), - created_by="user1", - ) - with mysql_varchar_limit_session_factory() as session: - with pytest.raises( - sqlalchemy.exc.IntegrityError, match="Data too long.*value" - ): - service.set_annotation( - session=session, - id=run.id, - key="team", - value="v" * 300, - user_name="user1", - ) - - def test_set_annotation_long_key_raises_on_overflow( - self, - mysql_varchar_limit_session_factory: orm.sessionmaker, - service: api_server_sql.PipelineRunsApiService_Sql, - ) -> None: - """set_annotation() with a 300-char key overflows the - VARCHAR(255) key column and triggers IntegrityError.""" - run = _create_run( - mysql_varchar_limit_session_factory, - service, - root_task=_make_task_spec(), - created_by="user1", - ) - with mysql_varchar_limit_session_factory() as session: - with pytest.raises( - sqlalchemy.exc.IntegrityError, match="Data too long.*key" - ): - service.set_annotation( - session=session, - id=run.id, - key="k" * 300, - value="short", - user_name="user1", - ) - - def test_create_run_long_pipeline_name_truncated( - self, - mysql_varchar_limit_session_factory: orm.sessionmaker, - service: api_server_sql.PipelineRunsApiService_Sql, - ) -> None: - """create() with a 300-char pipeline name is truncated to 255 - in _mirror_system_annotations().""" - run = _create_run( - mysql_varchar_limit_session_factory, - service, - root_task=_make_task_spec("p" * 300), - ) - key = filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME - with mysql_varchar_limit_session_factory() as session: - annotations = service.list_annotations(session=session, id=run.id) - assert annotations[key] == "p" * bts._STR_MAX_LENGTH - - def test_create_run_long_created_by_truncated( - self, - mysql_varchar_limit_session_factory: orm.sessionmaker, - service: api_server_sql.PipelineRunsApiService_Sql, - ) -> None: - """create() with a 300-char created_by is truncated to 255 - in _mirror_system_annotations().""" - run = _create_run( - mysql_varchar_limit_session_factory, - service, - root_task=_make_task_spec(), - created_by="u" * 300, - ) - key = filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY - with mysql_varchar_limit_session_factory() as session: - annotations = service.list_annotations(session=session, id=run.id) - assert annotations[key] == "u" * bts._STR_MAX_LENGTH - - class TestFilterQueryApiWiring: def test_filter_query_validates_invalid_json(self, session_factory, service): from pydantic import ValidationError @@ -1579,26 +1125,32 @@ def test_filter_query_created_by_unsupported_predicate( class TestGetPipelineNameFromTaskSpec: - """Unit tests for _get_pipeline_name_from_task_spec.""" + """Unit tests for get_pipeline_name_from_task_spec.""" def test_returns_name(self): """Happy path: task_spec_dict -> TaskSpec -> component_ref -> spec -> name""" task = _make_task_spec(pipeline_name="my-pipe") - result = api_server_sql._get_pipeline_name_from_task_spec( + from cloud_pipelines_backend.search import runs + + result = runs.get_pipeline_name_from_task_spec( task_spec_dict=task.to_json_dict() ) assert result == "my-pipe" def test_returns_none_when_spec_is_none(self): """task_spec_dict -> TaskSpec -> component_ref -> [spec=None]""" - result = api_server_sql._get_pipeline_name_from_task_spec( + from cloud_pipelines_backend.search import runs + + result = runs.get_pipeline_name_from_task_spec( task_spec_dict={"component_ref": {}}, ) assert result is None def test_returns_none_when_name_is_none(self): """task_spec_dict -> ... -> spec -> [name=None]""" - result = api_server_sql._get_pipeline_name_from_task_spec( + from cloud_pipelines_backend.search import runs + + result = runs.get_pipeline_name_from_task_spec( task_spec_dict={ "component_ref": { "spec": { @@ -1613,7 +1165,9 @@ def test_returns_none_when_name_is_none(self): def test_returns_none_when_name_is_empty(self): """task_spec_dict -> ... -> spec -> [name=""]""" - result = api_server_sql._get_pipeline_name_from_task_spec( + from cloud_pipelines_backend.search import runs + + result = runs.get_pipeline_name_from_task_spec( task_spec_dict={ "component_ref": { "spec": { @@ -1629,7 +1183,115 @@ def test_returns_none_when_name_is_empty(self): def test_returns_none_on_malformed_dict(self): """[task_spec_dict=malformed] -> from_json_dict() raises""" - result = api_server_sql._get_pipeline_name_from_task_spec( - task_spec_dict={"bad": "data"} - ) + from cloud_pipelines_backend.search import runs + + result = runs.get_pipeline_name_from_task_spec(task_spec_dict={"bad": "data"}) assert result is None + + +class TestCalculateExecutionStatusStats: + """Unit tests for search_runs._calculate_execution_status_stats.""" + + _MINIMAL_TASK_SPEC: dict = { + "componentRef": { + "spec": { + "name": "stub", + "implementation": { + "container": {"image": "stub:latest"}, + }, + } + } + } + + def _make_execution( + self, + session: orm.Session, + *, + root: bts.ExecutionNode, + status: bts.ContainerExecutionStatus | None = None, + ) -> bts.ExecutionNode: + node = bts.ExecutionNode(task_spec=self._MINIMAL_TASK_SPEC) + node.container_execution_status = status + session.add(node) + session.flush() + link = bts.ExecutionToAncestorExecutionLink( + ancestor_execution=root, + execution=node, + ) + session.add(link) + session.flush() + return node + + def test_empty(self, session_factory): + root = bts.ExecutionNode(task_spec=self._MINIMAL_TASK_SPEC) + with session_factory() as session: + session.add(root) + session.flush() + result = search_runs._calculate_execution_status_stats( + session=session, root_execution_id=root.id + ) + assert result == {} + + def test_single_status_group(self, session_factory): + root = bts.ExecutionNode(task_spec=self._MINIMAL_TASK_SPEC) + with session_factory() as session: + session.add(root) + session.flush() + self._make_execution( + session, + root=root, + status=bts.ContainerExecutionStatus.SUCCEEDED, + ) + self._make_execution( + session, + root=root, + status=bts.ContainerExecutionStatus.SUCCEEDED, + ) + result = search_runs._calculate_execution_status_stats( + session=session, root_execution_id=root.id + ) + assert result == {bts.ContainerExecutionStatus.SUCCEEDED: 2} + + def test_multiple_status_groups(self, session_factory): + root = bts.ExecutionNode(task_spec=self._MINIMAL_TASK_SPEC) + with session_factory() as session: + session.add(root) + session.flush() + self._make_execution( + session, + root=root, + status=bts.ContainerExecutionStatus.SUCCEEDED, + ) + self._make_execution( + session, + root=root, + status=bts.ContainerExecutionStatus.SUCCEEDED, + ) + self._make_execution( + session, + root=root, + status=bts.ContainerExecutionStatus.FAILED, + ) + result = search_runs._calculate_execution_status_stats( + session=session, root_execution_id=root.id + ) + assert result == { + bts.ContainerExecutionStatus.SUCCEEDED: 2, + bts.ContainerExecutionStatus.FAILED: 1, + } + + def test_null_statuses_excluded(self, session_factory): + root = bts.ExecutionNode(task_spec=self._MINIMAL_TASK_SPEC) + with session_factory() as session: + session.add(root) + session.flush() + self._make_execution( + session, + root=root, + status=bts.ContainerExecutionStatus.RUNNING, + ) + self._make_execution(session, root=root, status=None) + result = search_runs._calculate_execution_status_stats( + session=session, root_execution_id=root.id + ) + assert result == {bts.ContainerExecutionStatus.RUNNING: 1}