-
Notifications
You must be signed in to change notification settings - Fork 221
[Draft PR] Support router as replica with pipelines #3721
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -801,6 +801,12 @@ class ReplicaGroup(CoreModel): | |
| CommandsList, | ||
| Field(description="The shell commands to run for replicas in this group"), | ||
| ] = [] | ||
| router: Annotated[ | ||
| Optional[AnyServiceRouterConfig], | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
(see this thread) |
||
| Field( | ||
| description="When set, replicas in this group run the in-service HTTP router (e.g. SGLang).", | ||
| ), | ||
| ] = None | ||
|
|
||
| @validator("name") | ||
| def validate_name(cls, v: Optional[str]) -> Optional[str]: | ||
|
|
@@ -1032,6 +1038,20 @@ def validate_replica_groups_have_commands_or_image(cls, values): | |
|
|
||
| return values | ||
|
|
||
| @root_validator() | ||
| def validate_at_most_one_router_replica_group(cls, values): | ||
| replicas = values.get("replicas") | ||
| if not isinstance(replicas, list): | ||
| return values | ||
| router_groups = [g for g in replicas if g.router is not None] | ||
| if len(router_groups) > 1: | ||
| raise ValueError("At most one replica group may specify `router`.") | ||
| if router_groups: | ||
| router_group = router_groups[0] | ||
| if router_group.count.min != 1 or router_group.count.max != 1: | ||
| raise ValueError("For now replica group with `router` must have `count: 1`.") | ||
| return values | ||
|
|
||
|
|
||
| class ServiceConfigurationConfig( | ||
| ProfileParamsConfig, | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is there anything to prevent exposing the router replica's …? See the second comment in this thread |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -141,6 +141,7 @@ async def register_replica( | |
| nginx: Nginx, | ||
| service_conn_pool: ServiceConnectionPool, | ||
| internal_ip: Optional[str] = None, | ||
| is_router_replica: bool = False, | ||
| ) -> None: | ||
| replica = models.Replica( | ||
| id=replica_id, | ||
|
|
@@ -152,6 +153,7 @@ async def register_replica( | |
| ssh_head_proxy=ssh_head_proxy, | ||
| ssh_head_proxy_private_key=ssh_head_proxy_private_key, | ||
| internal_ip=internal_ip, | ||
| is_router_replica=is_router_replica, | ||
| ) | ||
|
|
||
| async with lock: | ||
|
|
@@ -291,6 +293,13 @@ async def apply_service( | |
| ) | ||
| for replica, conn in replica_conns.items() | ||
| ] | ||
| router_replicas = [r for r in service.replicas if r.is_router_replica] | ||
| if router_replicas: | ||
| replica_configs_for_nginx = [c for c in replica_configs if c.id == router_replicas[0].id] | ||
|
Comment on lines
+296
to
+298
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What if the router replica is not yet registered, or temporarily unregistered (e.g., if it failed and is being restarted)? It seems that the gateway will then assume that the service doesn't have a router replica, and Nginx will direct incoming requests directly to worker replicas, which is not expected. Also, do we actually need to register worker replicas on the gateway, considering the gateway should only communicate with the router replica? My initial proposal was not to register them, and I think that would fix the problem above, and also optimize and simplify a few things (no need for extra network communication, no need to distinguish between router and worker replicas on the gateway, etc). |
||
| service_config = await get_nginx_service_config(service, replica_configs_for_nginx) | ||
| await nginx.register(service_config, (await repo.get_config()).acme_settings) | ||
| return replica_failures | ||
|
|
||
| service_config = await get_nginx_service_config(service, replica_configs) | ||
| await nginx.register(service_config, (await repo.get_config()).acme_settings) | ||
| return replica_failures | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If you have some time, consider covering the pipeline by unit tests. I think all (or at least most of) our existing pipelines and background tasks currently have good coverage |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,275 @@ | ||
| import asyncio | ||
| import uuid | ||
| from dataclasses import dataclass | ||
| from datetime import timedelta | ||
| from typing import Sequence | ||
|
|
||
| from sqlalchemy import delete, or_, select, true, update | ||
| from sqlalchemy.orm import joinedload, load_only, selectinload | ||
|
|
||
| from dstack._internal.core.models.runs import JobStatus, RunStatus | ||
| from dstack._internal.server.background.pipeline_tasks.base import ( | ||
| Fetcher, | ||
| Heartbeater, | ||
| ItemUpdateMap, | ||
| Pipeline, | ||
| PipelineItem, | ||
| Worker, | ||
| log_lock_token_changed_after_processing, | ||
| log_lock_token_mismatch, | ||
| resolve_now_placeholders, | ||
| set_processed_update_map_fields, | ||
| set_unlock_update_map_fields, | ||
| ) | ||
| from dstack._internal.server.db import get_db, get_session_ctx | ||
| from dstack._internal.server.models import ( | ||
| InstanceModel, | ||
| JobModel, | ||
| ProjectModel, | ||
| RunModel, | ||
| ServiceRouterWorkerSyncModel, | ||
| ) | ||
| from dstack._internal.server.services.locking import get_locker | ||
| from dstack._internal.server.services.pipelines import PipelineHinterProtocol | ||
| from dstack._internal.server.services.router_worker_sync import ( | ||
| run_model_has_router_replica_group, | ||
| sync_router_workers_for_run_model, | ||
| ) | ||
| from dstack._internal.server.utils import sentry_utils | ||
| from dstack._internal.utils.common import get_current_datetime | ||
| from dstack._internal.utils.logging import get_logger | ||
|
|
||
| logger = get_logger(__name__) | ||
|
|
||
|
|
||
| @dataclass | ||
| class ServiceRouterWorkerSyncPipelineItem(PipelineItem): | ||
| run_id: uuid.UUID | ||
|
|
||
|
|
||
| class ServiceRouterWorkerSyncPipeline(Pipeline[ServiceRouterWorkerSyncPipelineItem]): | ||
| def __init__( | ||
| self, | ||
| workers_num: int = 8, | ||
| queue_lower_limit_factor: float = 0.5, | ||
| queue_upper_limit_factor: float = 2.0, | ||
| min_processing_interval: timedelta = timedelta(seconds=5), | ||
| lock_timeout: timedelta = timedelta(seconds=25), | ||
| heartbeat_trigger: timedelta = timedelta(seconds=10), | ||
| *, | ||
| pipeline_hinter: PipelineHinterProtocol, | ||
| ) -> None: | ||
| super().__init__( | ||
| workers_num=workers_num, | ||
| queue_lower_limit_factor=queue_lower_limit_factor, | ||
| queue_upper_limit_factor=queue_upper_limit_factor, | ||
| min_processing_interval=min_processing_interval, | ||
| lock_timeout=lock_timeout, | ||
| heartbeat_trigger=heartbeat_trigger, | ||
| ) | ||
| self.__heartbeater = Heartbeater[ServiceRouterWorkerSyncPipelineItem]( | ||
| model_type=ServiceRouterWorkerSyncModel, | ||
| lock_timeout=self._lock_timeout, | ||
| heartbeat_trigger=self._heartbeat_trigger, | ||
| ) | ||
| self.__fetcher = ServiceRouterWorkerSyncFetcher( | ||
| queue=self._queue, | ||
| queue_desired_minsize=self._queue_desired_minsize, | ||
| min_processing_interval=self._min_processing_interval, | ||
| lock_timeout=self._lock_timeout, | ||
| heartbeater=self.__heartbeater, | ||
| ) | ||
| self.__workers = [ | ||
| ServiceRouterWorkerSyncWorker( | ||
| queue=self._queue, | ||
| heartbeater=self.__heartbeater, | ||
| pipeline_hinter=pipeline_hinter, | ||
| ) | ||
| for _ in range(self._workers_num) | ||
| ] | ||
|
|
||
| @property | ||
| def hint_fetch_model_name(self) -> str: | ||
| return ServiceRouterWorkerSyncModel.__name__ | ||
|
|
||
| @property | ||
| def _heartbeater(self) -> Heartbeater[ServiceRouterWorkerSyncPipelineItem]: | ||
| return self.__heartbeater | ||
|
|
||
| @property | ||
| def _fetcher(self) -> Fetcher[ServiceRouterWorkerSyncPipelineItem]: | ||
| return self.__fetcher | ||
|
|
||
| @property | ||
| def _workers(self) -> Sequence["ServiceRouterWorkerSyncWorker"]: | ||
| return self.__workers | ||
|
|
||
|
|
||
| class ServiceRouterWorkerSyncFetcher(Fetcher[ServiceRouterWorkerSyncPipelineItem]): | ||
| @sentry_utils.instrument_pipeline_task("ServiceRouterWorkerSyncFetcher.fetch") | ||
| async def fetch(self, limit: int) -> list[ServiceRouterWorkerSyncPipelineItem]: | ||
| sync_lock, _ = get_locker(get_db().dialect_name).get_lockset( | ||
| ServiceRouterWorkerSyncModel.__tablename__ | ||
| ) | ||
| async with sync_lock: | ||
| async with get_session_ctx() as session: | ||
| now = get_current_datetime() | ||
| res = await session.execute( | ||
| select(ServiceRouterWorkerSyncModel) | ||
| .join(RunModel, RunModel.id == ServiceRouterWorkerSyncModel.run_id) | ||
| .where( | ||
| RunModel.status == RunStatus.RUNNING, | ||
| or_( | ||
| ServiceRouterWorkerSyncModel.last_processed_at | ||
| <= now - self._min_processing_interval, | ||
| ServiceRouterWorkerSyncModel.last_processed_at | ||
| == ServiceRouterWorkerSyncModel.created_at, | ||
| ), | ||
| or_( | ||
| ServiceRouterWorkerSyncModel.lock_expires_at.is_(None), | ||
| ServiceRouterWorkerSyncModel.lock_expires_at < now, | ||
| ), | ||
| ) | ||
| .order_by(ServiceRouterWorkerSyncModel.last_processed_at.asc()) | ||
| .limit(limit) | ||
| .with_for_update( | ||
| skip_locked=True, key_share=True, of=ServiceRouterWorkerSyncModel | ||
| ) | ||
| .options( | ||
| load_only( | ||
| ServiceRouterWorkerSyncModel.id, | ||
| ServiceRouterWorkerSyncModel.run_id, | ||
| ServiceRouterWorkerSyncModel.lock_token, | ||
| ServiceRouterWorkerSyncModel.lock_expires_at, | ||
| ) | ||
| ) | ||
| ) | ||
| rows = list(res.scalars().all()) | ||
| lock_expires_at = get_current_datetime() + self._lock_timeout | ||
| lock_token = uuid.uuid4() | ||
| items: list[ServiceRouterWorkerSyncPipelineItem] = [] | ||
| for row in rows: | ||
| prev_lock_expired = row.lock_expires_at is not None | ||
| row.lock_expires_at = lock_expires_at | ||
| row.lock_token = lock_token | ||
| row.lock_owner = ServiceRouterWorkerSyncPipeline.__name__ | ||
| items.append( | ||
| ServiceRouterWorkerSyncPipelineItem( | ||
| __tablename__=ServiceRouterWorkerSyncModel.__tablename__, | ||
| id=row.id, | ||
| lock_expires_at=lock_expires_at, | ||
| lock_token=lock_token, | ||
| prev_lock_expired=prev_lock_expired, | ||
| run_id=row.run_id, | ||
| ) | ||
| ) | ||
| await session.commit() | ||
| return items | ||
|
|
||
|
|
||
| class _SyncRowUpdateMap(ItemUpdateMap, total=False): | ||
| pass | ||
|
|
||
|
|
||
| class ServiceRouterWorkerSyncWorker(Worker[ServiceRouterWorkerSyncPipelineItem]): | ||
| def __init__( | ||
| self, | ||
| queue: asyncio.Queue[ServiceRouterWorkerSyncPipelineItem], | ||
| heartbeater: Heartbeater[ServiceRouterWorkerSyncPipelineItem], | ||
| pipeline_hinter: PipelineHinterProtocol, | ||
| ) -> None: | ||
| super().__init__( | ||
| queue=queue, | ||
| heartbeater=heartbeater, | ||
| pipeline_hinter=pipeline_hinter, | ||
| ) | ||
|
|
||
| @sentry_utils.instrument_pipeline_task("ServiceRouterWorkerSyncWorker.process") | ||
| async def process(self, item: ServiceRouterWorkerSyncPipelineItem) -> None: | ||
| async with get_session_ctx() as session: | ||
| res = await session.execute( | ||
| select(ServiceRouterWorkerSyncModel) | ||
| .where( | ||
| ServiceRouterWorkerSyncModel.id == item.id, | ||
| ServiceRouterWorkerSyncModel.lock_token == item.lock_token, | ||
| ) | ||
| .options(selectinload(ServiceRouterWorkerSyncModel.run)) | ||
| ) | ||
| sync_row = res.unique().scalar_one_or_none() | ||
| if sync_row is None: | ||
| log_lock_token_mismatch(logger, item) | ||
| return | ||
| run_model = sync_row.run | ||
| if ( | ||
| run_model.deleted | ||
| or run_model.status.is_finished() | ||
| or run_model.status != RunStatus.RUNNING | ||
| or not run_model_has_router_replica_group(run_model) | ||
| ): | ||
| await session.delete(sync_row) | ||
| await session.commit() | ||
| return | ||
|
|
||
| async with get_session_ctx() as session: | ||
| res = await session.execute( | ||
| select(RunModel) | ||
| .where(RunModel.id == item.run_id) | ||
| .options( | ||
| load_only(RunModel.id, RunModel.run_spec), | ||
| selectinload( | ||
| RunModel.jobs.and_( | ||
| JobModel.status == JobStatus.RUNNING, | ||
| JobModel.registered == true(), | ||
| ) | ||
| ) | ||
| .load_only( | ||
| JobModel.id, | ||
| JobModel.status, | ||
| JobModel.registered, | ||
| JobModel.job_spec_data, | ||
| JobModel.job_provisioning_data, | ||
| JobModel.job_runtime_data, | ||
| ) | ||
| .options( | ||
| joinedload(JobModel.project).load_only( | ||
| ProjectModel.id, ProjectModel.ssh_private_key | ||
| ), | ||
| joinedload(JobModel.instance) | ||
| .load_only(InstanceModel.id, InstanceModel.remote_connection_info) | ||
| .joinedload(InstanceModel.project) | ||
| .load_only(ProjectModel.id, ProjectModel.ssh_private_key), | ||
| ), | ||
| ) | ||
| ) | ||
|
Comment on lines
+217
to
+243
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is potentially a very inefficient select – a run can have thousands of job submissions. Select only the jobs that the processing needs, i.e. only the router replica job. Also, every selectinload will be a separate query here – not sure if that's justified; joinedload may be better suited for a one-to-one relationship. Also, try to avoid loading all of the models' columns – use load_only to select only the necessary ones.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please check if the proposed query below addresses the concerns.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Looks good, at least at a glance. |
||
| run_for_sync = res.unique().scalar_one_or_none() | ||
|
|
||
| if run_for_sync is None: | ||
| async with get_session_ctx() as session: | ||
| await session.execute( | ||
| delete(ServiceRouterWorkerSyncModel).where( | ||
| ServiceRouterWorkerSyncModel.id == item.id, | ||
| ServiceRouterWorkerSyncModel.lock_token == item.lock_token, | ||
| ) | ||
| ) | ||
| await session.commit() | ||
| return | ||
|
|
||
| await sync_router_workers_for_run_model(run_for_sync) | ||
|
|
||
| update_map: _SyncRowUpdateMap = {} | ||
| set_processed_update_map_fields(update_map) | ||
| set_unlock_update_map_fields(update_map) | ||
| async with get_session_ctx() as session: | ||
| now = get_current_datetime() | ||
| resolve_now_placeholders(update_map, now=now) | ||
| res2 = await session.execute( | ||
| update(ServiceRouterWorkerSyncModel) | ||
| .where( | ||
| ServiceRouterWorkerSyncModel.id == item.id, | ||
| ServiceRouterWorkerSyncModel.lock_token == item.lock_token, | ||
| ) | ||
| .values(**update_map) | ||
| .returning(ServiceRouterWorkerSyncModel.id) | ||
| ) | ||
| if not list(res2.scalars().all()): | ||
| log_lock_token_changed_after_processing(logger, item) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add excludes in
core/compatibility, for client compatibility with older servers