|
12 | 12 | #ifndef O2_FRAMEWORK_KERNELS_H_ |
13 | 13 | #define O2_FRAMEWORK_KERNELS_H_ |
14 | 14 |
|
15 | | -#include "Framework/BasicOps.h" |
16 | | -#include "Framework/TableBuilder.h" |
17 | | - |
18 | | -#include <arrow/compute/kernel.h> |
19 | | -#include <arrow/status.h> |
20 | | -#include <arrow/util/visibility.h> |
21 | | -#include <arrow/util/variant.h> |
22 | | -#include <arrow/util/config.h> |
23 | | - |
24 | | -#include <string> |
| 15 | +#include <arrow/table.h> |
25 | 16 |
|
26 | 17 | namespace o2::framework |
27 | 18 | { |
28 | 19 | using ListVector = std::vector<std::vector<int64_t>>; |
29 | | -template <typename T> |
30 | | -auto sliceByColumnGeneric( |
| 20 | +/// Slice a given table unchecked, filling slice caches |
| 21 | +arrow::Status getSlices( |
| 22 | + const char* key, |
| 23 | + std::shared_ptr<arrow::Table> const& input, |
| 24 | + std::shared_ptr<arrow::NumericArray<arrow::Int32Type>>& values, |
| 25 | + std::shared_ptr<arrow::NumericArray<arrow::Int64Type>>& counts); |
| 26 | + |
| 27 | +/// Slice a given table unchecked |
| 28 | +arrow::Status getSliceFor( |
| 29 | + int value, |
| 30 | + char const* key, |
| 31 | + std::shared_ptr<arrow::Table> const& input, |
| 32 | + std::shared_ptr<arrow::Table>& output, |
| 33 | + uint64_t& offset); |
| 34 | + |
| 35 | +/// Slice a given table checked, for grouping association |
| 36 | +void sliceByColumnGeneric( |
31 | 37 | char const* key, |
32 | 38 | char const* target, |
33 | 39 | std::shared_ptr<arrow::Table> const& input, |
34 | | - T fullSize, |
| 40 | + int32_t fullSize, |
35 | 41 | ListVector* groups, |
36 | | - ListVector* unassigned = nullptr) |
37 | | -{ |
38 | | - groups->resize(fullSize); |
39 | | - auto column = input->GetColumnByName(key); |
40 | | - int64_t row = 0; |
41 | | - for (auto iChunk = 0; iChunk < column->num_chunks(); ++iChunk) { |
42 | | - auto chunk = static_cast<arrow::NumericArray<typename detail::ConversionTraits<T>::ArrowType>>(column->chunk(iChunk)->data()); |
43 | | - for (auto iElement = 0; iElement < chunk.length(); ++iElement) { |
44 | | - auto v = chunk.Value(iElement); |
45 | | - if (v >= 0) { |
46 | | - if (v >= groups->size()) { |
47 | | - throw runtime_error_f("Table %s has an entry with index (%d) that is larger than the grouping table size (%d)", target, v, fullSize); |
48 | | - } |
49 | | - (*groups)[v].push_back(row); |
50 | | - } else if (unassigned != nullptr) { |
51 | | - auto av = std::abs(v); |
52 | | - if (unassigned->size() < av + 1) { |
53 | | - unassigned->resize(av + 1); |
54 | | - } |
55 | | - (*unassigned)[av].push_back(row); |
56 | | - } |
57 | | - ++row; |
58 | | - } |
59 | | - } |
60 | | -} |
| 42 | + ListVector* unassigned = nullptr); |
61 | 43 |
|
62 | | -/// Slice a given table in a vector of tables each containing a slice. |
63 | | -/// @a slices the arrow tables in which the original @a input |
64 | | -/// is split into. |
65 | | -/// @a offset the offset in the original table at which the corresponding |
66 | | -/// slice was split. |
67 | | -template <typename T> |
68 | | -auto sliceByColumn( |
| 44 | +/// Slice a given table checked, fast, for grouping association assuming |
| 45 | +/// the index is properly sorted |
| 46 | +arrow::Status sliceByColumn( |
69 | 47 | char const* key, |
70 | 48 | char const* target, |
71 | 49 | std::shared_ptr<arrow::Table> const& input, |
72 | | - T fullSize, |
| 50 | + int32_t fullSize, |
73 | 51 | std::vector<arrow::Datum>* slices, |
74 | 52 | std::vector<uint64_t>* offsets = nullptr, |
75 | 53 | std::vector<int>* sizes = nullptr, |
76 | 54 | std::vector<arrow::Datum>* unassignedSlices = nullptr, |
77 | | - std::vector<uint64_t>* unassignedOffsets = nullptr) |
78 | | -{ |
79 | | - arrow::Datum value_counts; |
80 | | - auto column = input->GetColumnByName(key); |
81 | | - for (auto i = 0; i < column->num_chunks(); ++i) { |
82 | | - T prev = 0; |
83 | | - T cur = 0; |
84 | | - T lastNeg = 0; |
85 | | - T lastPos = 0; |
86 | | - |
87 | | - auto array = static_cast<arrow::NumericArray<typename detail::ConversionTraits<T>::ArrowType>>(column->chunk(i)->data()); |
88 | | - for (auto e = 0; e < array.length(); ++e) { |
89 | | - prev = cur; |
90 | | - if (prev >= 0) { |
91 | | - lastPos = prev; |
92 | | - } else { |
93 | | - lastNeg = prev; |
94 | | - } |
95 | | - cur = array.Value(e); |
96 | | - if (cur >= 0) { |
97 | | - if (lastPos > cur) { |
98 | | - throw runtime_error_f("Table %s index %s is not sorted: next value %d < previous value %d!", target, key, cur, lastPos); |
99 | | - } else if (lastPos == cur && prev < 0) { |
100 | | - throw runtime_error_f("Table %s index %s has a group with index %d that is split by %d", target, key, cur, prev); |
101 | | - } |
102 | | - } else { |
103 | | - if (lastNeg < cur) { |
104 | | - throw runtime_error_f("Table %s index %s is not sorted: next negative value %d > previous negative value %d!", target, key, cur, lastNeg); |
105 | | - } else if (lastNeg == cur && prev >= 0) { |
106 | | - throw runtime_error_f("Table %s index %s has a group with index %d that is split by %d", target, key, cur, prev); |
107 | | - } |
108 | | - } |
109 | | - } |
110 | | - } |
111 | | - auto options = arrow::compute::ScalarAggregateOptions::Defaults(); |
112 | | - ARROW_ASSIGN_OR_RAISE(value_counts, |
113 | | - arrow::compute::CallFunction("value_counts", {column}, |
114 | | - &options)); |
115 | | - auto pair = static_cast<arrow::StructArray>(value_counts.array()); |
116 | | - auto values = static_cast<arrow::NumericArray<typename detail::ConversionTraits<T>::ArrowType>>(pair.field(0)->data()); |
117 | | - auto counts = static_cast<arrow::NumericArray<arrow::Int64Type>>(pair.field(1)->data()); |
118 | | - |
119 | | - // create slices and offsets |
120 | | - uint64_t offset = 0; |
121 | | - auto count = 0; |
122 | | - auto size = values.length(); |
123 | | - |
124 | | - auto makeSlice = [&](uint64_t offset_, T count_) { |
125 | | - slices->emplace_back(arrow::Datum{input->Slice(offset_, count_)}); |
126 | | - if (offsets) { |
127 | | - offsets->emplace_back(offset_); |
128 | | - } |
129 | | - if (sizes) { |
130 | | - sizes->emplace_back(count_); |
131 | | - } |
132 | | - }; |
133 | | - |
134 | | - auto makeUnassignedSlice = [&](uint64_t offset_, T count_) { |
135 | | - if (unassignedSlices) { |
136 | | - unassignedSlices->emplace_back(arrow::Datum{input->Slice(offset_, count_)}); |
137 | | - } |
138 | | - if (unassignedOffsets) { |
139 | | - unassignedOffsets->emplace_back(offset_); |
140 | | - } |
141 | | - }; |
142 | | - |
143 | | - auto v = 0; |
144 | | - auto vprev = v; |
145 | | - auto nzeros = 0; |
146 | | - |
147 | | - for (auto i = 0; i < size; ++i) { |
148 | | - count = counts.Value(i); |
149 | | - if (v >= 0) { |
150 | | - vprev = v; |
151 | | - } |
152 | | - v = values.Value(i); |
153 | | - if (v < 0) { |
154 | | - makeUnassignedSlice(offset, count); |
155 | | - offset += count; |
156 | | - continue; |
157 | | - } |
158 | | - nzeros = v - vprev - ((i == 0 || slices->empty() == true) ? 0 : 1); |
159 | | - for (auto z = 0; z < nzeros; ++z) { |
160 | | - makeSlice(offset, 0); |
161 | | - } |
162 | | - makeSlice(offset, count); |
163 | | - offset += count; |
164 | | - } |
165 | | - v = values.Value(size - 1); |
166 | | - if (v >= 0) { |
167 | | - vprev = v; |
168 | | - } |
169 | | - if (vprev < fullSize - 1) { |
170 | | - for (auto v = vprev + 1; v < fullSize; ++v) { |
171 | | - makeSlice(offset, 0); |
172 | | - } |
173 | | - } |
174 | | - |
175 | | - return arrow::Status::OK(); |
176 | | -} |
| 55 | + std::vector<uint64_t>* unassignedOffsets = nullptr); |
177 | 56 | } // namespace o2::framework |
178 | 57 |
|
179 | 58 | #endif // O2_FRAMEWORK_KERNELS_H_ |
0 commit comments