From fe3baa6dfd5c18bfcfac0a048265c4f4340738df Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sat, 23 Oct 2021 15:21:15 -0400 Subject: [PATCH 1/2] Fix blas usage with vector of stride <= 0 Stride == 0 is unsuppored for vector increments; Stride < 0 would be supported but the code needs to be adapted to pass the right pointer for this case (lowest in memory pointer). Co-authored-by: bluss --- src/linalg/impl_linalg.rs | 19 +++++++++++++++---- xtest-blas/tests/oper.rs | 19 +++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 61b91eaed..d851f2f08 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -8,6 +8,8 @@ use crate::imp_prelude::*; use crate::numeric_util; +#[cfg(feature = "blas")] +use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr; use crate::{LinalgScalar, Zip}; @@ -649,6 +651,12 @@ unsafe fn general_mat_vec_mul_impl( } }; + // Low addr in memory pointers required for x, y + let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides); + let x_ptr = x.ptr.as_ptr().sub(x_offset); + let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.dim, &y.strides); + let y_ptr = y.ptr.as_ptr().sub(y_offset); + let x_stride = x.strides()[0] as blas_index; let y_stride = y.strides()[0] as blas_index; @@ -660,10 +668,10 @@ unsafe fn general_mat_vec_mul_impl( cast_as(&alpha), // alpha a.ptr.as_ptr() as *const _, // a a_stride, // lda - x.ptr.as_ptr() as *const _, // x + x_ptr as *const _, // x x_stride, - cast_as(&beta), // beta - y.ptr.as_ptr() as *mut _, // x + cast_as(&beta), // beta + y_ptr as *mut _, // y y_stride, ); return; @@ -719,7 +727,10 @@ where return false; } let stride = a.strides()[0]; - if stride > blas_index::max_value() as isize || stride < blas_index::min_value() as isize { + if stride == 0 + || stride > blas_index::max_value() as isize + || stride < blas_index::min_value() as isize + { return false; } true diff --git a/xtest-blas/tests/oper.rs b/xtest-blas/tests/oper.rs index 0aeb47680..95ab4ebb2 100644 --- a/xtest-blas/tests/oper.rs +++ b/xtest-blas/tests/oper.rs @@ -21,6 +21,25 @@ fn mat_vec_product_1d() { assert_eq!(a.t().dot(&b), ans); } +#[test] +fn mat_vec_product_1d_broadcast() { + let a = arr2(&[[1.], [2.], [3.]]); + let b = arr1(&[1.]); + let b = b.broadcast(3).unwrap(); + let ans = arr1(&[6.]); + assert_eq!(a.t().dot(&b), ans); +} + +#[test] +fn mat_vec_product_1d_inverted_axis() { + let a = arr2(&[[1.], [2.], [3.]]); + let mut b = arr1(&[1., 2., 3.]); + b.invert_axis(Axis(0)); + + let ans = arr1(&[3. + 4. + 3.]); + assert_eq!(a.t().dot(&b), ans); +} + fn range_mat(m: Ix, n: Ix) -> Array2 { Array::linspace(0., (m * n) as f32 - 1., m * n) .into_shape((m, n)) From b954a380c1f5d0fdfff02ee8b45e06cf6f27859d Mon Sep 17 00:00:00 2001 From: bluss Date: Sun, 24 Oct 2021 14:03:51 +0200 Subject: [PATCH 2/2] MAINT: Silence clippy lint about if_then_panic (style issue) --- src/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 63e6316c5..1448b89a0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,6 +15,8 @@ clippy::manual_map, // is not an error clippy::while_let_on_iterator, // is not an error clippy::from_iter_instead_of_collect, // using from_iter is good style + clippy::if_then_panic, // is not an error + clippy::redundant_closure, // false positives clippy #7812 )] #![doc(test(attr(deny(warnings))))] #![doc(test(attr(allow(unused_variables))))]