Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
254 changes: 195 additions & 59 deletions Framework/Core/src/TableTreeHelpers.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ std::pair<std::shared_ptr<arrow::ChunkedArray>, std::shared_ptr<arrow::Field>> B
while (readEntries < totalEntries) {
auto readLast = mBranch->GetBulkRead().GetBulkEntries(readEntries, *buffer);
readEntries += readLast;
status &= static_cast<arrow::BooleanBuilder*>(mValueBuilder)->AppendValues(reinterpret_cast<uint8_t const*>(buffer->GetCurrent()), readLast * mListSize);
status &= static_cast<arrow::BooleanBuilder*>(mValueBuilder)->AppendValues(reinterpret_cast<uint8_t const*>(buffer->GetCurrent()), (int64_t)readLast * (int64_t)mListSize);
}
if (mListSize > 1) {
status &= static_cast<arrow::FixedSizeListBuilder*>(mListBuilder.get())->AppendValues(readEntries);
Expand All @@ -214,73 +214,209 @@ std::pair<std::shared_ptr<arrow::ChunkedArray>, std::shared_ptr<arrow::Field>> B
}
} else {
// other types: use serialized read to build arrays directly
auto&& result = arrow::AllocateResizableBuffer(mBranch->GetTotBytes(), mPool);
if (!result.ok()) {
throw runtime_error("Cannot allocate values buffer");
}
std::shared_ptr<arrow::Buffer> arrowValuesBuffer = std::move(result).ValueUnsafe();
auto ptr = arrowValuesBuffer->mutable_data();
if (ptr == nullptr) {
throw runtime_error("Invalid buffer");
}
if (mVLA && totalEntries > 616) {
// special case workaround
auto status = arrow::MakeBuilder(mPool, mArrowType->field(0)->type(), &mBuilder);
if (!status.ok()) {
throw runtime_error("Failed to create value builder");
}
mListBuilder = std::make_unique<arrow::ListBuilder>(mPool, std::move(mBuilder));
mValueBuilder = static_cast<arrow::ListBuilder*>(mListBuilder.get())->value_builder();
void* ptr = nullptr;

switch (mType) {
case EDataType::kUChar_t:
ptr = new uint8_t[255];
break;
case EDataType::kUShort_t:
ptr = new uint16_t[255];
break;
case EDataType::kUInt_t:
ptr = new uint32_t[255];
break;
case EDataType::kULong64_t:
ptr = new uint64_t[255];
break;
case EDataType::kChar_t:
ptr = new int8_t[255];
break;
case EDataType::kShort_t:
ptr = new int16_t[255];
break;
case EDataType::kInt_t:
ptr = new int32_t[255];
break;
case EDataType::kLong64_t:
ptr = new int64_t[255];
break;
case EDataType::kFloat_t:
ptr = new float[255];
break;
case EDataType::kDouble_t:
ptr = new double[255];
break;
default:
throw runtime_error("Unsupported branch type");
}

auto typeSize = TDataType::GetDataType(mType)->Size();
std::unique_ptr<TBufferFile> offsetBuffer;

uint32_t offset = 0;
uint32_t lastOffset;
int count = 0;
std::shared_ptr<arrow::Buffer> arrowOffsetBuffer;
gsl::span<int> offsets;
int size = 0;
uint32_t totalSize = 0;
if (mVLA) {
offsetBuffer.reset(new TBufferFile{TBuffer::EMode::kWrite, 4 * 1024 * 1024});
result = arrow::AllocateResizableBuffer((totalEntries + 1) * sizeof(int), mPool);
int sz;
auto* mSizeBranch = mBranch->GetTree()->GetBranch((std::string{mBranch->GetName()} + TableTreeHelpers::sizeBranchSuffix).c_str());
mSizeBranch->SetAddress(&sz);
mBranch->SetAddress(ptr);
std::vector<int> offsets;

offsets.push_back(0);
for (auto entry = 0; entry < totalEntries; ++entry) {
mBranch->GetEntry(entry);
mSizeBranch->GetEntry(entry);
offsets.push_back(sz + offsets.back());
arrow::Status status;
switch (mType) {
case EDataType::kUChar_t:
status = static_cast<arrow::UInt8Builder*>(mValueBuilder)->AppendValues(reinterpret_cast<uint8_t const*>(ptr), sz);
break;
case EDataType::kUShort_t:
status = static_cast<arrow::UInt16Builder*>(mValueBuilder)->AppendValues(reinterpret_cast<uint16_t const*>(ptr), sz);
break;
case EDataType::kUInt_t:
status = static_cast<arrow::UInt32Builder*>(mValueBuilder)->AppendValues(reinterpret_cast<uint32_t const*>(ptr), sz);
break;
case EDataType::kULong64_t:
status = static_cast<arrow::UInt64Builder*>(mValueBuilder)->AppendValues(reinterpret_cast<uint64_t const*>(ptr), sz);
break;
case EDataType::kChar_t:
status = static_cast<arrow::Int8Builder*>(mValueBuilder)->AppendValues(reinterpret_cast<int8_t const*>(ptr), sz);
break;
case EDataType::kShort_t:
status = static_cast<arrow::Int16Builder*>(mValueBuilder)->AppendValues(reinterpret_cast<int16_t const*>(ptr), sz);
break;
case EDataType::kInt_t:
status = static_cast<arrow::Int32Builder*>(mValueBuilder)->AppendValues(reinterpret_cast<int32_t const*>(ptr), sz);
break;
case EDataType::kLong64_t:
status = static_cast<arrow::Int64Builder*>(mValueBuilder)->AppendValues(reinterpret_cast<int64_t const*>(ptr), sz);
break;
case EDataType::kFloat_t:
status = static_cast<arrow::FloatBuilder*>(mValueBuilder)->AppendValues(reinterpret_cast<float const*>(ptr), sz);
break;
case EDataType::kDouble_t:
status = static_cast<arrow::DoubleBuilder*>(mValueBuilder)->AppendValues(reinterpret_cast<double const*>(ptr), sz);
break;
default:
throw runtime_error("Unsupported branch type");
}
}
status &= static_cast<arrow::ListBuilder*>(mListBuilder.get())->AppendValues(offsets.data(), totalEntries);
status &= static_cast<arrow::ListBuilder*>(mListBuilder.get())->Finish(&array);

mSizeBranch->SetStatus(false);
mSizeBranch->DropBaskets("all");
mSizeBranch->Reset();
mSizeBranch->GetTransientBuffer(0)->Expand(0);

switch (mType) {
case EDataType::kUChar_t:
delete[] static_cast<uint8_t*>(ptr);
break;
case EDataType::kUShort_t:
delete[] static_cast<uint16_t*>(ptr);
break;
case EDataType::kUInt_t:
delete[] static_cast<uint32_t*>(ptr);
break;
case EDataType::kULong64_t:
delete[] static_cast<uint64_t*>(ptr);
break;
case EDataType::kChar_t:
delete[] static_cast<int8_t*>(ptr);
break;
case EDataType::kShort_t:
delete[] static_cast<int16_t*>(ptr);
break;
case EDataType::kInt_t:
delete[] static_cast<int32_t*>(ptr);
break;
case EDataType::kLong64_t:
delete[] static_cast<int64_t*>(ptr);
break;
case EDataType::kFloat_t:
delete[] static_cast<float*>(ptr);
break;
case EDataType::kDouble_t:
delete[] static_cast<double*>(ptr);
break;
default:
throw runtime_error("Unsupported branch type");
}
} else {
auto&& result = arrow::AllocateResizableBuffer(mBranch->GetTotBytes(), mPool);
if (!result.ok()) {
throw runtime_error("Cannot allocate offset buffer");
throw runtime_error("Cannot allocate values buffer");
}
std::shared_ptr<arrow::Buffer> arrowValuesBuffer = std::move(result).ValueUnsafe();
auto ptr = arrowValuesBuffer->mutable_data();
if (ptr == nullptr) {
throw runtime_error("Invalid buffer");
}
arrowOffsetBuffer = std::move(result).ValueUnsafe();
unsigned char* ptrOffset = arrowOffsetBuffer->mutable_data();
auto* tPtrOffset = reinterpret_cast<int*>(ptrOffset);
offsets = gsl::span<int>{tPtrOffset, tPtrOffset + totalEntries + 1};
}

while (readEntries < totalEntries) {
auto readLast = mBranch->GetBulkRead().GetEntriesSerialized(readEntries, *buffer, offsetBuffer.get());
readEntries += readLast;
auto typeSize = TDataType::GetDataType(mType)->Size();
std::unique_ptr<TBufferFile> offsetBuffer;

uint32_t offset = 0;
uint32_t lastOffset;
int count = 0;
std::shared_ptr<arrow::Buffer> arrowOffsetBuffer;
gsl::span<int> offsets;
int size = 0;
uint32_t totalSize = 0;
if (mVLA) {
lastOffset = offset;
for (auto i = 0; i < readLast; ++i) {
offsets[count++] = (int)offset;
offset += swap32_(reinterpret_cast<uint32_t*>(offsetBuffer->GetCurrent())[i]);
offsetBuffer = std::make_unique<TBufferFile>(TBuffer::EMode::kWrite, 4 * 1024 * 1024);
result = arrow::AllocateResizableBuffer((int64_t)(sizeof(int) * (totalEntries + 1)), mPool);
if (!result.ok()) {
throw runtime_error("Cannot allocate offset buffer");
}
arrowOffsetBuffer = std::move(result).ValueUnsafe();
unsigned char* ptrOffset = arrowOffsetBuffer->mutable_data();
auto* tPtrOffset = reinterpret_cast<int*>(ptrOffset);
offsets = gsl::span<int>{tPtrOffset, tPtrOffset + totalEntries + 1};
}

while (readEntries < totalEntries) {
auto readLast = mBranch->GetBulkRead().GetEntriesSerialized(readEntries, *buffer, offsetBuffer.get());
readEntries += readLast;

if (mVLA) {
lastOffset = offset;
for (auto i = 0; i < readLast; ++i) {
offsets[count++] = (int)offset;
offset += swap32_(reinterpret_cast<uint32_t*>(offsetBuffer->GetCurrent())[i]);
}
size = (int)(offset - lastOffset);
} else {
size = readLast * mListSize;
}
size = offset - lastOffset;
swapCopy(ptr, buffer->GetCurrent(), size, typeSize);
ptr += (ptrdiff_t)(size * typeSize);
}
if (mVLA) {
offsets[count] = (int)offset;
totalSize = offset;
} else {
size = readLast * mListSize;
totalSize = readEntries * mListSize;
}
std::shared_ptr<arrow::PrimitiveArray> varray;
switch (mListSize) {
case -1:
varray = std::make_shared<arrow::PrimitiveArray>(mArrowType->field(0)->type(), totalSize, arrowValuesBuffer);
array = std::make_shared<arrow::ListArray>(mArrowType, readEntries, arrowOffsetBuffer, varray);
break;
case 1:
array = std::make_shared<arrow::PrimitiveArray>(mArrowType, readEntries, arrowValuesBuffer);
break;
default:
varray = std::make_shared<arrow::PrimitiveArray>(mArrowType->field(0)->type(), totalSize, arrowValuesBuffer);
array = std::make_shared<arrow::FixedSizeListArray>(mArrowType, readEntries, varray);
}
swapCopy(ptr, buffer->GetCurrent(), size, typeSize);
ptr += size * typeSize;
}
if (mVLA) {
offsets[count] = offset;
totalSize = offset;
} else {
totalSize = readEntries * mListSize;
}
std::shared_ptr<arrow::PrimitiveArray> varray;
switch (mListSize) {
case -1:
varray = std::make_shared<arrow::PrimitiveArray>(mArrowType->field(0)->type(), totalSize, arrowValuesBuffer);
array = std::make_shared<arrow::ListArray>(mArrowType, readEntries, arrowOffsetBuffer, varray);
break;
case 1:
array = std::make_shared<arrow::PrimitiveArray>(mArrowType, readEntries, arrowValuesBuffer);
break;
default:
varray = std::make_shared<arrow::PrimitiveArray>(mArrowType->field(0)->type(), totalSize, arrowValuesBuffer);
array = std::make_shared<arrow::FixedSizeListArray>(mArrowType, readEntries, varray);
}
}

Expand Down