diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index f789869b3e7..ce00d1437e3 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -10,4 +10,6 @@ ### Bundles + * Jobs that use cluster policy default values for their cluster configuration now correctly update those defaults on every deployment ([#3255](https://github.com/databricks/cli/pull/3255)). + ### API Changes diff --git a/bundle/deploy/terraform/tfdyn/convert_job.go b/bundle/deploy/terraform/tfdyn/convert_job.go index bb2f8cd0fed..e38ea36d6b2 100644 --- a/bundle/deploy/terraform/tfdyn/convert_job.go +++ b/bundle/deploy/terraform/tfdyn/convert_job.go @@ -3,7 +3,9 @@ package tfdyn import ( "context" "fmt" + "slices" "sort" + "strings" "github.com/databricks/cli/bundle/internal/tf/schema" "github.com/databricks/cli/libs/dyn" @@ -12,6 +14,58 @@ import ( "github.com/databricks/databricks-sdk-go/service/jobs" ) +func patchApplyPolicyDefaultValues(_ dyn.Path, v dyn.Value) (dyn.Value, error) { + // If the field "apply_policy_default_values" is not set to true, do nothing. + if b, ok := v.Get("apply_policy_default_values").AsBool(); !ok || !b { + return v, nil + } + + // If the field "policy_id" is not set, do nothing. + if _, ok := v.Get("policy_id").AsString(); !ok { + return v, nil + } + + // The field "apply_policy_default_values" is true and "policy_id" is set. + // We need to collect the list of fields that are set explicitly + // and pass it to Terraform. This enables Terraform to clear + // server-side defaults from the update request, which in turn + // allows the backend to re-apply the policy defaults. + // + // For more details, see: https://github.com/databricks/terraform-provider-databricks/pull/4834 + // + paths := dyn.CollectLeafPaths(v) + + // If any of the map or sequence fields are set, always include them entirely instead of traversing them. 
+ for _, field := range []string{ + "custom_tags", + "init_scripts", + "spark_conf", + "spark_env_vars", + "ssh_public_keys", + } { + if vv := v.Get(field); vv.IsValid() { + // Remove all leaf paths nested under the field (prefixed by "<field>." or "<field>["). + paths = slices.DeleteFunc(paths, func(p string) bool { + return strings.HasPrefix(p, field+".") || strings.HasPrefix(p, field+"[") + }) + // Add the field itself so it is listed as a single entry. + paths = append(paths, field) + } + } + + sort.Strings(paths) + valList := make([]dyn.Value, len(paths)) + for i, s := range paths { + valList[i] = dyn.V(s) + } + v, err := dyn.Set(v, "__apply_policy_default_values_allow_list", dyn.V(valList)) + if err != nil { + return dyn.InvalidValue, err + } + + return v, nil +} + func convertJobResource(ctx context.Context, vin dyn.Value) (dyn.Value, error) { // Normalize the input value to the underlying job schema. // This removes superfluous keys and adapts the input to the expected schema. @@ -101,6 +155,22 @@ func convertJobResource(ctx context.Context, vin dyn.Value) (dyn.Value, error) { log.Debugf(ctx, "job normalization diagnostic: %s", diag.Summary) } + // Apply __apply_policy_default_values_allow_list for tasks + vout, err = dyn.Map(vout, "task", dyn.Foreach(func(_ dyn.Path, v dyn.Value) (dyn.Value, error) { + return dyn.Map(v, "new_cluster", patchApplyPolicyDefaultValues) + })) + if err != nil { + return dyn.InvalidValue, err + } + + // Apply __apply_policy_default_values_allow_list for job clusters + vout, err = dyn.Map(vout, "job_cluster", dyn.Foreach(func(_ dyn.Path, v dyn.Value) (dyn.Value, error) { + return dyn.Map(v, "new_cluster", patchApplyPolicyDefaultValues) + })) + if err != nil { + return dyn.InvalidValue, err + } + return vout, err } diff --git a/bundle/deploy/terraform/tfdyn/convert_job_test.go b/bundle/deploy/terraform/tfdyn/convert_job_test.go index 8f4cfc2fa83..a7c506d592b 100644 --- a/bundle/deploy/terraform/tfdyn/convert_job_test.go +++ b/bundle/deploy/terraform/tfdyn/convert_job_test.go @@ -149,3 +149,140 @@ 
func TestConvertJob(t *testing.T) { }, }, out.Permissions["job_my_job"]) } + +func TestConvertJobApplyPolicyDefaultValues(t *testing.T) { + src := resources.Job{ + JobSettings: jobs.JobSettings{ + Name: "my job", + JobClusters: []jobs.JobCluster{ + { + JobClusterKey: "key", + NewCluster: compute.ClusterSpec{ + ApplyPolicyDefaultValues: true, + PolicyId: "policy_id", + GcpAttributes: &compute.GcpAttributes{ + Availability: "SPOT", + LocalSsdCount: 2, + }, + }, + }, + { + JobClusterKey: "key2", + NewCluster: compute.ClusterSpec{ + ApplyPolicyDefaultValues: true, + PolicyId: "policy_id2", + CustomTags: map[string]string{ + "key": "value", + }, + InitScripts: []compute.InitScriptInfo{ + { + Workspace: &compute.WorkspaceStorageInfo{ + Destination: "/Workspace/path/to/init_script1", + }, + }, + { + Workspace: &compute.WorkspaceStorageInfo{ + Destination: "/Workspace/path/to/init_script2", + }, + }, + }, + SparkConf: map[string]string{ + "key": "value", + }, + SparkEnvVars: map[string]string{ + "key": "value", + }, + SshPublicKeys: []string{ + "ssh-rsa 1234", + }, + }, + }, + { + JobClusterKey: "key3", + NewCluster: compute.ClusterSpec{ + ApplyPolicyDefaultValues: true, + SparkVersion: "16.4.x-scala2.12", + }, + }, + }, + }, + } + + vin, err := convert.FromTyped(src, dyn.NilValue) + require.NoError(t, err) + + ctx := context.Background() + out := schema.NewResources() + err = jobConverter{}.Convert(ctx, "my_job", vin, out) + require.NoError(t, err) + + assert.Equal(t, map[string]any{ + "name": "my job", + "job_cluster": []any{ + map[string]any{ + "job_cluster_key": "key", + "new_cluster": map[string]any{ + "__apply_policy_default_values_allow_list": []any{ + "apply_policy_default_values", + "gcp_attributes.availability", + "gcp_attributes.local_ssd_count", + "policy_id", + }, + "apply_policy_default_values": true, + "policy_id": "policy_id", + "gcp_attributes": map[string]any{ + "availability": "SPOT", + "local_ssd_count": int64(2), + }, + }, + }, + map[string]any{ 
"job_cluster_key": "key2", + "new_cluster": map[string]any{ + "__apply_policy_default_values_allow_list": []any{ + "apply_policy_default_values", + "custom_tags", + "init_scripts", + "policy_id", + "spark_conf", + "spark_env_vars", + "ssh_public_keys", + }, + "apply_policy_default_values": true, + "policy_id": "policy_id2", + "custom_tags": map[string]any{ + "key": "value", + }, + "init_scripts": []any{ + map[string]any{ + "workspace": map[string]any{ + "destination": "/Workspace/path/to/init_script1", + }, + }, + map[string]any{ + "workspace": map[string]any{ + "destination": "/Workspace/path/to/init_script2", + }, + }, + }, + "spark_conf": map[string]any{ + "key": "value", + }, + "spark_env_vars": map[string]any{ + "key": "value", + }, + "ssh_public_keys": []any{ + "ssh-rsa 1234", + }, + }, + }, + map[string]any{ + "job_cluster_key": "key3", + "new_cluster": map[string]any{ + "apply_policy_default_values": true, + "spark_version": "16.4.x-scala2.12", + }, + }, + }, + }, out.Job["my_job"]) +} diff --git a/libs/dyn/walk.go b/libs/dyn/walk.go index 9d0a99356da..3f705caf960 100644 --- a/libs/dyn/walk.go +++ b/libs/dyn/walk.go @@ -66,3 +66,26 @@ func walk(v Value, p Path, fn func(p Path, v Value) (Value, error)) (Value, erro return v, nil } + +// CollectLeafPaths traverses the value and returns all paths (as dot notation strings) to leaf nodes (non-map, non-sequence). +// The root value itself and map/sequence nodes never produce an entry. The return value is not ordered. +func CollectLeafPaths(v Value) []string { + var paths []string + + Walk(v, func(p Path, v Value) (Value, error) { //nolint:errcheck + if len(p) == 0 { + return v, nil + } + + switch v.Kind() { + case KindMap, KindSequence: + // Ignore internal nodes. 
+ default: + paths = append(paths, p.String()) + } + + return v, nil + }) + + return paths +} diff --git a/libs/dyn/walk_test.go b/libs/dyn/walk_test.go index f7222b0a5b7..2ccd8695cbe 100644 --- a/libs/dyn/walk_test.go +++ b/libs/dyn/walk_test.go @@ -252,3 +252,18 @@ func TestWalkSequenceError(t *testing.T) { assert.Equal(t, MustPathFromString(".[1]"), tracker.calls[2].path) assert.Equal(t, V("bar"), tracker.calls[2].value) } + +func TestCollectLeafPaths(t *testing.T) { + v := V(map[string]Value{ + "a": V(1), + "b": V(map[string]Value{ + "c": V(2), + "d": V(map[string]Value{ + "e": V(3), + }), + }), + "f": V([]Value{V(4), V(5)}), + }) + paths := CollectLeafPaths(v) + assert.ElementsMatch(t, []string{"a", "b.c", "b.d.e", "f[0]", "f[1]"}, paths) +}