Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Framework/Core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ o2_add_library(Framework
src/InputSpan.cxx
src/InputSpec.cxx
src/OutputSpec.cxx
src/Kernels.cxx
src/LifetimeHelpers.cxx
src/LocalRootFileService.cxx
src/RootConfigParamHelpers.cxx
Expand Down
17 changes: 3 additions & 14 deletions Framework/Core/include/Framework/ASoA.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@
#include "Framework/Expressions.h"
#include "Framework/ArrowTypes.h"
#include "Framework/RuntimeError.h"
#include "Framework/Kernels.h"
#include <arrow/table.h>
#include <arrow/array.h>
#include <arrow/util/variant.h>
#include <arrow/compute/kernel.h>
#include <arrow/compute/api_aggregate.h>
#include <gandiva/selection_vector.h>
#include <cassert>
#include <fmt/format.h>
Expand Down Expand Up @@ -936,14 +935,12 @@ auto select(T const& t, framework::expressions::Filter const& f)
return Filtered<T>({t.asArrowTable()}, selectionToVector(framework::expressions::createSelection(t.asArrowTable(), f)));
}

arrow::Status getSliceFor(int value, char const* key, std::shared_ptr<arrow::Table> const& input, std::shared_ptr<arrow::Table>& output, uint64_t& offset);

template <typename T>
auto sliceBy(T const& t, framework::expressions::BindingNode const& node, int value)
{
uint64_t offset = 0;
std::shared_ptr<arrow::Table> result = nullptr;
auto status = getSliceFor(value, node.name.c_str(), t.asArrowTable(), result, offset);
auto status = o2::framework::getSliceFor(value, node.name.c_str(), t.asArrowTable(), result, offset);
if (status.ok()) {
return T({result}, offset);
}
Expand Down Expand Up @@ -1260,15 +1257,7 @@ class Table
/// Remember @a key as the current slicing key and fill the slice caches
/// (mValues / mCounts) from the column of mTable with that name.
///
/// NOTE(review): as rendered here the body contains both an inline
/// "value_counts" computation and, after it, a delegation to
/// o2::framework::getSlices. The first `return` makes the delegation
/// unreachable; this looks like two revisions of the same method shown
/// together -- confirm which one is intended before touching the code.
arrow::Status initializeSliceCaches(char const* key)
{
mCurrentKey = key;
arrow::Datum value_counts;
auto options = arrow::compute::ScalarAggregateOptions::Defaults();
// Run Arrow's "value_counts" kernel on the key column; ARROW_ASSIGN_OR_RAISE
// returns early from this method if the call does not succeed.
ARROW_ASSIGN_OR_RAISE(value_counts,
arrow::compute::CallFunction("value_counts", {mTable->GetColumnByName(key)},
&options));
// The result is viewed as a struct array: field 0 is cached as the Int32
// values, field 1 as the Int64 counts.
auto pair = static_cast<arrow::StructArray>(value_counts.array());
mValues = std::make_shared<arrow::NumericArray<arrow::Int32Type>>(pair.field(0)->data());
mCounts = std::make_shared<arrow::NumericArray<arrow::Int64Type>>(pair.field(1)->data());
return arrow::Status::OK();
// Unreachable as written: control always leaves through the return above.
return o2::framework::getSlices(key, mTable, mValues, mCounts);
}

public:
Expand Down
1 change: 0 additions & 1 deletion Framework/Core/include/Framework/ASoAHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#define O2_FRAMEWORK_ASOAHELPERS_H_

#include "Framework/ASoA.h"
#include "Framework/Kernels.h"
#include "Framework/RuntimeError.h"
#include <arrow/table.h>

Expand Down
1 change: 0 additions & 1 deletion Framework/Core/include/Framework/AnalysisManagers.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#define FRAMEWORK_ANALYSISMANAGERS_H
#include "Framework/AnalysisHelpers.h"
#include "Framework/GroupedCombinations.h"
#include "Framework/Kernels.h"
#include "Framework/ASoA.h"
#include "Framework/ProcessingContext.h"
#include "Framework/EndOfStreamContext.h"
Expand Down
171 changes: 25 additions & 146 deletions Framework/Core/include/Framework/Kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,168 +12,47 @@
#ifndef O2_FRAMEWORK_KERNELS_H_
#define O2_FRAMEWORK_KERNELS_H_

#include "Framework/BasicOps.h"
#include "Framework/TableBuilder.h"

#include <arrow/compute/kernel.h>
#include <arrow/status.h>
#include <arrow/util/visibility.h>
#include <arrow/util/variant.h>
#include <arrow/util/config.h>

#include <string>
#include <arrow/table.h>

namespace o2::framework
{
using ListVector = std::vector<std::vector<int64_t>>;
template <typename T>
auto sliceByColumnGeneric(
/// Slice a given table unchecked, filling slice caches
arrow::Status getSlices(
const char* key,
std::shared_ptr<arrow::Table> const& input,
std::shared_ptr<arrow::NumericArray<arrow::Int32Type>>& values,
std::shared_ptr<arrow::NumericArray<arrow::Int64Type>>& counts);

/// Slice a given table unchecked
arrow::Status getSliceFor(
int value,
char const* key,
std::shared_ptr<arrow::Table> const& input,
std::shared_ptr<arrow::Table>& output,
uint64_t& offset);

/// Slice a given table checked, for grouping association
void sliceByColumnGeneric(
char const* key,
char const* target,
std::shared_ptr<arrow::Table> const& input,
T fullSize,
int32_t fullSize,
ListVector* groups,
ListVector* unassigned = nullptr)
{
groups->resize(fullSize);
auto column = input->GetColumnByName(key);
int64_t row = 0;
for (auto iChunk = 0; iChunk < column->num_chunks(); ++iChunk) {
auto chunk = static_cast<arrow::NumericArray<typename detail::ConversionTraits<T>::ArrowType>>(column->chunk(iChunk)->data());
for (auto iElement = 0; iElement < chunk.length(); ++iElement) {
auto v = chunk.Value(iElement);
if (v >= 0) {
if (v >= groups->size()) {
throw runtime_error_f("Table %s has an entry with index (%d) that is larger than the grouping table size (%d)", target, v, fullSize);
}
(*groups)[v].push_back(row);
} else if (unassigned != nullptr) {
auto av = std::abs(v);
if (unassigned->size() < av + 1) {
unassigned->resize(av + 1);
}
(*unassigned)[av].push_back(row);
}
++row;
}
}
}
ListVector* unassigned = nullptr);

/// Slice a given table into a vector of tables, each containing a slice.
/// @a slices the arrow tables into which the original @a input
/// is split.
/// @a offsets the offsets in the original table at which the corresponding
/// slices start.
template <typename T>
auto sliceByColumn(
/// Slice a given table checked, fast, for grouping association assuming
/// the index is properly sorted
arrow::Status sliceByColumn(
char const* key,
char const* target,
std::shared_ptr<arrow::Table> const& input,
T fullSize,
int32_t fullSize,
std::vector<arrow::Datum>* slices,
std::vector<uint64_t>* offsets = nullptr,
std::vector<int>* sizes = nullptr,
std::vector<arrow::Datum>* unassignedSlices = nullptr,
std::vector<uint64_t>* unassignedOffsets = nullptr)
{
arrow::Datum value_counts;
auto column = input->GetColumnByName(key);
for (auto i = 0; i < column->num_chunks(); ++i) {
T prev = 0;
T cur = 0;
T lastNeg = 0;
T lastPos = 0;

auto array = static_cast<arrow::NumericArray<typename detail::ConversionTraits<T>::ArrowType>>(column->chunk(i)->data());
for (auto e = 0; e < array.length(); ++e) {
prev = cur;
if (prev >= 0) {
lastPos = prev;
} else {
lastNeg = prev;
}
cur = array.Value(e);
if (cur >= 0) {
if (lastPos > cur) {
throw runtime_error_f("Table %s index %s is not sorted: next value %d < previous value %d!", target, key, cur, lastPos);
} else if (lastPos == cur && prev < 0) {
throw runtime_error_f("Table %s index %s has a group with index %d that is split by %d", target, key, cur, prev);
}
} else {
if (lastNeg < cur) {
throw runtime_error_f("Table %s index %s is not sorted: next negative value %d > previous negative value %d!", target, key, cur, lastNeg);
} else if (lastNeg == cur && prev >= 0) {
throw runtime_error_f("Table %s index %s has a group with index %d that is split by %d", target, key, cur, prev);
}
}
}
}
auto options = arrow::compute::ScalarAggregateOptions::Defaults();
ARROW_ASSIGN_OR_RAISE(value_counts,
arrow::compute::CallFunction("value_counts", {column},
&options));
auto pair = static_cast<arrow::StructArray>(value_counts.array());
auto values = static_cast<arrow::NumericArray<typename detail::ConversionTraits<T>::ArrowType>>(pair.field(0)->data());
auto counts = static_cast<arrow::NumericArray<arrow::Int64Type>>(pair.field(1)->data());

// create slices and offsets
uint64_t offset = 0;
auto count = 0;
auto size = values.length();

auto makeSlice = [&](uint64_t offset_, T count_) {
slices->emplace_back(arrow::Datum{input->Slice(offset_, count_)});
if (offsets) {
offsets->emplace_back(offset_);
}
if (sizes) {
sizes->emplace_back(count_);
}
};

auto makeUnassignedSlice = [&](uint64_t offset_, T count_) {
if (unassignedSlices) {
unassignedSlices->emplace_back(arrow::Datum{input->Slice(offset_, count_)});
}
if (unassignedOffsets) {
unassignedOffsets->emplace_back(offset_);
}
};

auto v = 0;
auto vprev = v;
auto nzeros = 0;

for (auto i = 0; i < size; ++i) {
count = counts.Value(i);
if (v >= 0) {
vprev = v;
}
v = values.Value(i);
if (v < 0) {
makeUnassignedSlice(offset, count);
offset += count;
continue;
}
nzeros = v - vprev - ((i == 0 || slices->empty() == true) ? 0 : 1);
for (auto z = 0; z < nzeros; ++z) {
makeSlice(offset, 0);
}
makeSlice(offset, count);
offset += count;
}
v = values.Value(size - 1);
if (v >= 0) {
vprev = v;
}
if (vprev < fullSize - 1) {
for (auto v = vprev + 1; v < fullSize; ++v) {
makeSlice(offset, 0);
}
}

return arrow::Status::OK();
}
std::vector<uint64_t>* unassignedOffsets = nullptr);
} // namespace o2::framework

#endif // O2_FRAMEWORK_KERNELS_H_
23 changes: 1 addition & 22 deletions Framework/Core/src/ASoA.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
// or submit itself to any jurisdiction.

#include "Framework/ASoA.h"
#include "Framework/Kernels.h"
#include "ArrowDebugHelpers.h"
#include "Framework/RuntimeError.h"
#include <arrow/util/key_value_metadata.h>
Expand Down Expand Up @@ -109,26 +110,4 @@ arrow::ChunkedArray* getIndexFromLabel(arrow::Table* table, const char* label)
return table->column(index[0]).get();
}

/// Pick out of @a input the rows whose @a key column equals @a value.
///
/// The distinct entries of the key column and how often each occurs are
/// obtained from Arrow's "value_counts" kernel and visited in the order the
/// kernel returns them. The count of every entry that precedes the match is
/// added to @a offset; @a output is then set to
/// `input->Slice(offset, <count of the match>)`. When no entry matches,
/// @a output is a zero-length slice taken at the accumulated offset.
///
/// No validation of the column content is performed.
/// @a offset is only ever incremented here, never reset, so the caller
/// decides the starting point (0 yields a position relative to the start of
/// @a input).
/// NOTE(review): turning cumulative counts into row positions presumably
/// relies on rows with equal keys being stored contiguously -- confirm
/// against the callers.
///
/// @return OK on completion; a failure of the compute call is handed back to
///         the caller through ARROW_ASSIGN_OR_RAISE.
arrow::Status getSliceFor(int value, char const* key, std::shared_ptr<arrow::Table> const& input, std::shared_ptr<arrow::Table>& output, uint64_t& offset)
{
  auto aggregateOptions = arrow::compute::ScalarAggregateOptions::Defaults();
  arrow::Datum histogram;
  ARROW_ASSIGN_OR_RAISE(histogram,
                        arrow::compute::CallFunction("value_counts", {input->GetColumnByName(key)},
                                                     &aggregateOptions));

  // The kernel yields a struct array: field 0 = distinct keys, field 1 = occurrences.
  auto histogramStruct = static_cast<arrow::StructArray>(histogram.array());
  auto uniqueKeys = static_cast<arrow::NumericArray<arrow::Int32Type>>(histogramStruct.field(0)->data());
  auto multiplicities = static_cast<arrow::NumericArray<arrow::Int64Type>>(histogramStruct.field(1)->data());

  // Stays 0 when the requested value is absent, which produces the empty slice.
  int64_t matchedRows = 0;
  for (auto group = 0; group < uniqueKeys.length(); ++group) {
    auto const groupRows = multiplicities.Value(group);
    if (uniqueKeys.Value(group) == value) {
      matchedRows = groupRows;
      break;
    }
    // Not the requested group: step over its rows.
    offset += groupRows;
  }

  output = input->Slice(offset, matchedRows);
  return arrow::Status::OK();
}

} // namespace o2::soa
Loading