diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 00c1e05d0d0..17eaf0f8754 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -53,6 +53,16 @@ const ( minEnvironmentVersion = 4 ) +// acceleratorProvisioningNotice maps a GPU accelerator type to the upfront notice +// shown while its serverless compute is provisioned. Latencies vary widely by type +// (a single A10 is acquired in minutes; an 8xH100 node is ~10 min at P50 and can +// exceed 30 min at P90), so the wording is tuned per type to set expectations +// accurately. Types absent from this map fall back to a generic message. +var acceleratorProvisioningNotice = map[string]string{ + "GPU_1xA10": "Provisioning GPU_1xA10 compute. This usually takes a few minutes and may take longer when capacity is constrained.", + "GPU_8xH100": "Provisioning GPU_8xH100 compute. This typically takes around 10 minutes and can exceed 30 minutes when capacity is constrained.", +} + type ClientOptions struct { // Id of the cluster to connect to (for dedicated clusters) ClusterID string @@ -578,7 +588,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, cmdio.LogString(ctx, fmt.Sprintf("Job submitted successfully with run ID: %d", waiter.RunId)) // Return the run ID even on error so callers can fetch the run's failure details. - return waiter.RunId, waitForJobToStart(ctx, client, waiter.RunId, opts.TaskStartupTimeout) + return waiter.RunId, waitForJobToStart(ctx, client, waiter.RunId, opts) } func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, serverPort int, clusterID string, opts ClientOptions) error { @@ -642,7 +652,7 @@ func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient, sp := cmdio.NewSpinner(ctx, cmdio.WithElapsedTime()) defer sp.Close() if autoStart { - sp.Update("Ensuring the cluster is running...") + sp.Update("Waiting for compute to start...") err := client.Clusters.EnsureClusterIsRunning(ctx, clusterID) if err != nil { return fmt.Errorf("failed to ensure that the cluster is running: %w", err) @@ -662,13 +672,27 @@ func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient, // waitForJobToStart polls the task status until the SSH server task is in RUNNING state or terminates. // Returns an error if the task fails to start or if polling times out. -func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient, runID int64, taskStartupTimeout time.Duration) error { +func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient, runID int64, opts ClientOptions) error { + waitingMessage := "Waiting for compute to start..." + if opts.Accelerator != "" { + // GPU capacity is acquired on demand and the wait varies a lot by accelerator + // type; without this notice users assume a long PENDING wait means the service + // is down. Latencies differ enough between types that a single message would be + // misleading, so phrase the heads-up per accelerator with a generic fallback. + notice, ok := acceleratorProvisioningNotice[opts.Accelerator] + if !ok { + notice = fmt.Sprintf("Provisioning %s compute. This can take several minutes and may take longer when capacity is constrained.", opts.Accelerator) + } + cmdio.LogString(ctx, notice) + waitingMessage = fmt.Sprintf("Provisioning %s compute...", opts.Accelerator) + } + sp := cmdio.NewSpinner(ctx, cmdio.WithElapsedTime()) defer sp.Close() - sp.Update("Starting SSH server...") + sp.Update(waitingMessage) var prevState jobs.RunLifecycleStateV2State - _, err := retries.Poll(ctx, taskStartupTimeout, func() (*jobs.RunTask, *retries.Err) { + _, err := retries.Poll(ctx, opts.TaskStartupTimeout, func() (*jobs.RunTask, *retries.Err) { run, err := client.Jobs.GetRun(ctx, jobs.GetRunRequest{ RunId: runID, }) @@ -697,7 +721,7 @@ func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient, // Update spinner if state changed if currentState != prevState { - sp.Update(fmt.Sprintf("Starting SSH server... (task: %s)", currentState)) + sp.Update(fmt.Sprintf("%s (task: %s)", waitingMessage, currentState)) prevState = currentState } diff --git a/experimental/ssh/internal/client/client_internal_test.go b/experimental/ssh/internal/client/client_internal_test.go index 35740d8cee7..a7347b08c74 100644 --- a/experimental/ssh/internal/client/client_internal_test.go +++ b/experimental/ssh/internal/client/client_internal_test.go @@ -110,7 +110,7 @@ func TestWaitForJobToStartSurfacesFailure(t *testing.T) { api.EXPECT().GetRunOutput(mock.Anything, jobs.GetRunOutputRequest{RunId: 99}).Return( &jobs.RunOutput{}, nil) - err := waitForJobToStart(ctx, m.WorkspaceClient, 1, 30*time.Second) + err := waitForJobToStart(ctx, m.WorkspaceClient, 1, ClientOptions{TaskStartupTimeout: 30 * time.Second}) require.Error(t, err) assert.Contains(t, err.Error(), "ssh server bootstrap job failed") assert.Contains(t, err.Error(), "Could not reach driver of cluster 0605-x.")