From 7b4d1a6bc93a19438fb306450b21ad672109304d Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sat, 30 May 2026 13:06:16 +0800 Subject: [PATCH 01/12] feat(aggregate): enhance avg functionality with shared helper and Spark integration - Added shared helper for average calculations in `avg.rs` with conversion to average state. - Exported the aggregate module in `groups_accumulator.rs`. - Updated Spark's average function to maintain state order and count type. - Added tests for common helper null/filter semantics and Spark null filter cases. --- .../src/aggregate/groups_accumulator.rs | 1 + .../src/aggregate/groups_accumulator/avg.rs | 110 ++++++++++++++++++ .../spark/src/function/aggregate/avg.rs | 30 +++-- 3 files changed, 133 insertions(+), 8 deletions(-) create mode 100644 datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index ad2a21bb4733c..936316e333b6b 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -19,6 +19,7 @@ //! Adapter that makes [`GroupsAccumulator`] out of [`Accumulator`] pub mod accumulate; +pub mod avg; pub mod bool_op; pub mod nulls; pub mod prim_op; diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs new file mode 100644 index 0000000000000..47b7eea3cb23c --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Shared helpers for average group accumulator state handling. + +use arrow::array::{ArrowNumericType, BooleanArray, PrimitiveArray}; + +use super::nulls::{filtered_null_mask, set_nulls}; + +/// Converts an AVG input value array into nullable per-row state arrays. +/// +/// The returned arrays are `(sum_state, count_state)`. Callers keep control of +/// their aggregate-specific state field order when wrapping these arrays in the +/// final state vector. +/// +/// Rows with NULL input values, `false` filters, or NULL filters are marked NULL +/// in both output arrays so later merge steps can ignore them consistently. +pub fn convert_to_avg_state( + sums: PrimitiveArray, + count_value: CountType::Native, + opt_filter: Option<&BooleanArray>, +) -> (PrimitiveArray, PrimitiveArray) +where + SumType: ArrowNumericType + Send, + CountType: ArrowNumericType + Send, +{ + let counts = PrimitiveArray::::from_value(count_value, sums.len()); + let nulls = filtered_null_mask(opt_filter, &sums); + let counts = set_nulls(counts, nulls.clone()); + let sums = set_nulls(sums, nulls); + + (sums, counts) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, BooleanArray, Float64Array}; + use arrow::datatypes::Float64Type; + + use super::convert_to_avg_state; + + #[test] + fn convert_to_avg_state_applies_input_nulls_to_sum_and_count() { + let sums = Float64Array::from(vec![Some(1.0), None, Some(3.0)]); + + let (sums, counts) = convert_to_avg_state::< + Float64Type, + arrow::datatypes::Int64Type, + >(sums, 1, None); + + assert!(!sums.is_null(0)); + assert!(sums.is_null(1)); + assert!(!sums.is_null(2)); + assert!(!counts.is_null(0)); + assert!(counts.is_null(1)); + assert!(!counts.is_null(2)); + assert_eq!(counts.values().as_ref(), &[1, 1, 1]); + } + + #[test] + fn convert_to_avg_state_applies_filter_nulls_to_sum_and_count() { + let sums = Float64Array::from(vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)]); + let filter = BooleanArray::from(vec![Some(true), Some(false), None, Some(true)]); + + let (sums, counts) = convert_to_avg_state::< + Float64Type, + arrow::datatypes::Int64Type, + >(sums, 1, Some(&filter)); + + assert_eq!(sums.null_count(), 2); + assert!(!sums.is_null(0)); + assert!(sums.is_null(1)); + assert!(sums.is_null(2)); + assert!(!sums.is_null(3)); + + assert_eq!(counts.null_count(), 2); + assert!(!counts.is_null(0)); + assert!(counts.is_null(1)); + assert!(counts.is_null(2)); + assert!(!counts.is_null(3)); + assert_eq!(counts.values().as_ref(), &[1, 1, 1, 1]); + } + + #[test] + fn convert_to_avg_state_preserves_sum_data_type() { + let sums = Float64Array::from(vec![1.0, 2.0]) + .with_data_type(arrow::datatypes::DataType::Float64); + + let (sums, _counts) = convert_to_avg_state::< + Float64Type, + arrow::datatypes::Int64Type, + >(sums, 1, None); + + assert_eq!(sums.data_type(), &arrow::datatypes::DataType::Float64); + } +} diff --git a/datafusion/spark/src/function/aggregate/avg.rs b/datafusion/spark/src/function/aggregate/avg.rs index 5f4d2c253a2dc..0fc541e1dbaef 100644 --- a/datafusion/spark/src/function/aggregate/avg.rs +++ b/datafusion/spark/src/function/aggregate/avg.rs @@ -32,9 +32,7 @@ use datafusion_expr::{ Accumulator, AggregateUDFImpl, Coercion, EmitTo, GroupsAccumulator, ReversedUDAF, Signature, TypeSignatureClass, Volatility, }; -use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{ - filtered_null_mask, set_nulls, -}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::avg::convert_to_avg_state; use std::sync::Arc; /// AVG aggregate expression @@ -356,11 +354,7 @@ where .as_primitive::() .clone() .with_data_type(self.return_data_type.clone()); - let counts = Int64Array::from_value(1, sums.len()); - - let nulls = filtered_null_mask(opt_filter, &sums); - let counts = set_nulls(counts, nulls.clone()); - let sums = set_nulls(sums, nulls); + let (sums, counts) = convert_to_avg_state::(sums, 1, opt_filter); // [sum, count] - must match state() and merge_batch() Ok(vec![ @@ -453,6 +447,26 @@ mod tests { assert_eq!(counts.value(2), 1); } + #[test] + fn convert_to_state_with_null_filter() { + let acc = make_acc(); + let values: Vec = + vec![Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]))]; + let filter = BooleanArray::from(vec![Some(true), None, Some(true)]); + let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); + + let sums = state[0].as_primitive::(); + let counts = state[1].as_primitive::(); + + assert!(!sums.is_null(0)); + assert!(sums.is_null(1)); + assert!(!sums.is_null(2)); + + assert_eq!(counts.value(0), 1); + assert!(counts.is_null(1)); + assert_eq!(counts.value(2), 1); + } + #[test] fn convert_to_state_roundtrips_through_merge() { let mut acc = make_acc(); From 484fe010a077bd45903b51fe87c2d8d71627ca54 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sat, 30 May 2026 13:10:22 +0800 Subject: [PATCH 02/12] feat(aggregate): refactor Avg function to use shared state and preserve data type integrity - Updated built-in Avg to utilize shared `convert_to_avg_state`. - Ensured the order of state is preserved as [count, sum]. - Maintained count type as UInt64. - Ensured sum data type consistency for Decimal and Duration. Added tests for: - Float64: validating count/sum order and null filter semantics. - Decimal128: checking sum type and input null semantics. - DurationNanosecond: verifying sum type and filter semantics. --- datafusion/functions-aggregate/src/average.rs | 123 ++++++++++++++++-- 1 file changed, 113 insertions(+), 10 deletions(-) diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index ddeb9b0870a16..08ed27f6fb330 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -44,9 +44,7 @@ use datafusion_functions_aggregate_common::aggregate::avg_distinct::{ DecimalDistinctAvgAccumulator, Float64DistinctAvgAccumulator, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState; -use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{ - filtered_null_mask, set_nulls, -}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::avg::convert_to_avg_state; use datafusion_functions_aggregate_common::utils::DecimalAverager; use datafusion_macros::user_doc; use log::debug; @@ -955,13 +953,7 @@ where .as_primitive::() .clone() .with_data_type(self.sum_data_type.clone()); - let counts = UInt64Array::from_value(1, sums.len()); - - let nulls = filtered_null_mask(opt_filter, &sums); - - // set nulls on the arrays - let counts = set_nulls(counts, nulls.clone()); - let sums = set_nulls(sums, nulls); + let (sums, counts) = convert_to_avg_state::(sums, 1, opt_filter); Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)]) } @@ -974,3 +966,114 @@ where self.counts.capacity() * size_of::() + self.sums.capacity() * size_of::() } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Array, Float64Array}; + use arrow::datatypes::{Decimal128Type, DurationNanosecondType}; + + fn float64_acc() -> AvgGroupsAccumulator Result> + { + AvgGroupsAccumulator::::new( + &DataType::Float64, + &DataType::Float64, + |sum, count| Ok(sum / count as f64), + ) + } + + fn decimal128_acc() + -> AvgGroupsAccumulator Result> { + AvgGroupsAccumulator::::new( + &DataType::Decimal128(10, 2), + &DataType::Decimal128(14, 6), + |sum, _count| Ok(sum), + ) + } + + fn duration_acc() + -> AvgGroupsAccumulator Result> + { + AvgGroupsAccumulator::::new( + &DataType::Duration(TimeUnit::Nanosecond), + &DataType::Duration(TimeUnit::Nanosecond), + |sum, count| Ok(sum / count as i64), + ) + } + + #[test] + fn float64_convert_to_state_uses_count_sum_order_and_null_filter() { + let acc = float64_acc(); + let values: Vec = vec![Arc::new(Float64Array::from(vec![ + Some(10.0), + Some(20.0), + None, + Some(40.0), + ]))]; + let filter = BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]); + + let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); + + let counts = state[0].as_primitive::(); + let sums = state[1].as_primitive::(); + assert_eq!(counts.values().as_ref(), &[1, 1, 1, 1]); + assert!(!counts.is_null(0)); + assert!(counts.is_null(1)); + assert!(counts.is_null(2)); + assert!(counts.is_null(3)); + assert!(!sums.is_null(0)); + assert!(sums.is_null(1)); + assert!(sums.is_null(2)); + assert!(sums.is_null(3)); + } + + #[test] + fn decimal_convert_to_state_preserves_sum_type_and_nulls() { + let acc = decimal128_acc(); + let values: Vec = vec![Arc::new( + PrimitiveArray::::from(vec![ + Some(100_i128), + None, + Some(300_i128), + ]) + .with_data_type(DataType::Decimal128(10, 2)), + )]; + + let state = acc.convert_to_state(&values, None).unwrap(); + + let counts = state[0].as_primitive::(); + let sums = state[1].as_primitive::(); + assert_eq!(sums.data_type(), &DataType::Decimal128(10, 2)); + assert_eq!(counts.values().as_ref(), &[1, 1, 1]); + assert!(!counts.is_null(0)); + assert!(counts.is_null(1)); + assert!(!counts.is_null(2)); + assert!(!sums.is_null(0)); + assert!(sums.is_null(1)); + assert!(!sums.is_null(2)); + } + + #[test] + fn duration_convert_to_state_preserves_sum_type_and_applies_filter() { + let acc = duration_acc(); + let values: Vec = vec![Arc::new( + PrimitiveArray::::from(vec![ + Some(10_i64), + Some(20_i64), + ]) + .with_data_type(DataType::Duration(TimeUnit::Nanosecond)), + )]; + let filter = BooleanArray::from(vec![Some(false), Some(true)]); + + let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); + + let counts = state[0].as_primitive::(); + let sums = state[1].as_primitive::(); + assert_eq!(sums.data_type(), &DataType::Duration(TimeUnit::Nanosecond)); + assert_eq!(counts.values().as_ref(), &[1, 1]); + assert!(counts.is_null(0)); + assert!(!counts.is_null(1)); + assert!(sums.is_null(0)); + assert!(!sums.is_null(1)); + } +} From 13edd95754d8936932ed19b4d936bd1dc56968c9 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sat, 30 May 2026 13:18:17 +0800 Subject: [PATCH 03/12] feat(tests): enhance test helpers and reduce redundancy - Added assert_validity test helpers for improved validation in tests - Reduced repeated null assertions to streamline code - Shortened common helper tests using local imports and type aliasing - Introduced built-in Avg avg_state test helper - Added comment for decimal test closure to clarify the unused avg_fn by convert_to_state --- .../src/aggregate/groups_accumulator/avg.rs | 54 ++++++++----------- datafusion/functions-aggregate/src/average.rs | 51 +++++++++--------- .../spark/src/function/aggregate/avg.rs | 25 ++++----- 3 files changed, 63 insertions(+), 67 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs index 47b7eea3cb23c..cc6f889082d1d 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs @@ -49,25 +49,28 @@ where #[cfg(test)] mod tests { use arrow::array::{Array, BooleanArray, Float64Array}; - use arrow::datatypes::Float64Type; + use arrow::datatypes::{DataType, Float64Type, Int64Type}; use super::convert_to_avg_state; + type CountType = Int64Type; + + fn assert_validity(array: &dyn Array, expected: &[bool]) { + assert_eq!(array.len(), expected.len()); + for (idx, expected_valid) in expected.iter().copied().enumerate() { + assert_eq!(!array.is_null(idx), expected_valid, "validity at row {idx}"); + } + } + #[test] fn convert_to_avg_state_applies_input_nulls_to_sum_and_count() { let sums = Float64Array::from(vec![Some(1.0), None, Some(3.0)]); - let (sums, counts) = convert_to_avg_state::< - Float64Type, - arrow::datatypes::Int64Type, - >(sums, 1, None); - - assert!(!sums.is_null(0)); - assert!(sums.is_null(1)); - assert!(!sums.is_null(2)); - assert!(!counts.is_null(0)); - assert!(counts.is_null(1)); - assert!(!counts.is_null(2)); + let (sums, counts) = + convert_to_avg_state::(sums, 1, None); + + assert_validity(&sums, &[true, false, true]); + assert_validity(&counts, &[true, false, true]); assert_eq!(counts.values().as_ref(), &[1, 1, 1]); } @@ -76,35 +79,24 @@ mod tests { let sums = Float64Array::from(vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)]); let filter = BooleanArray::from(vec![Some(true), Some(false), None, Some(true)]); - let (sums, counts) = convert_to_avg_state::< - Float64Type, - arrow::datatypes::Int64Type, - >(sums, 1, Some(&filter)); + let (sums, counts) = + convert_to_avg_state::(sums, 1, Some(&filter)); assert_eq!(sums.null_count(), 2); - assert!(!sums.is_null(0)); - assert!(sums.is_null(1)); - assert!(sums.is_null(2)); - assert!(!sums.is_null(3)); + assert_validity(&sums, &[true, false, false, true]); assert_eq!(counts.null_count(), 2); - assert!(!counts.is_null(0)); - assert!(counts.is_null(1)); - assert!(counts.is_null(2)); - assert!(!counts.is_null(3)); + assert_validity(&counts, &[true, false, false, true]); assert_eq!(counts.values().as_ref(), &[1, 1, 1, 1]); } #[test] fn convert_to_avg_state_preserves_sum_data_type() { - let sums = Float64Array::from(vec![1.0, 2.0]) - .with_data_type(arrow::datatypes::DataType::Float64); + let sums = Float64Array::from(vec![1.0, 2.0]).with_data_type(DataType::Float64); - let (sums, _counts) = convert_to_avg_state::< - Float64Type, - arrow::datatypes::Int64Type, - >(sums, 1, None); + let (sums, _counts) = + convert_to_avg_state::(sums, 1, None); - assert_eq!(sums.data_type(), &arrow::datatypes::DataType::Float64); + assert_eq!(sums.data_type(), &DataType::Float64); } } diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 08ed27f6fb330..a5fc4ec938654 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -973,6 +973,23 @@ mod tests { use arrow::array::{Array, Float64Array}; use arrow::datatypes::{Decimal128Type, DurationNanosecondType}; + fn assert_validity(array: &dyn Array, expected: &[bool]) { + assert_eq!(array.len(), expected.len()); + for (idx, expected_valid) in expected.iter().copied().enumerate() { + assert_eq!(!array.is_null(idx), expected_valid, "validity at row {idx}"); + } + } + + fn avg_state( + state: &[ArrayRef], + ) -> (&PrimitiveArray, &PrimitiveArray) { + assert_eq!(state.len(), 2); + ( + state[0].as_primitive::(), + state[1].as_primitive::(), + ) + } + fn float64_acc() -> AvgGroupsAccumulator Result> { AvgGroupsAccumulator::::new( @@ -987,6 +1004,7 @@ mod tests { AvgGroupsAccumulator::::new( &DataType::Decimal128(10, 2), &DataType::Decimal128(14, 6), + // convert_to_state does not evaluate averages, so avg_fn is unused here. |sum, _count| Ok(sum), ) } @@ -1014,17 +1032,10 @@ mod tests { let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); - let counts = state[0].as_primitive::(); - let sums = state[1].as_primitive::(); + let (counts, sums) = avg_state::(&state); assert_eq!(counts.values().as_ref(), &[1, 1, 1, 1]); - assert!(!counts.is_null(0)); - assert!(counts.is_null(1)); - assert!(counts.is_null(2)); - assert!(counts.is_null(3)); - assert!(!sums.is_null(0)); - assert!(sums.is_null(1)); - assert!(sums.is_null(2)); - assert!(sums.is_null(3)); + assert_validity(counts, &[true, false, false, false]); + assert_validity(sums, &[true, false, false, false]); } #[test] @@ -1041,16 +1052,11 @@ mod tests { let state = acc.convert_to_state(&values, None).unwrap(); - let counts = state[0].as_primitive::(); - let sums = state[1].as_primitive::(); + let (counts, sums) = avg_state::(&state); assert_eq!(sums.data_type(), &DataType::Decimal128(10, 2)); assert_eq!(counts.values().as_ref(), &[1, 1, 1]); - assert!(!counts.is_null(0)); - assert!(counts.is_null(1)); - assert!(!counts.is_null(2)); - assert!(!sums.is_null(0)); - assert!(sums.is_null(1)); - assert!(!sums.is_null(2)); + assert_validity(counts, &[true, false, true]); + assert_validity(sums, &[true, false, true]); } #[test] @@ -1067,13 +1073,10 @@ mod tests { let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); - let counts = state[0].as_primitive::(); - let sums = state[1].as_primitive::(); + let (counts, sums) = avg_state::(&state); assert_eq!(sums.data_type(), &DataType::Duration(TimeUnit::Nanosecond)); assert_eq!(counts.values().as_ref(), &[1, 1]); - assert!(counts.is_null(0)); - assert!(!counts.is_null(1)); - assert!(sums.is_null(0)); - assert!(!sums.is_null(1)); + assert_validity(counts, &[false, true]); + assert_validity(sums, &[false, true]); } } diff --git a/datafusion/spark/src/function/aggregate/avg.rs b/datafusion/spark/src/function/aggregate/avg.rs index 0fc541e1dbaef..90fcc71783eec 100644 --- a/datafusion/spark/src/function/aggregate/avg.rs +++ b/datafusion/spark/src/function/aggregate/avg.rs @@ -383,6 +383,13 @@ mod tests { }) } + fn assert_validity(array: &dyn Array, expected: &[bool]) { + assert_eq!(array.len(), expected.len()); + for (idx, expected_valid) in expected.iter().copied().enumerate() { + assert_eq!(!array.is_null(idx), expected_valid, "validity at row {idx}"); + } + } + #[test] fn supports_convert_to_state() { assert!(make_acc().supports_convert_to_state()); @@ -418,12 +425,10 @@ mod tests { let sums = state[0].as_primitive::(); let counts = state[1].as_primitive::(); - assert!(!sums.is_null(0)); - assert!(sums.is_null(1)); - assert!(!sums.is_null(2)); + assert_validity(sums, &[true, false, true]); assert_eq!(counts.value(0), 1); - assert!(counts.is_null(1)); + assert_validity(counts, &[true, false, true]); assert_eq!(counts.value(2), 1); } @@ -438,12 +443,10 @@ mod tests { let sums = state[0].as_primitive::(); let counts = state[1].as_primitive::(); - assert!(!sums.is_null(0)); - assert!(sums.is_null(1)); - assert!(!sums.is_null(2)); + assert_validity(sums, &[true, false, true]); assert_eq!(counts.value(0), 1); - assert!(counts.is_null(1)); + assert_validity(counts, &[true, false, true]); assert_eq!(counts.value(2), 1); } @@ -458,12 +461,10 @@ mod tests { let sums = state[0].as_primitive::(); let counts = state[1].as_primitive::(); - assert!(!sums.is_null(0)); - assert!(sums.is_null(1)); - assert!(!sums.is_null(2)); + assert_validity(sums, &[true, false, true]); assert_eq!(counts.value(0), 1); - assert!(counts.is_null(1)); + assert_validity(counts, &[true, false, true]); assert_eq!(counts.value(2), 1); } From 7e8b4fe76f4a3c07fccd7814fe412eb543f5d02c Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sat, 30 May 2026 14:31:04 +0800 Subject: [PATCH 04/12] feat(datafusion/spark): enhance Spark Avg merge_batch to honor opt_filter - Implemented logic to skip false and NULL values in merge_batch. - Maintained skipping of null converted state rows. - Added regression test: merge_batch_applies_filter. - Introduced spark_avg_state test helper for better testing. - Refactored code to eliminate repeated state[0]/state[1] decode boilerplate. --- .../spark/src/function/aggregate/avg.rs | 49 ++++++++++++++----- 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/datafusion/spark/src/function/aggregate/avg.rs b/datafusion/spark/src/function/aggregate/avg.rs index 90fcc71783eec..d288c380a3cf6 100644 --- a/datafusion/spark/src/function/aggregate/avg.rs +++ b/datafusion/spark/src/function/aggregate/avg.rs @@ -287,7 +287,7 @@ where &mut self, values: &[ArrayRef], group_indices: &[usize], - _opt_filter: Option<&BooleanArray>, + opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 2, "two arguments to merge_batch"); @@ -300,8 +300,12 @@ where for (idx, &group_index) in group_indices.iter().enumerate() { // Skip null state entries emitted by convert_to_state for - // filtered / null input rows. - if partial_counts.is_null(idx) || partial_sums.is_null(idx) { + // filtered / null input rows, and rows filtered during merge. + if partial_counts.is_null(idx) + || partial_sums.is_null(idx) + || opt_filter + .is_some_and(|filter| filter.is_null(idx) || !filter.value(idx)) + { continue; } self.counts[group_index] += partial_counts.value(idx); @@ -390,6 +394,16 @@ mod tests { } } + fn spark_avg_state( + state: &[ArrayRef], + ) -> (&PrimitiveArray, &PrimitiveArray) { + assert_eq!(state.len(), 2); + ( + state[0].as_primitive::(), + state[1].as_primitive::(), + ) + } + #[test] fn supports_convert_to_state() { assert!(make_acc().supports_convert_to_state()); @@ -402,9 +416,7 @@ mod tests { vec![Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]))]; let state = acc.convert_to_state(&values, None).unwrap(); - assert_eq!(state.len(), 2); - let sums = state[0].as_primitive::(); - let counts = state[1].as_primitive::(); + let (sums, counts) = spark_avg_state(&state); assert_eq!(sums.values().as_ref(), &[1.0, 2.0, 3.0]); assert_eq!(counts.values().as_ref(), &[1, 1, 1]); @@ -422,8 +434,7 @@ mod tests { ]))]; let state = acc.convert_to_state(&values, None).unwrap(); - let sums = state[0].as_primitive::(); - let counts = state[1].as_primitive::(); + let (sums, counts) = spark_avg_state(&state); assert_validity(sums, &[true, false, true]); @@ -440,8 +451,7 @@ mod tests { let filter = BooleanArray::from(vec![true, false, true]); let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); - let sums = state[0].as_primitive::(); - let counts = state[1].as_primitive::(); + let (sums, counts) = spark_avg_state(&state); assert_validity(sums, &[true, false, true]); @@ -458,8 +468,7 @@ mod tests { let filter = BooleanArray::from(vec![Some(true), None, Some(true)]); let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); - let sums = state[0].as_primitive::(); - let counts = state[1].as_primitive::(); + let (sums, counts) = spark_avg_state(&state); assert_validity(sums, &[true, false, true]); @@ -468,6 +477,22 @@ mod tests { assert_eq!(counts.value(2), 1); } + #[test] + fn merge_batch_applies_filter() { + let mut acc = make_acc(); + let input: Vec = + vec![Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0]))]; + let state = acc.convert_to_state(&input, None).unwrap(); + let filter = BooleanArray::from(vec![Some(true), Some(false), None]); + + acc.merge_batch(&state, &[0, 0, 0], Some(&filter), 1) + .unwrap(); + + let result = acc.evaluate(EmitTo::All).unwrap(); + let result = result.as_primitive::(); + assert_eq!(result.value(0), 10.0); + } + #[test] fn convert_to_state_roundtrips_through_merge() { let mut acc = make_acc(); From 93a14d051183d7f1dd23aef5847fd2d3b06358c8 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 15 Jun 2026 16:24:28 +0800 Subject: [PATCH 05/12] feat(aggregate): refactor AVG functionality and integrate Spark with built-in AVG - Removed the common AVG helper and deleted avg.rs - Folded Spark's grouped AVG into the built-in AVG implementation - Generalized AvgGroupsAccumulator to support both count types and state orders - Updated built-in AVG to use UInt64Type with count and sum; Spark now uses Int64Type with sum and count - Updated Spark to call the new spark_avg_groups_accumulator function - Removed duplication of Spark-local grouped accumulator --- .../src/aggregate/groups_accumulator.rs | 1 - .../src/aggregate/groups_accumulator/avg.rs | 102 ------- datafusion/functions-aggregate/src/average.rs | 275 ++++++++++++------ .../spark/src/function/aggregate/avg.rs | 200 +------------ 4 files changed, 194 insertions(+), 384 deletions(-) delete mode 100644 datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index 936316e333b6b..ad2a21bb4733c 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -19,7 +19,6 @@ //! Adapter that makes [`GroupsAccumulator`] out of [`Accumulator`] pub mod accumulate; -pub mod avg; pub mod bool_op; pub mod nulls; pub mod prim_op; diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs deleted file mode 100644 index cc6f889082d1d..0000000000000 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs +++ /dev/null @@ -1,102 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Shared helpers for average group accumulator state handling. - -use arrow::array::{ArrowNumericType, BooleanArray, PrimitiveArray}; - -use super::nulls::{filtered_null_mask, set_nulls}; - -/// Converts an AVG input value array into nullable per-row state arrays. -/// -/// The returned arrays are `(sum_state, count_state)`. Callers keep control of -/// their aggregate-specific state field order when wrapping these arrays in the -/// final state vector. -/// -/// Rows with NULL input values, `false` filters, or NULL filters are marked NULL -/// in both output arrays so later merge steps can ignore them consistently. -pub fn convert_to_avg_state( - sums: PrimitiveArray, - count_value: CountType::Native, - opt_filter: Option<&BooleanArray>, -) -> (PrimitiveArray, PrimitiveArray) -where - SumType: ArrowNumericType + Send, - CountType: ArrowNumericType + Send, -{ - let counts = PrimitiveArray::::from_value(count_value, sums.len()); - let nulls = filtered_null_mask(opt_filter, &sums); - let counts = set_nulls(counts, nulls.clone()); - let sums = set_nulls(sums, nulls); - - (sums, counts) -} - -#[cfg(test)] -mod tests { - use arrow::array::{Array, BooleanArray, Float64Array}; - use arrow::datatypes::{DataType, Float64Type, Int64Type}; - - use super::convert_to_avg_state; - - type CountType = Int64Type; - - fn assert_validity(array: &dyn Array, expected: &[bool]) { - assert_eq!(array.len(), expected.len()); - for (idx, expected_valid) in expected.iter().copied().enumerate() { - assert_eq!(!array.is_null(idx), expected_valid, "validity at row {idx}"); - } - } - - #[test] - fn convert_to_avg_state_applies_input_nulls_to_sum_and_count() { - let sums = Float64Array::from(vec![Some(1.0), None, Some(3.0)]); - - let (sums, counts) = - convert_to_avg_state::(sums, 1, None); - - assert_validity(&sums, &[true, false, true]); - assert_validity(&counts, &[true, false, true]); - assert_eq!(counts.values().as_ref(), &[1, 1, 1]); - } - - #[test] - fn convert_to_avg_state_applies_filter_nulls_to_sum_and_count() { - let sums = Float64Array::from(vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)]); - let filter = BooleanArray::from(vec![Some(true), Some(false), None, Some(true)]); - - let (sums, counts) = - convert_to_avg_state::(sums, 1, Some(&filter)); - - assert_eq!(sums.null_count(), 2); - assert_validity(&sums, &[true, false, false, true]); - - assert_eq!(counts.null_count(), 2); - assert_validity(&counts, &[true, false, false, true]); - assert_eq!(counts.values().as_ref(), &[1, 1, 1, 1]); - } - - #[test] - fn convert_to_avg_state_preserves_sum_data_type() { - let sums = Float64Array::from(vec![1.0, 2.0]).with_data_type(DataType::Float64); - - let (sums, _counts) = - convert_to_avg_state::(sums, 1, None); - - assert_eq!(sums.data_type(), &DataType::Float64); - } -} diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index a5fc4ec938654..afe6ceeb1d4eb 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -19,7 +19,7 @@ use arrow::array::{ Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, AsArray, - BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array, + BooleanArray, PrimitiveArray, PrimitiveBuilder, }; use arrow::compute::sum; @@ -29,7 +29,8 @@ use arrow::datatypes::{ DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, - DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type, i256, + DurationSecondType, Field, FieldRef, Float64Type, Int64Type, TimeUnit, UInt64Type, + i256, }; use datafusion_common::types::{NativeType, logical_float64}; use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err}; @@ -44,7 +45,9 @@ use datafusion_functions_aggregate_common::aggregate::avg_distinct::{ DecimalDistinctAvgAccumulator, Float64DistinctAvgAccumulator, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState; -use datafusion_functions_aggregate_common::aggregate::groups_accumulator::avg::convert_to_avg_state; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{ + filtered_null_mask, set_nulls, +}; use datafusion_functions_aggregate_common::utils::DecimalAverager; use datafusion_macros::user_doc; use log::debug; @@ -352,13 +355,16 @@ impl AggregateUDFImpl for Avg { // instantiate specialized accumulator based for the type match (data_type, args.return_field.data_type()) { - (Float64, Float64) => { - Ok(Box::new(AvgGroupsAccumulator::::new( - data_type, - args.return_field.data_type(), - |sum: f64, count: u64| Ok(sum / count as f64), - ))) - } + (Float64, Float64) => Ok(Box::new(AvgGroupsAccumulator::< + Float64Type, + UInt64Type, + _, + true, + >::new( + data_type, + args.return_field.data_type(), + |sum: f64, count: u64| Ok(sum / count as f64), + ))), ( Decimal32(_sum_precision, sum_scale), Decimal32(target_precision, target_scale), @@ -372,10 +378,13 @@ impl AggregateUDFImpl for Avg { let avg_fn = move |sum: i32, count: u64| decimal_averager.avg(sum, count as i32); - Ok(Box::new(AvgGroupsAccumulator::::new( - data_type, - args.return_field.data_type(), - avg_fn, + Ok(Box::new(AvgGroupsAccumulator::< + Decimal32Type, + UInt64Type, + _, + true, + >::new( + data_type, args.return_field.data_type(), avg_fn ))) } ( @@ -391,10 +400,13 @@ impl AggregateUDFImpl for Avg { let avg_fn = move |sum: i64, count: u64| decimal_averager.avg(sum, count as i64); - Ok(Box::new(AvgGroupsAccumulator::::new( - data_type, - args.return_field.data_type(), - avg_fn, + Ok(Box::new(AvgGroupsAccumulator::< + Decimal64Type, + UInt64Type, + _, + true, + >::new( + data_type, args.return_field.data_type(), avg_fn ))) } ( @@ -410,10 +422,13 @@ impl AggregateUDFImpl for Avg { let avg_fn = move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128); - Ok(Box::new(AvgGroupsAccumulator::::new( - data_type, - args.return_field.data_type(), - avg_fn, + Ok(Box::new(AvgGroupsAccumulator::< + Decimal128Type, + UInt64Type, + _, + true, + >::new( + data_type, args.return_field.data_type(), avg_fn ))) } @@ -431,10 +446,13 @@ impl AggregateUDFImpl for Avg { decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap()) }; - Ok(Box::new(AvgGroupsAccumulator::::new( - data_type, - args.return_field.data_type(), - avg_fn, + Ok(Box::new(AvgGroupsAccumulator::< + Decimal256Type, + UInt64Type, + _, + true, + >::new( + data_type, args.return_field.data_type(), avg_fn ))) } @@ -442,38 +460,46 @@ impl AggregateUDFImpl for Avg { let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64); match time_unit { - TimeUnit::Second => Ok(Box::new(AvgGroupsAccumulator::< - DurationSecondType, - _, - >::new( - data_type, - args.return_type(), - avg_fn, - ))), - TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::< - DurationMillisecondType, - _, - >::new( - data_type, - args.return_type(), - avg_fn, - ))), - TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::< - DurationMicrosecondType, - _, - >::new( - data_type, - args.return_type(), - avg_fn, - ))), - TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::< - DurationNanosecondType, - _, - >::new( - data_type, - args.return_type(), - avg_fn, - ))), + TimeUnit::Second => { + Ok(Box::new(AvgGroupsAccumulator::< + DurationSecondType, + UInt64Type, + _, + true, + >::new( + data_type, args.return_type(), avg_fn + ))) + } + TimeUnit::Millisecond => { + Ok(Box::new(AvgGroupsAccumulator::< + DurationMillisecondType, + UInt64Type, + _, + true, + >::new( + data_type, args.return_type(), avg_fn + ))) + } + TimeUnit::Microsecond => { + Ok(Box::new(AvgGroupsAccumulator::< + DurationMicrosecondType, + UInt64Type, + _, + true, + >::new( + data_type, args.return_type(), avg_fn + ))) + } + TimeUnit::Nanosecond => { + Ok(Box::new(AvgGroupsAccumulator::< + DurationNanosecondType, + UInt64Type, + _, + true, + >::new( + data_type, args.return_type(), avg_fn + ))) + } } } @@ -762,12 +788,17 @@ impl Accumulator for DurationAvgAccumulator { /// Stores values as native types, and does overflow checking /// /// F: Function that calculates the average value from a sum of -/// T::Native and a total count +/// T::Native and a total count. +/// +/// `COUNT_FIRST` controls the state field order: +/// * `true`: `[count, sum]` for built-in AVG +/// * `false`: `[sum, count]` for Spark AVG #[derive(Debug)] -struct AvgGroupsAccumulator +struct AvgGroupsAccumulator where T: ArrowNumericType + Send, - F: Fn(T::Native, u64) -> Result + Send + 'static, + CountType: ArrowPrimitiveType + Send, + F: Fn(T::Native, CountType::Native) -> Result + Send + 'static, { /// The type of the internal sum sum_data_type: DataType, @@ -775,8 +806,8 @@ where /// The type of the returned sum return_data_type: DataType, - /// Count per group (use u64 to make UInt64Array) - counts: Vec, + /// Count per group + counts: Vec, /// Sums per group, stored as the native type sums: Vec, @@ -788,10 +819,12 @@ where avg_fn: F, } -impl AvgGroupsAccumulator +impl + AvgGroupsAccumulator where T: ArrowNumericType + Send, - F: Fn(T::Native, u64) -> Result + Send + 'static, + CountType: ArrowPrimitiveType + Send, + F: Fn(T::Native, CountType::Native) -> Result + Send + 'static, { pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self { debug!( @@ -808,12 +841,25 @@ where avg_fn, } } + + fn state_arrays( + sums: PrimitiveArray, + counts: PrimitiveArray, + ) -> Vec { + if COUNT_FIRST { + vec![Arc::new(counts) as ArrayRef, Arc::new(sums) as ArrayRef] + } else { + vec![Arc::new(sums) as ArrayRef, Arc::new(counts) as ArrayRef] + } + } } -impl GroupsAccumulator for AvgGroupsAccumulator +impl GroupsAccumulator + for AvgGroupsAccumulator where T: ArrowNumericType + Send, - F: Fn(T::Native, u64) -> Result + Send + 'static, + CountType: ArrowPrimitiveType + Send, + F: Fn(T::Native, CountType::Native) -> Result + Send + 'static, { fn update_batch( &mut self, @@ -826,7 +872,8 @@ where let values = values[0].as_primitive::(); // increment counts, update sums - self.counts.resize(total_num_groups, 0); + self.counts + .resize(total_num_groups, CountType::Native::usize_as(0)); self.sums.resize(total_num_groups, T::default_value()); self.null_state.accumulate( group_indices, @@ -838,7 +885,8 @@ where let sum = unsafe { self.sums.get_unchecked_mut(group_index) }; *sum = sum.add_wrapping(new_value); - self.counts[group_index] += 1; + let count = unsafe { self.counts.get_unchecked_mut(group_index) }; + *count = count.add_wrapping(CountType::Native::usize_as(1)); }, ); @@ -890,16 +938,13 @@ where let nulls = self.null_state.build(emit_to); let counts = emit_to.take_needed(&mut self.counts); - let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy + let counts = PrimitiveArray::::new(counts.into(), nulls.clone()); let sums = emit_to.take_needed(&mut self.sums); let sums = PrimitiveArray::::new(sums.into(), nulls) // zero copy .with_data_type(self.sum_data_type.clone()); - Ok(vec![ - Arc::new(counts) as ArrayRef, - Arc::new(sums) as ArrayRef, - ]) + Ok(Self::state_arrays(sums, counts)) } fn merge_batch( @@ -910,11 +955,21 @@ where total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 2, "two arguments to merge_batch"); - // first batch is counts, second is partial sums - let partial_counts = values[0].as_primitive::(); - let partial_sums = values[1].as_primitive::(); + let (partial_counts, partial_sums) = if COUNT_FIRST { + ( + values[0].as_primitive::(), + values[1].as_primitive::(), + ) + } else { + ( + values[1].as_primitive::(), + values[0].as_primitive::(), + ) + }; + // update counts with partial counts - self.counts.resize(total_num_groups, 0); + self.counts + .resize(total_num_groups, CountType::Native::usize_as(0)); self.null_state.accumulate( group_indices, partial_counts, @@ -923,7 +978,7 @@ where |group_index, partial_count| { // SAFETY: group_index is guaranteed to be in bounds let count = unsafe { self.counts.get_unchecked_mut(group_index) }; - *count += partial_count; + *count = count.add_wrapping(partial_count); }, ); @@ -953,9 +1008,16 @@ where .as_primitive::() .clone() .with_data_type(self.sum_data_type.clone()); - let (sums, counts) = convert_to_avg_state::(sums, 1, opt_filter); + let counts = PrimitiveArray::::from_value( + CountType::Native::usize_as(1), + sums.len(), + ); - Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)]) + let nulls = filtered_null_mask(opt_filter, &sums); + let counts = set_nulls(counts, nulls.clone()); + let sums = set_nulls(sums, nulls); + + Ok(Self::state_arrays(sums, counts)) } fn supports_convert_to_state(&self) -> bool { @@ -963,10 +1025,28 @@ where } fn size(&self) -> usize { - self.counts.capacity() * size_of::() + self.sums.capacity() * size_of::() + self.counts.capacity() * size_of::() + + self.sums.capacity() * size_of::() + + self.null_state.size() } } +/// Creates the Spark AVG grouped accumulator using the built-in AVG state +/// handling implementation while preserving Spark's `[sum, count]` state order +/// and signed count type. +#[doc(hidden)] +pub fn spark_avg_groups_accumulator( + return_data_type: &DataType, +) -> Box { + Box::new( + AvgGroupsAccumulator::::new( + return_data_type, + return_data_type, + |sum: f64, count: i64| Ok(sum / count as f64), + ), + ) +} + #[cfg(test)] mod tests { use super::*; @@ -990,18 +1070,26 @@ mod tests { ) } - fn float64_acc() -> AvgGroupsAccumulator Result> - { - AvgGroupsAccumulator::::new( + fn float64_acc() -> AvgGroupsAccumulator< + Float64Type, + UInt64Type, + impl Fn(f64, u64) -> Result, + true, + > { + AvgGroupsAccumulator::::new( &DataType::Float64, &DataType::Float64, |sum, count| Ok(sum / count as f64), ) } - fn decimal128_acc() - -> AvgGroupsAccumulator Result> { - AvgGroupsAccumulator::::new( + fn decimal128_acc() -> AvgGroupsAccumulator< + Decimal128Type, + UInt64Type, + impl Fn(i128, u64) -> Result, + true, + > { + AvgGroupsAccumulator::::new( &DataType::Decimal128(10, 2), &DataType::Decimal128(14, 6), // convert_to_state does not evaluate averages, so avg_fn is unused here. @@ -1009,10 +1097,13 @@ mod tests { ) } - fn duration_acc() - -> AvgGroupsAccumulator Result> - { - AvgGroupsAccumulator::::new( + fn duration_acc() -> AvgGroupsAccumulator< + DurationNanosecondType, + UInt64Type, + impl Fn(i64, u64) -> Result, + true, + > { + AvgGroupsAccumulator::::new( &DataType::Duration(TimeUnit::Nanosecond), &DataType::Duration(TimeUnit::Nanosecond), |sum, count| Ok(sum / count as i64), diff --git a/datafusion/spark/src/function/aggregate/avg.rs b/datafusion/spark/src/function/aggregate/avg.rs index d288c380a3cf6..70aa42968fa82 100644 --- a/datafusion/spark/src/function/aggregate/avg.rs +++ b/datafusion/spark/src/function/aggregate/avg.rs @@ -16,9 +16,7 @@ // under the License. use arrow::array::{ - Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, BooleanArray, Int64Array, - PrimitiveArray, - builder::PrimitiveBuilder, + Array, ArrayRef, cast::AsArray, types::{Float64Type, Int64Type}, }; @@ -29,10 +27,11 @@ use datafusion_common::{Result, ScalarValue, not_impl_err}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Coercion, EmitTo, GroupsAccumulator, ReversedUDAF, - Signature, TypeSignatureClass, Volatility, + Accumulator, AggregateUDFImpl, Coercion, GroupsAccumulator, ReversedUDAF, Signature, + TypeSignatureClass, Volatility, }; -use datafusion_functions_aggregate_common::aggregate::groups_accumulator::avg::convert_to_avg_state; +use datafusion_functions_aggregate::average::spark_avg_groups_accumulator; +use std::mem::size_of_val; use std::sync::Arc; /// AVG aggregate expression @@ -128,10 +127,7 @@ impl AggregateUDFImpl for SparkAvg { // instantiate specialized accumulator based for the type match (&data_type, args.return_type()) { (DataType::Float64, DataType::Float64) => { - Ok(Box::new(AvgGroupsAccumulator::::new( - args.return_field.data_type(), - |sum: f64, count: i64| Ok(sum / count as f64), - ))) + Ok(spark_avg_groups_accumulator(args.return_field.data_type())) } (dt, return_type) => { not_impl_err!("AvgGroupsAccumulator for ({dt} --> {return_type})") @@ -202,189 +198,15 @@ impl Accumulator for AvgAccumulator { } } -/// An accumulator to compute the average of `[PrimitiveArray]`. -/// Stores values as native types, and does overflow checking -/// -/// F: Function that calculates the average value from a sum of -/// T::Native and a total count -#[derive(Debug)] -struct AvgGroupsAccumulator -where - T: ArrowNumericType + Send, - F: Fn(T::Native, i64) -> Result + Send + 'static, -{ - /// The type of the returned average - return_data_type: DataType, - - /// Count per group (use i64 to make Int64Array) - counts: Vec, - - /// Sums per group, stored as the native type - sums: Vec, - - /// Function that computes the final average (value / count) - avg_fn: F, -} - -impl AvgGroupsAccumulator -where - T: ArrowNumericType + Send, - F: Fn(T::Native, i64) -> Result + Send + 'static, -{ - pub fn new(return_data_type: &DataType, avg_fn: F) -> Self { - Self { - return_data_type: return_data_type.clone(), - counts: vec![], - sums: vec![], - avg_fn, - } - } -} - -impl GroupsAccumulator for AvgGroupsAccumulator -where - T: ArrowNumericType + Send, - F: Fn(T::Native, i64) -> Result + Send + 'static, -{ - fn update_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - _opt_filter: Option<&BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - assert_eq!(values.len(), 1, "single argument to update_batch"); - let values = values[0].as_primitive::(); - let data = values.values(); - - // increment counts, update sums - self.counts.resize(total_num_groups, 0); - self.sums.resize(total_num_groups, T::default_value()); - - let iter = group_indices.iter().zip(data.iter()); - if values.null_count() == 0 { - for (&group_index, &value) in iter { - let sum = &mut self.sums[group_index]; - *sum = (*sum).add_wrapping(value); - self.counts[group_index] += 1; - } - } else { - for (idx, (&group_index, &value)) in iter.enumerate() { - if values.is_null(idx) { - continue; - } - let sum = &mut self.sums[group_index]; - *sum = (*sum).add_wrapping(value); - - self.counts[group_index] += 1; - } - } - - Ok(()) - } - - fn merge_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - assert_eq!(values.len(), 2, "two arguments to merge_batch"); - // first batch is partial sums, second is counts - let partial_sums = values[0].as_primitive::(); - let partial_counts = values[1].as_primitive::(); - - self.counts.resize(total_num_groups, 0); - self.sums.resize(total_num_groups, T::default_value()); - - for (idx, &group_index) in group_indices.iter().enumerate() { - // Skip null state entries emitted by convert_to_state for - // filtered / null input rows, and rows filtered during merge. - if partial_counts.is_null(idx) - || partial_sums.is_null(idx) - || opt_filter - .is_some_and(|filter| filter.is_null(idx) || !filter.value(idx)) - { - continue; - } - self.counts[group_index] += partial_counts.value(idx); - let sum = &mut self.sums[group_index]; - *sum = sum.add_wrapping(partial_sums.value(idx)); - } - - Ok(()) - } - - fn evaluate(&mut self, emit_to: EmitTo) -> Result { - let counts = emit_to.take_needed(&mut self.counts); - let sums = emit_to.take_needed(&mut self.sums); - let mut builder = PrimitiveBuilder::::with_capacity(sums.len()); - let iter = sums.into_iter().zip(counts); - - for (sum, count) in iter { - if count != 0 { - builder.append_value((self.avg_fn)(sum, count)?) - } else { - builder.append_null(); - } - } - let array: PrimitiveArray = builder.finish(); - - Ok(Arc::new(array)) - } - - // return arrays for sums and counts - fn state(&mut self, emit_to: EmitTo) -> Result> { - let counts = emit_to.take_needed(&mut self.counts); - let counts = Int64Array::new(counts.into(), None); - - let sums = emit_to.take_needed(&mut self.sums); - let sums = PrimitiveArray::::new(sums.into(), None) - .with_data_type(self.return_data_type.clone()); - - Ok(vec![ - Arc::new(sums) as ArrayRef, - Arc::new(counts) as ArrayRef, - ]) - } - - fn convert_to_state( - &self, - values: &[ArrayRef], - opt_filter: Option<&BooleanArray>, - ) -> Result> { - let sums = values[0] - .as_primitive::() - .clone() - .with_data_type(self.return_data_type.clone()); - let (sums, counts) = convert_to_avg_state::(sums, 1, opt_filter); - - // [sum, count] - must match state() and merge_batch() - Ok(vec![ - Arc::new(sums) as ArrayRef, - Arc::new(counts) as ArrayRef, - ]) - } - - fn supports_convert_to_state(&self) -> bool { - true - } - - fn size(&self) -> usize { - self.counts.capacity() * size_of::() + self.sums.capacity() * size_of::() - } -} - #[cfg(test)] mod tests { use super::*; - use arrow::array::Float64Array; + use arrow::array::types::Int64Type; + use arrow::array::{Array, BooleanArray, Float64Array, PrimitiveArray}; + use datafusion_expr::EmitTo; - fn make_acc() -> AvgGroupsAccumulator Result> { - AvgGroupsAccumulator::::new(&DataType::Float64, |sum, count| { - Ok(sum / count as f64) - }) + fn make_acc() -> Box { + spark_avg_groups_accumulator(&DataType::Float64) } fn assert_validity(array: &dyn Array, expected: &[bool]) { From a525c0a468e8f79762503644e5688028c5cfa7e2 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 15 Jun 2026 16:34:23 +0800 Subject: [PATCH 06/12] feat: enhance accumulator functionality - Added private aliases for `BuiltInAvgGroupsAccumulator` and `SparkAvgGroupsAccumulator` - Introduced private constructor helper `new_builtin_avg_groups_accumulator` - Replaced raw true/false generic call-site noise for clarity - Simplified built-in AVG test helper return types to `Box` - Consolidated repeated Spark state assertions for maintainability --- datafusion/functions-aggregate/src/average.rs | 192 ++++++++---------- .../spark/src/function/aggregate/avg.rs | 33 ++- 2 files changed, 94 insertions(+), 131 deletions(-) diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index afe6ceeb1d4eb..6b09846749240 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -355,16 +355,13 @@ impl AggregateUDFImpl for Avg { // instantiate specialized accumulator based for the type match (data_type, args.return_field.data_type()) { - (Float64, Float64) => Ok(Box::new(AvgGroupsAccumulator::< - Float64Type, - UInt64Type, - _, - true, - >::new( - data_type, - args.return_field.data_type(), - |sum: f64, count: u64| Ok(sum / count as f64), - ))), + (Float64, Float64) => { + Ok(new_builtin_avg_groups_accumulator::( + data_type, + args.return_field.data_type(), + |sum: f64, count: u64| Ok(sum / count as f64), + )) + } ( Decimal32(_sum_precision, sum_scale), Decimal32(target_precision, target_scale), @@ -378,14 +375,11 @@ impl AggregateUDFImpl for Avg { let avg_fn = move |sum: i32, count: u64| decimal_averager.avg(sum, count as i32); - Ok(Box::new(AvgGroupsAccumulator::< - Decimal32Type, - UInt64Type, - _, - true, - >::new( - data_type, args.return_field.data_type(), avg_fn - ))) + Ok(new_builtin_avg_groups_accumulator::( + data_type, + args.return_field.data_type(), + avg_fn, + )) } ( Decimal64(_sum_precision, sum_scale), @@ -400,14 +394,11 @@ impl AggregateUDFImpl for Avg { let avg_fn = move |sum: i64, count: u64| decimal_averager.avg(sum, count as i64); - Ok(Box::new(AvgGroupsAccumulator::< - Decimal64Type, - UInt64Type, - _, - true, - >::new( - data_type, args.return_field.data_type(), avg_fn - ))) + Ok(new_builtin_avg_groups_accumulator::( + data_type, + args.return_field.data_type(), + avg_fn, + )) } ( Decimal128(_sum_precision, sum_scale), @@ -422,14 +413,11 @@ impl AggregateUDFImpl for Avg { let avg_fn = move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128); - Ok(Box::new(AvgGroupsAccumulator::< - Decimal128Type, - UInt64Type, - _, - true, - >::new( - data_type, args.return_field.data_type(), avg_fn - ))) + Ok(new_builtin_avg_groups_accumulator::( + data_type, + args.return_field.data_type(), + avg_fn, + )) } ( @@ -446,60 +434,41 @@ impl AggregateUDFImpl for Avg { decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap()) }; - Ok(Box::new(AvgGroupsAccumulator::< - Decimal256Type, - UInt64Type, - _, - true, - >::new( - data_type, args.return_field.data_type(), avg_fn - ))) + Ok(new_builtin_avg_groups_accumulator::( + data_type, + args.return_field.data_type(), + avg_fn, + )) } (Duration(time_unit), Duration(_result_unit)) => { let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64); match time_unit { - TimeUnit::Second => { - Ok(Box::new(AvgGroupsAccumulator::< - DurationSecondType, - UInt64Type, - _, - true, - >::new( - data_type, args.return_type(), avg_fn - ))) - } - TimeUnit::Millisecond => { - Ok(Box::new(AvgGroupsAccumulator::< - DurationMillisecondType, - UInt64Type, - _, - true, - >::new( - data_type, args.return_type(), avg_fn - ))) - } - TimeUnit::Microsecond => { - Ok(Box::new(AvgGroupsAccumulator::< - DurationMicrosecondType, - UInt64Type, - _, - true, - >::new( - data_type, args.return_type(), avg_fn - ))) - } - TimeUnit::Nanosecond => { - Ok(Box::new(AvgGroupsAccumulator::< - DurationNanosecondType, - UInt64Type, - _, - true, - >::new( - data_type, args.return_type(), avg_fn - ))) - } + TimeUnit::Second => Ok(new_builtin_avg_groups_accumulator::< + DurationSecondType, + _, + >( + data_type, args.return_type(), avg_fn + )), + TimeUnit::Millisecond => Ok(new_builtin_avg_groups_accumulator::< + DurationMillisecondType, + _, + >( + data_type, args.return_type(), avg_fn + )), + TimeUnit::Microsecond => Ok(new_builtin_avg_groups_accumulator::< + DurationMicrosecondType, + _, + >( + data_type, args.return_type(), avg_fn + )), + TimeUnit::Nanosecond => Ok(new_builtin_avg_groups_accumulator::< + DurationNanosecondType, + _, + >( + data_type, args.return_type(), avg_fn + )), } } @@ -819,6 +788,10 @@ where avg_fn: F, } +type BuiltInAvgGroupsAccumulator = AvgGroupsAccumulator; +type SparkAvgGroupsAccumulator = + AvgGroupsAccumulator; + impl AvgGroupsAccumulator where @@ -854,6 +827,22 @@ where } } +fn new_builtin_avg_groups_accumulator( + sum_data_type: &DataType, + return_data_type: &DataType, + avg_fn: F, +) -> Box +where + T: ArrowNumericType + Send + 'static, + F: Fn(T::Native, u64) -> Result + Send + 'static, +{ + Box::new(BuiltInAvgGroupsAccumulator::::new( + sum_data_type, + return_data_type, + avg_fn, + )) +} + impl GroupsAccumulator for AvgGroupsAccumulator where @@ -1038,13 +1027,11 @@ where pub fn spark_avg_groups_accumulator( return_data_type: &DataType, ) -> Box { - Box::new( - AvgGroupsAccumulator::::new( - return_data_type, - return_data_type, - |sum: f64, count: i64| Ok(sum / count as f64), - ), - ) + Box::new(SparkAvgGroupsAccumulator::<_>::new( + return_data_type, + return_data_type, + |sum: f64, count: i64| Ok(sum / count as f64), + )) } #[cfg(test)] @@ -1070,26 +1057,16 @@ mod tests { ) } - fn float64_acc() -> AvgGroupsAccumulator< - Float64Type, - UInt64Type, - impl Fn(f64, u64) -> Result, - true, - > { - AvgGroupsAccumulator::::new( + fn float64_acc() -> Box { + new_builtin_avg_groups_accumulator::( &DataType::Float64, &DataType::Float64, |sum, count| Ok(sum / count as f64), ) } - fn decimal128_acc() -> AvgGroupsAccumulator< - Decimal128Type, - UInt64Type, - impl Fn(i128, u64) -> Result, - true, - > { - AvgGroupsAccumulator::::new( + fn decimal128_acc() -> Box { + new_builtin_avg_groups_accumulator::( &DataType::Decimal128(10, 2), &DataType::Decimal128(14, 6), // convert_to_state does not evaluate averages, so avg_fn is unused here. @@ -1097,13 +1074,8 @@ mod tests { ) } - fn duration_acc() -> AvgGroupsAccumulator< - DurationNanosecondType, - UInt64Type, - impl Fn(i64, u64) -> Result, - true, - > { - AvgGroupsAccumulator::::new( + fn duration_acc() -> Box { + new_builtin_avg_groups_accumulator::( &DataType::Duration(TimeUnit::Nanosecond), &DataType::Duration(TimeUnit::Nanosecond), |sum, count| Ok(sum / count as i64), diff --git a/datafusion/spark/src/function/aggregate/avg.rs b/datafusion/spark/src/function/aggregate/avg.rs index 70aa42968fa82..07e1d9c3af2df 100644 --- a/datafusion/spark/src/function/aggregate/avg.rs +++ b/datafusion/spark/src/function/aggregate/avg.rs @@ -226,6 +226,15 @@ mod tests { ) } + fn assert_state_validity_and_counts(state: &[ArrayRef], expected_validity: &[bool]) { + let (sums, counts) = spark_avg_state(state); + let expected_counts: Vec = vec![1; expected_validity.len()]; + + assert_validity(sums, expected_validity); + assert_eq!(counts.values().as_ref(), expected_counts.as_slice()); + assert_validity(counts, expected_validity); + } + #[test] fn supports_convert_to_state() { assert!(make_acc().supports_convert_to_state()); @@ -256,13 +265,7 @@ mod tests { ]))]; let state = acc.convert_to_state(&values, None).unwrap(); - let (sums, counts) = spark_avg_state(&state); - - assert_validity(sums, &[true, false, true]); - - assert_eq!(counts.value(0), 1); - assert_validity(counts, &[true, false, true]); - assert_eq!(counts.value(2), 1); + assert_state_validity_and_counts(&state, &[true, false, true]); } #[test] @@ -273,13 +276,7 @@ mod tests { let filter = BooleanArray::from(vec![true, false, true]); let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); - let (sums, counts) = spark_avg_state(&state); - - assert_validity(sums, &[true, false, true]); - - assert_eq!(counts.value(0), 1); - assert_validity(counts, &[true, false, true]); - assert_eq!(counts.value(2), 1); + assert_state_validity_and_counts(&state, &[true, false, true]); } #[test] @@ -290,13 +287,7 @@ mod tests { let filter = BooleanArray::from(vec![Some(true), None, Some(true)]); let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); - let (sums, counts) = spark_avg_state(&state); - - assert_validity(sums, &[true, false, true]); - - assert_eq!(counts.value(0), 1); - assert_validity(counts, &[true, false, true]); - assert_eq!(counts.value(2), 1); + assert_state_validity_and_counts(&state, &[true, false, true]); } #[test] From fcd0fbd28b30816cf3c473f6ffbffddde808b868 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 15 Jun 2026 17:02:33 +0800 Subject: [PATCH 07/12] feat: update AvgGroupsAccumulator visibility and enhance Spark testing - Made `AvgGroupsAccumulator::new` private. - Improved Spark test helper to avoid temporary Vec. - Added end-to-end Spark AVG SQL test covering NULL values with FILTER. - Introduced `tokio` development dependency and enabled DataFusion SQL feature for development. - Updated `Cargo.lock` for the new development dependency. --- Cargo.lock | 1 + datafusion/functions-aggregate/src/average.rs | 2 +- datafusion/spark/Cargo.toml | 3 +- .../spark/src/function/aggregate/avg.rs | 29 +++++++++++++++++-- 4 files changed, 31 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c08be6f29ffd7..47dbd663ad27a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2580,6 +2580,7 @@ dependencies = [ "serde_json", "sha1 0.11.0", "sha2", + "tokio", "twox-hash", "url", ] diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 6b09846749240..538598f605b80 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -799,7 +799,7 @@ where CountType: ArrowPrimitiveType + Send, F: Fn(T::Native, CountType::Native) -> Result + Send + 'static, { - pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self { + fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self { debug!( "AvgGroupsAccumulator ({}, sum type: {sum_data_type}) --> {return_data_type}", std::any::type_name::() diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml index 14f9396d7656e..291b986a6236f 100644 --- a/datafusion/spark/Cargo.toml +++ b/datafusion/spark/Cargo.toml @@ -71,7 +71,8 @@ url = { workspace = true } arrow = { workspace = true, features = ["test_utils"] } criterion = { workspace = true } # for SessionStateBuilderSpark tests -datafusion = { workspace = true, default-features = false } +datafusion = { workspace = true, default-features = false, features = ["sql"] } +tokio = { workspace = true, features = ["macros", "rt"] } [[bench]] harness = false diff --git a/datafusion/spark/src/function/aggregate/avg.rs b/datafusion/spark/src/function/aggregate/avg.rs index 07e1d9c3af2df..dca711ab84b97 100644 --- a/datafusion/spark/src/function/aggregate/avg.rs +++ b/datafusion/spark/src/function/aggregate/avg.rs @@ -228,10 +228,9 @@ mod tests { fn assert_state_validity_and_counts(state: &[ArrayRef], expected_validity: &[bool]) { let (sums, counts) = spark_avg_state(state); - let expected_counts: Vec = vec![1; expected_validity.len()]; assert_validity(sums, expected_validity); - assert_eq!(counts.values().as_ref(), expected_counts.as_slice()); + assert!(counts.values().iter().all(|&count| count == 1)); assert_validity(counts, expected_validity); } @@ -306,6 +305,32 @@ mod tests { assert_eq!(result.value(0), 10.0); } + #[tokio::test] + async fn spark_avg_query_applies_nulls_and_filter() -> Result<()> { + use datafusion::prelude::SessionContext; + + let mut ctx = SessionContext::new(); + crate::register_all(&mut ctx)?; + + let batches = ctx + .sql( + "SELECT avg(v) FILTER (WHERE keep) AS avg_v \ + FROM (VALUES \ + (CAST(1.0 AS DOUBLE), true), \ + (CAST(NULL AS DOUBLE), true), \ + (CAST(3.0 AS DOUBLE), false), \ + (CAST(5.0 AS DOUBLE), CAST(NULL AS BOOLEAN)) \ + ) AS t(v, keep)", + ) + .await? + .collect() + .await?; + + let result = batches[0].column(0).as_primitive::(); + assert_eq!(result.value(0), 1.0); + Ok(()) + } + #[test] fn convert_to_state_roundtrips_through_merge() { let mut acc = make_acc(); From 5cc3a837d99d04e5c9294747c5777efcd440f525 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 15 Jun 2026 17:17:34 +0800 Subject: [PATCH 08/12] feat(layout): replace COUNT_FIRST with typed layout abstractions for AVG state - Introduced AvgStateLayout, BuiltInAvgLayout, and SparkAvgLayout - Centralized state ordering in layout implementations: - built-in: [count, sum] - Spark: [sum, count] - Added Spark AVG SLT regression tests for grouped AVG with NULL input and FILTER in the test files --- datafusion/functions-aggregate/src/average.rs | 129 +++++++++++++----- .../test_files/spark/aggregate/avg.slt | 21 +++ 2 files changed, 116 insertions(+), 34 deletions(-) diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 538598f605b80..a033906e02a24 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -52,6 +52,7 @@ use datafusion_functions_aggregate_common::utils::DecimalAverager; use datafusion_macros::user_doc; use log::debug; use std::fmt::Debug; +use std::marker::PhantomData; use std::mem::{size_of, size_of_val}; use std::sync::Arc; @@ -759,14 +760,13 @@ impl Accumulator for DurationAvgAccumulator { /// F: Function that calculates the average value from a sum of /// T::Native and a total count. /// -/// `COUNT_FIRST` controls the state field order: -/// * `true`: `[count, sum]` for built-in AVG -/// * `false`: `[sum, count]` for Spark AVG +/// `Layout` controls the state field order. #[derive(Debug)] -struct AvgGroupsAccumulator +struct AvgGroupsAccumulator where T: ArrowNumericType + Send, CountType: ArrowPrimitiveType + Send, + Layout: AvgStateLayout, F: Fn(T::Native, CountType::Native) -> Result + Send + 'static, { /// The type of the internal sum @@ -786,17 +786,96 @@ where /// Function that computes the final average (value / count) avg_fn: F, + + /// AVG state layout marker. + layout: PhantomData, +} + +#[derive(Debug)] +struct BuiltInAvgLayout; + +#[derive(Debug)] +struct SparkAvgLayout; + +trait AvgStateLayout: Debug + Send + 'static { + fn state_arrays( + sums: PrimitiveArray, + counts: PrimitiveArray, + ) -> Vec + where + T: ArrowPrimitiveType, + CountType: ArrowPrimitiveType; + + fn partial_state<'a, T, CountType>( + values: &'a [ArrayRef], + ) -> (&'a PrimitiveArray, &'a PrimitiveArray) + where + T: ArrowPrimitiveType, + CountType: ArrowPrimitiveType; } -type BuiltInAvgGroupsAccumulator = AvgGroupsAccumulator; +impl AvgStateLayout for BuiltInAvgLayout { + fn state_arrays( + sums: PrimitiveArray, + counts: PrimitiveArray, + ) -> Vec + where + T: ArrowPrimitiveType, + CountType: ArrowPrimitiveType, + { + vec![Arc::new(counts) as ArrayRef, Arc::new(sums) as ArrayRef] + } + + fn partial_state<'a, T, CountType>( + values: &'a [ArrayRef], + ) -> (&'a PrimitiveArray, &'a PrimitiveArray) + where + T: ArrowPrimitiveType, + CountType: ArrowPrimitiveType, + { + ( + values[0].as_primitive::(), + values[1].as_primitive::(), + ) + } +} + +impl AvgStateLayout for SparkAvgLayout { + fn state_arrays( + sums: PrimitiveArray, + counts: PrimitiveArray, + ) -> Vec + where + T: ArrowPrimitiveType, + CountType: ArrowPrimitiveType, + { + vec![Arc::new(sums) as ArrayRef, Arc::new(counts) as ArrayRef] + } + + fn partial_state<'a, T, CountType>( + values: &'a [ArrayRef], + ) -> (&'a PrimitiveArray, &'a PrimitiveArray) + where + T: ArrowPrimitiveType, + CountType: ArrowPrimitiveType, + { + ( + values[1].as_primitive::(), + values[0].as_primitive::(), + ) + } +} + +type BuiltInAvgGroupsAccumulator = + AvgGroupsAccumulator; type SparkAvgGroupsAccumulator = - AvgGroupsAccumulator; + AvgGroupsAccumulator; -impl - AvgGroupsAccumulator +impl AvgGroupsAccumulator where T: ArrowNumericType + Send, CountType: ArrowPrimitiveType + Send, + Layout: AvgStateLayout, F: Fn(T::Native, CountType::Native) -> Result + Send + 'static, { fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self { @@ -812,17 +891,7 @@ where sums: vec![], null_state: NullState::new(), avg_fn, - } - } - - fn state_arrays( - sums: PrimitiveArray, - counts: PrimitiveArray, - ) -> Vec { - if COUNT_FIRST { - vec![Arc::new(counts) as ArrayRef, Arc::new(sums) as ArrayRef] - } else { - vec![Arc::new(sums) as ArrayRef, Arc::new(counts) as ArrayRef] + layout: PhantomData, } } } @@ -843,11 +912,12 @@ where )) } -impl GroupsAccumulator - for AvgGroupsAccumulator +impl GroupsAccumulator + for AvgGroupsAccumulator where T: ArrowNumericType + Send, CountType: ArrowPrimitiveType + Send, + Layout: AvgStateLayout, F: Fn(T::Native, CountType::Native) -> Result + Send + 'static, { fn update_batch( @@ -933,7 +1003,7 @@ where let sums = PrimitiveArray::::new(sums.into(), nulls) // zero copy .with_data_type(self.sum_data_type.clone()); - Ok(Self::state_arrays(sums, counts)) + Ok(Layout::state_arrays(sums, counts)) } fn merge_batch( @@ -944,17 +1014,8 @@ where total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 2, "two arguments to merge_batch"); - let (partial_counts, partial_sums) = if COUNT_FIRST { - ( - values[0].as_primitive::(), - values[1].as_primitive::(), - ) - } else { - ( - values[1].as_primitive::(), - values[0].as_primitive::(), - ) - }; + let (partial_counts, partial_sums) = + Layout::partial_state::(values); // update counts with partial counts self.counts @@ -1006,7 +1067,7 @@ where let counts = set_nulls(counts, nulls.clone()); let sums = set_nulls(sums, nulls); - Ok(Self::state_arrays(sums, counts)) + Ok(Layout::state_arrays(sums, counts)) } fn supports_convert_to_state(&self) -> bool { diff --git a/datafusion/sqllogictest/test_files/spark/aggregate/avg.slt b/datafusion/sqllogictest/test_files/spark/aggregate/avg.slt index 6ae647989aee9..2c7a685193571 100644 --- a/datafusion/sqllogictest/test_files/spark/aggregate/avg.slt +++ b/datafusion/sqllogictest/test_files/spark/aggregate/avg.slt @@ -54,3 +54,24 @@ GROUP BY a ORDER BY a; ---- 0 0 + +# Regression coverage for Spark AVG grouped state conversion with NULL input and FILTER. +# Spark AVG uses state order [sum, count] with Int64 count; this query exercises +# planner/executor partial/final aggregation rather than only accumulator internals. +query IR +SELECT g, avg(v) FILTER (WHERE keep) AS avg_v +FROM (VALUES + (1, CAST(1.0 AS DOUBLE), true), + (1, CAST(NULL AS DOUBLE), true), + (1, CAST(3.0 AS DOUBLE), false), + (1, CAST(5.0 AS DOUBLE), CAST(NULL AS BOOLEAN)), + (2, CAST(10.0 AS DOUBLE), true), + (2, CAST(30.0 AS DOUBLE), true), + (2, CAST(50.0 AS DOUBLE), false), + (2, CAST(NULL AS DOUBLE), true) +) AS t(g, v, keep) +GROUP BY g +ORDER BY g; +---- +1 1 +2 20 From 23c114dbca254acac0ce10a950e531daa02fa3b5 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 15 Jun 2026 17:30:17 +0800 Subject: [PATCH 09/12] feat: rename new_builtin_avg_groups_accumulator to create_builtin_avg_groups_accumulator in average.rs - Updated all references in production call sites for create_groups_accumulator - Updated all references in test helper call sites --- datafusion/functions-aggregate/src/average.rs | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index a033906e02a24..fa0058b463967 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -357,7 +357,7 @@ impl AggregateUDFImpl for Avg { // instantiate specialized accumulator based for the type match (data_type, args.return_field.data_type()) { (Float64, Float64) => { - Ok(new_builtin_avg_groups_accumulator::( + Ok(create_builtin_avg_groups_accumulator::( data_type, args.return_field.data_type(), |sum: f64, count: u64| Ok(sum / count as f64), @@ -376,7 +376,7 @@ impl AggregateUDFImpl for Avg { let avg_fn = move |sum: i32, count: u64| decimal_averager.avg(sum, count as i32); - Ok(new_builtin_avg_groups_accumulator::( + Ok(create_builtin_avg_groups_accumulator::( data_type, args.return_field.data_type(), avg_fn, @@ -395,7 +395,7 @@ impl AggregateUDFImpl for Avg { let avg_fn = move |sum: i64, count: u64| decimal_averager.avg(sum, count as i64); - Ok(new_builtin_avg_groups_accumulator::( + Ok(create_builtin_avg_groups_accumulator::( data_type, args.return_field.data_type(), avg_fn, @@ -414,7 +414,7 @@ impl AggregateUDFImpl for Avg { let avg_fn = move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128); - Ok(new_builtin_avg_groups_accumulator::( + Ok(create_builtin_avg_groups_accumulator::( data_type, args.return_field.data_type(), avg_fn, @@ -435,7 +435,7 @@ impl AggregateUDFImpl for Avg { decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap()) }; - Ok(new_builtin_avg_groups_accumulator::( + Ok(create_builtin_avg_groups_accumulator::( data_type, args.return_field.data_type(), avg_fn, @@ -446,25 +446,25 @@ impl AggregateUDFImpl for Avg { let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64); match time_unit { - TimeUnit::Second => Ok(new_builtin_avg_groups_accumulator::< + TimeUnit::Second => Ok(create_builtin_avg_groups_accumulator::< DurationSecondType, _, >( data_type, args.return_type(), avg_fn )), - TimeUnit::Millisecond => Ok(new_builtin_avg_groups_accumulator::< + TimeUnit::Millisecond => Ok(create_builtin_avg_groups_accumulator::< DurationMillisecondType, _, >( data_type, args.return_type(), avg_fn )), - TimeUnit::Microsecond => Ok(new_builtin_avg_groups_accumulator::< + TimeUnit::Microsecond => Ok(create_builtin_avg_groups_accumulator::< DurationMicrosecondType, _, >( data_type, args.return_type(), avg_fn )), - TimeUnit::Nanosecond => Ok(new_builtin_avg_groups_accumulator::< + TimeUnit::Nanosecond => Ok(create_builtin_avg_groups_accumulator::< DurationNanosecondType, _, >( @@ -896,7 +896,7 @@ where } } -fn new_builtin_avg_groups_accumulator( +fn create_builtin_avg_groups_accumulator( sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F, @@ -1119,7 +1119,7 @@ mod tests { } fn float64_acc() -> Box { - new_builtin_avg_groups_accumulator::( + create_builtin_avg_groups_accumulator::( &DataType::Float64, &DataType::Float64, |sum, count| Ok(sum / count as f64), @@ -1127,7 +1127,7 @@ mod tests { } fn decimal128_acc() -> Box { - new_builtin_avg_groups_accumulator::( + create_builtin_avg_groups_accumulator::( &DataType::Decimal128(10, 2), &DataType::Decimal128(14, 6), // convert_to_state does not evaluate averages, so avg_fn is unused here. @@ -1136,7 +1136,7 @@ mod tests { } fn duration_acc() -> Box { - new_builtin_avg_groups_accumulator::( + create_builtin_avg_groups_accumulator::( &DataType::Duration(TimeUnit::Nanosecond), &DataType::Duration(TimeUnit::Nanosecond), |sum, count| Ok(sum / count as i64), From 9d46311c7306914c87ae70814a6866184da05073 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 15 Jun 2026 17:50:13 +0800 Subject: [PATCH 10/12] chore: clean up dependencies and update average function - Removed async SQL unit test from avg.rs - Updated Cargo.toml to remove tokio dev-dependency and datafusion dev sql feature - Updated Cargo.lock to remove tokio from datafusion-spark dependencies - Renamed partial_state to decode_partial_state in average.rs --- Cargo.lock | 1 - datafusion/functions-aggregate/src/average.rs | 8 +++--- datafusion/spark/Cargo.toml | 3 +-- .../spark/src/function/aggregate/avg.rs | 26 ------------------- 4 files changed, 5 insertions(+), 33 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 47dbd663ad27a..c08be6f29ffd7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2580,7 +2580,6 @@ dependencies = [ "serde_json", "sha1 0.11.0", "sha2", - "tokio", "twox-hash", "url", ] diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index fa0058b463967..6a547fb04f3fe 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -806,7 +806,7 @@ trait AvgStateLayout: Debug + Send + 'static { T: ArrowPrimitiveType, CountType: ArrowPrimitiveType; - fn partial_state<'a, T, CountType>( + fn decode_partial_state<'a, T, CountType>( values: &'a [ArrayRef], ) -> (&'a PrimitiveArray, &'a PrimitiveArray) where @@ -826,7 +826,7 @@ impl AvgStateLayout for BuiltInAvgLayout { vec![Arc::new(counts) as ArrayRef, Arc::new(sums) as ArrayRef] } - fn partial_state<'a, T, CountType>( + fn decode_partial_state<'a, T, CountType>( values: &'a [ArrayRef], ) -> (&'a PrimitiveArray, &'a PrimitiveArray) where @@ -852,7 +852,7 @@ impl AvgStateLayout for SparkAvgLayout { vec![Arc::new(sums) as ArrayRef, Arc::new(counts) as ArrayRef] } - fn partial_state<'a, T, CountType>( + fn decode_partial_state<'a, T, CountType>( values: &'a [ArrayRef], ) -> (&'a PrimitiveArray, &'a PrimitiveArray) where @@ -1015,7 +1015,7 @@ where ) -> Result<()> { assert_eq!(values.len(), 2, "two arguments to merge_batch"); let (partial_counts, partial_sums) = - Layout::partial_state::(values); + Layout::decode_partial_state::(values); // update counts with partial counts self.counts diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml index 291b986a6236f..14f9396d7656e 100644 --- a/datafusion/spark/Cargo.toml +++ b/datafusion/spark/Cargo.toml @@ -71,8 +71,7 @@ url = { workspace = true } arrow = { workspace = true, features = ["test_utils"] } criterion = { workspace = true } # for SessionStateBuilderSpark tests -datafusion = { workspace = true, default-features = false, features = ["sql"] } -tokio = { workspace = true, features = ["macros", "rt"] } +datafusion = { workspace = true, default-features = false } [[bench]] harness = false diff --git a/datafusion/spark/src/function/aggregate/avg.rs b/datafusion/spark/src/function/aggregate/avg.rs index dca711ab84b97..cab69d6a33a50 100644 --- a/datafusion/spark/src/function/aggregate/avg.rs +++ b/datafusion/spark/src/function/aggregate/avg.rs @@ -305,32 +305,6 @@ mod tests { assert_eq!(result.value(0), 10.0); } - #[tokio::test] - async fn spark_avg_query_applies_nulls_and_filter() -> Result<()> { - use datafusion::prelude::SessionContext; - - let mut ctx = SessionContext::new(); - crate::register_all(&mut ctx)?; - - let batches = ctx - .sql( - "SELECT avg(v) FILTER (WHERE keep) AS avg_v \ - FROM (VALUES \ - (CAST(1.0 AS DOUBLE), true), \ - (CAST(NULL AS DOUBLE), true), \ - (CAST(3.0 AS DOUBLE), false), \ - (CAST(5.0 AS DOUBLE), CAST(NULL AS BOOLEAN)) \ - ) AS t(v, keep)", - ) - .await? - .collect() - .await?; - - let result = batches[0].column(0).as_primitive::(); - assert_eq!(result.value(0), 1.0); - Ok(()) - } - #[test] fn convert_to_state_roundtrips_through_merge() { let mut acc = make_acc(); From 173b54d155d7830b70174b0ea96309a5f7aea791 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 15 Jun 2026 18:26:14 +0800 Subject: [PATCH 11/12] fix: remove unnecessary lifetimes in decode_partial_state for improved Clippy compliance --- datafusion/functions-aggregate/src/average.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 6a547fb04f3fe..2a3be34a16035 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -806,9 +806,9 @@ trait AvgStateLayout: Debug + Send + 'static { T: ArrowPrimitiveType, CountType: ArrowPrimitiveType; - fn decode_partial_state<'a, T, CountType>( - values: &'a [ArrayRef], - ) -> (&'a PrimitiveArray, &'a PrimitiveArray) + fn decode_partial_state( + values: &[ArrayRef], + ) -> (&PrimitiveArray, &PrimitiveArray) where T: ArrowPrimitiveType, CountType: ArrowPrimitiveType; @@ -826,9 +826,9 @@ impl AvgStateLayout for BuiltInAvgLayout { vec![Arc::new(counts) as ArrayRef, Arc::new(sums) as ArrayRef] } - fn decode_partial_state<'a, T, CountType>( - values: &'a [ArrayRef], - ) -> (&'a PrimitiveArray, &'a PrimitiveArray) + fn decode_partial_state( + values: &[ArrayRef], + ) -> (&PrimitiveArray, &PrimitiveArray) where T: ArrowPrimitiveType, CountType: ArrowPrimitiveType, @@ -852,9 +852,9 @@ impl AvgStateLayout for SparkAvgLayout { vec![Arc::new(sums) as ArrayRef, Arc::new(counts) as ArrayRef] } - fn decode_partial_state<'a, T, CountType>( - values: &'a [ArrayRef], - ) -> (&'a PrimitiveArray, &'a PrimitiveArray) + fn decode_partial_state( + values: &[ArrayRef], + ) -> (&PrimitiveArray, &PrimitiveArray) where T: ArrowPrimitiveType, CountType: ArrowPrimitiveType, From 404e29fbe1b9f111cf5bdce1877d6b21a7e65345 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 15 Jun 2026 18:30:41 +0800 Subject: [PATCH 12/12] chore: remove unused datafusion-functions-aggregate-common dependency from Cargo.toml and Cargo.lock --- Cargo.lock | 1 - datafusion/spark/Cargo.toml | 1 - 2 files changed, 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c08be6f29ffd7..c67a039f014cc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2571,7 +2571,6 @@ dependencies = [ "datafusion-expr", "datafusion-functions", "datafusion-functions-aggregate", - "datafusion-functions-aggregate-common", "datafusion-functions-nested", "log", "num-traits", diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml index 14f9396d7656e..971d557439856 100644 --- a/datafusion/spark/Cargo.toml +++ b/datafusion/spark/Cargo.toml @@ -55,7 +55,6 @@ datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true, features = ["crypto_expressions"] } datafusion-functions-aggregate = { workspace = true } -datafusion-functions-aggregate-common = { workspace = true } datafusion-functions-nested = { workspace = true } log = { workspace = true } num-traits = { workspace = true }