diff --git a/Framework/Core/include/Framework/MessageSet.h b/Framework/Core/include/Framework/MessageSet.h index e651c7d5960ff..7093486c6c888 100644 --- a/Framework/Core/include/Framework/MessageSet.h +++ b/Framework/Core/include/Framework/MessageSet.h @@ -14,6 +14,7 @@ #include "Framework/PartRef.h" #include #include +#include namespace o2 { @@ -24,16 +25,98 @@ namespace framework struct MessageSet { std::vector parts; + MessageSet() + : parts() + { + } + + template + MessageSet(F&& getter, size_t size) + : parts() + { + add(std::forward(getter), size); + } + + MessageSet(MessageSet&& other) + : parts(std::move(other.parts)) + { + other.clear(); + } + + MessageSet& operator=(MessageSet&& other) + { + if (&other == this) { + return *this; + } + parts = std::move(other.parts); + other.clear(); + return *this; + } + size_t size() const { return parts.size(); } + size_t getNumberOfPayloads(size_t part) const + { + // this is for upcoming change of message store + return 1; + } + void clear() { parts.clear(); } + // this is more or less legacy + void reset(PartRef&& ref) + { + clear(); + add(std::move(ref)); + } + + void add(PartRef&& ref) + { + parts.emplace_back(std::move(ref)); + } + + template + void add(F getter, size_t size) + { + for (size_t i = 0; i < size; ++i) { + PartRef ref{std::move(getter(i)), std::move(getter(i + 1))}; + parts.emplace_back(std::move(ref)); + ++i; + } + } + + FairMQMessagePtr& header(size_t partIndex) + { + assert(partIndex < parts.size()); + return parts[partIndex].header; + } + + FairMQMessagePtr& payload(size_t partIndex, size_t payloadIndex = 0) + { + assert(partIndex < parts.size()); + // payload index will be supported in linear message store + assert(payloadIndex == 0); + return parts[partIndex].payload; + } + + FairMQMessagePtr const& header(size_t partIndex) const + { + assert(partIndex < parts.size()); + return parts[partIndex].header; + } + + FairMQMessagePtr const& payload(size_t partIndex) const + { + assert(partIndex < parts.size()); + return parts[partIndex].payload; + } + PartRef& operator[](size_t index) { return parts[index]; diff --git a/Framework/Core/src/DataProcessingDevice.cxx b/Framework/Core/src/DataProcessingDevice.cxx index 4e3e4b6bd0804..c9d8a57ca2000 100644 --- a/Framework/Core/src/DataProcessingDevice.cxx +++ b/Framework/Core/src/DataProcessingDevice.cxx @@ -886,38 +886,54 @@ void DataProcessingDevice::handleData(DataProcessorContext& context, InputChanne SourceInfo }; + struct InputInfo { + InputInfo(size_t p, size_t s, InputType t) + : position(p), size(s), type(t) + { + } + size_t position; + size_t size; + InputType type; + }; + // This is how we validate inputs. I.e. we try to enforce the O2 Data model // and we do a few stats. We bind parts as a lambda captured variable, rather // than an input, because we do not want the outer loop actually be exposed // to the implementation details of the messaging layer. auto getInputTypes = [&stats = context.registry->get(), - &info, &context]() -> std::optional> { + &info, &context]() -> std::optional> { auto& parts = info.parts; stats.inputParts = parts.Size(); TracyPlot("messages received", (int64_t)parts.Size()); - if (parts.Size() % 2) { - return std::nullopt; - } - std::vector results(parts.Size() / 2, InputType::Invalid); + std::vector results; + // we can reserve the upper limit + results.reserve(parts.Size() / 2); + size_t nTotalPayloads = 0; + + auto insertInputInfo = [&results, &nTotalPayloads](size_t position, size_t length, InputType type) { + results.emplace_back(position, length, type); + if (type != InputType::Invalid && length > 1) { + nTotalPayloads += length - 1; + } + }; - for (size_t hi = 0; hi < parts.Size() / 2; ++hi) { - auto pi = hi * 2; + for (size_t pi = 0; pi < parts.Size(); pi += 2) { auto sih = o2::header::get(parts.At(pi)->GetData()); if (sih) { info.state = sih->state; - results[hi] = InputType::SourceInfo; + insertInputInfo(pi, 2, InputType::SourceInfo); *context.wasActive = true; continue; } auto dh = o2::header::get(parts.At(pi)->GetData()); if (!dh) { - results[hi] = InputType::Invalid; + insertInputInfo(pi, 0, InputType::Invalid); LOGP(error, "Header is not a DataHeader?"); continue; } if (dh->payloadSize != parts.At(pi + 1)->GetSize()) { - results[hi] = InputType::Invalid; + insertInputInfo(pi, 0, InputType::Invalid); LOGP(error, "DataHeader payloadSize mismatch"); continue; } @@ -925,24 +941,31 @@ void DataProcessingDevice::handleData(DataProcessorContext& context, InputChanne auto dph = o2::header::get(parts.At(pi)->GetData()); TracyAlloc(parts.At(pi + 1)->GetData(), parts.At(pi + 1)->GetSize()); if (!dph) { - results[hi] = InputType::Invalid; + insertInputInfo(pi, 2, InputType::Invalid); LOGP(error, "Header stack does not contain DataProcessingHeader"); continue; } - // We can set the type for the next splitPayloadParts - // because we are guaranteed they are all the same. - // If splitPayloadParts = 0, we assume that means there is only one (header, payload) - // pair. - size_t finalSplitPayloadIndex = hi + (dh->splitPayloadParts > 0 ? dh->splitPayloadParts : 1); - if (finalSplitPayloadIndex > results.size()) { - LOGP(error, "DataHeader::splitPayloadParts invalid"); - results[hi] = InputType::Invalid; - continue; - } - for (; hi < finalSplitPayloadIndex; ++hi) { - results[hi] = InputType::Data; + { + // We can set the type for the next splitPayloadParts + // because we are guaranteed they are all the same. + // If splitPayloadParts = 0, we assume that means there is only one (header, payload) + // pair. + size_t finalSplitPayloadIndex = pi + (dh->splitPayloadParts > 0 ? dh->splitPayloadParts : 1) * 2; + if (finalSplitPayloadIndex > parts.Size()) { + LOGP(error, "DataHeader::splitPayloadParts invalid"); + insertInputInfo(pi, 0, InputType::Invalid); + continue; + } + insertInputInfo(pi, 2, InputType::Data); + for (; pi + 2 < finalSplitPayloadIndex; pi += 2) { + insertInputInfo(pi + 2, 2, InputType::Data); + } } - hi = finalSplitPayloadIndex - 1; + } + assert(std::accumulate(results.begin(), results.end(), 0, [](size_t const& count, auto const& element) -> size_t { return count + element.size; })); + if (results.size() + nTotalPayloads != parts.Size()) { + LOG(ERROR) << "inconsistent number of inputs extracted"; + return std::nullopt; } return results; }; @@ -951,21 +974,22 @@ void DataProcessingDevice::handleData(DataProcessorContext& context, InputChanne registry.get().errorCount++; }; - auto handleValidMessages = [&info, &context = context, &relayer = *context.relayer, &reportError](std::vector const& types) { + auto handleValidMessages = [&info, &context = context, &relayer = *context.relayer, &reportError](std::vector const& inputInfos) { static WaitBackpressurePolicy policy; auto& parts = info.parts; // We relay execution to make sure we have a complete set of parts // available. - for (size_t pi = 0; pi < (parts.Size() / 2); ++pi) { - switch (types[pi]) { + for (auto ii = 0; ii < inputInfos.size(); ++ii) { + auto const& input = inputInfos[ii]; + switch (input.type) { case InputType::Data: { - auto headerIndex = 2 * pi; - auto payloadIndex = 2 * pi + 1; + auto headerIndex = input.position; + auto payloadIndex = headerIndex + 1; assert(payloadIndex < parts.Size()); auto dh = o2::header::get(parts.At(headerIndex)->GetData()); auto relayed = relayer.relay(parts.At(headerIndex), &parts.At(payloadIndex), dh->splitPayloadParts > 0 ? dh->splitPayloadParts * 2 - 1 : 0); - pi += dh->splitPayloadParts > 0 ? dh->splitPayloadParts - 1 : 0; + ii += dh->splitPayloadParts > 0 ? dh->splitPayloadParts - 1 : 0; switch (relayed) { case DataRelayer::Backpressured: if (info.normalOpsNotified == true && info.backpressureNotified == false) { @@ -988,8 +1012,8 @@ void DataProcessingDevice::handleData(DataProcessorContext& context, InputChanne } break; case InputType::SourceInfo: { *context.wasActive = true; - auto headerIndex = 2 * pi; - auto payloadIndex = 2 * pi + 1; + auto headerIndex = input.position; + auto payloadIndex = input.position + 1; assert(payloadIndex < parts.Size()); auto dh = o2::header::get(parts.At(headerIndex)->GetData()); // FIXME: the message with the end of stream cannot contain