Skip to content

Commit ff21c1d

Browse files
authored
DPL Analysis: disentangle slicing kernels (#8005)
* DPL Analysis: disentangle slicing kernels * Remove unnecessary templates * Move code to .cxx * Consolidate usage of arrow value_counts kernel * Remove unnecessary includes * fix erroneous unassigned sortedness check
1 parent 1974652 commit ff21c1d

8 files changed

Lines changed: 240 additions & 186 deletions

File tree

Framework/Core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ o2_add_library(Framework
8080
src/InputSpan.cxx
8181
src/InputSpec.cxx
8282
src/OutputSpec.cxx
83+
src/Kernels.cxx
8384
src/LifetimeHelpers.cxx
8485
src/LocalRootFileService.cxx
8586
src/RootConfigParamHelpers.cxx

Framework/Core/include/Framework/ASoA.h

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,10 @@
2020
#include "Framework/Expressions.h"
2121
#include "Framework/ArrowTypes.h"
2222
#include "Framework/RuntimeError.h"
23+
#include "Framework/Kernels.h"
2324
#include <arrow/table.h>
2425
#include <arrow/array.h>
2526
#include <arrow/util/variant.h>
26-
#include <arrow/compute/kernel.h>
27-
#include <arrow/compute/api_aggregate.h>
2827
#include <gandiva/selection_vector.h>
2928
#include <cassert>
3029
#include <fmt/format.h>
@@ -936,14 +935,12 @@ auto select(T const& t, framework::expressions::Filter const& f)
936935
return Filtered<T>({t.asArrowTable()}, selectionToVector(framework::expressions::createSelection(t.asArrowTable(), f)));
937936
}
938937

939-
arrow::Status getSliceFor(int value, char const* key, std::shared_ptr<arrow::Table> const& input, std::shared_ptr<arrow::Table>& output, uint64_t& offset);
940-
941938
template <typename T>
942939
auto sliceBy(T const& t, framework::expressions::BindingNode const& node, int value)
943940
{
944941
uint64_t offset = 0;
945942
std::shared_ptr<arrow::Table> result = nullptr;
946-
auto status = getSliceFor(value, node.name.c_str(), t.asArrowTable(), result, offset);
943+
auto status = o2::framework::getSliceFor(value, node.name.c_str(), t.asArrowTable(), result, offset);
947944
if (status.ok()) {
948945
return T({result}, offset);
949946
}
@@ -1260,15 +1257,7 @@ class Table
12601257
arrow::Status initializeSliceCaches(char const* key)
12611258
{
12621259
mCurrentKey = key;
1263-
arrow::Datum value_counts;
1264-
auto options = arrow::compute::ScalarAggregateOptions::Defaults();
1265-
ARROW_ASSIGN_OR_RAISE(value_counts,
1266-
arrow::compute::CallFunction("value_counts", {mTable->GetColumnByName(key)},
1267-
&options));
1268-
auto pair = static_cast<arrow::StructArray>(value_counts.array());
1269-
mValues = std::make_shared<arrow::NumericArray<arrow::Int32Type>>(pair.field(0)->data());
1270-
mCounts = std::make_shared<arrow::NumericArray<arrow::Int64Type>>(pair.field(1)->data());
1271-
return arrow::Status::OK();
1260+
return o2::framework::getSlices(key, mTable, mValues, mCounts);
12721261
}
12731262

12741263
public:

Framework/Core/include/Framework/ASoAHelpers.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#define O2_FRAMEWORK_ASOAHELPERS_H_
1414

1515
#include "Framework/ASoA.h"
16-
#include "Framework/Kernels.h"
1716
#include "Framework/RuntimeError.h"
1817
#include <arrow/table.h>
1918

Framework/Core/include/Framework/AnalysisManagers.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#define FRAMEWORK_ANALYSISMANAGERS_H
1414
#include "Framework/AnalysisHelpers.h"
1515
#include "Framework/GroupedCombinations.h"
16-
#include "Framework/Kernels.h"
1716
#include "Framework/ASoA.h"
1817
#include "Framework/ProcessingContext.h"
1918
#include "Framework/EndOfStreamContext.h"

Framework/Core/include/Framework/Kernels.h

Lines changed: 25 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -12,168 +12,47 @@
1212
#ifndef O2_FRAMEWORK_KERNELS_H_
1313
#define O2_FRAMEWORK_KERNELS_H_
1414

15-
#include "Framework/BasicOps.h"
16-
#include "Framework/TableBuilder.h"
17-
18-
#include <arrow/compute/kernel.h>
19-
#include <arrow/status.h>
20-
#include <arrow/util/visibility.h>
21-
#include <arrow/util/variant.h>
22-
#include <arrow/util/config.h>
23-
24-
#include <string>
15+
#include <arrow/table.h>
2516

2617
namespace o2::framework
2718
{
2819
using ListVector = std::vector<std::vector<int64_t>>;
29-
template <typename T>
30-
auto sliceByColumnGeneric(
20+
/// Slice a given table uncheked, filling slice caches
21+
arrow::Status getSlices(
22+
const char* key,
23+
std::shared_ptr<arrow::Table> const& input,
24+
std::shared_ptr<arrow::NumericArray<arrow::Int32Type>>& values,
25+
std::shared_ptr<arrow::NumericArray<arrow::Int64Type>>& counts);
26+
27+
/// Slice a given table unchecked
28+
arrow::Status getSliceFor(
29+
int value,
30+
char const* key,
31+
std::shared_ptr<arrow::Table> const& input,
32+
std::shared_ptr<arrow::Table>& output,
33+
uint64_t& offset);
34+
35+
/// Slice a given table checked, for grouping association
36+
void sliceByColumnGeneric(
3137
char const* key,
3238
char const* target,
3339
std::shared_ptr<arrow::Table> const& input,
34-
T fullSize,
40+
int32_t fullSize,
3541
ListVector* groups,
36-
ListVector* unassigned = nullptr)
37-
{
38-
groups->resize(fullSize);
39-
auto column = input->GetColumnByName(key);
40-
int64_t row = 0;
41-
for (auto iChunk = 0; iChunk < column->num_chunks(); ++iChunk) {
42-
auto chunk = static_cast<arrow::NumericArray<typename detail::ConversionTraits<T>::ArrowType>>(column->chunk(iChunk)->data());
43-
for (auto iElement = 0; iElement < chunk.length(); ++iElement) {
44-
auto v = chunk.Value(iElement);
45-
if (v >= 0) {
46-
if (v >= groups->size()) {
47-
throw runtime_error_f("Table %s has an entry with index (%d) that is larger than the grouping table size (%d)", target, v, fullSize);
48-
}
49-
(*groups)[v].push_back(row);
50-
} else if (unassigned != nullptr) {
51-
auto av = std::abs(v);
52-
if (unassigned->size() < av + 1) {
53-
unassigned->resize(av + 1);
54-
}
55-
(*unassigned)[av].push_back(row);
56-
}
57-
++row;
58-
}
59-
}
60-
}
42+
ListVector* unassigned = nullptr);
6143

62-
/// Slice a given table in a vector of tables each containing a slice.
63-
/// @a slices the arrow tables in which the original @a input
64-
/// is split into.
65-
/// @a offset the offset in the original table at which the corresponding
66-
/// slice was split.
67-
template <typename T>
68-
auto sliceByColumn(
44+
/// Slice a given table checked, fast, for grouping association assuming
45+
/// the index is properly sorted
46+
arrow::Status sliceByColumn(
6947
char const* key,
7048
char const* target,
7149
std::shared_ptr<arrow::Table> const& input,
72-
T fullSize,
50+
int32_t fullSize,
7351
std::vector<arrow::Datum>* slices,
7452
std::vector<uint64_t>* offsets = nullptr,
7553
std::vector<int>* sizes = nullptr,
7654
std::vector<arrow::Datum>* unassignedSlices = nullptr,
77-
std::vector<uint64_t>* unassignedOffsets = nullptr)
78-
{
79-
arrow::Datum value_counts;
80-
auto column = input->GetColumnByName(key);
81-
for (auto i = 0; i < column->num_chunks(); ++i) {
82-
T prev = 0;
83-
T cur = 0;
84-
T lastNeg = 0;
85-
T lastPos = 0;
86-
87-
auto array = static_cast<arrow::NumericArray<typename detail::ConversionTraits<T>::ArrowType>>(column->chunk(i)->data());
88-
for (auto e = 0; e < array.length(); ++e) {
89-
prev = cur;
90-
if (prev >= 0) {
91-
lastPos = prev;
92-
} else {
93-
lastNeg = prev;
94-
}
95-
cur = array.Value(e);
96-
if (cur >= 0) {
97-
if (lastPos > cur) {
98-
throw runtime_error_f("Table %s index %s is not sorted: next value %d < previous value %d!", target, key, cur, lastPos);
99-
} else if (lastPos == cur && prev < 0) {
100-
throw runtime_error_f("Table %s index %s has a group with index %d that is split by %d", target, key, cur, prev);
101-
}
102-
} else {
103-
if (lastNeg < cur) {
104-
throw runtime_error_f("Table %s index %s is not sorted: next negative value %d > previous negative value %d!", target, key, cur, lastNeg);
105-
} else if (lastNeg == cur && prev >= 0) {
106-
throw runtime_error_f("Table %s index %s has a group with index %d that is split by %d", target, key, cur, prev);
107-
}
108-
}
109-
}
110-
}
111-
auto options = arrow::compute::ScalarAggregateOptions::Defaults();
112-
ARROW_ASSIGN_OR_RAISE(value_counts,
113-
arrow::compute::CallFunction("value_counts", {column},
114-
&options));
115-
auto pair = static_cast<arrow::StructArray>(value_counts.array());
116-
auto values = static_cast<arrow::NumericArray<typename detail::ConversionTraits<T>::ArrowType>>(pair.field(0)->data());
117-
auto counts = static_cast<arrow::NumericArray<arrow::Int64Type>>(pair.field(1)->data());
118-
119-
// create slices and offsets
120-
uint64_t offset = 0;
121-
auto count = 0;
122-
auto size = values.length();
123-
124-
auto makeSlice = [&](uint64_t offset_, T count_) {
125-
slices->emplace_back(arrow::Datum{input->Slice(offset_, count_)});
126-
if (offsets) {
127-
offsets->emplace_back(offset_);
128-
}
129-
if (sizes) {
130-
sizes->emplace_back(count_);
131-
}
132-
};
133-
134-
auto makeUnassignedSlice = [&](uint64_t offset_, T count_) {
135-
if (unassignedSlices) {
136-
unassignedSlices->emplace_back(arrow::Datum{input->Slice(offset_, count_)});
137-
}
138-
if (unassignedOffsets) {
139-
unassignedOffsets->emplace_back(offset_);
140-
}
141-
};
142-
143-
auto v = 0;
144-
auto vprev = v;
145-
auto nzeros = 0;
146-
147-
for (auto i = 0; i < size; ++i) {
148-
count = counts.Value(i);
149-
if (v >= 0) {
150-
vprev = v;
151-
}
152-
v = values.Value(i);
153-
if (v < 0) {
154-
makeUnassignedSlice(offset, count);
155-
offset += count;
156-
continue;
157-
}
158-
nzeros = v - vprev - ((i == 0 || slices->empty() == true) ? 0 : 1);
159-
for (auto z = 0; z < nzeros; ++z) {
160-
makeSlice(offset, 0);
161-
}
162-
makeSlice(offset, count);
163-
offset += count;
164-
}
165-
v = values.Value(size - 1);
166-
if (v >= 0) {
167-
vprev = v;
168-
}
169-
if (vprev < fullSize - 1) {
170-
for (auto v = vprev + 1; v < fullSize; ++v) {
171-
makeSlice(offset, 0);
172-
}
173-
}
174-
175-
return arrow::Status::OK();
176-
}
55+
std::vector<uint64_t>* unassignedOffsets = nullptr);
17756
} // namespace o2::framework
17857

17958
#endif // O2_FRAMEWORK_KERNELS_H_

Framework/Core/src/ASoA.cxx

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
// or submit itself to any jurisdiction.
1111

1212
#include "Framework/ASoA.h"
13+
#include "Framework/Kernels.h"
1314
#include "ArrowDebugHelpers.h"
1415
#include "Framework/RuntimeError.h"
1516
#include <arrow/util/key_value_metadata.h>
@@ -109,26 +110,4 @@ arrow::ChunkedArray* getIndexFromLabel(arrow::Table* table, const char* label)
109110
return table->column(index[0]).get();
110111
}
111112

112-
arrow::Status getSliceFor(int value, char const* key, std::shared_ptr<arrow::Table> const& input, std::shared_ptr<arrow::Table>& output, uint64_t& offset)
113-
{
114-
arrow::Datum value_counts;
115-
auto options = arrow::compute::ScalarAggregateOptions::Defaults();
116-
ARROW_ASSIGN_OR_RAISE(value_counts,
117-
arrow::compute::CallFunction("value_counts", {input->GetColumnByName(key)},
118-
&options));
119-
auto pair = static_cast<arrow::StructArray>(value_counts.array());
120-
auto values = static_cast<arrow::NumericArray<arrow::Int32Type>>(pair.field(0)->data());
121-
auto counts = static_cast<arrow::NumericArray<arrow::Int64Type>>(pair.field(1)->data());
122-
123-
for (auto slice = 0; slice < values.length(); ++slice) {
124-
if (values.Value(slice) == value) {
125-
output = input->Slice(offset, counts.Value(slice));
126-
return arrow::Status::OK();
127-
}
128-
offset += counts.Value(slice);
129-
}
130-
output = input->Slice(offset, 0);
131-
return arrow::Status::OK();
132-
}
133-
134113
} // namespace o2::soa

0 commit comments

Comments
 (0)