diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 51c1403a5..e597881d7 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -1452,7 +1452,7 @@ where /// Make the array unshared. /// /// This method is mostly only useful with unsafe code. - fn ensure_unique(&mut self) + pub(crate) fn ensure_unique(&mut self) where S: DataMut, { diff --git a/src/impl_raw_views.rs b/src/impl_raw_views.rs index 2ac5c08c7..b12b2a727 100644 --- a/src/impl_raw_views.rs +++ b/src/impl_raw_views.rs @@ -1,3 +1,4 @@ +use num_complex::Complex; use std::mem; use std::ptr::NonNull; @@ -149,6 +150,73 @@ where } } +impl RawArrayView, D> +where + D: Dimension, +{ + /// Splits the view into views of the real and imaginary components of the + /// elements. + pub fn split_re_im(self) -> Complex> { + // Check that the size and alignment of `Complex` are as expected. + // These assertions should always pass, for arbitrary `T`. + assert_eq!( + mem::size_of::>(), + mem::size_of::().checked_mul(2).unwrap() + ); + assert_eq!(mem::align_of::>(), mem::align_of::()); + + let dim = self.dim.clone(); + + // Double the strides. In the zero-sized element case and for axes of + // length <= 1, we leave the strides as-is to avoid possible overflow. + let mut strides = self.strides.clone(); + if mem::size_of::() != 0 { + for ax in 0..strides.ndim() { + if dim[ax] > 1 { + strides[ax] *= 2; + } + } + } + + let ptr_re: *mut T = self.ptr.as_ptr().cast(); + let ptr_im: *mut T = if self.is_empty() { + // In the empty case, we can just reuse the existing pointer since + // it won't be dereferenced anyway. It is not safe to offset by + // one, since the allocation may be empty. + ptr_re + } else { + // In the nonempty case, we can safely offset into the first + // (complex) element. + unsafe { ptr_re.add(1) } + }; + + // `Complex` is `repr(C)` with only fields `re: T` and `im: T`. So, the + // real components of the elements start at the same pointer, and the + // imaginary components start at the pointer offset by one, with + // exactly double the strides. The new, doubled strides still meet the + // overflow constraints: + // + // - For the zero-sized element case, the strides are unchanged in + // units of bytes and in units of the element type. + // + // - For the nonzero-sized element case: + // + // - In units of bytes, the strides are unchanged. The only exception + // is axes of length <= 1, but those strides are irrelevant anyway. + // + // - Since `Complex` for nonzero `T` is always at least 2 bytes, + // and the original strides did not overflow in units of bytes, we + // know that the new, doubled strides will not overflow in units of + // `T`. + unsafe { + Complex { + re: RawArrayView::new_(ptr_re, dim.clone(), strides.clone()), + im: RawArrayView::new_(ptr_im, dim, strides), + } + } + } +} + impl RawArrayViewMut where D: Dimension, @@ -300,3 +368,20 @@ where unsafe { RawArrayViewMut::new(ptr, self.dim, self.strides) } } } + +impl RawArrayViewMut, D> +where + D: Dimension, +{ + /// Splits the view into views of the real and imaginary components of the + /// elements. + pub fn split_re_im(self) -> Complex> { + let Complex { re, im } = self.into_raw_view().split_re_im(); + unsafe { + Complex { + re: RawArrayViewMut::new(re.ptr, re.dim, re.strides), + im: RawArrayViewMut::new(im.ptr, im.dim, im.strides), + } + } + } +} diff --git a/src/impl_views/splitting.rs b/src/impl_views/splitting.rs index a36ae4ddb..5ea554d8a 100644 --- a/src/impl_views/splitting.rs +++ b/src/impl_views/splitting.rs @@ -8,6 +8,7 @@ use crate::imp_prelude::*; use crate::slice::MultiSliceArg; +use num_complex::Complex; /// Methods for read-only array views. impl<'a, A, D> ArrayView<'a, A, D> @@ -95,6 +96,37 @@ where } } +impl<'a, T, D> ArrayView<'a, Complex, D> +where + D: Dimension, +{ + /// Splits the view into views of the real and imaginary components of the + /// elements. + /// + /// ``` + /// use ndarray::prelude::*; + /// use num_complex::{Complex, Complex64}; + /// + /// let arr = array![ + /// [Complex64::new(1., 2.), Complex64::new(3., 4.)], + /// [Complex64::new(5., 6.), Complex64::new(7., 8.)], + /// [Complex64::new(9., 10.), Complex64::new(11., 12.)], + /// ]; + /// let Complex { re, im } = arr.view().split_re_im(); + /// assert_eq!(re, array![[1., 3.], [5., 7.], [9., 11.]]); + /// assert_eq!(im, array![[2., 4.], [6., 8.], [10., 12.]]); + /// ``` + pub fn split_re_im(self) -> Complex> { + unsafe { + let Complex { re, im } = self.into_raw_view().split_re_im(); + Complex { + re: re.deref_into_view(), + im: im.deref_into_view(), + } + } + } +} + /// Methods for read-write array views. impl<'a, A, D> ArrayViewMut<'a, A, D> where @@ -135,3 +167,41 @@ where info.multi_slice_move(self) } } + +impl<'a, T, D> ArrayViewMut<'a, Complex, D> +where + D: Dimension, +{ + /// Splits the view into views of the real and imaginary components of the + /// elements. + /// + /// ``` + /// use ndarray::prelude::*; + /// use num_complex::{Complex, Complex64}; + /// + /// let mut arr = array![ + /// [Complex64::new(1., 2.), Complex64::new(3., 4.)], + /// [Complex64::new(5., 6.), Complex64::new(7., 8.)], + /// [Complex64::new(9., 10.), Complex64::new(11., 12.)], + /// ]; + /// + /// let Complex { mut re, mut im } = arr.view_mut().split_re_im(); + /// assert_eq!(re, array![[1., 3.], [5., 7.], [9., 11.]]); + /// assert_eq!(im, array![[2., 4.], [6., 8.], [10., 12.]]); + /// + /// re[[0, 1]] = 13.; + /// im[[2, 0]] = 14.; + /// + /// assert_eq!(arr[[0, 1]], Complex64::new(13., 4.)); + /// assert_eq!(arr[[2, 0]], Complex64::new(9., 14.)); + /// ``` + pub fn split_re_im(self) -> Complex> { + unsafe { + let Complex { re, im } = self.into_raw_view_mut().split_re_im(); + Complex { + re: re.deref_into_view_mut(), + im: im.deref_into_view_mut(), + } + } + } +}