Skip to content

Commit 2ea3421

Browse files
authored
DPL Analysis: fall back to per-entry reading of VLA branches with >616 entries (#7912)
1 parent fae14b8 commit 2ea3421

1 file changed

Lines changed: 195 additions & 59 deletions

File tree

Framework/Core/src/TableTreeHelpers.cxx

Lines changed: 195 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -196,7 +196,7 @@ std::pair<std::shared_ptr<arrow::ChunkedArray>, std::shared_ptr<arrow::Field>> B
196196
while (readEntries < totalEntries) {
197197
auto readLast = mBranch->GetBulkRead().GetBulkEntries(readEntries, *buffer);
198198
readEntries += readLast;
199-
status &= static_cast<arrow::BooleanBuilder*>(mValueBuilder)->AppendValues(reinterpret_cast<uint8_t const*>(buffer->GetCurrent()), readLast * mListSize);
199+
status &= static_cast<arrow::BooleanBuilder*>(mValueBuilder)->AppendValues(reinterpret_cast<uint8_t const*>(buffer->GetCurrent()), (int64_t)readLast * (int64_t)mListSize);
200200
}
201201
if (mListSize > 1) {
202202
status &= static_cast<arrow::FixedSizeListBuilder*>(mListBuilder.get())->AppendValues(readEntries);
@@ -214,73 +214,209 @@ std::pair<std::shared_ptr<arrow::ChunkedArray>, std::shared_ptr<arrow::Field>> B
214214
}
215215
} else {
216216
// other types: use serialized read to build arrays directly
217-
auto&& result = arrow::AllocateResizableBuffer(mBranch->GetTotBytes(), mPool);
218-
if (!result.ok()) {
219-
throw runtime_error("Cannot allocate values buffer");
220-
}
221-
std::shared_ptr<arrow::Buffer> arrowValuesBuffer = std::move(result).ValueUnsafe();
222-
auto ptr = arrowValuesBuffer->mutable_data();
223-
if (ptr == nullptr) {
224-
throw runtime_error("Invalid buffer");
225-
}
217+
if (mVLA && totalEntries > 616) {
218+
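// special case workaround: for VLA branches with more than 616 entries, read entry by entry instead of using the serialized bulk read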
219+
auto status = arrow::MakeBuilder(mPool, mArrowType->field(0)->type(), &mBuilder);
220+
if (!status.ok()) {
221+
throw runtime_error("Failed to create value builder");
222+
}
223+
mListBuilder = std::make_unique<arrow::ListBuilder>(mPool, std::move(mBuilder));
224+
mValueBuilder = static_cast<arrow::ListBuilder*>(mListBuilder.get())->value_builder();
225+
void* ptr = nullptr;
226+
227+
switch (mType) {
228+
case EDataType::kUChar_t:
229+
ptr = new uint8_t[255];
230+
break;
231+
case EDataType::kUShort_t:
232+
ptr = new uint16_t[255];
233+
break;
234+
case EDataType::kUInt_t:
235+
ptr = new uint32_t[255];
236+
break;
237+
case EDataType::kULong64_t:
238+
ptr = new uint64_t[255];
239+
break;
240+
case EDataType::kChar_t:
241+
ptr = new int8_t[255];
242+
break;
243+
case EDataType::kShort_t:
244+
ptr = new int16_t[255];
245+
break;
246+
case EDataType::kInt_t:
247+
ptr = new int32_t[255];
248+
break;
249+
case EDataType::kLong64_t:
250+
ptr = new int64_t[255];
251+
break;
252+
case EDataType::kFloat_t:
253+
ptr = new float[255];
254+
break;
255+
case EDataType::kDouble_t:
256+
ptr = new double[255];
257+
break;
258+
default:
259+
throw runtime_error("Unsupported branch type");
260+
}
226261

227-
auto typeSize = TDataType::GetDataType(mType)->Size();
228-
std::unique_ptr<TBufferFile> offsetBuffer;
229-
230-
uint32_t offset = 0;
231-
uint32_t lastOffset;
232-
int count = 0;
233-
std::shared_ptr<arrow::Buffer> arrowOffsetBuffer;
234-
gsl::span<int> offsets;
235-
int size = 0;
236-
uint32_t totalSize = 0;
237-
if (mVLA) {
238-
offsetBuffer.reset(new TBufferFile{TBuffer::EMode::kWrite, 4 * 1024 * 1024});
239-
result = arrow::AllocateResizableBuffer((totalEntries + 1) * sizeof(int), mPool);
262+
int sz;
263+
auto* mSizeBranch = mBranch->GetTree()->GetBranch((std::string{mBranch->GetName()} + TableTreeHelpers::sizeBranchSuffix).c_str());
264+
mSizeBranch->SetAddress(&sz);
265+
mBranch->SetAddress(ptr);
266+
std::vector<int> offsets;
267+
268+
offsets.push_back(0);
269+
for (auto entry = 0; entry < totalEntries; ++entry) {
270+
mBranch->GetEntry(entry);
271+
mSizeBranch->GetEntry(entry);
272+
offsets.push_back(sz + offsets.back());
273+
arrow::Status status;
274+
switch (mType) {
275+
case EDataType::kUChar_t:
276+
status = static_cast<arrow::UInt8Builder*>(mValueBuilder)->AppendValues(reinterpret_cast<uint8_t const*>(ptr), sz);
277+
break;
278+
case EDataType::kUShort_t:
279+
status = static_cast<arrow::UInt16Builder*>(mValueBuilder)->AppendValues(reinterpret_cast<uint16_t const*>(ptr), sz);
280+
break;
281+
case EDataType::kUInt_t:
282+
status = static_cast<arrow::UInt32Builder*>(mValueBuilder)->AppendValues(reinterpret_cast<uint32_t const*>(ptr), sz);
283+
break;
284+
case EDataType::kULong64_t:
285+
status = static_cast<arrow::UInt64Builder*>(mValueBuilder)->AppendValues(reinterpret_cast<uint64_t const*>(ptr), sz);
286+
break;
287+
case EDataType::kChar_t:
288+
status = static_cast<arrow::Int8Builder*>(mValueBuilder)->AppendValues(reinterpret_cast<int8_t const*>(ptr), sz);
289+
break;
290+
case EDataType::kShort_t:
291+
status = static_cast<arrow::Int16Builder*>(mValueBuilder)->AppendValues(reinterpret_cast<int16_t const*>(ptr), sz);
292+
break;
293+
case EDataType::kInt_t:
294+
status = static_cast<arrow::Int32Builder*>(mValueBuilder)->AppendValues(reinterpret_cast<int32_t const*>(ptr), sz);
295+
break;
296+
case EDataType::kLong64_t:
297+
status = static_cast<arrow::Int64Builder*>(mValueBuilder)->AppendValues(reinterpret_cast<int64_t const*>(ptr), sz);
298+
break;
299+
case EDataType::kFloat_t:
300+
status = static_cast<arrow::FloatBuilder*>(mValueBuilder)->AppendValues(reinterpret_cast<float const*>(ptr), sz);
301+
break;
302+
case EDataType::kDouble_t:
303+
status = static_cast<arrow::DoubleBuilder*>(mValueBuilder)->AppendValues(reinterpret_cast<double const*>(ptr), sz);
304+
break;
305+
default:
306+
throw runtime_error("Unsupported branch type");
307+
}
308+
}
309+
status &= static_cast<arrow::ListBuilder*>(mListBuilder.get())->AppendValues(offsets.data(), totalEntries);
310+
status &= static_cast<arrow::ListBuilder*>(mListBuilder.get())->Finish(&array);
311+
312+
mSizeBranch->SetStatus(false);
313+
mSizeBranch->DropBaskets("all");
314+
mSizeBranch->Reset();
315+
mSizeBranch->GetTransientBuffer(0)->Expand(0);
316+
317+
switch (mType) {
318+
case EDataType::kUChar_t:
319+
delete[] static_cast<uint8_t*>(ptr);
320+
break;
321+
case EDataType::kUShort_t:
322+
delete[] static_cast<uint16_t*>(ptr);
323+
break;
324+
case EDataType::kUInt_t:
325+
delete[] static_cast<uint32_t*>(ptr);
326+
break;
327+
case EDataType::kULong64_t:
328+
delete[] static_cast<uint64_t*>(ptr);
329+
break;
330+
case EDataType::kChar_t:
331+
delete[] static_cast<int8_t*>(ptr);
332+
break;
333+
case EDataType::kShort_t:
334+
delete[] static_cast<int16_t*>(ptr);
335+
break;
336+
case EDataType::kInt_t:
337+
delete[] static_cast<int32_t*>(ptr);
338+
break;
339+
case EDataType::kLong64_t:
340+
delete[] static_cast<int64_t*>(ptr);
341+
break;
342+
case EDataType::kFloat_t:
343+
delete[] static_cast<float*>(ptr);
344+
break;
345+
case EDataType::kDouble_t:
346+
delete[] static_cast<double*>(ptr);
347+
break;
348+
default:
349+
throw runtime_error("Unsupported branch type");
350+
}
351+
} else {
352+
auto&& result = arrow::AllocateResizableBuffer(mBranch->GetTotBytes(), mPool);
240353
if (!result.ok()) {
241-
throw runtime_error("Cannot allocate offset buffer");
354+
throw runtime_error("Cannot allocate values buffer");
355+
}
356+
std::shared_ptr<arrow::Buffer> arrowValuesBuffer = std::move(result).ValueUnsafe();
357+
auto ptr = arrowValuesBuffer->mutable_data();
358+
if (ptr == nullptr) {
359+
throw runtime_error("Invalid buffer");
242360
}
243-
arrowOffsetBuffer = std::move(result).ValueUnsafe();
244-
unsigned char* ptrOffset = arrowOffsetBuffer->mutable_data();
245-
auto* tPtrOffset = reinterpret_cast<int*>(ptrOffset);
246-
offsets = gsl::span<int>{tPtrOffset, tPtrOffset + totalEntries + 1};
247-
}
248361

249-
while (readEntries < totalEntries) {
250-
auto readLast = mBranch->GetBulkRead().GetEntriesSerialized(readEntries, *buffer, offsetBuffer.get());
251-
readEntries += readLast;
362+
auto typeSize = TDataType::GetDataType(mType)->Size();
363+
std::unique_ptr<TBufferFile> offsetBuffer;
252364

365+
uint32_t offset = 0;
366+
uint32_t lastOffset;
367+
int count = 0;
368+
std::shared_ptr<arrow::Buffer> arrowOffsetBuffer;
369+
gsl::span<int> offsets;
370+
int size = 0;
371+
uint32_t totalSize = 0;
253372
if (mVLA) {
254-
lastOffset = offset;
255-
for (auto i = 0; i < readLast; ++i) {
256-
offsets[count++] = (int)offset;
257-
offset += swap32_(reinterpret_cast<uint32_t*>(offsetBuffer->GetCurrent())[i]);
373+
offsetBuffer = std::make_unique<TBufferFile>(TBuffer::EMode::kWrite, 4 * 1024 * 1024);
374+
result = arrow::AllocateResizableBuffer((int64_t)(sizeof(int) * (totalEntries + 1)), mPool);
375+
if (!result.ok()) {
376+
throw runtime_error("Cannot allocate offset buffer");
377+
}
378+
arrowOffsetBuffer = std::move(result).ValueUnsafe();
379+
unsigned char* ptrOffset = arrowOffsetBuffer->mutable_data();
380+
auto* tPtrOffset = reinterpret_cast<int*>(ptrOffset);
381+
offsets = gsl::span<int>{tPtrOffset, tPtrOffset + totalEntries + 1};
382+
}
383+
384+
while (readEntries < totalEntries) {
385+
auto readLast = mBranch->GetBulkRead().GetEntriesSerialized(readEntries, *buffer, offsetBuffer.get());
386+
readEntries += readLast;
387+
388+
if (mVLA) {
389+
lastOffset = offset;
390+
for (auto i = 0; i < readLast; ++i) {
391+
offsets[count++] = (int)offset;
392+
offset += swap32_(reinterpret_cast<uint32_t*>(offsetBuffer->GetCurrent())[i]);
393+
}
394+
size = (int)(offset - lastOffset);
395+
} else {
396+
size = readLast * mListSize;
258397
}
259-
size = offset - lastOffset;
398+
swapCopy(ptr, buffer->GetCurrent(), size, typeSize);
399+
ptr += (ptrdiff_t)(size * typeSize);
400+
}
401+
if (mVLA) {
402+
offsets[count] = (int)offset;
403+
totalSize = offset;
260404
} else {
261-
size = readLast * mListSize;
405+
totalSize = readEntries * mListSize;
406+
}
407+
std::shared_ptr<arrow::PrimitiveArray> varray;
408+
switch (mListSize) {
409+
case -1:
410+
varray = std::make_shared<arrow::PrimitiveArray>(mArrowType->field(0)->type(), totalSize, arrowValuesBuffer);
411+
array = std::make_shared<arrow::ListArray>(mArrowType, readEntries, arrowOffsetBuffer, varray);
412+
break;
413+
case 1:
414+
array = std::make_shared<arrow::PrimitiveArray>(mArrowType, readEntries, arrowValuesBuffer);
415+
break;
416+
default:
417+
varray = std::make_shared<arrow::PrimitiveArray>(mArrowType->field(0)->type(), totalSize, arrowValuesBuffer);
418+
array = std::make_shared<arrow::FixedSizeListArray>(mArrowType, readEntries, varray);
262419
}
263-
swapCopy(ptr, buffer->GetCurrent(), size, typeSize);
264-
ptr += size * typeSize;
265-
}
266-
if (mVLA) {
267-
offsets[count] = offset;
268-
totalSize = offset;
269-
} else {
270-
totalSize = readEntries * mListSize;
271-
}
272-
std::shared_ptr<arrow::PrimitiveArray> varray;
273-
switch (mListSize) {
274-
case -1:
275-
varray = std::make_shared<arrow::PrimitiveArray>(mArrowType->field(0)->type(), totalSize, arrowValuesBuffer);
276-
array = std::make_shared<arrow::ListArray>(mArrowType, readEntries, arrowOffsetBuffer, varray);
277-
break;
278-
case 1:
279-
array = std::make_shared<arrow::PrimitiveArray>(mArrowType, readEntries, arrowValuesBuffer);
280-
break;
281-
default:
282-
varray = std::make_shared<arrow::PrimitiveArray>(mArrowType->field(0)->type(), totalSize, arrowValuesBuffer);
283-
array = std::make_shared<arrow::FixedSizeListArray>(mArrowType, readEntries, varray);
284420
}
285421
}
286422

0 commit comments

Comments
 (0)