diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index c0a7a310c..2a18752ec 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -21,6 +21,56 @@ _INDENTATION = " " +_UTILS_H_CONTENT = """#ifndef INFINI_OPS_BINDINGS_UTILS_H_ +#define INFINI_OPS_BINDINGS_UTILS_H_ + +#include +#include + +namespace infini::ops { + +inline DataType DataTypeFromString(const std::string& name) { + return kStringToDataType.at(name); +} + +inline Device::Type DeviceTypeFromString(const std::string& name) { + static const std::unordered_map kTorchNameToTypes{ + {"cpu", Device::Type::kCpu}, +#ifdef WITH_NVIDIA + {"cuda", Device::Type::kNvidia}, +#endif +#ifdef WITH_METAX + {"cuda", Device::Type::kMetax}, +#endif +#ifdef WITH_ILUVATAR + {"cuda", Device::Type::kIluvatar}, +#endif +#ifdef WITH_KUNLUN + {"cuda", Device::Type::kKunlun}, +#endif +#ifdef WITH_HYGON + {"cuda", Device::Type::kHygon}, +#endif +#ifdef WITH_QY + {"cuda", Device::Type::kQy}, +#endif + {"mlu", Device::Type::kCambricon}, {"npu", Device::Type::kAscend}, + {"musa", Device::Type::kMoore}}; + + auto it{kTorchNameToTypes.find(name)}; + + if (it != kTorchNameToTypes.cend()) { + return it->second; + } + + return Device::TypeFromString(name); +} + +} // namespace infini::ops + +#endif +""" + class _OperatorExtractor: def __call__(self, op_name): @@ -132,51 +182,13 @@ def _generate_call(op_name, call, method=True): #include #include -#include - #include "base/{op_name.lower()}.h" +#include "utils.h" namespace py = pybind11; namespace infini::ops {{ -inline DataType DataTypeFromString(const std::string& name) {{ - return kStringToDataType.at(name); -}} - -inline Device::Type DeviceTypeFromString(const std::string& name) {{ - static const std::unordered_map kTorchNameToTypes{{ - {{"cpu", Device::Type::kCpu}}, -#ifdef WITH_NVIDIA - {{"cuda", Device::Type::kNvidia}}, -#endif -#ifdef WITH_METAX - {{"cuda", Device::Type::kMetax}}, -#endif -#ifdef WITH_ILUVATAR - {{"cuda", Device::Type::kIluvatar}}, -#endif -#ifdef WITH_KUNLUN - {{"cuda", Device::Type::kKunlun}}, -#endif -#ifdef WITH_HYGON - {{"cuda", Device::Type::kHygon}}, -#endif -#ifdef WITH_QY - {{"cuda", Device::Type::kQy}}, -#endif - {{"mlu", Device::Type::kCambricon}}, {{"npu", Device::Type::kAscend}}, - {{"musa", Device::Type::kMoore}}}}; - - auto it{{kTorchNameToTypes.find(name)}}; - - if (it != kTorchNameToTypes.cend()) {{ - return it->second; - }} - - return Device::TypeFromString(name); -}} - void Bind{op_name}(py::module& m) {{ using Self = {op_name}; @@ -415,6 +427,8 @@ def _get_all_ops(devices): header_paths = [] bind_func_names = [] + (_BINDINGS_DIR / "utils.h").write_text(_UTILS_H_CONTENT) + for op_name, impl_paths in ops.items(): extractor = _OperatorExtractor() operator = extractor(op_name) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 16059c700..2f1f5cc57 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -16,9 +16,8 @@ if(WITH_CPU) target_compile_definitions(infiniops PUBLIC WITH_CPU=1) - # Reserve for OpenMP. - # find_package(OpenMP REQUIRED) - # target_link_libraries(infiniops PRIVATE OpenMP::OpenMP_CXX) + find_package(OpenMP REQUIRED) + target_link_libraries(infiniops PRIVATE OpenMP::OpenMP_CXX) list(APPEND DEVICE_LIST "cpu") endif() @@ -59,6 +58,7 @@ if(WITH_METAX) set_source_files_properties(${METAX_SOURCES} PROPERTIES LANGUAGE CXX) target_compile_definitions(infiniops PRIVATE WITH_METAX=1) + target_compile_options(infiniops PUBLIC "-x" "maca") target_sources(infiniops PRIVATE ${METAX_SOURCES}) target_include_directories(infiniops PUBLIC "${MACA_PATH}/include") @@ -86,10 +86,17 @@ if(GENERATE_PYTHON_BINDINGS) message(STATUS "Generating wrappers - done") endif() + set(PYBIND11_SOURCES "${PROJECT_SOURCE_DIR}/generated/bindings/ops.cc") + + # TODO: There might be a better solution. + if(WITH_NVIDIA) + set_source_files_properties(${PYBIND11_SOURCES} PROPERTIES LANGUAGE CUDA) + endif() + find_package(Python COMPONENTS Interpreter Development) find_package(pybind11 CONFIG) - pybind11_add_module(ops "${PROJECT_SOURCE_DIR}/generated/bindings/ops.cc") + pybind11_add_module(ops ${PYBIND11_SOURCES}) target_include_directories(ops PRIVATE ${PROJECT_SOURCE_DIR}) target_link_libraries(ops PRIVATE infiniops) diff --git a/src/base/add.h b/src/base/add.h new file mode 100644 index 000000000..a2da9ef51 --- /dev/null +++ b/src/base/add.h @@ -0,0 +1,71 @@ +#ifndef INFINI_OPS_BASE_ADD_H_ +#define INFINI_OPS_BASE_ADD_H_ + +#include + +#include "operator.h" + +namespace infini::ops { + +class Add : public Operator { + public: + Add(const Tensor input, const Tensor other, Tensor out) + : ndim_{out.ndim()}, + output_size_{out.numel()}, + input_type_{input.dtype()}, + other_type_{other.dtype()}, + out_type_{out.dtype()}, + input_shape_{input.shape()}, + other_shape_{other.shape()}, + out_shape_{out.shape()}, + input_strides_{input.strides()}, + other_strides_{other.strides()}, + out_strides_{out.strides()}, + is_input_contiguous_{input.IsContiguous()}, + is_other_contiguous_{other.IsContiguous()}, + is_out_contiguous_{out.IsContiguous()} { + assert(!out.HasBroadcastDim() && + "The output of `Add` should NOT have broadcasted dim!"); + // TODO(lzm): support mix-precision later using the generic elementwise + // framework. + assert(input_type_ == other_type_ && other_type_ == out_type_ && + "Operator `Add` requires all input and output Tensors to have the " + "same dtype"); + } + + virtual void operator()(void* stream, const Tensor input, const Tensor other, + Tensor out) const = 0; + + protected: + Tensor::Size ndim_{0}; + + Tensor::Size output_size_{0}; + + const DataType input_type_; + + const DataType other_type_; + + const DataType out_type_; + + Tensor::Shape input_shape_; + + Tensor::Shape other_shape_; + + Tensor::Shape out_shape_; + + Tensor::Strides input_strides_; + + Tensor::Strides other_strides_; + + Tensor::Strides out_strides_; + + bool is_input_contiguous_{false}; + + bool is_other_contiguous_{false}; + + bool is_out_contiguous_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/common/cuda/kernel_commons.h b/src/common/cuda/kernel_commons.h new file mode 100644 index 000000000..4d92e0082 --- /dev/null +++ b/src/common/cuda/kernel_commons.h @@ -0,0 +1,25 @@ +#ifndef INFINI_OPS_COMMON_CUDA_KERNEL_COMMONS_H_ +#define INFINI_OPS_COMMON_CUDA_KERNEL_COMMONS_H_ + +#ifdef WITH_NVIDIA +#include +#elif WITH_METAX +#include +#endif + +namespace infini::ops { + +__forceinline__ __device__ __host__ size_t +indexToOffset(size_t flat_index, size_t ndim, const size_t *shape, + const ptrdiff_t *strides) { + size_t res = 0; + for (size_t i = ndim; i-- > 0;) { + res += (flat_index % shape[i]) * strides[i]; + flat_index /= shape[i]; + } + return res; +} + +} // namespace infini::ops + +#endif diff --git a/src/common/generic_utils.h b/src/common/generic_utils.h new file mode 100644 index 000000000..6c82f4972 --- /dev/null +++ b/src/common/generic_utils.h @@ -0,0 +1,26 @@ +#ifndef INFINI_OPS_COMMON_GENERIC_UTILS_H_ +#define INFINI_OPS_COMMON_GENERIC_UTILS_H_ + +#include + +namespace infini::ops::utils { + +std::size_t indexToOffset(std::size_t flat_index, std::size_t ndim, + const std::size_t* shape, + const std::ptrdiff_t* strides) { + std::size_t res = 0; + for (std::size_t i = ndim; i-- > 0;) { + res += (flat_index % shape[i]) * strides[i]; + flat_index /= shape[i]; + } + return res; +} + +template +constexpr auto CeilDiv(const Tx& x, const Ty& y) { + return (x + y - 1) / y; +} + +} // namespace infini::ops::utils + +#endif diff --git a/src/cpu/add/add.h b/src/cpu/add/add.h new file mode 100644 index 000000000..d9a456a4c --- /dev/null +++ b/src/cpu/add/add.h @@ -0,0 +1,55 @@ +#ifndef INFINI_OPS_CPU_ADD_ADD_H_ +#define INFINI_OPS_CPU_ADD_ADD_H_ + +#include + +#include "base/add.h" +#include "common/generic_utils.h" + +namespace infini::ops { + +template <> +class Operator : public Add { + public: + Operator(const Tensor input, const Tensor other, Tensor out) + : Add{input, other, out} { + // TODO: Check constraints. + } + + void operator()(void* stream, const Tensor input, const Tensor other, + Tensor out) const override { + DispatchFunc>( + out_type_, [&]() { compute(stream, input, other, out); }, + "Operator::operator()"); + } + + private: + template + void compute(void* stream, const Tensor input, const Tensor other, + Tensor out) const { + const auto* input_ptr = static_cast(input.data()); + const auto* other_ptr = static_cast(other.data()); + auto* out_ptr = static_cast(out.data()); + + auto get_idx = [&](Tensor::Size i, bool is_contig, const auto* shape, + const auto* strides) { + return is_contig ? i : utils::indexToOffset(i, ndim_, shape, strides); + }; + +#pragma omp parallel for + for (Tensor::Size i = 0; i < output_size_; ++i) { + auto input_idx = get_idx(i, is_input_contiguous_, input_shape_.data(), + input_strides_.data()); + auto other_idx = get_idx(i, is_other_contiguous_, other_shape_.data(), + other_strides_.data()); + auto out_idx = get_idx(i, is_out_contiguous_, out_shape_.data(), + out_strides_.data()); + + out_ptr[out_idx] = input_ptr[input_idx] + other_ptr[other_idx]; + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/add/kernel.h b/src/cuda/add/kernel.h new file mode 100644 index 000000000..664bcee51 --- /dev/null +++ b/src/cuda/add/kernel.h @@ -0,0 +1,136 @@ +#ifndef INFINI_OPS_CUDA_ADD_KERNEL_H_ +#define INFINI_OPS_CUDA_ADD_KERNEL_H_ + +#include + +#include "base/add.h" +#include "common/cuda/kernel_commons.h" +#include "common/generic_utils.h" + +namespace infini::ops { + +typedef struct AddOp { + public: + static constexpr std::size_t num_inputs = 2; + template + __device__ __forceinline__ T operator()(const T& input, + const T& other) const { + if constexpr (std::is_same_v) { + return __hadd2(input, other); + } else if constexpr (std::is_same_v || + std::is_same_v>) { + return __hadd(input, other); + } else if constexpr (std::is_same_v) { + return __fadd_rn(input, other); + } else { + return input + other; + } + } +} AddOp; + +template +__global__ void AddKernel( + T* out, const T* input, const T* other, const Tensor::Size* out_shape, + const Tensor::Size* input_shape, const Tensor::Size* other_shape, + const Tensor::Stride* out_strides, const Tensor::Stride* input_strides, + const Tensor::Stride* other_strides, size_t output_size, size_t ndim, + size_t offset, bool out_contiguous, bool input_contiguous, + bool other_contiguous) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; + + if (idx < output_size) { + Tensor::Size out_idx = + out_contiguous ? idx : indexToOffset(idx, ndim, out_shape, out_strides); + Tensor::Size input_idx = + input_contiguous ? idx + : indexToOffset(idx, ndim, input_shape, input_strides); + Tensor::Size other_idx = + other_contiguous ? idx + : indexToOffset(idx, ndim, other_shape, other_strides); + + out[out_idx] = AddOp{}(input[input_idx], other[other_idx]); + } +} + +template +class CudaAdd : public Add { + public: + CudaAdd(const Tensor input, const Tensor other, Tensor out) + : Add{input, other, out} { + size_t shape_size = ndim_ * sizeof(*d_input_shape_); + size_t strides_size = ndim_ * sizeof(*d_input_strides_); + + Backend::malloc((void**)&d_input_shape_, shape_size); + Backend::malloc((void**)&d_other_shape_, shape_size); + Backend::malloc((void**)&d_out_shape_, shape_size); + Backend::malloc((void**)&d_input_strides_, strides_size); + Backend::malloc((void**)&d_other_strides_, strides_size); + Backend::malloc((void**)&d_out_strides_, strides_size); + + Backend::memcpy(d_input_shape_, input_shape_.data(), shape_size, + Backend::memcpyH2D); + Backend::memcpy(d_other_shape_, other_shape_.data(), shape_size, + Backend::memcpyH2D); + Backend::memcpy(d_out_shape_, out_shape_.data(), shape_size, + Backend::memcpyH2D); + Backend::memcpy(d_input_strides_, input_strides_.data(), strides_size, + Backend::memcpyH2D); + Backend::memcpy(d_other_strides_, other_strides_.data(), strides_size, + Backend::memcpyH2D); + Backend::memcpy(d_out_strides_, out_strides_.data(), strides_size, + Backend::memcpyH2D); + } + + ~CudaAdd() { + Backend::free(d_input_shape_); + Backend::free(d_other_shape_); + Backend::free(d_out_shape_); + Backend::free(d_input_strides_); + Backend::free(d_other_strides_); + Backend::free(d_out_strides_); + } + + void operator()(void* stream, const Tensor input, const Tensor other, + Tensor out) const override { + DispatchFunc( + out_type_, + [&]() { + // TODO(lzm): currently hard-code block_size to be 256. + dim3 blockDims( + std::min(static_cast(256), output_size_)); + dim3 gridDims(utils::CeilDiv(output_size_, blockDims.x)); + size_t step = gridDims.x * blockDims.x; + + T* d_out = reinterpret_cast(out.data()); + const T* d_input = reinterpret_cast(input.data()); + const T* d_other = reinterpret_cast(other.data()); + + for (size_t i = 0; i < output_size_; i += step) { + AddKernel<<(stream)>>>( + d_out, d_input, d_other, d_out_shape_, d_input_shape_, + d_other_shape_, d_out_strides_, d_input_strides_, + d_other_strides_, output_size_, ndim_, i, is_out_contiguous_, + is_input_contiguous_, is_other_contiguous_); + } + }, + "CudaAdd::operator()"); + } + + private: + Tensor::Size* d_input_shape_{nullptr}; + + Tensor::Size* d_other_shape_{nullptr}; + + Tensor::Size* d_out_shape_{nullptr}; + + Tensor::Stride* d_input_strides_{nullptr}; + + Tensor::Stride* d_other_strides_{nullptr}; + + Tensor::Stride* d_out_strides_{nullptr}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/data_type.h b/src/data_type.h index b9e2feadf..567a07654 100644 --- a/src/data_type.h +++ b/src/data_type.h @@ -4,6 +4,14 @@ #include #include +#ifdef WITH_NVIDIA +#include +#include +#elif WITH_METAX +#include +#include +#endif + #include "common/constexpr_map.h" #include "common/traits.h" @@ -102,7 +110,27 @@ DEFINE_DATA_TYPE_MAPPING(kUInt64, uint64_t) DEFINE_DATA_TYPE_MAPPING(kInt64, int64_t) DEFINE_DATA_TYPE_MAPPING(kFloat32, float) DEFINE_DATA_TYPE_MAPPING(kFloat64, double) -// TODO(lzm): Support fp16 and bf16. + +#ifdef WITH_NVIDIA +DEFINE_DATA_TYPE_MAPPING(kFloat16, half) +DEFINE_DATA_TYPE_MAPPING(kBFloat16, __nv_bfloat16) +#elif WITH_METAX +DEFINE_DATA_TYPE_MAPPING(kFloat16, __half) +DEFINE_DATA_TYPE_MAPPING(kBFloat16, __maca_bfloat16) +#else +// TODO(lzm): currently there's an ambiguity of uint16_t mapping to both kUInt16 +// and kFloat16/kBFloat16 for CPU. When CPU custom bfloat16/float16 types are +// defined, this should be replaced. +template <> +struct TypeMap { + using type = uint16_t; +}; +template <> +struct TypeMap { + using type = uint16_t; +}; +#endif +#undef DEFINE_DATA_TYPE_MAPPING // Defines the common categories of data types using List. using FloatTypes = List; diff --git a/src/metax/add/kernel.h b/src/metax/add/kernel.h new file mode 100644 index 000000000..ce9ec0165 --- /dev/null +++ b/src/metax/add/kernel.h @@ -0,0 +1,38 @@ +#ifndef INFINI_OPS_METAX_ADD_KERNEL_H_ +#define INFINI_OPS_METAX_ADD_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/add/kernel.h" + +namespace infini::ops { + +namespace add { + +struct MetaxBackend { + using stream_t = mcStream_t; + + static constexpr auto malloc = mcMalloc; + + static constexpr auto memcpy = mcMemcpy; + + static constexpr auto free = mcFree; + + static constexpr auto memcpyH2D = mcMemcpyHostToDevice; +}; + +} // namespace add + +template <> +class Operator : public CudaAdd { + public: + using CudaAdd::CudaAdd; +}; + +} // namespace infini::ops + +#endif diff --git a/src/metax/gemm/mcblas.h b/src/metax/gemm/mcblas.h index 10bef3b0f..1fd6f2f11 100644 --- a/src/metax/gemm/mcblas.h +++ b/src/metax/gemm/mcblas.h @@ -11,19 +11,28 @@ namespace infini::ops { +namespace gemm { + struct MetaxBackend { using blasHandle_t = mcblasHandle_t; + using stream_t = mcStream_t; static constexpr auto BLAS_OP_N = MCBLAS_OP_N; + static constexpr auto BLAS_OP_T = MCBLAS_OP_T; + static constexpr auto R_32F = MACA_R_32F; + static constexpr auto BLAS_COMPUTE_32F_FAST_TF32 = MCBLAS_COMPUTE_32F_FAST_TF32; + static constexpr auto BLAS_GEMM_DEFAULT = MCBLAS_GEMM_DEFAULT; static constexpr auto blasCreate = mcblasCreate; + static constexpr auto blasSetStream = mcblasSetStream; + static constexpr auto blasDestroy = mcblasDestroy; static constexpr auto blasGemmStridedBatchedEx = [](auto&&... args) { @@ -31,10 +40,12 @@ struct MetaxBackend { }; }; +} // namespace gemm + template <> -class Operator : public Blas { +class Operator : public Blas { public: - using Blas::Blas; + using Blas::Blas; }; } // namespace infini::ops diff --git a/src/nvidia/add/kernel.h b/src/nvidia/add/kernel.h new file mode 100644 index 000000000..7e6c3e57f --- /dev/null +++ b/src/nvidia/add/kernel.h @@ -0,0 +1,41 @@ +#ifndef INFINI_OPS_NVIDIA_ADD_KERNEL_H_ +#define INFINI_OPS_NVIDIA_ADD_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/add/kernel.h" + +namespace infini::ops { + +namespace add { + +struct NvidiaBackend { + using stream_t = cudaStream_t; + + static constexpr auto malloc = [](auto&&... args) { + return cudaMalloc(std::forward(args)...); + }; + + static constexpr auto memcpy = cudaMemcpy; + + static constexpr auto free = cudaFree; + + static constexpr auto memcpyH2D = cudaMemcpyHostToDevice; +}; + +} // namespace add + +template <> +class Operator + : public CudaAdd { + public: + using CudaAdd::CudaAdd; +}; + +} // namespace infini::ops + +#endif diff --git a/src/nvidia/gemm/cublas.h b/src/nvidia/gemm/cublas.h index d4e4b78e2..16c1b7ac9 100644 --- a/src/nvidia/gemm/cublas.h +++ b/src/nvidia/gemm/cublas.h @@ -11,19 +11,28 @@ namespace infini::ops { +namespace gemm { + struct NvidiaBackend { using blasHandle_t = cublasHandle_t; + using stream_t = cudaStream_t; static constexpr auto BLAS_OP_N = CUBLAS_OP_N; + static constexpr auto BLAS_OP_T = CUBLAS_OP_T; + static constexpr auto R_32F = CUDA_R_32F; + static constexpr auto BLAS_COMPUTE_32F_FAST_TF32 = CUBLAS_COMPUTE_32F_FAST_TF32; + static constexpr auto BLAS_GEMM_DEFAULT = CUBLAS_GEMM_DEFAULT; static constexpr auto blasCreate = cublasCreate; + static constexpr auto blasSetStream = cublasSetStream; + static constexpr auto blasDestroy = cublasDestroy; static constexpr auto blasGemmStridedBatchedEx = [](auto&&... args) { @@ -31,10 +40,12 @@ struct NvidiaBackend { }; }; +} // namespace gemm + template <> -class Operator : public Blas { +class Operator : public Blas { public: - using Blas::Blas; + using Blas::Blas; }; } // namespace infini::ops diff --git a/src/tensor.cc b/src/tensor.cc index 5203ac8b5..8746c60e3 100644 --- a/src/tensor.cc +++ b/src/tensor.cc @@ -1,6 +1,8 @@ #include "tensor.h" +#include #include +#include #include "dispatcher.h" @@ -49,6 +51,12 @@ Tensor::Size Tensor::ndim() const { return shape_.size(); } Tensor::Size Tensor::element_size() const { return kDataTypeToSize.at(dtype_); } +Tensor::Size Tensor::numel() const { + return std::accumulate(shape_.begin(), shape_.end(), + static_cast(1), + [](Tensor::Size a, Tensor::Size b) { return a * b; }); +} + Tensor Tensor::T() const { return {data_, {shape_[1], shape_[0]}, @@ -63,6 +71,25 @@ std::string Tensor::ToString() const { device_.ToString() + "')"; } +bool Tensor::HasBroadcastDim() const { + return std::any_of(shape_.begin(), shape_.end(), + [&, i = 0](const auto&) mutable { + return shape_[i] != 1 && strides_[i++] == 0; + }); +} + +bool Tensor::IsContiguous() const { + if (ndim() == 0) { + return true; + } + + if (!IsMergeable(0, ndim() - 1)) { + return false; + } + + return stride(ndim() - 1) == 1; +} + const DataType Tensor::DefaultDataType() { return DataType::kFloat32; } Device Tensor::DefaultDevice() { return Device{Device::Type::kCpu}; } @@ -85,10 +112,10 @@ Tensor::Strides Tensor::DefaultStrides(const Shape& shape) { std::string Tensor::ToStringHelper() const { if (ndim() == 0) { - return DispatchFunc( + return DispatchFunc>( dtype_, [&]() { return std::to_string(*static_cast(data_)); }, - "ToStringHelper"); + "Tensor::ToStringHelper()"); } std::string result{"["}; @@ -103,4 +130,21 @@ std::string Tensor::ToStringHelper() const { return result; } +bool Tensor::IsMergeable(Tensor::Size dim_start, Tensor::Size dim_end) const { + if (dim_start == dim_end) { + return true; + } + + for (Tensor::Size i = dim_start; i < dim_end; ++i) { + if (size(i) == 1 && stride(i) == 0) { + return false; + } + if (stride(i) != size(i + 1) * stride(i + 1)) { + return false; + } + } + + return true; +} + } // namespace infini::ops diff --git a/src/tensor.h b/src/tensor.h index 0feb4c437..39d4f98d5 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -12,9 +12,9 @@ namespace infini::ops { class Tensor { public: - using Size = std::uint64_t; + using Size = std::size_t; - using Stride = std::int64_t; + using Stride = std::ptrdiff_t; using Index = Stride; @@ -89,10 +89,16 @@ class Tensor { Size element_size() const; + Size numel() const; + Tensor T() const; std::string ToString() const; + bool HasBroadcastDim() const; + + bool IsContiguous() const; + private: static const DataType DefaultDataType(); @@ -102,6 +108,8 @@ class Tensor { std::string ToStringHelper() const; + bool IsMergeable(Size dim_start, Size dim_end) const; + void* data_{nullptr}; Shape shape_; diff --git a/tests/test_add.py b/tests/test_add.py new file mode 100644 index 000000000..7a1351f6d --- /dev/null +++ b/tests/test_add.py @@ -0,0 +1,45 @@ +import infini.ops +import pytest +import torch + +from tests.utils import empty_strided, get_available_devices, randn_strided + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize( + "dtype, rtol, atol", + ( + (torch.float32, 1e-7, 1e-7), + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-3, 1e-3), + ), +) +@pytest.mark.parametrize( + "shape, a_strides, b_strides, c_strides", + ( + ((13, 4), None, None, None), + ((13, 4), (10, 1), (10, 1), (10, 1)), + ((13, 4), (0, 1), None, None), + ((13, 4, 4), None, None, None), + ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)), + ((13, 4, 4), (4, 0, 1), (0, 4, 1), None), + ((16, 5632), None, None, None), + ((16, 5632), (13312, 1), (13312, 1), (13312, 1)), + ((13, 16, 2), (128, 4, 1), (0, 2, 1), (64, 4, 1)), + ((13, 16, 2), (128, 4, 1), (2, 0, 1), (64, 4, 1)), + ((4, 4, 5632), None, None, None), + ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), + ), +) +def test_add(shape, a_strides, b_strides, c_strides, dtype, device, rtol, atol): + a = randn_strided(shape, a_strides, dtype=dtype, device=device) + b = randn_strided(shape, b_strides, dtype=dtype, device=device) + + output = empty_strided(shape, c_strides, dtype=dtype, device=device) + expected = output.clone() + + # TODO: Add keyword argument support. + infini.ops.add(a, b, output) + torch.add(a, b, out=expected) + + assert torch.allclose(output, expected, rtol=rtol, atol=atol) diff --git a/tests/test_gemm.py b/tests/test_gemm.py index a169dcb8d..da97f531b 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import empty_strided, get_available_devices +from tests.utils import empty_strided, get_available_devices, randn_strided @pytest.mark.parametrize("device", get_available_devices()) @@ -38,8 +38,8 @@ def test_gemm( rtol, atol, ): - a = empty_strided(a_shape, a_strides, dtype=dtype, device=device) - b = empty_strided(b_shape, b_strides, dtype=dtype, device=device) + a = randn_strided(a_shape, a_strides, dtype=dtype, device=device) + b = randn_strided(b_shape, b_strides, dtype=dtype, device=device) if trans_a: a = a.transpose(-2, -1) @@ -50,9 +50,6 @@ def test_gemm( output = empty_strided(c_shape, c_strides, dtype=dtype, device=device) expected = output.clone() - a.normal_() - b.normal_() - # TODO: Add keyword argument support. infini.ops.gemm(a, b, alpha, beta, trans_a, trans_b, output) _torch_gemm( diff --git a/tests/utils.py b/tests/utils.py index 3a6ca17e3..c9b800628 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -24,3 +24,13 @@ def empty_strided(shape, strides, *, dtype=None, device=None): return torch.empty(shape, dtype=dtype, device=device) return torch.empty_strided(shape, strides, dtype=dtype, device=device) + + +def randn_strided(shape, strides, *, dtype=None, device=None): + output = empty_strided(shape, strides, dtype=dtype, device=device) + + output.as_strided( + (output.untyped_storage().size() // output.element_size(),), (1,) + ).normal_() + + return output