Add docs_per_step for dynamic microbatch accumulation #520
Open
jlamypoirier wants to merge 1 commit into
A schedule config field that replaces the static microbatch count with a
runtime document-count target. Matches DeepSpeed's
gradient_accumulation_passes semantics for RL: each microbatch holds one
rollout and the step boundary is set by total rollouts rather than a
fixed microbatch count.
- ScheduleConfig.docs_per_step — when >0, Trainer._prefetch_to_doc_target
fetches microbatches one at a time, all-reduces the per-microbatch doc
count, and stops once the global total reaches the target. The final
step total is broadcast to every microbatch so the loss normalization
stays consistent.
- Trainer._get_or_build_schedule(N) builds and caches a per-N Schedule
with _depth_first_override = N // breadth_first_micro_batches, reusing
the schedule machinery without touching the runner.
- Schedule._eff_{depth_first,sequential_micro_batches,num_inputs} expose
the effective values under an override.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
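The accumulation loop described in the commit message can be sketched roughly as follows. This is a minimal, single-process illustration, not the code from this PR: the function name `prefetch_to_doc_target`, the `all_reduce_sum` callable (standing in for a sum all-reduce across data-parallel ranks), and the list-of-documents microbatch type are all assumptions made for the example.

```python
from typing import Callable, Iterator, List, Tuple


def prefetch_to_doc_target(
    microbatches: Iterator[List[str]],
    docs_per_step: int,
    all_reduce_sum: Callable[[int], int] = lambda count: count,
) -> Tuple[List[List[str]], int]:
    """Collect microbatches until the global document count reaches the target.

    Hypothetical stand-in for the behaviour described above. `all_reduce_sum`
    models a sum all-reduce over ranks; the default is the identity, which
    corresponds to a single rank.
    """
    if docs_per_step <= 0:
        raise ValueError("docs_per_step must be > 0 for dynamic accumulation")

    collected: List[List[str]] = []
    global_docs = 0
    for microbatch in microbatches:
        collected.append(microbatch)
        # Each rank contributes its local count and every rank sees the same
        # sum, so all ranks stop after the same number of microbatches.
        global_docs += all_reduce_sum(len(microbatch))
        if global_docs >= docs_per_step:
            break
    # If the iterator runs dry first, the partial step is returned as-is;
    # how the real trainer handles that case is not covered by this sketch.
    return collected, global_docs


# Usage: three documents are required, so the step closes after two microbatches.
stream = iter([["a"], ["b", "c"], ["d"], ["e"]])
step_batches, step_total = prefetch_to_doc_target(stream, docs_per_step=3)
# The same step total is attached to every microbatch as the loss denominator.
normalized = [(batch, step_total) for batch in step_batches]
```

Because the step total is only known once the last microbatch has been fetched, the sketch attaches it after the loop, which mirrors the "broadcast to every microbatch" step described above.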
Summary
A schedule config field that replaces the static microbatch count with a runtime document-count target. Each step accumulates microbatches one at a time, all-reduces the per-microbatch document count, and stops once the global cumulative total reaches the target. Matches DeepSpeed's gradient_accumulation_passes semantics for RL: each microbatch holds one rollout and the step boundary is set by total rollouts rather than a fixed microbatch count.
- ScheduleConfig.docs_per_step: when > 0, Trainer._prefetch_to_doc_target drives the dynamic accumulation. The final step total is broadcast back to every microbatch so the loss normalization denominator stays consistent.
- Trainer._get_or_build_schedule(N) builds and caches a per-N Schedule with _depth_first_override = N // breadth_first_micro_batches, reusing existing schedule machinery without touching the runner.
- Schedule._eff_{depth_first,sequential_micro_batches,num_inputs} expose the effective values under an override.
Off by default (docs_per_step=0): the disabled path takes the original static-schedule branch.
Test plan
- pytest tests/layers/test_docs_per_step.py: passes
Originally part of #502.
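The per-N schedule cache and the depth-first override arithmetic can be illustrated with a small stand-alone sketch. The classes `SketchSchedule` and `ScheduleCache` are hypothetical and model only the override logic. In particular, computing the effective sequential microbatch count as depth-first times breadth-first is an assumption of this example, not something stated in the PR.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass(frozen=True)
class SketchSchedule:
    """Hypothetical stand-in for Schedule; only the override is modelled."""

    breadth_first_micro_batches: int
    depth_first_micro_batches: int
    depth_first_override: Optional[int] = None

    @property
    def eff_depth_first(self) -> int:
        # When set, the override replaces the configured depth-first count.
        if self.depth_first_override is not None:
            return self.depth_first_override
        return self.depth_first_micro_batches

    @property
    def eff_sequential_micro_batches(self) -> int:
        # Assumption for this sketch: sequential = depth-first * breadth-first.
        return self.eff_depth_first * self.breadth_first_micro_batches


class ScheduleCache:
    """Builds one schedule per microbatch count N and reuses it afterwards."""

    def __init__(self, breadth_first_micro_batches: int, depth_first_micro_batches: int = 1):
        self.breadth_first_micro_batches = breadth_first_micro_batches
        self.depth_first_micro_batches = depth_first_micro_batches
        self._cache: Dict[int, SketchSchedule] = {}

    def get_or_build(self, n: int) -> SketchSchedule:
        if n not in self._cache:
            self._cache[n] = SketchSchedule(
                breadth_first_micro_batches=self.breadth_first_micro_batches,
                depth_first_micro_batches=self.depth_first_micro_batches,
                # Same arithmetic as in the description: N // breadth_first_micro_batches.
                depth_first_override=n // self.breadth_first_micro_batches,
            )
        return self._cache[n]


# Usage: a step that ended up with 6 microbatches and breadth-first 2.
cache = ScheduleCache(breadth_first_micro_batches=2)
schedule = cache.get_or_build(6)
```

In this sketch the floor division means that an N that is not a multiple of breadth_first_micro_batches yields a schedule covering fewer than N microbatches (N = 5 with breadth-first 2 gives an override of 2). How the actual change treats that case is not stated in the description.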