Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions experimental/ssh/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ const (
minEnvironmentVersion = 4
)

// acceleratorProvisioningNotice maps a GPU accelerator type to the upfront notice
// shown while its serverless compute is provisioned. Latencies vary widely by type
// (a single A10 is acquired in minutes; an 8xH100 node is ~10 min at P50 and can
// exceed 30 min at P90), so the wording is tuned per type to set expectations
// accurately. Types absent from this map fall back to a generic message.
var acceleratorProvisioningNotice = map[string]string{
"GPU_1xA10": "Provisioning GPU_1xA10 compute. This usually takes a few minutes and may take longer when capacity is constrained.",
"GPU_8xH100": "Provisioning GPU_8xH100 compute. This typically takes around 10 minutes and can exceed 30 minutes when capacity is constrained.",
}

type ClientOptions struct {
// Id of the cluster to connect to (for dedicated clusters)
ClusterID string
Expand Down Expand Up @@ -578,7 +588,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
cmdio.LogString(ctx, fmt.Sprintf("Job submitted successfully with run ID: %d", waiter.RunId))

// Return the run ID even on error so callers can fetch the run's failure details.
return waiter.RunId, waitForJobToStart(ctx, client, waiter.RunId, opts.TaskStartupTimeout)
return waiter.RunId, waitForJobToStart(ctx, client, waiter.RunId, opts)
}

func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, serverPort int, clusterID string, opts ClientOptions) error {
Expand Down Expand Up @@ -642,7 +652,7 @@ func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient,
sp := cmdio.NewSpinner(ctx, cmdio.WithElapsedTime())
defer sp.Close()
if autoStart {
sp.Update("Ensuring the cluster is running...")
sp.Update("Waiting for compute to start...")
err := client.Clusters.EnsureClusterIsRunning(ctx, clusterID)
if err != nil {
return fmt.Errorf("failed to ensure that the cluster is running: %w", err)
Expand All @@ -662,13 +672,27 @@ func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient,

// waitForJobToStart polls the task status until the SSH server task is in RUNNING state or terminates.
// Returns an error if the task fails to start or if polling times out.
func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient, runID int64, taskStartupTimeout time.Duration) error {
func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient, runID int64, opts ClientOptions) error {
waitingMessage := "Waiting for compute to start..."
if opts.Accelerator != "" {
// GPU capacity is acquired on demand and the wait varies a lot by accelerator
// type; without this notice users assume a long PENDING wait means the service
// is down. Latencies differ enough between types that a single message would be
// misleading, so phrase the heads-up per accelerator with a generic fallback.
notice, ok := acceleratorProvisioningNotice[opts.Accelerator]
if !ok {
notice = fmt.Sprintf("Provisioning %s compute. This can take several minutes and may take longer when capacity is constrained.", opts.Accelerator)
}
cmdio.LogString(ctx, notice)
waitingMessage = fmt.Sprintf("Provisioning %s compute...", opts.Accelerator)
}

sp := cmdio.NewSpinner(ctx, cmdio.WithElapsedTime())
defer sp.Close()
sp.Update("Starting SSH server...")
sp.Update(waitingMessage)
var prevState jobs.RunLifecycleStateV2State

_, err := retries.Poll(ctx, taskStartupTimeout, func() (*jobs.RunTask, *retries.Err) {
_, err := retries.Poll(ctx, opts.TaskStartupTimeout, func() (*jobs.RunTask, *retries.Err) {
run, err := client.Jobs.GetRun(ctx, jobs.GetRunRequest{
RunId: runID,
})
Expand Down Expand Up @@ -697,7 +721,7 @@ func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient,

// Update spinner if state changed
if currentState != prevState {
sp.Update(fmt.Sprintf("Starting SSH server... (task: %s)", currentState))
sp.Update(fmt.Sprintf("%s (task: %s)", waitingMessage, currentState))
prevState = currentState
}

Expand Down
2 changes: 1 addition & 1 deletion experimental/ssh/internal/client/client_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func TestWaitForJobToStartSurfacesFailure(t *testing.T) {
api.EXPECT().GetRunOutput(mock.Anything, jobs.GetRunOutputRequest{RunId: 99}).Return(
&jobs.RunOutput{}, nil)

err := waitForJobToStart(ctx, m.WorkspaceClient, 1, 30*time.Second)
err := waitForJobToStart(ctx, m.WorkspaceClient, 1, ClientOptions{TaskStartupTimeout: 30 * time.Second})
require.Error(t, err)
assert.Contains(t, err.Error(), "ssh server bootstrap job failed")
assert.Contains(t, err.Error(), "Could not reach driver of cluster 0605-x.")
Expand Down
Loading