diff --git a/Cargo.lock b/Cargo.lock index c08be6f29ffd7..c67a039f014cc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2571,7 +2571,6 @@ dependencies = [ "datafusion-expr", "datafusion-functions", "datafusion-functions-aggregate", - "datafusion-functions-aggregate-common", "datafusion-functions-nested", "log", "num-traits", diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index ddeb9b0870a16..2a3be34a16035 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -19,7 +19,7 @@ use arrow::array::{ Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, AsArray, - BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array, + BooleanArray, PrimitiveArray, PrimitiveBuilder, }; use arrow::compute::sum; @@ -29,7 +29,8 @@ use arrow::datatypes::{ DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, - DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type, i256, + DurationSecondType, Field, FieldRef, Float64Type, Int64Type, TimeUnit, UInt64Type, + i256, }; use datafusion_common::types::{NativeType, logical_float64}; use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err}; @@ -51,6 +52,7 @@ use datafusion_functions_aggregate_common::utils::DecimalAverager; use datafusion_macros::user_doc; use log::debug; use std::fmt::Debug; +use std::marker::PhantomData; use std::mem::{size_of, size_of_val}; use std::sync::Arc; @@ -355,11 +357,11 @@ impl AggregateUDFImpl for Avg { // instantiate specialized accumulator based for the type match (data_type, args.return_field.data_type()) { (Float64, Float64) => { - Ok(Box::new(AvgGroupsAccumulator::::new( + Ok(create_builtin_avg_groups_accumulator::( data_type, args.return_field.data_type(), |sum: f64, count: u64| Ok(sum / count as f64), - ))) + )) } ( Decimal32(_sum_precision, sum_scale), @@ -374,11 +376,11 @@ impl AggregateUDFImpl for Avg { let avg_fn = move |sum: i32, count: u64| decimal_averager.avg(sum, count as i32); - Ok(Box::new(AvgGroupsAccumulator::::new( + Ok(create_builtin_avg_groups_accumulator::( data_type, args.return_field.data_type(), avg_fn, - ))) + )) } ( Decimal64(_sum_precision, sum_scale), @@ -393,11 +395,11 @@ impl AggregateUDFImpl for Avg { let avg_fn = move |sum: i64, count: u64| decimal_averager.avg(sum, count as i64); - Ok(Box::new(AvgGroupsAccumulator::::new( + Ok(create_builtin_avg_groups_accumulator::( data_type, args.return_field.data_type(), avg_fn, - ))) + )) } ( Decimal128(_sum_precision, sum_scale), @@ -412,11 +414,11 @@ impl AggregateUDFImpl for Avg { let avg_fn = move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128); - Ok(Box::new(AvgGroupsAccumulator::::new( + Ok(create_builtin_avg_groups_accumulator::( data_type, args.return_field.data_type(), avg_fn, - ))) + )) } ( @@ -433,49 +435,41 @@ impl AggregateUDFImpl for Avg { decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap()) }; - Ok(Box::new(AvgGroupsAccumulator::::new( + Ok(create_builtin_avg_groups_accumulator::( data_type, args.return_field.data_type(), avg_fn, - ))) + )) } (Duration(time_unit), Duration(_result_unit)) => { let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64); match time_unit { - TimeUnit::Second => Ok(Box::new(AvgGroupsAccumulator::< + TimeUnit::Second => Ok(create_builtin_avg_groups_accumulator::< DurationSecondType, _, - >::new( - data_type, - args.return_type(), - avg_fn, - ))), - TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::< + >( + data_type, args.return_type(), avg_fn + )), + TimeUnit::Millisecond => Ok(create_builtin_avg_groups_accumulator::< DurationMillisecondType, _, - >::new( - data_type, - args.return_type(), - avg_fn, - ))), - TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::< + >( + data_type, args.return_type(), avg_fn + )), + TimeUnit::Microsecond => Ok(create_builtin_avg_groups_accumulator::< DurationMicrosecondType, _, - >::new( - data_type, - args.return_type(), - avg_fn, - ))), - TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::< + >( + data_type, args.return_type(), avg_fn + )), + TimeUnit::Nanosecond => Ok(create_builtin_avg_groups_accumulator::< DurationNanosecondType, _, - >::new( - data_type, - args.return_type(), - avg_fn, - ))), + >( + data_type, args.return_type(), avg_fn + )), } } @@ -764,12 +758,16 @@ impl Accumulator for DurationAvgAccumulator { /// Stores values as native types, and does overflow checking /// /// F: Function that calculates the average value from a sum of -/// T::Native and a total count +/// T::Native and a total count. +/// +/// `Layout` controls the state field order. #[derive(Debug)] -struct AvgGroupsAccumulator +struct AvgGroupsAccumulator where T: ArrowNumericType + Send, - F: Fn(T::Native, u64) -> Result + Send + 'static, + CountType: ArrowPrimitiveType + Send, + Layout: AvgStateLayout, + F: Fn(T::Native, CountType::Native) -> Result + Send + 'static, { /// The type of the internal sum sum_data_type: DataType, @@ -777,8 +775,8 @@ where /// The type of the returned sum return_data_type: DataType, - /// Count per group (use u64 to make UInt64Array) - counts: Vec, + /// Count per group + counts: Vec, /// Sums per group, stored as the native type sums: Vec, @@ -788,14 +786,99 @@ where /// Function that computes the final average (value / count) avg_fn: F, + + /// AVG state layout marker. + layout: PhantomData, +} + +#[derive(Debug)] +struct BuiltInAvgLayout; + +#[derive(Debug)] +struct SparkAvgLayout; + +trait AvgStateLayout: Debug + Send + 'static { + fn state_arrays( + sums: PrimitiveArray, + counts: PrimitiveArray, + ) -> Vec + where + T: ArrowPrimitiveType, + CountType: ArrowPrimitiveType; + + fn decode_partial_state( + values: &[ArrayRef], + ) -> (&PrimitiveArray, &PrimitiveArray) + where + T: ArrowPrimitiveType, + CountType: ArrowPrimitiveType; +} + +impl AvgStateLayout for BuiltInAvgLayout { + fn state_arrays( + sums: PrimitiveArray, + counts: PrimitiveArray, + ) -> Vec + where + T: ArrowPrimitiveType, + CountType: ArrowPrimitiveType, + { + vec![Arc::new(counts) as ArrayRef, Arc::new(sums) as ArrayRef] + } + + fn decode_partial_state( + values: &[ArrayRef], + ) -> (&PrimitiveArray, &PrimitiveArray) + where + T: ArrowPrimitiveType, + CountType: ArrowPrimitiveType, + { + ( + values[0].as_primitive::(), + values[1].as_primitive::(), + ) + } +} + +impl AvgStateLayout for SparkAvgLayout { + fn state_arrays( + sums: PrimitiveArray, + counts: PrimitiveArray, + ) -> Vec + where + T: ArrowPrimitiveType, + CountType: ArrowPrimitiveType, + { + vec![Arc::new(sums) as ArrayRef, Arc::new(counts) as ArrayRef] + } + + fn decode_partial_state( + values: &[ArrayRef], + ) -> (&PrimitiveArray, &PrimitiveArray) + where + T: ArrowPrimitiveType, + CountType: ArrowPrimitiveType, + { + ( + values[1].as_primitive::(), + values[0].as_primitive::(), + ) + } } -impl AvgGroupsAccumulator +type BuiltInAvgGroupsAccumulator = + AvgGroupsAccumulator; +type SparkAvgGroupsAccumulator = + AvgGroupsAccumulator; + +impl AvgGroupsAccumulator where T: ArrowNumericType + Send, - F: Fn(T::Native, u64) -> Result + Send + 'static, + CountType: ArrowPrimitiveType + Send, + Layout: AvgStateLayout, + F: Fn(T::Native, CountType::Native) -> Result + Send + 'static, { - pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self { + fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self { debug!( "AvgGroupsAccumulator ({}, sum type: {sum_data_type}) --> {return_data_type}", std::any::type_name::() @@ -808,14 +891,34 @@ where sums: vec![], null_state: NullState::new(), avg_fn, + layout: PhantomData, } } } -impl GroupsAccumulator for AvgGroupsAccumulator +fn create_builtin_avg_groups_accumulator( + sum_data_type: &DataType, + return_data_type: &DataType, + avg_fn: F, +) -> Box where - T: ArrowNumericType + Send, + T: ArrowNumericType + Send + 'static, F: Fn(T::Native, u64) -> Result + Send + 'static, +{ + Box::new(BuiltInAvgGroupsAccumulator::::new( + sum_data_type, + return_data_type, + avg_fn, + )) +} + +impl GroupsAccumulator + for AvgGroupsAccumulator +where + T: ArrowNumericType + Send, + CountType: ArrowPrimitiveType + Send, + Layout: AvgStateLayout, + F: Fn(T::Native, CountType::Native) -> Result + Send + 'static, { fn update_batch( &mut self, @@ -828,7 +931,8 @@ where let values = values[0].as_primitive::(); // increment counts, update sums - self.counts.resize(total_num_groups, 0); + self.counts + .resize(total_num_groups, CountType::Native::usize_as(0)); self.sums.resize(total_num_groups, T::default_value()); self.null_state.accumulate( group_indices, @@ -840,7 +944,8 @@ where let sum = unsafe { self.sums.get_unchecked_mut(group_index) }; *sum = sum.add_wrapping(new_value); - self.counts[group_index] += 1; + let count = unsafe { self.counts.get_unchecked_mut(group_index) }; + *count = count.add_wrapping(CountType::Native::usize_as(1)); }, ); @@ -892,16 +997,13 @@ where let nulls = self.null_state.build(emit_to); let counts = emit_to.take_needed(&mut self.counts); - let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy + let counts = PrimitiveArray::::new(counts.into(), nulls.clone()); let sums = emit_to.take_needed(&mut self.sums); let sums = PrimitiveArray::::new(sums.into(), nulls) // zero copy .with_data_type(self.sum_data_type.clone()); - Ok(vec![ - Arc::new(counts) as ArrayRef, - Arc::new(sums) as ArrayRef, - ]) + Ok(Layout::state_arrays(sums, counts)) } fn merge_batch( @@ -912,11 +1014,12 @@ where total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 2, "two arguments to merge_batch"); - // first batch is counts, second is partial sums - let partial_counts = values[0].as_primitive::(); - let partial_sums = values[1].as_primitive::(); + let (partial_counts, partial_sums) = + Layout::decode_partial_state::(values); + // update counts with partial counts - self.counts.resize(total_num_groups, 0); + self.counts + .resize(total_num_groups, CountType::Native::usize_as(0)); self.null_state.accumulate( group_indices, partial_counts, @@ -925,7 +1028,7 @@ where |group_index, partial_count| { // SAFETY: group_index is guaranteed to be in bounds let count = unsafe { self.counts.get_unchecked_mut(group_index) }; - *count += partial_count; + *count = count.add_wrapping(partial_count); }, ); @@ -955,15 +1058,16 @@ where .as_primitive::() .clone() .with_data_type(self.sum_data_type.clone()); - let counts = UInt64Array::from_value(1, sums.len()); + let counts = PrimitiveArray::::from_value( + CountType::Native::usize_as(1), + sums.len(), + ); let nulls = filtered_null_mask(opt_filter, &sums); - - // set nulls on the arrays let counts = set_nulls(counts, nulls.clone()); let sums = set_nulls(sums, nulls); - Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)]) + Ok(Layout::state_arrays(sums, counts)) } fn supports_convert_to_state(&self) -> bool { @@ -971,6 +1075,132 @@ where } fn size(&self) -> usize { - self.counts.capacity() * size_of::() + self.sums.capacity() * size_of::() + self.counts.capacity() * size_of::() + + self.sums.capacity() * size_of::() + + self.null_state.size() + } +} + +/// Creates the Spark AVG grouped accumulator using the built-in AVG state +/// handling implementation while preserving Spark's `[sum, count]` state order +/// and signed count type. +#[doc(hidden)] +pub fn spark_avg_groups_accumulator( + return_data_type: &DataType, +) -> Box { + Box::new(SparkAvgGroupsAccumulator::<_>::new( + return_data_type, + return_data_type, + |sum: f64, count: i64| Ok(sum / count as f64), + )) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Array, Float64Array}; + use arrow::datatypes::{Decimal128Type, DurationNanosecondType}; + + fn assert_validity(array: &dyn Array, expected: &[bool]) { + assert_eq!(array.len(), expected.len()); + for (idx, expected_valid) in expected.iter().copied().enumerate() { + assert_eq!(!array.is_null(idx), expected_valid, "validity at row {idx}"); + } + } + + fn avg_state( + state: &[ArrayRef], + ) -> (&PrimitiveArray, &PrimitiveArray) { + assert_eq!(state.len(), 2); + ( + state[0].as_primitive::(), + state[1].as_primitive::(), + ) + } + + fn float64_acc() -> Box { + create_builtin_avg_groups_accumulator::( + &DataType::Float64, + &DataType::Float64, + |sum, count| Ok(sum / count as f64), + ) + } + + fn decimal128_acc() -> Box { + create_builtin_avg_groups_accumulator::( + &DataType::Decimal128(10, 2), + &DataType::Decimal128(14, 6), + // convert_to_state does not evaluate averages, so avg_fn is unused here. + |sum, _count| Ok(sum), + ) + } + + fn duration_acc() -> Box { + create_builtin_avg_groups_accumulator::( + &DataType::Duration(TimeUnit::Nanosecond), + &DataType::Duration(TimeUnit::Nanosecond), + |sum, count| Ok(sum / count as i64), + ) + } + + #[test] + fn float64_convert_to_state_uses_count_sum_order_and_null_filter() { + let acc = float64_acc(); + let values: Vec = vec![Arc::new(Float64Array::from(vec![ + Some(10.0), + Some(20.0), + None, + Some(40.0), + ]))]; + let filter = BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]); + + let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); + + let (counts, sums) = avg_state::(&state); + assert_eq!(counts.values().as_ref(), &[1, 1, 1, 1]); + assert_validity(counts, &[true, false, false, false]); + assert_validity(sums, &[true, false, false, false]); + } + + #[test] + fn decimal_convert_to_state_preserves_sum_type_and_nulls() { + let acc = decimal128_acc(); + let values: Vec = vec![Arc::new( + PrimitiveArray::::from(vec![ + Some(100_i128), + None, + Some(300_i128), + ]) + .with_data_type(DataType::Decimal128(10, 2)), + )]; + + let state = acc.convert_to_state(&values, None).unwrap(); + + let (counts, sums) = avg_state::(&state); + assert_eq!(sums.data_type(), &DataType::Decimal128(10, 2)); + assert_eq!(counts.values().as_ref(), &[1, 1, 1]); + assert_validity(counts, &[true, false, true]); + assert_validity(sums, &[true, false, true]); + } + + #[test] + fn duration_convert_to_state_preserves_sum_type_and_applies_filter() { + let acc = duration_acc(); + let values: Vec = vec![Arc::new( + PrimitiveArray::::from(vec![ + Some(10_i64), + Some(20_i64), + ]) + .with_data_type(DataType::Duration(TimeUnit::Nanosecond)), + )]; + let filter = BooleanArray::from(vec![Some(false), Some(true)]); + + let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); + + let (counts, sums) = avg_state::(&state); + assert_eq!(sums.data_type(), &DataType::Duration(TimeUnit::Nanosecond)); + assert_eq!(counts.values().as_ref(), &[1, 1]); + assert_validity(counts, &[false, true]); + assert_validity(sums, &[false, true]); } } diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml index 14f9396d7656e..971d557439856 100644 --- a/datafusion/spark/Cargo.toml +++ b/datafusion/spark/Cargo.toml @@ -55,7 +55,6 @@ datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true, features = ["crypto_expressions"] } datafusion-functions-aggregate = { workspace = true } -datafusion-functions-aggregate-common = { workspace = true } datafusion-functions-nested = { workspace = true } log = { workspace = true } num-traits = { workspace = true } diff --git a/datafusion/spark/src/function/aggregate/avg.rs b/datafusion/spark/src/function/aggregate/avg.rs index 5f4d2c253a2dc..cab69d6a33a50 100644 --- a/datafusion/spark/src/function/aggregate/avg.rs +++ b/datafusion/spark/src/function/aggregate/avg.rs @@ -16,9 +16,7 @@ // under the License. use arrow::array::{ - Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, BooleanArray, Int64Array, - PrimitiveArray, - builder::PrimitiveBuilder, + Array, ArrayRef, cast::AsArray, types::{Float64Type, Int64Type}, }; @@ -29,12 +27,11 @@ use datafusion_common::{Result, ScalarValue, not_impl_err}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Coercion, EmitTo, GroupsAccumulator, ReversedUDAF, - Signature, TypeSignatureClass, Volatility, -}; -use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{ - filtered_null_mask, set_nulls, + Accumulator, AggregateUDFImpl, Coercion, GroupsAccumulator, ReversedUDAF, Signature, + TypeSignatureClass, Volatility, }; +use datafusion_functions_aggregate::average::spark_avg_groups_accumulator; +use std::mem::size_of_val; use std::sync::Arc; /// AVG aggregate expression @@ -130,10 +127,7 @@ impl AggregateUDFImpl for SparkAvg { // instantiate specialized accumulator based for the type match (&data_type, args.return_type()) { (DataType::Float64, DataType::Float64) => { - Ok(Box::new(AvgGroupsAccumulator::::new( - args.return_field.data_type(), - |sum: f64, count: i64| Ok(sum / count as f64), - ))) + Ok(spark_avg_groups_accumulator(args.return_field.data_type())) } (dt, return_type) => { not_impl_err!("AvgGroupsAccumulator for ({dt} --> {return_type})") @@ -204,189 +198,40 @@ impl Accumulator for AvgAccumulator { } } -/// An accumulator to compute the average of `[PrimitiveArray]`. -/// Stores values as native types, and does overflow checking -/// -/// F: Function that calculates the average value from a sum of -/// T::Native and a total count -#[derive(Debug)] -struct AvgGroupsAccumulator -where - T: ArrowNumericType + Send, - F: Fn(T::Native, i64) -> Result + Send + 'static, -{ - /// The type of the returned average - return_data_type: DataType, - - /// Count per group (use i64 to make Int64Array) - counts: Vec, - - /// Sums per group, stored as the native type - sums: Vec, - - /// Function that computes the final average (value / count) - avg_fn: F, -} - -impl AvgGroupsAccumulator -where - T: ArrowNumericType + Send, - F: Fn(T::Native, i64) -> Result + Send + 'static, -{ - pub fn new(return_data_type: &DataType, avg_fn: F) -> Self { - Self { - return_data_type: return_data_type.clone(), - counts: vec![], - sums: vec![], - avg_fn, - } - } -} - -impl GroupsAccumulator for AvgGroupsAccumulator -where - T: ArrowNumericType + Send, - F: Fn(T::Native, i64) -> Result + Send + 'static, -{ - fn update_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - _opt_filter: Option<&BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - assert_eq!(values.len(), 1, "single argument to update_batch"); - let values = values[0].as_primitive::(); - let data = values.values(); - - // increment counts, update sums - self.counts.resize(total_num_groups, 0); - self.sums.resize(total_num_groups, T::default_value()); - - let iter = group_indices.iter().zip(data.iter()); - if values.null_count() == 0 { - for (&group_index, &value) in iter { - let sum = &mut self.sums[group_index]; - *sum = (*sum).add_wrapping(value); - self.counts[group_index] += 1; - } - } else { - for (idx, (&group_index, &value)) in iter.enumerate() { - if values.is_null(idx) { - continue; - } - let sum = &mut self.sums[group_index]; - *sum = (*sum).add_wrapping(value); - - self.counts[group_index] += 1; - } - } - - Ok(()) - } - - fn merge_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - _opt_filter: Option<&BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - assert_eq!(values.len(), 2, "two arguments to merge_batch"); - // first batch is partial sums, second is counts - let partial_sums = values[0].as_primitive::(); - let partial_counts = values[1].as_primitive::(); - - self.counts.resize(total_num_groups, 0); - self.sums.resize(total_num_groups, T::default_value()); - - for (idx, &group_index) in group_indices.iter().enumerate() { - // Skip null state entries emitted by convert_to_state for - // filtered / null input rows. - if partial_counts.is_null(idx) || partial_sums.is_null(idx) { - continue; - } - self.counts[group_index] += partial_counts.value(idx); - let sum = &mut self.sums[group_index]; - *sum = sum.add_wrapping(partial_sums.value(idx)); - } +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::types::Int64Type; + use arrow::array::{Array, BooleanArray, Float64Array, PrimitiveArray}; + use datafusion_expr::EmitTo; - Ok(()) + fn make_acc() -> Box { + spark_avg_groups_accumulator(&DataType::Float64) } - fn evaluate(&mut self, emit_to: EmitTo) -> Result { - let counts = emit_to.take_needed(&mut self.counts); - let sums = emit_to.take_needed(&mut self.sums); - let mut builder = PrimitiveBuilder::::with_capacity(sums.len()); - let iter = sums.into_iter().zip(counts); - - for (sum, count) in iter { - if count != 0 { - builder.append_value((self.avg_fn)(sum, count)?) - } else { - builder.append_null(); - } + fn assert_validity(array: &dyn Array, expected: &[bool]) { + assert_eq!(array.len(), expected.len()); + for (idx, expected_valid) in expected.iter().copied().enumerate() { + assert_eq!(!array.is_null(idx), expected_valid, "validity at row {idx}"); } - let array: PrimitiveArray = builder.finish(); - - Ok(Arc::new(array)) - } - - // return arrays for sums and counts - fn state(&mut self, emit_to: EmitTo) -> Result> { - let counts = emit_to.take_needed(&mut self.counts); - let counts = Int64Array::new(counts.into(), None); - - let sums = emit_to.take_needed(&mut self.sums); - let sums = PrimitiveArray::::new(sums.into(), None) - .with_data_type(self.return_data_type.clone()); - - Ok(vec![ - Arc::new(sums) as ArrayRef, - Arc::new(counts) as ArrayRef, - ]) - } - - fn convert_to_state( - &self, - values: &[ArrayRef], - opt_filter: Option<&BooleanArray>, - ) -> Result> { - let sums = values[0] - .as_primitive::() - .clone() - .with_data_type(self.return_data_type.clone()); - let counts = Int64Array::from_value(1, sums.len()); - - let nulls = filtered_null_mask(opt_filter, &sums); - let counts = set_nulls(counts, nulls.clone()); - let sums = set_nulls(sums, nulls); - - // [sum, count] - must match state() and merge_batch() - Ok(vec![ - Arc::new(sums) as ArrayRef, - Arc::new(counts) as ArrayRef, - ]) - } - - fn supports_convert_to_state(&self) -> bool { - true } - fn size(&self) -> usize { - self.counts.capacity() * size_of::() + self.sums.capacity() * size_of::() + fn spark_avg_state( + state: &[ArrayRef], + ) -> (&PrimitiveArray, &PrimitiveArray) { + assert_eq!(state.len(), 2); + ( + state[0].as_primitive::(), + state[1].as_primitive::(), + ) } -} -#[cfg(test)] -mod tests { - use super::*; - use arrow::array::Float64Array; + fn assert_state_validity_and_counts(state: &[ArrayRef], expected_validity: &[bool]) { + let (sums, counts) = spark_avg_state(state); - fn make_acc() -> AvgGroupsAccumulator Result> { - AvgGroupsAccumulator::::new(&DataType::Float64, |sum, count| { - Ok(sum / count as f64) - }) + assert_validity(sums, expected_validity); + assert!(counts.values().iter().all(|&count| count == 1)); + assert_validity(counts, expected_validity); } #[test] @@ -401,9 +246,7 @@ mod tests { vec![Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]))]; let state = acc.convert_to_state(&values, None).unwrap(); - assert_eq!(state.len(), 2); - let sums = state[0].as_primitive::(); - let counts = state[1].as_primitive::(); + let (sums, counts) = spark_avg_state(&state); assert_eq!(sums.values().as_ref(), &[1.0, 2.0, 3.0]); assert_eq!(counts.values().as_ref(), &[1, 1, 1]); @@ -421,16 +264,7 @@ mod tests { ]))]; let state = acc.convert_to_state(&values, None).unwrap(); - let sums = state[0].as_primitive::(); - let counts = state[1].as_primitive::(); - - assert!(!sums.is_null(0)); - assert!(sums.is_null(1)); - assert!(!sums.is_null(2)); - - assert_eq!(counts.value(0), 1); - assert!(counts.is_null(1)); - assert_eq!(counts.value(2), 1); + assert_state_validity_and_counts(&state, &[true, false, true]); } #[test] @@ -441,16 +275,34 @@ mod tests { let filter = BooleanArray::from(vec![true, false, true]); let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); - let sums = state[0].as_primitive::(); - let counts = state[1].as_primitive::(); + assert_state_validity_and_counts(&state, &[true, false, true]); + } + + #[test] + fn convert_to_state_with_null_filter() { + let acc = make_acc(); + let values: Vec = + vec![Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]))]; + let filter = BooleanArray::from(vec![Some(true), None, Some(true)]); + let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); + + assert_state_validity_and_counts(&state, &[true, false, true]); + } + + #[test] + fn merge_batch_applies_filter() { + let mut acc = make_acc(); + let input: Vec = + vec![Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0]))]; + let state = acc.convert_to_state(&input, None).unwrap(); + let filter = BooleanArray::from(vec![Some(true), Some(false), None]); - assert!(!sums.is_null(0)); - assert!(sums.is_null(1)); - assert!(!sums.is_null(2)); + acc.merge_batch(&state, &[0, 0, 0], Some(&filter), 1) + .unwrap(); - assert_eq!(counts.value(0), 1); - assert!(counts.is_null(1)); - assert_eq!(counts.value(2), 1); + let result = acc.evaluate(EmitTo::All).unwrap(); + let result = result.as_primitive::(); + assert_eq!(result.value(0), 10.0); } #[test] diff --git a/datafusion/sqllogictest/test_files/spark/aggregate/avg.slt b/datafusion/sqllogictest/test_files/spark/aggregate/avg.slt index 6ae647989aee9..2c7a685193571 100644 --- a/datafusion/sqllogictest/test_files/spark/aggregate/avg.slt +++ b/datafusion/sqllogictest/test_files/spark/aggregate/avg.slt @@ -54,3 +54,24 @@ GROUP BY a ORDER BY a; ---- 0 0 + +# Regression coverage for Spark AVG grouped state conversion with NULL input and FILTER. +# Spark AVG uses state order [sum, count] with Int64 count; this query exercises +# planner/executor partial/final aggregation rather than only accumulator internals. +query IR +SELECT g, avg(v) FILTER (WHERE keep) AS avg_v +FROM (VALUES + (1, CAST(1.0 AS DOUBLE), true), + (1, CAST(NULL AS DOUBLE), true), + (1, CAST(3.0 AS DOUBLE), false), + (1, CAST(5.0 AS DOUBLE), CAST(NULL AS BOOLEAN)), + (2, CAST(10.0 AS DOUBLE), true), + (2, CAST(30.0 AS DOUBLE), true), + (2, CAST(50.0 AS DOUBLE), false), + (2, CAST(NULL AS DOUBLE), true) +) AS t(g, v, keep) +GROUP BY g +ORDER BY g; +---- +1 1 +2 20