// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/tensor/inference_io_mapper.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_join.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensor_span.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/util/tflite/tflite_signature_reader.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model_builder.h"
namespace mediapipe {
namespace {
using ::tflite::FlatBufferModel;
using ::tflite::Interpreter;
using ::tflite::InterpreterBuilder;
using ::tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates;
// Converts a TensorIndicesMap into a vector of tensor indices, verifying that
// the indices are unique.
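// For example, a TensorIndicesMap with model_tensor_indices [2, 0, 1] yields
// the vector {2, 0, 1}, while [0, 0, 1] fails because index 0 appears twice.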
absl::StatusOr<std::vector<int>> GenerateAndValidateTensorList(
const InferenceCalculatorOptions::InputOutputConfig::TensorIndicesMap&
tensor_indices_list) {
absl::flat_hash_set<int> indices_set;
std::vector<int> result;
for (const int index : tensor_indices_list.model_tensor_indices()) {
RET_CHECK(indices_set.insert(index).second)
<< "Indices in TensorIndicesMap are not unique.";
result.push_back(index);
}
return result;
}
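
// Builds a map from each tensor name to its position in `names`; e.g.
// {"image", "mask"} becomes {{"image", 0}, {"mask", 1}}. Fails if any name
// appears more than once.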
absl::StatusOr<absl::flat_hash_map<std::string, int>> CreateNameToIndexMap(
const std::vector<std::string>& names) {
absl::flat_hash_map<std::string, int> name_to_index_map;
  for (int i = 0; i < names.size(); ++i) {
auto [unused_iter, was_inserted] = name_to_index_map.insert({names[i], i});
RET_CHECK(was_inserted)
<< "Duplicate tensor names found in model signatures: "
<< absl::StrJoin(names, ", ");
}
return name_to_index_map;
}
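
// Returns true if `input` contains the same element more than once.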
template <typename T>
static bool ContainsDuplicates(const std::vector<T>& input) {
absl::flat_hash_set<T> set;
for (const auto& item : input) {
if (!set.insert(item).second) {
return true;
}
}
return false;
}
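
// Maps each tensor name in `config_tensor_names` to its index in
// `signature_tensor_names`. For example, with signature names {"a", "b", "c"}
// and configured names {"c", "a"}, the result is {2, 0}.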
static absl::StatusOr<std::vector<int>> MapTensorNamesToIndices(
const std::vector<std::string>& signature_tensor_names,
const InferenceCalculatorOptions::InputOutputConfig::TensorNamesMap&
config_tensor_names) {
std::vector<int> result;
result.reserve(signature_tensor_names.size());
MP_ASSIGN_OR_RETURN(const auto input_name_to_index_map,
CreateNameToIndexMap(signature_tensor_names));
for (const auto& tensor_name : config_tensor_names.tensor_names()) {
const auto it = input_name_to_index_map.find(tensor_name);
RET_CHECK(it != input_name_to_index_map.end())
<< "Tensor name " << tensor_name
<< " not found in model signatures. Model tensor names: "
<< absl::StrJoin(signature_tensor_names, ", ");
result.push_back(it->second);
}
RET_CHECK(!ContainsDuplicates(result))
<< "Duplicate tensor names found in TensorNamesMap: "
<< absl::StrJoin(config_tensor_names.tensor_names(), ", ");
return result;
}

// Feedback tensors are excluded from the InferenceRunner input and output
// (since they are handled internally by the InferenceFeedbackManager). As a
// result, the input and output Tensor order of the InferenceRunner no longer
// matches the model I/O tensors, so the tensor I/O indices must be adjusted
// accordingly.
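//
// For example, with model tensor names {"in0", "state", "in1"} where "state"
// is a feedback tensor, the translation vector below becomes {0, -1, 1}, so a
// remapping index of 2 (pointing at "in1") is rewritten to 1.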
absl::Status ExcludeFeedbackTensorsFromRemappingIndicesVector(
const InferenceCalculatorOptions::InputOutputConfig& io_config,
const std::vector<std::string>& model_tensor_names,
std::vector<int>& remapping_tensor_indices) {
// Create set of all feedback tensor names.
absl::flat_hash_set<std::string> feedback_tensor_names;
for (const auto& link : io_config.feedback_tensor_links()) {
{
// No need to check for name collisions. Inference feedback manager
// confirms validity of feedback tensor names.
feedback_tensor_names.insert(link.from_output_tensor_name());
feedback_tensor_names.insert(link.to_input_tensor_name());
}
}
  // Build an index translation vector that maps model I/O tensor indices to
  // InferenceRunner I/O indices with feedback tensors excluded.
std::vector<int> indices_translation(model_tensor_names.size(), -1);
int model_output_idx = 0;
  for (int i = 0; i < model_tensor_names.size(); ++i) {
if (!feedback_tensor_names.contains(model_tensor_names[i])) {
indices_translation[i] = model_output_idx;
++model_output_idx;
}
}
// Adjust remapping_tensor_indices.
  for (int i = 0; i < remapping_tensor_indices.size(); ++i) {
    const int model_index = remapping_tensor_indices[i];
    RET_CHECK(model_index >= 0 && model_index < indices_translation.size())
        << "Index " << model_index << " out of range.";
    remapping_tensor_indices[i] = indices_translation[model_index];
  }
return absl::OkStatus();
}

}  // namespace

// static
absl::StatusOr<InputOutputTensorNames>
InferenceIoMapper::GetInputOutputTensorNamesFromInterpreter(
const tflite::Interpreter& interpreter) {
auto input_output_tensor_names =
TfLiteSignatureReader::GetInputOutputTensorNamesFromAllTfliteSignatures(
interpreter);
if (!input_output_tensor_names.ok()) {
// TODO b/336260063 - remove this warning once the bug is fixed.
ABSL_LOG_FIRST_N(WARNING, 1)
<< "Unable to extract TfLite model's tensor names from "
"TfliteSignature. Disabling tensor name-based I/O mapping.";
return InputOutputTensorNames();
}
return *input_output_tensor_names;
}

// static
absl::StatusOr<InputOutputTensorNames>
InferenceIoMapper::GetInputOutputTensorNamesFromModel(
const tflite::FlatBufferModel& flatbuffer,
const tflite::OpResolver& op_resolver) {
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::InterpreterBuilder interpreter_builder(flatbuffer, op_resolver);
if (interpreter_builder(&interpreter) != kTfLiteOk || !interpreter) {
    ABSL_LOG_FIRST_N(WARNING, 1)
<< "Extracting input output tensor names from TfliteSignature failed: "
"Unable to prepare interpreter. Ignoring tensor name-based I/O "
"mapping.";
return InputOutputTensorNames();
}
return GetInputOutputTensorNamesFromInterpreter(*interpreter);
}
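
// Illustrative config sketch (field names as used below; the tensor names
// themselves are hypothetical): a name-based mapping could look like
//   input_tensor_names_map { tensor_names: "image" tensor_names: "mask" }
//   output_tensor_names_map { tensor_names: "scores" }
// and an index-based mapping like
//   input_tensor_indices_map { model_tensor_indices: 1 model_tensor_indices: 0 }
// Note that if both kinds are set for the same direction, the name-based
// mapping assigned later in this method overwrites the index-based one.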
absl::Status InferenceIoMapper::UpdateIoMap(
const InferenceCalculatorOptions::InputOutputConfig& io_config,
const InputOutputTensorNames& input_output_tensor_names) {
num_feedback_tensors_ = io_config.feedback_tensor_links().size();
if ((io_config.has_input_tensor_indices_map() ||
io_config.has_output_tensor_indices_map()) &&
num_feedback_tensors_ > 0) {
// TODO b/336767692 - remove this check once indices-based feedback
// tensors are supported.
return absl::FailedPreconditionError(
"Feedback tensors are not supported with tensor index-based I/O "
"mapping.");
}
input_tensor_indices_.clear();
output_tensor_indices_.clear();
if (io_config.has_input_tensor_indices_map()) {
input_tensor_indices_.reserve(
io_config.input_tensor_indices_map().model_tensor_indices().size());
MP_ASSIGN_OR_RETURN(
input_tensor_indices_,
GenerateAndValidateTensorList(io_config.input_tensor_indices_map()));
}
if (io_config.has_output_tensor_indices_map()) {
output_tensor_indices_.reserve(
io_config.output_tensor_indices_map().model_tensor_indices().size());
MP_ASSIGN_OR_RETURN(
output_tensor_indices_,
GenerateAndValidateTensorList(io_config.output_tensor_indices_map()));
}
if (!io_config.has_input_tensor_names_map() &&
!io_config.has_output_tensor_names_map()) {
// No tensor name mapping is provided.
return absl::OkStatus();
}
if (input_output_tensor_names.empty()) {
return absl::FailedPreconditionError(
"Tensor name-based mapping requires a model with one signature.");
}
if (input_output_tensor_names.size() > 1) {
return absl::FailedPreconditionError(
"Tensor name-based mapping is not supported with multi-signature "
"models.");
}
// Use tensor names of default signature.
  const auto& input_output_tensor_names_default_signature =
      input_output_tensor_names.begin()->second;
if (io_config.has_input_tensor_names_map()) {
// Read number of model inputs directly from the signature.
    const int num_model_input_tensors =
        input_output_tensor_names_default_signature.input_tensor_names.size();
input_tensor_indices_.reserve(
io_config.input_tensor_names_map().tensor_names().size());
MP_ASSIGN_OR_RETURN(
input_tensor_indices_,
MapTensorNamesToIndices(
input_output_tensor_names_default_signature.input_tensor_names,
io_config.input_tensor_names_map()));
if (num_feedback_tensors_ > 0) {
MP_RETURN_IF_ERROR(ExcludeFeedbackTensorsFromRemappingIndicesVector(
io_config,
input_output_tensor_names_default_signature.input_tensor_names,
input_tensor_indices_));
}
// Feedback tensors are excluded from the input_tensor_indices_.
    RET_CHECK_EQ(input_tensor_indices_.size() + num_feedback_tensors_,
                 num_model_input_tensors)
<< "Unexpected number of input tensors.";
}
if (io_config.has_output_tensor_names_map()) {
    const int num_model_output_tensors =
        input_output_tensor_names_default_signature.output_tensor_names.size();
output_tensor_indices_.reserve(num_model_output_tensors);
MP_ASSIGN_OR_RETURN(
output_tensor_indices_,
MapTensorNamesToIndices(
input_output_tensor_names_default_signature.output_tensor_names,
io_config.output_tensor_names_map()));
if (num_feedback_tensors_ > 0) {
MP_RETURN_IF_ERROR(ExcludeFeedbackTensorsFromRemappingIndicesVector(
io_config,
input_output_tensor_names_default_signature.output_tensor_names,
output_tensor_indices_));
}
// Feedback tensors are excluded from the output_tensor_indices_.
    RET_CHECK_EQ(output_tensor_indices_.size() + num_feedback_tensors_,
                 num_model_output_tensors)
<< "Unexpected number of output tensors.";
}
return absl::OkStatus();
}
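
// Scatters the incoming tensors into model order: input tensor i is placed at
// position input_tensor_indices_[i]. For example, with indices {2, 0, 1},
// inputs [A, B, C] are remapped to [B, C, A].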
absl::StatusOr<TensorSpan> InferenceIoMapper::RemapInputTensors(
const TensorSpan& unmapped_tensors) {
if (input_tensor_indices_.empty()) {
return unmapped_tensors;
}
  RET_CHECK_EQ(unmapped_tensors.size(), input_tensor_indices_.size())
<< "Unexpected number of input tensors.";
std::vector<const Tensor*> mapped_tensors(unmapped_tensors.size());
for (int i = 0; i < unmapped_tensors.size(); ++i) {
const int index = input_tensor_indices_[i];
RET_CHECK(index < unmapped_tensors.size())
<< "Index " << index << " out of range"
<< ". Size of TensorIndicesMap: " << unmapped_tensors.size() << ".";
mapped_tensors[index] = &unmapped_tensors[i];
}
return TensorSpan(std::move(mapped_tensors));
}
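
// Gathers the outputs into calculator order: position i receives tensor
// output_tensor_indices_[i]. For example, with indices {2, 0, 1}, outputs
// [X, Y, Z] are remapped to [Z, X, Y]. Note this is a gather, whereas
// RemapInputTensors above performs a scatter.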
absl::StatusOr<std::vector<Tensor>> InferenceIoMapper::RemapOutputTensors(
std::vector<Tensor>&& unmapped_tensors) {
if (output_tensor_indices_.empty()) {
return std::move(unmapped_tensors);
}
RET_CHECK_EQ(unmapped_tensors.size(), output_tensor_indices_.size())
<< "Unexpected number of output tensors.";
std::vector<Tensor> mapped_tensors;
mapped_tensors.reserve(unmapped_tensors.size());
  for (int i = 0; i < unmapped_tensors.size(); ++i) {
    const int index = output_tensor_indices_[i];
    RET_CHECK(index < unmapped_tensors.size())
<< "Index " << index << " out of range"
<< ". Size of TensorIndicesMap: " << unmapped_tensors.size() << ".";
mapped_tensors.emplace_back(std::move(unmapped_tensors[index]));
}
return mapped_tensors;
}

}  // namespace mediapipe