Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ bazel-src-cli
.DS_Store
samples
.amp
bin/
1 change: 1 addition & 0 deletions cmd/src/batch_exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ func convertWorkspace(w batcheslib.WorkspacesExecutionInput) *executor.Task {
BatchChangeAttributes: &w.BatchChangeAttributes,
CachedStepResultFound: w.CachedStepResultFound,
CachedStepResult: w.CachedStepResult,
ModelProviderURL: w.ModelProviderURL,
}

return task
Expand Down
114 changes: 110 additions & 4 deletions internal/batches/executor/run_steps.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"time"

batcheslib "github.com/sourcegraph/sourcegraph/lib/batches"
"github.com/sourcegraph/sourcegraph/lib/batches/codingagent"
codingagenttypes "github.com/sourcegraph/sourcegraph/lib/batches/codingagent/types"
"github.com/sourcegraph/sourcegraph/lib/batches/execution"
"github.com/sourcegraph/sourcegraph/lib/batches/git"
"github.com/sourcegraph/sourcegraph/lib/batches/template"
Expand Down Expand Up @@ -272,12 +274,32 @@ func executeSingleStep(
return bytes.Buffer{}, bytes.Buffer{}, err
}

runScriptFile, runScript, cleanup, err := createRunScriptFile(ctx, opts.TempDir, step.Run, stepContext)
var (
runScriptFile string
runScript string
runScriptCleanup func()
)
if step.CodingAgent != nil {
if opts.Task.ModelProviderURL == "" {
err = errors.New("codingAgent step requires a model-provider URL")
opts.UI.StepPreparingFailed(stepIdx+1, err)
return bytes.Buffer{}, bytes.Buffer{}, err
}
runScript, err = codingagent.RenderRunCommand(step.CodingAgent, opts.Task.ModelProviderURL, stepContext)
if err != nil {
err = errors.Wrap(err, "rendering codingAgent step")
opts.UI.StepPreparingFailed(stepIdx+1, err)
return bytes.Buffer{}, bytes.Buffer{}, err
}
runScriptFile, runScriptCleanup, err = writeRunScriptFile(opts.TempDir, runScript)
} else {
runScriptFile, runScript, runScriptCleanup, err = createRunScriptFile(ctx, opts.TempDir, step.Run, stepContext)
}
if err != nil {
opts.UI.StepPreparingFailed(stepIdx+1, err)
return bytes.Buffer{}, bytes.Buffer{}, err
}
defer cleanup()
defer runScriptCleanup()

// Parse and render the step.Files.
filesToMount, cleanup, err := createFilesToMount(opts.TempDir, step, stepContext)
Expand All @@ -303,12 +325,16 @@ func executeSingleStep(
return bytes.Buffer{}, bytes.Buffer{}, err
}

if step.CodingAgent != nil {
forwardCodingAgentEnv(opts.GlobalEnv, env)
}

opts.UI.StepPreparingSuccess(stepIdx + 1)

// ----------
// EXECUTION
// ----------
opts.UI.StepStarted(stepIdx+1, runScript, env)
opts.UI.StepStarted(stepIdx+1, runScript, redactSensitiveEnv(env))

workspaceOpts, err := workspace.DockerRunOpts(ctx, workDir)
if err != nil {
Expand Down Expand Up @@ -394,7 +420,7 @@ func executeSingleStep(
}

opts.Logger.Logf("[Step %d] run: %q, container: %q", stepIdx+1, step.Run, step.Container)
opts.Logger.Logf("[Step %d] full command: %q", stepIdx+1, strings.Join(cmd.Args, " "))
opts.Logger.Logf("[Step %d] full command: %q", stepIdx+1, strings.Join(redactSensitiveArgs(cmd.Args), " "))

// Start the command.
t0 := time.Now()
Expand Down Expand Up @@ -573,6 +599,86 @@ func createRunScriptFile(ctx context.Context, tempDir string, stepRun string, st
return runScriptFile.Name(), runScript.String(), cleanup, nil
}

// forwardCodingAgentEnv makes the model-provider auth variables
// (SRC_BATCHES_MODEL_PROVIDER_TOKEN, SRC_BATCHES_JOB_ID) available to the
// user container by copying them out of globalEnv into stepEnv.
//
// globalEnv holds "KEY=VALUE" entries. For each forwarded variable the first
// matching entry wins and replaces any value already in stepEnv. Variables
// that do not appear in globalEnv leave stepEnv untouched.
func forwardCodingAgentEnv(globalEnv []string, stepEnv map[string]string) {
	forwarded := [...]string{
		codingagenttypes.ModelProviderTokenEnvVar,
		codingagenttypes.JobIDEnvVar,
	}
	for _, name := range forwarded {
		prefix := name + "="
		for i := range globalEnv {
			if !strings.HasPrefix(globalEnv[i], prefix) {
				continue
			}
			stepEnv[name] = globalEnv[i][len(prefix):]
			break
		}
	}
}

// sensitiveEnvKeys names env vars that get passed verbatim into the user
// container but must be scrubbed from UI sinks and log lines.
//
// The map is used as a set: only key presence matters. It is consulted by
// redactSensitiveEnv and redactSensitiveArgs. The job ID variable is not
// listed here, so it is logged unredacted.
var sensitiveEnvKeys = map[string]struct{}{
codingagenttypes.ModelProviderTokenEnvVar: {},
}

// redactedPlaceholder is the text substituted for the value of a sensitive
// env var wherever a redacted copy is produced.
const redactedPlaceholder = "REDACTED"

// redactSensitiveEnv returns a fresh map with the same keys as env in which
// every non-empty value belonging to a key in sensitiveEnvKeys is replaced by
// redactedPlaceholder. Empty values and non-sensitive entries are copied
// unchanged. env itself is never modified.
func redactSensitiveEnv(env map[string]string) map[string]string {
	scrubbed := make(map[string]string, len(env))
	for name, value := range env {
		_, secret := sensitiveEnvKeys[name]
		if secret && value != "" {
			value = redactedPlaceholder
		}
		scrubbed[name] = value
	}
	return scrubbed
}

// redactSensitiveArgs returns a copy of args that is safe to log: whenever an
// argument equal to `-e` is directly followed by a `KEY=VALUE` argument whose
// KEY is listed in sensitiveEnvKeys, the VALUE part is replaced with
// redactedPlaceholder. All other arguments are copied as-is and args itself
// is left unmodified.
func redactSensitiveArgs(args []string) []string {
	scrubbed := append(make([]string, 0, len(args)), args...)
	for pos := 1; pos < len(scrubbed); pos++ {
		if scrubbed[pos-1] != "-e" {
			continue
		}
		eq := strings.IndexByte(scrubbed[pos], '=')
		if eq < 0 {
			// `-e KEY` without a value carries nothing to scrub.
			continue
		}
		name := scrubbed[pos][:eq]
		if _, secret := sensitiveEnvKeys[name]; secret {
			scrubbed[pos] = name + "=" + redactedPlaceholder
		}
	}
	return scrubbed
}

// writeRunScriptFile writes a pre-rendered run script verbatim to a temp
// file. Unlike createRunScriptFile it does NOT pass the content through
// template.RenderStepTemplate, so embedded `{{` sequences in a
// shell-quoted prompt are not re-parsed as templates.
//
// The file is created inside tempDir with a generated name and its
// permissions are set to 0644 after the content has been written.
//
// It returns the path of the file and a cleanup function that removes it.
// On error the returned path is empty, the cleanup function is nil, and any
// partially written file has already been closed and removed.
func writeRunScriptFile(tempDir, script string) (string, func(), error) {
	runScriptFile, err := os.CreateTemp(tempDir, "")
	if err != nil {
		return "", nil, errors.Wrap(err, "creating temporary file")
	}
	name := runScriptFile.Name()
	cleanup := func() { os.Remove(name) }

	if _, err := runScriptFile.WriteString(script); err != nil {
		// Close before removing so the descriptor is not leaked. The write
		// error is the one worth reporting, so the close error is
		// deliberately dropped.
		_ = runScriptFile.Close()
		cleanup()
		return "", nil, errors.Wrap(err, "writing to temporary file")
	}
	if err := runScriptFile.Close(); err != nil {
		cleanup()
		return "", nil, errors.Wrap(err, "closing temporary file")
	}
	if err := os.Chmod(name, 0644); err != nil {
		cleanup()
		return "", nil, errors.Wrap(err, "setting permissions on the temporary file")
	}
	return name, cleanup, nil
}

// createCidFile creates a temporary file that will contain the container ID
// when executing steps.
// It returns the location of the file and a function that cleans up the
Expand Down
112 changes: 112 additions & 0 deletions internal/batches/executor/run_steps_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package executor

import (
"slices"
"testing"

codingagenttypes "github.com/sourcegraph/sourcegraph/lib/batches/codingagent/types"
)

// TestRedactSensitiveEnv checks that redactSensitiveEnv masks the
// model-provider token, leaves unrelated variables alone, and does not modify
// the map it was given.
func TestRedactSensitiveEnv(t *testing.T) {
	const secret = "tok-abc"
	tokenKey := codingagenttypes.ModelProviderTokenEnvVar
	input := map[string]string{
		tokenKey: secret,
		"PATH":   "/bin",
	}

	redacted := redactSensitiveEnv(input)

	if got := redacted[tokenKey]; got != redactedPlaceholder {
		t.Errorf("token: got %q want %q", got, redactedPlaceholder)
	}
	if got := redacted["PATH"]; got != "/bin" {
		t.Errorf("PATH should not be redacted: got %q", got)
	}
	if input[tokenKey] != secret {
		t.Errorf("input must not be mutated")
	}
}

// TestRedactSensitiveArgs checks that redactSensitiveArgs replaces the token
// value in a docker-style argument list with the placeholder while the job ID
// argument is passed through unchanged.
func TestRedactSensitiveArgs(t *testing.T) {
	rawToken := codingagenttypes.ModelProviderTokenEnvVar + "=tok-abc"
	maskedToken := codingagenttypes.ModelProviderTokenEnvVar + "=" + redactedPlaceholder
	jobID := codingagenttypes.JobIDEnvVar + "=job-123"

	out := redactSensitiveArgs([]string{
		"docker", "run",
		"-e", rawToken,
		"-e", jobID,
		"-e", "PATH=/bin",
		"--", "image:tag", "/script",
	})
	has := func(arg string) bool { return slices.Contains(out, arg) }

	if has(rawToken) {
		t.Errorf("token value still present in args: %v", out)
	}
	if !has(maskedToken) {
		t.Errorf("token not redacted in args: %v", out)
	}
	if !has(jobID) {
		t.Errorf("job id should pass through: %v", out)
	}
}

// TestForwardCodingAgentEnv runs forwardCodingAgentEnv against a table of
// global/step environments and verifies the resulting step environment:
// both variables forwarded, only the set ones forwarded, existing step values
// overwritten on a match, and nothing added when neither variable is present.
func TestForwardCodingAgentEnv(t *testing.T) {
	tokenVar := codingagenttypes.ModelProviderTokenEnvVar
	jobVar := codingagenttypes.JobIDEnvVar

	type scenario struct {
		name      string
		globalEnv []string
		stepEnv   map[string]string
		want      map[string]string
	}
	scenarios := []scenario{
		{
			name:      "forwards both vars",
			globalEnv: []string{"PATH=/bin", tokenVar + "=tok-abc", jobVar + "=job-123"},
			stepEnv:   map[string]string{},
			want:      map[string]string{tokenVar: "tok-abc", jobVar: "job-123"},
		},
		{
			name:      "forwards only what is set",
			globalEnv: []string{jobVar + "=job-456"},
			stepEnv:   map[string]string{},
			want:      map[string]string{jobVar: "job-456"},
		},
		{
			name:      "preserves preexisting step env and overwrites on match",
			globalEnv: []string{tokenVar + "=from-global"},
			stepEnv:   map[string]string{"OTHER": "x", tokenVar: "from-step"},
			want:      map[string]string{"OTHER": "x", tokenVar: "from-global"},
		},
		{
			name:      "no-op when env not present",
			globalEnv: []string{"PATH=/bin"},
			stepEnv:   map[string]string{},
			want:      map[string]string{},
		},
	}

	for i := range scenarios {
		sc := scenarios[i]
		t.Run(sc.name, func(t *testing.T) {
			got := sc.stepEnv
			forwardCodingAgentEnv(sc.globalEnv, got)

			// A size mismatch means unexpected keys were added or dropped, so
			// the per-key comparison below would be misleading.
			if len(got) != len(sc.want) {
				t.Fatalf("len mismatch: got %d want %d (got=%v want=%v)", len(got), len(sc.want), got, sc.want)
			}
			for key, expected := range sc.want {
				if actual := got[key]; actual != expected {
					t.Errorf("env[%q]: got %q want %q", key, actual, expected)
				}
			}
		})
	}
}
3 changes: 3 additions & 0 deletions internal/batches/executor/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ type Task struct {
// When this field is true, CachedStepResult is also populated.
CachedStepResultFound bool
CachedStepResult execution.AfterStepResult
// ModelProviderURL is the resolved proxy base URL for coding-agent
// steps; empty unless the spec contains at least one codingAgent step.
ModelProviderURL string
}

func (t *Task) ArchivePathToFetch() string {
Expand Down
33 changes: 26 additions & 7 deletions lib/batches/batch_spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,26 @@ func (oqor *OnQueryOrRepository) GetBranches() ([]string, error) {
}

type Step struct {
Run string `json:"run,omitempty" yaml:"run"`
Container string `json:"container,omitempty" yaml:"container"`
Env env.Environment `json:"env" yaml:"env"`
Files map[string]string `json:"files,omitempty" yaml:"files,omitempty"`
Outputs Outputs `json:"outputs,omitempty" yaml:"outputs,omitempty"`
Mount []Mount `json:"mount,omitempty" yaml:"mount,omitempty"`
If any `json:"if,omitempty" yaml:"if,omitempty"`
Run string `json:"run,omitempty" yaml:"run"`
CodingAgent *CodingAgentStep `json:"codingAgent,omitempty" yaml:"codingAgent,omitempty"`
Container string `json:"container,omitempty" yaml:"container"`
Image string `json:"image,omitempty" yaml:"image"`
Env env.Environment `json:"env" yaml:"env"`
Files map[string]string `json:"files,omitempty" yaml:"files,omitempty"`
Outputs Outputs `json:"outputs,omitempty" yaml:"outputs,omitempty"`
Mount []Mount `json:"mount,omitempty" yaml:"mount,omitempty"`
If any `json:"if,omitempty" yaml:"if,omitempty"`
}

// CodingAgentType is the string value of a coding-agent step's `type` key.
// Presumably it selects which agent implementation runs the step — confirm
// against the codingagent package.
type CodingAgentType string

const (
// CodingAgentTypeCodex is the `type` value "codex", the only agent type
// declared here.
CodingAgentTypeCodex CodingAgentType = "codex"
)

// CodingAgentStep is the value of a step's `codingAgent` key (see
// Step.CodingAgent). Batch spec validation rejects a step that sets both
// `codingAgent` and `run`, and requires such a step to name a container or
// image.
type CodingAgentStep struct {
// Type is the agent type; CodingAgentTypeCodex is the only declared value.
Type CodingAgentType `json:"type,omitempty" yaml:"type"`
// Prompt is the prompt text for the agent. NOTE(review): presumably it is
// shell-quoted into the rendered run command — confirm against
// codingagent.RenderRunCommand.
Prompt string `json:"prompt,omitempty" yaml:"prompt"`
}

func (s *Step) IfCondition() string {
Expand Down Expand Up @@ -181,6 +194,12 @@ func parseBatchSpec(schema string, data []byte) (*BatchSpec, error) {
errs = errors.Append(errs, NewValidationError(errors.Newf("step %d files target path contains invalid characters", i+1)))
}
}
if step.CodingAgent != nil && step.Run != "" {
errs = errors.Append(errs, NewValidationError(errors.Newf("step %d: codingAgent and run cannot be combined in the same step", i+1)))
}
if step.CodingAgent != nil && step.Container == "" && step.Image == "" {
errs = errors.Append(errs, NewValidationError(errors.Newf("step %d: codingAgent step requires an image", i+1)))
}
}

return &spec, errs
Expand Down
Loading
Loading