Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,12 @@ class ReplicaGroup(CoreModel):
CommandsList,
Field(description="The shell commands to run for replicas in this group"),
] = []
router: Annotated[
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add excludes in core/compatibility, for client compatibility with older servers

Optional[AnyServiceRouterConfig],
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AnyServiceRouterConfig has the policy and pd_disaggregation properties that are not applicable here. I think we might need a separate class for ReplicaGroup.router, without these options.

(see this thread)

Field(
description="When set, replicas in this group run the in-service HTTP router (e.g. SGLang).",
),
] = None

@validator("name")
def validate_name(cls, v: Optional[str]) -> Optional[str]:
Expand Down Expand Up @@ -1032,6 +1038,20 @@ def validate_replica_groups_have_commands_or_image(cls, values):

return values

@root_validator()
def validate_at_most_one_router_replica_group(cls, values):
    """Allow `router` on at most one replica group, and require that
    group to be pinned to exactly one replica (`count: 1`).
    """
    groups = values.get("replicas")
    if not isinstance(groups, list):
        # `replicas` is absent or failed earlier validation — nothing to check.
        return values
    with_router = [group for group in groups if group.router is not None]
    if len(with_router) > 1:
        raise ValueError("At most one replica group may specify `router`.")
    if with_router:
        count = with_router[0].count
        if (count.min, count.max) != (1, 1):
            raise ValueError("For now replica group with `router` must have `count: 1`.")
    return values


class ServiceConfigurationConfig(
ProfileParamsConfig,
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/gateway/routers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ async def register_replica(
ssh_head_proxy=body.ssh_head_proxy,
ssh_head_proxy_private_key=body.ssh_head_proxy_private_key,
internal_ip=body.internal_ip,
is_router_replica=body.is_router_replica,
repo=repo,
nginx=nginx,
service_conn_pool=service_conn_pool,
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/gateway/schemas/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class RegisterReplicaRequest(BaseModel):
ssh_head_proxy: Optional[SSHConnectionParams]
ssh_head_proxy_private_key: Optional[str]
internal_ip: Optional[str] = None
is_router_replica: bool = False


class RegisterEntrypointRequest(BaseModel):
Expand Down
9 changes: 9 additions & 0 deletions src/dstack/_internal/proxy/gateway/services/registry.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there anything to prevent exposing the router replica's /workers API on Nginx?

See the second comment in this thread

Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ async def register_replica(
nginx: Nginx,
service_conn_pool: ServiceConnectionPool,
internal_ip: Optional[str] = None,
is_router_replica: bool = False,
) -> None:
replica = models.Replica(
id=replica_id,
Expand All @@ -152,6 +153,7 @@ async def register_replica(
ssh_head_proxy=ssh_head_proxy,
ssh_head_proxy_private_key=ssh_head_proxy_private_key,
internal_ip=internal_ip,
is_router_replica=is_router_replica,
)

async with lock:
Expand Down Expand Up @@ -291,6 +293,13 @@ async def apply_service(
)
for replica, conn in replica_conns.items()
]
router_replicas = [r for r in service.replicas if r.is_router_replica]
if router_replicas:
replica_configs_for_nginx = [c for c in replica_configs if c.id == router_replicas[0].id]
Comment on lines +296 to +298
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if the router replica is not yet registered, or temporarily unregistered (e.g., if it failed and is being restarted)? It seems that the gateway will then assume that the service doesn't have a router replica, and Nginx will direct incoming requests directly to worker replicas, which is not expected.

Also, do we actually need to register worker replicas on the gateway, considering the gateway should only communicate with the router replica? My initial proposal was not to register them, and I think that would fix the problem above, and also optimize and simplify a few things (no need for extra network communication, no need to distinguish between router and worker replicas on the gateway, etc).

service_config = await get_nginx_service_config(service, replica_configs_for_nginx)
await nginx.register(service_config, (await repo.get_config()).acme_settings)
return replica_failures

service_config = await get_nginx_service_config(service, replica_configs)
await nginx.register(service_config, (await repo.get_config()).acme_settings)
return replica_failures
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/lib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class Replica(ImmutableModel):
ssh_head_proxy: Optional[SSHConnectionParams] = None
ssh_head_proxy_private_key: Optional[str] = None
internal_ip: Optional[str] = None
is_router_replica: bool = False


class IPAddressPartitioningKey(ImmutableModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ async def get_service_replica_client(
timeout=HTTP_TIMEOUT,
)
# Nginx not available, forward directly to the tunnel
replica = random.choice(service.replicas)
router_replicas = [r for r in service.replicas if r.is_router_replica]
replicas_to_use = router_replicas if router_replicas else service.replicas
replica = random.choice(replicas_to_use)
connection = await service_conn_pool.get(replica.id)
if connection is None:
project = await repo.get_project(service.project_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
PlacementGroupPipeline,
)
from dstack._internal.server.background.pipeline_tasks.runs import RunPipeline
from dstack._internal.server.background.pipeline_tasks.service_router_worker_sync import (
ServiceRouterWorkerSyncPipeline,
)
from dstack._internal.server.background.pipeline_tasks.volumes import VolumePipeline
from dstack._internal.utils.logging import get_logger

Expand All @@ -36,6 +39,7 @@ def __init__(self) -> None:
InstancePipeline(pipeline_hinter=self._hinter),
PlacementGroupPipeline(pipeline_hinter=self._hinter),
RunPipeline(pipeline_hinter=self._hinter),
ServiceRouterWorkerSyncPipeline(pipeline_hinter=self._hinter),
VolumePipeline(pipeline_hinter=self._hinter),
]:
self.register_pipeline(builtin_pipeline)
Expand Down
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have some time, consider covering the pipeline by unit tests. I think all (or at least most of) our existing pipelines and background tasks currently have good coverage

Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
import asyncio
import uuid
from dataclasses import dataclass
from datetime import timedelta
from typing import Sequence

from sqlalchemy import delete, or_, select, true, update
from sqlalchemy.orm import joinedload, load_only, selectinload

from dstack._internal.core.models.runs import JobStatus, RunStatus
from dstack._internal.server.background.pipeline_tasks.base import (
Fetcher,
Heartbeater,
ItemUpdateMap,
Pipeline,
PipelineItem,
Worker,
log_lock_token_changed_after_processing,
log_lock_token_mismatch,
resolve_now_placeholders,
set_processed_update_map_fields,
set_unlock_update_map_fields,
)
from dstack._internal.server.db import get_db, get_session_ctx
from dstack._internal.server.models import (
InstanceModel,
JobModel,
ProjectModel,
RunModel,
ServiceRouterWorkerSyncModel,
)
from dstack._internal.server.services.locking import get_locker
from dstack._internal.server.services.pipelines import PipelineHinterProtocol
from dstack._internal.server.services.router_worker_sync import (
run_model_has_router_replica_group,
sync_router_workers_for_run_model,
)
from dstack._internal.server.utils import sentry_utils
from dstack._internal.utils.common import get_current_datetime
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)


@dataclass
class ServiceRouterWorkerSyncPipelineItem(PipelineItem):
    """Pipeline item identifying one run whose router replica needs syncing."""

    # ID of the RunModel the claimed ServiceRouterWorkerSyncModel row points to.
    run_id: uuid.UUID


class ServiceRouterWorkerSyncPipeline(Pipeline[ServiceRouterWorkerSyncPipelineItem]):
    """Background pipeline that syncs service router replicas with their workers.

    Wires together the standard pipeline triad on top of the generic
    ``Pipeline`` base class: a ``Heartbeater`` (presumably refreshing row
    locks while items are in flight — confirm against the base class), a
    ``Fetcher`` that claims due ``ServiceRouterWorkerSyncModel`` rows, and a
    pool of ``Worker``s that perform the actual sync.
    """

    def __init__(
        self,
        workers_num: int = 8,
        queue_lower_limit_factor: float = 0.5,
        queue_upper_limit_factor: float = 2.0,
        min_processing_interval: timedelta = timedelta(seconds=5),
        lock_timeout: timedelta = timedelta(seconds=25),
        heartbeat_trigger: timedelta = timedelta(seconds=10),
        *,
        pipeline_hinter: PipelineHinterProtocol,
    ) -> None:
        super().__init__(
            workers_num=workers_num,
            queue_lower_limit_factor=queue_lower_limit_factor,
            queue_upper_limit_factor=queue_upper_limit_factor,
            min_processing_interval=min_processing_interval,
            lock_timeout=lock_timeout,
            heartbeat_trigger=heartbeat_trigger,
        )
        # Name-mangled (double-underscore) attributes back the abstract
        # `_heartbeater` / `_fetcher` / `_workers` properties below.
        self.__heartbeater = Heartbeater[ServiceRouterWorkerSyncPipelineItem](
            model_type=ServiceRouterWorkerSyncModel,
            lock_timeout=self._lock_timeout,
            heartbeat_trigger=self._heartbeat_trigger,
        )
        self.__fetcher = ServiceRouterWorkerSyncFetcher(
            queue=self._queue,
            queue_desired_minsize=self._queue_desired_minsize,
            min_processing_interval=self._min_processing_interval,
            lock_timeout=self._lock_timeout,
            heartbeater=self.__heartbeater,
        )
        self.__workers = [
            ServiceRouterWorkerSyncWorker(
                queue=self._queue,
                heartbeater=self.__heartbeater,
                pipeline_hinter=pipeline_hinter,
            )
            for _ in range(self._workers_num)
        ]

    @property
    def hint_fetch_model_name(self) -> str:
        # Model name the pipeline hinter uses to route fetch hints here.
        return ServiceRouterWorkerSyncModel.__name__

    @property
    def _heartbeater(self) -> Heartbeater[ServiceRouterWorkerSyncPipelineItem]:
        return self.__heartbeater

    @property
    def _fetcher(self) -> Fetcher[ServiceRouterWorkerSyncPipelineItem]:
        return self.__fetcher

    @property
    def _workers(self) -> Sequence["ServiceRouterWorkerSyncWorker"]:
        return self.__workers


class ServiceRouterWorkerSyncFetcher(Fetcher[ServiceRouterWorkerSyncPipelineItem]):
    """Fetches and claims due ``ServiceRouterWorkerSyncModel`` rows."""

    @sentry_utils.instrument_pipeline_task("ServiceRouterWorkerSyncFetcher.fetch")
    async def fetch(self, limit: int) -> list[ServiceRouterWorkerSyncPipelineItem]:
        """Claim up to ``limit`` sync rows that are due for processing.

        A row is due when its run is RUNNING and either enough time has
        passed since the last processing, or it has never been processed
        (``last_processed_at == created_at`` — presumably both are set to
        the same timestamp on insert; confirm against row creation). Rows
        held by an unexpired lock are excluded, and SKIP LOCKED avoids
        blocking on concurrent fetchers.
        """
        sync_lock, _ = get_locker(get_db().dialect_name).get_lockset(
            ServiceRouterWorkerSyncModel.__tablename__
        )
        async with sync_lock:
            async with get_session_ctx() as session:
                now = get_current_datetime()
                res = await session.execute(
                    select(ServiceRouterWorkerSyncModel)
                    .join(RunModel, RunModel.id == ServiceRouterWorkerSyncModel.run_id)
                    .where(
                        RunModel.status == RunStatus.RUNNING,
                        or_(
                            ServiceRouterWorkerSyncModel.last_processed_at
                            <= now - self._min_processing_interval,
                            ServiceRouterWorkerSyncModel.last_processed_at
                            == ServiceRouterWorkerSyncModel.created_at,
                        ),
                        or_(
                            ServiceRouterWorkerSyncModel.lock_expires_at.is_(None),
                            ServiceRouterWorkerSyncModel.lock_expires_at < now,
                        ),
                    )
                    # Oldest-processed rows first, so no row is starved.
                    .order_by(ServiceRouterWorkerSyncModel.last_processed_at.asc())
                    .limit(limit)
                    .with_for_update(
                        skip_locked=True, key_share=True, of=ServiceRouterWorkerSyncModel
                    )
                    .options(
                        load_only(
                            ServiceRouterWorkerSyncModel.id,
                            ServiceRouterWorkerSyncModel.run_id,
                            ServiceRouterWorkerSyncModel.lock_token,
                            ServiceRouterWorkerSyncModel.lock_expires_at,
                        )
                    )
                )
                rows = list(res.scalars().all())
                # Claim the whole batch with one shared token and expiry.
                lock_expires_at = get_current_datetime() + self._lock_timeout
                lock_token = uuid.uuid4()
                items: list[ServiceRouterWorkerSyncPipelineItem] = []
                for row in rows:
                    # Given the WHERE clause above, a non-null lock_expires_at
                    # here means a previous lock existed and has expired.
                    prev_lock_expired = row.lock_expires_at is not None
                    row.lock_expires_at = lock_expires_at
                    row.lock_token = lock_token
                    row.lock_owner = ServiceRouterWorkerSyncPipeline.__name__
                    items.append(
                        ServiceRouterWorkerSyncPipelineItem(
                            __tablename__=ServiceRouterWorkerSyncModel.__tablename__,
                            id=row.id,
                            lock_expires_at=lock_expires_at,
                            lock_token=lock_token,
                            prev_lock_expired=prev_lock_expired,
                            run_id=row.run_id,
                        )
                    )
                await session.commit()
                return items


class _SyncRowUpdateMap(ItemUpdateMap, total=False):
    # No sync-specific update fields yet; the shared set_processed_* /
    # set_unlock_* helpers populate the base ItemUpdateMap keys.
    pass


class ServiceRouterWorkerSyncWorker(Worker[ServiceRouterWorkerSyncPipelineItem]):
    """Worker that performs the router/worker sync for one claimed item."""

    def __init__(
        self,
        queue: asyncio.Queue[ServiceRouterWorkerSyncPipelineItem],
        heartbeater: Heartbeater[ServiceRouterWorkerSyncPipelineItem],
        pipeline_hinter: PipelineHinterProtocol,
    ) -> None:
        # Pure pass-through to the generic Worker base; no extra state.
        super().__init__(
            queue=queue,
            heartbeater=heartbeater,
            pipeline_hinter=pipeline_hinter,
        )

@sentry_utils.instrument_pipeline_task("ServiceRouterWorkerSyncWorker.process")
async def process(self, item: ServiceRouterWorkerSyncPipelineItem) -> None:
    """Sync the router replica of one run with its worker replicas."""
    # Phase 1: re-read the sync row under our lock token. A miss means
    # another owner re-claimed the row after our lock expired.
    async with get_session_ctx() as session:
        res = await session.execute(
            select(ServiceRouterWorkerSyncModel)
            .where(
                ServiceRouterWorkerSyncModel.id == item.id,
                ServiceRouterWorkerSyncModel.lock_token == item.lock_token,
            )
            .options(selectinload(ServiceRouterWorkerSyncModel.run))
        )
        sync_row = res.unique().scalar_one_or_none()
        if sync_row is None:
            log_lock_token_mismatch(logger, item)
            return
        run_model = sync_row.run
        # Garbage-collect the sync row when the run is deleted, finished,
        # not running, or no longer has a router replica group.
        # (NOTE(review): status.is_finished() is subsumed by
        # status != RunStatus.RUNNING; kept byte-identical here.)
        if (
            run_model.deleted
            or run_model.status.is_finished()
            or run_model.status != RunStatus.RUNNING
            or not run_model_has_router_replica_group(run_model)
        ):
            await session.delete(sync_row)
            await session.commit()
            return

    # Phase 2: reload the run with only the data the sync needs — RUNNING,
    # registered jobs plus the project/instance columns used for SSH access.
    async with get_session_ctx() as session:
        res = await session.execute(
            select(RunModel)
            .where(RunModel.id == item.run_id)
            .options(
                load_only(RunModel.id, RunModel.run_spec),
                # selectinload: RunModel.jobs is a one-to-many collection;
                # the and_() filter restricts which jobs get loaded.
                selectinload(
                    RunModel.jobs.and_(
                        JobModel.status == JobStatus.RUNNING,
                        JobModel.registered == true(),
                    )
                )
                .load_only(
                    JobModel.id,
                    JobModel.status,
                    JobModel.registered,
                    JobModel.job_spec_data,
                    JobModel.job_provisioning_data,
                    JobModel.job_runtime_data,
                )
                .options(
                    # joinedload for the scalar (many-to-one) relationships,
                    # loading only the columns the sync consumes.
                    joinedload(JobModel.project).load_only(
                        ProjectModel.id, ProjectModel.ssh_private_key
                    ),
                    joinedload(JobModel.instance)
                    .load_only(InstanceModel.id, InstanceModel.remote_connection_info)
                    .joinedload(InstanceModel.project)
                    .load_only(ProjectModel.id, ProjectModel.ssh_private_key),
                ),
            )
        )
Comment on lines +217 to +243
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is potentially a very inefficient select – a run can have thousands of job submissions. Select only the jobs that the processing needs, i.e. only the router replica job. Also every selectinload will be a separate query here – not sure if it's justified. joinedload may be better suited for a one-to-one rel. Also, try to avoid loading all models' columns and use load_only to select only the necessary ones.

Copy link
Copy Markdown
Collaborator Author

@Bihan Bihan Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check if below proposed query addresses the concerns

  1. Avoid loading thousands of job submissions: no longer load RunModel.jobs unconditionally. The selectinload(RunModel.jobs.and_(...)) restricts the loaded jobs to only RUNNING + registered replicas, which are the only ones sync_router_workers_for_run_model() can use (router job selection and worker list building both ignore non‑running / unregistered jobs).

  2. selectinload is intentional: RunModel.jobs is a one‑to‑many collection; using joinedload would duplicate the RunModel row per job.

  3. joinedload for one‑to‑one/many‑to‑one: RunModel.project, JobModel.project, JobModel.instance, InstanceModel.project are loaded with joinedload because these are scalar relationships from run, job, and instance.

  4. Use load_only: This limits columns required by sync_router_workers_for_run_model(run_for_sync) and _get_service_replica_client(job_model)

res = await session.execute(
    select(RunModel)
    .where(RunModel.id == item.run_id)
    .options(
        load_only(RunModel.id, RunModel.run_spec),
        selectinload(
            RunModel.jobs.and_(
                JobModel.status == JobStatus.RUNNING,
                JobModel.registered == true(),
            )
        )
        .load_only(
            JobModel.id,
            JobModel.status,
            JobModel.registered,
            JobModel.job_spec_data,
            JobModel.job_provisioning_data,
            JobModel.job_runtime_data,
        )
        .options(
            joinedload(JobModel.project).load_only(ProjectModel.id, ProjectModel.ssh_private_key),
            joinedload(JobModel.instance)
            .load_only(InstanceModel.id, InstanceModel.remote_connection_info)
            .joinedload(InstanceModel.project)
            .load_only(ProjectModel.id, ProjectModel.ssh_private_key),
        ),
    )
)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good, at least at a glance

# Result of the phase-2 query: the run plus only its RUNNING, registered jobs.
run_for_sync = res.unique().scalar_one_or_none()

if run_for_sync is None:
    # The run row vanished between phases: drop the sync row, still
    # guarded by our lock token so we never delete someone else's claim.
    async with get_session_ctx() as session:
        await session.execute(
            delete(ServiceRouterWorkerSyncModel).where(
                ServiceRouterWorkerSyncModel.id == item.id,
                ServiceRouterWorkerSyncModel.lock_token == item.lock_token,
            )
        )
        await session.commit()
    return

# Presumably pushes the current worker set to the run's router replica —
# see dstack._internal.server.services.router_worker_sync.
await sync_router_workers_for_run_model(run_for_sync)

# Mark the row processed and release the lock, but only if we still hold
# the same lock token (guards against takeover after lock expiry).
update_map: _SyncRowUpdateMap = {}
set_processed_update_map_fields(update_map)
set_unlock_update_map_fields(update_map)
async with get_session_ctx() as session:
    now = get_current_datetime()
    resolve_now_placeholders(update_map, now=now)
    res2 = await session.execute(
        update(ServiceRouterWorkerSyncModel)
        .where(
            ServiceRouterWorkerSyncModel.id == item.id,
            ServiceRouterWorkerSyncModel.lock_token == item.lock_token,
        )
        .values(**update_map)
        .returning(ServiceRouterWorkerSyncModel.id)
    )
    # RETURNING yields no rows when the token changed mid-processing.
    if not list(res2.scalars().all()):
        log_lock_token_changed_after_processing(logger, item)
Loading
Loading