Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 53 additions & 39 deletions scripts/generate_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,56 @@

_INDENTATION = " "

_UTILS_H_CONTENT = """#ifndef INFINI_OPS_BINDINGS_UTILS_H_
#define INFINI_OPS_BINDINGS_UTILS_H_

#include <string>
#include <unordered_map>

namespace infini::ops {

inline DataType DataTypeFromString(const std::string& name) {
return kStringToDataType.at(name);
}

inline Device::Type DeviceTypeFromString(const std::string& name) {
static const std::unordered_map<std::string, Device::Type> kTorchNameToTypes{
{"cpu", Device::Type::kCpu},
#ifdef WITH_NVIDIA
{"cuda", Device::Type::kNvidia},
#endif
#ifdef WITH_METAX
{"cuda", Device::Type::kMetax},
#endif
#ifdef WITH_ILUVATAR
{"cuda", Device::Type::kIluvatar},
#endif
#ifdef WITH_KUNLUN
{"cuda", Device::Type::kKunlun},
#endif
#ifdef WITH_HYGON
{"cuda", Device::Type::kHygon},
#endif
#ifdef WITH_QY
{"cuda", Device::Type::kQy},
#endif
{"mlu", Device::Type::kCambricon}, {"npu", Device::Type::kAscend},
{"musa", Device::Type::kMoore}};

auto it{kTorchNameToTypes.find(name)};

if (it != kTorchNameToTypes.cend()) {
return it->second;
}

return Device::TypeFromString(name);
}

} // namespace infini::ops

#endif
"""


class _OperatorExtractor:
def __call__(self, op_name):
Expand Down Expand Up @@ -132,51 +182,13 @@ def _generate_call(op_name, call, method=True):
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <unordered_map>

#include "base/{op_name.lower()}.h"
#include "utils.h"

namespace py = pybind11;

namespace infini::ops {{

inline DataType DataTypeFromString(const std::string& name) {{
return kStringToDataType.at(name);
}}

inline Device::Type DeviceTypeFromString(const std::string& name) {{
static const std::unordered_map<std::string, Device::Type> kTorchNameToTypes{{
{{"cpu", Device::Type::kCpu}},
#ifdef WITH_NVIDIA
{{"cuda", Device::Type::kNvidia}},
#endif
#ifdef WITH_METAX
{{"cuda", Device::Type::kMetax}},
#endif
#ifdef WITH_ILUVATAR
{{"cuda", Device::Type::kIluvatar}},
#endif
#ifdef WITH_KUNLUN
{{"cuda", Device::Type::kKunlun}},
#endif
#ifdef WITH_HYGON
{{"cuda", Device::Type::kHygon}},
#endif
#ifdef WITH_QY
{{"cuda", Device::Type::kQy}},
#endif
{{"mlu", Device::Type::kCambricon}}, {{"npu", Device::Type::kAscend}},
{{"musa", Device::Type::kMoore}}}};

auto it{{kTorchNameToTypes.find(name)}};

if (it != kTorchNameToTypes.cend()) {{
return it->second;
}}

return Device::TypeFromString(name);
}}

void Bind{op_name}(py::module& m) {{
using Self = {op_name};

Expand Down Expand Up @@ -415,6 +427,8 @@ def _get_all_ops(devices):
header_paths = []
bind_func_names = []

(_BINDINGS_DIR / "utils.h").write_text(_UTILS_H_CONTENT)

for op_name, impl_paths in ops.items():
extractor = _OperatorExtractor()
operator = extractor(op_name)
Expand Down
15 changes: 11 additions & 4 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@ if(WITH_CPU)

target_compile_definitions(infiniops PUBLIC WITH_CPU=1)

# Reserve for OpenMP.
# find_package(OpenMP REQUIRED)
# target_link_libraries(infiniops PRIVATE OpenMP::OpenMP_CXX)
find_package(OpenMP REQUIRED)
target_link_libraries(infiniops PRIVATE OpenMP::OpenMP_CXX)

list(APPEND DEVICE_LIST "cpu")
endif()
Expand Down Expand Up @@ -59,6 +58,7 @@ if(WITH_METAX)
set_source_files_properties(${METAX_SOURCES} PROPERTIES LANGUAGE CXX)

target_compile_definitions(infiniops PRIVATE WITH_METAX=1)
target_compile_options(infiniops PUBLIC "-x" "maca")
target_sources(infiniops PRIVATE ${METAX_SOURCES})

target_include_directories(infiniops PUBLIC "${MACA_PATH}/include")
Expand Down Expand Up @@ -86,10 +86,17 @@ if(GENERATE_PYTHON_BINDINGS)
message(STATUS "Generating wrappers - done")
endif()

set(PYBIND11_SOURCES "${PROJECT_SOURCE_DIR}/generated/bindings/ops.cc")

# TODO: There might be a better solution.
if(WITH_NVIDIA)
set_source_files_properties(${PYBIND11_SOURCES} PROPERTIES LANGUAGE CUDA)
endif()

find_package(Python COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG)

pybind11_add_module(ops "${PROJECT_SOURCE_DIR}/generated/bindings/ops.cc")
pybind11_add_module(ops ${PYBIND11_SOURCES})

target_include_directories(ops PRIVATE ${PROJECT_SOURCE_DIR})
target_link_libraries(ops PRIVATE infiniops)
Expand Down
71 changes: 71 additions & 0 deletions src/base/add.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#ifndef INFINI_OPS_BASE_ADD_H_
#define INFINI_OPS_BASE_ADD_H_

#include <optional>

#include "operator.h"

namespace infini::ops {

class Add : public Operator<Add> {
public:
Add(const Tensor input, const Tensor other, Tensor out)
: ndim_{out.ndim()},
output_size_{out.numel()},
input_type_{input.dtype()},
other_type_{other.dtype()},
out_type_{out.dtype()},
input_shape_{input.shape()},
other_shape_{other.shape()},
out_shape_{out.shape()},
input_strides_{input.strides()},
other_strides_{other.strides()},
out_strides_{out.strides()},
is_input_contiguous_{input.IsContiguous()},
is_other_contiguous_{other.IsContiguous()},
is_out_contiguous_{out.IsContiguous()} {
assert(!out.HasBroadcastDim() &&
"The output of `Add` should NOT have broadcasted dim!");
// TODO(lzm): support mix-precision later using the generic elementwise
// framework.
assert(input_type_ == other_type_ && other_type_ == out_type_ &&
"Operator `Add` requires all input and output Tensors to have the "
"same dtype");
}

virtual void operator()(void* stream, const Tensor input, const Tensor other,
Tensor out) const = 0;

protected:
Tensor::Size ndim_{0};

Tensor::Size output_size_{0};

const DataType input_type_;

const DataType other_type_;

const DataType out_type_;

Tensor::Shape input_shape_;

Tensor::Shape other_shape_;

Tensor::Shape out_shape_;

Tensor::Strides input_strides_;

Tensor::Strides other_strides_;

Tensor::Strides out_strides_;

bool is_input_contiguous_{false};

bool is_other_contiguous_{false};

bool is_out_contiguous_{false};
};

} // namespace infini::ops

#endif
25 changes: 25 additions & 0 deletions src/common/cuda/kernel_commons.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef INFINI_OPS_COMMON_CUDA_KERNEL_COMMONS_H_
#define INFINI_OPS_COMMON_CUDA_KERNEL_COMMONS_H_

#ifdef WITH_NVIDIA
#include <cuda_runtime.h>
#elif WITH_METAX
#include <mcr/mc_runtime.h>
#endif

namespace infini::ops {

__forceinline__ __device__ __host__ size_t
indexToOffset(size_t flat_index, size_t ndim, const size_t *shape,
const ptrdiff_t *strides) {
size_t res = 0;
for (size_t i = ndim; i-- > 0;) {
res += (flat_index % shape[i]) * strides[i];
flat_index /= shape[i];
}
return res;
}

} // namespace infini::ops

#endif
26 changes: 26 additions & 0 deletions src/common/generic_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef INFINI_OPS_COMMON_GENERIC_UTILS_H_
#define INFINI_OPS_COMMON_GENERIC_UTILS_H_

#include <cstddef>

namespace infini::ops::utils {

std::size_t indexToOffset(std::size_t flat_index, std::size_t ndim,
const std::size_t* shape,
const std::ptrdiff_t* strides) {
std::size_t res = 0;
for (std::size_t i = ndim; i-- > 0;) {
res += (flat_index % shape[i]) * strides[i];
flat_index /= shape[i];
}
return res;
}

template <typename Tx, typename Ty>
constexpr auto CeilDiv(const Tx& x, const Ty& y) {
return (x + y - 1) / y;
}

} // namespace infini::ops::utils

#endif
55 changes: 55 additions & 0 deletions src/cpu/add/add.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#ifndef INFINI_OPS_CPU_ADD_ADD_H_
#define INFINI_OPS_CPU_ADD_ADD_H_

#include <utility>

#include "base/add.h"
#include "common/generic_utils.h"

namespace infini::ops {

template <>
class Operator<Add, Device::Type::kCpu> : public Add {
public:
Operator(const Tensor input, const Tensor other, Tensor out)
: Add{input, other, out} {
// TODO: Check constraints.
}

void operator()(void* stream, const Tensor input, const Tensor other,
Tensor out) const override {
DispatchFunc<ConcatType<FloatTypes, AllIntTypes>>(
out_type_, [&]<typename T>() { compute<T>(stream, input, other, out); },
"Operator<Add, Device::Type::kCpu>::operator()");
}

private:
template <typename T>
void compute(void* stream, const Tensor input, const Tensor other,
Tensor out) const {
const auto* input_ptr = static_cast<const T*>(input.data());
const auto* other_ptr = static_cast<const T*>(other.data());
auto* out_ptr = static_cast<T*>(out.data());

auto get_idx = [&](Tensor::Size i, bool is_contig, const auto* shape,
const auto* strides) {
return is_contig ? i : utils::indexToOffset(i, ndim_, shape, strides);
};

#pragma omp parallel for
for (Tensor::Size i = 0; i < output_size_; ++i) {
auto input_idx = get_idx(i, is_input_contiguous_, input_shape_.data(),
input_strides_.data());
auto other_idx = get_idx(i, is_other_contiguous_, other_shape_.data(),
other_strides_.data());
auto out_idx = get_idx(i, is_out_contiguous_, out_shape_.data(),
out_strides_.data());

out_ptr[out_idx] = input_ptr[input_idx] + other_ptr[other_idx];
}
}
};

} // namespace infini::ops

#endif
Loading