portable: accumulate in fp32 for Half/BFloat16 in softmax, log_softmax, mean, and sum#20089
portable: accumulate in fp32 for Half/BFloat16 in softmax, log_softmax, mean, and sum#20089vacu9708 wants to merge 2 commits into
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20089
Note: Links to docs will display an error until the docs builds have been completed.
|
This PR needs a
|
…tmax Problem: Softmax and log_softmax accumulated exp(x - max) in the tensor dtype. For BFloat16, the running sum saturates around 256 — adding 1.0 stops changing the total — so a uniform softmax over N=512 elements outputs ~1/256 instead of 1/512. Changes: Accumulate the exp-sum in float for Half/BFloat16 by threading an ACC type through the map-reduce calls. Loads and stores remain in the tensor dtype. Continues the fp32-accumulation work in pytorch#19117. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Problem: The fast-path and generic reduction loops in mean.out and sum.IntList_out accumulated the running sum in the tensor dtype. For BFloat16, the sum saturates around 256, so a mean over N=512 all-ones elements gives 0.5 instead of 1.0, and summing 512 all-ones elements gives 256 instead of 512. Changes: Accumulate in float for Half/BFloat16 by promoting the loop accumulator to ACC in both the fast path and the generic path. The final result is cast back to the tensor dtype on store. Continues the fp32-accumulation work in pytorch#19117. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
7777b92 to
98d2f39
Compare
|
I am opening the PR again with a more appropriate branch name |
Motivation
softmax, log_softmax, mean, and sum all accumulate their reduction in the input dtype. For BFloat16, that sum saturates around 256. Once it gets there, adding 1.0 rounds away and the total gets stuck. A uniform softmax over 512 elements in BFloat16 gives
~1/256per output instead of1/512.Why FP32 accumulation is needed
BFloat16 has the same exponent width as Float32, so it has a similar range. However, it has far fewer fraction bits, which makes its representable spacing much coarser as values grow.
BFloat16Float32, but coarse spacingFloat32For BFloat16, the gap between consecutive representable values (i.e, the smallest step size) increases at each power-of-two range:
[128, 256)1128, 129, 130, ..., 255[256, 512)2256, 258, 260, ..., 510As a result, once a BFloat16 running sum reaches
256, adding1.0no longer changes the value:256 + 1257256257is not representable and rounds back to256(according to IEEE 754; round-to-nearest-even)This directly affects all four ops for large inputs. For a softmax over 512 zeros, each
exp(0)contributes1.0, so the denominator should be512. If the BFloat16 accumulation gets stuck at256, the output becomes approximately1/256instead of the correct1/512.5125121/512512~256~1/256ATen accumulates reductions in float for Half/BFloat16 (via
acc_type). This PR does the same, following the pattern already established inop_grid_sampler_2d(#19117).Tests