Skip to content

File inference_feedback_manager.cc

File List > calculators > tensor > inference_feedback_manager.cc

Go to the documentation of this file

#include "mediapipe/calculators/tensor/inference_feedback_manager.h"

#include <cstring>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/calculators/tensor/inference_io_mapper.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/util/tflite/tflite_signature_reader.h"
#include "mediapipe/util/tflite/utils.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"

namespace mediapipe {

namespace {

// Returns true if `a` and `b` agree on element type, dims, quantization
// scale and zero point, allocation type, and byte size. Float quantization
// parameters are compared exactly.
bool TfLiteTensorSpecEqual(const TfLiteTensor& a, const TfLiteTensor& b) {
  if (a.type != b.type) return false;
  if (!TfLiteIntArrayEqual(a.dims, b.dims)) return false;
  const bool same_quantization = a.params.scale == b.params.scale &&
                                 a.params.zero_point == b.params.zero_point;
  if (!same_quantization) return false;
  return a.allocation_type == b.allocation_type && a.bytes == b.bytes;
}

// Builds a lookup from each name in `names` to its position in the vector.
// If a name occurs more than once, the position of its last occurrence wins.
absl::flat_hash_map<std::string, int> CreateNameToIndexMap(
    const std::vector<std::string>& names) {
  absl::flat_hash_map<std::string, int> index_by_name;
  int position = 0;
  for (const std::string& name : names) {
    index_by_name[name] = position++;
  }
  return index_by_name;
}

}  // namespace

// Initializes the feedback manager for `interpreter`.
//
// - Resolves the name-based feedback links from `io_config` into model
//   output/input indices (see ConvertSignatureTensorNamesToModelIndices).
// - Validates every link: indices must be within the interpreter's I/O
//   lists; each output and each input may take part in at most one link;
//   neither end may be dynamic; both ends must have an identical tensor spec.
// - Zero-fills each linked input tensor.
// - Records, in order, the model input indices that are NOT feedback inputs.
//
// Returns an error status if any validation fails.
absl::Status InferenceFeedbackManager::Init(
    const InferenceCalculatorOptions::InputOutputConfig& io_config,
    const InputOutputTensorNames& input_output_tensor_names,
    tflite::Interpreter* interpreter) {
  RET_CHECK(interpreter != nullptr) << "Interpreter must not be null.";
  interpreter_ = interpreter;
  MP_ASSIGN_OR_RETURN(feedback_tensor_indices_links_,
                      ConvertSignatureTensorNamesToModelIndices(
                          io_config, input_output_tensor_names));

  const std::vector<int>& model_inputs = interpreter_->inputs();
  const std::vector<int>& model_outputs = interpreter_->outputs();

  for (const auto& link : feedback_tensor_indices_links_) {
    // The link indices are positions in the signature's name lists. Guard
    // them before using them to index the interpreter's I/O lists, since
    // std::vector::operator[] does no bounds checking.
    RET_CHECK(link.from_idx >= 0 &&
              static_cast<size_t>(link.from_idx) < model_outputs.size())
        << "Feedback output tensor index out of range: " << link.from_idx;
    RET_CHECK(link.to_idx >= 0 &&
              static_cast<size_t>(link.to_idx) < model_inputs.size())
        << "Feedback input tensor index out of range: " << link.to_idx;

    const auto [output_unused_iter, output_was_inserted] =
        feedback_output_indices_.insert(link.from_idx);
    RET_CHECK(output_was_inserted) << "Feedback output tensors must be unique.";
    TfLiteTensor* from_tensor =
        interpreter_->tensor(model_outputs[link.from_idx]);
    RET_CHECK(!util::tflite::IsDynamicTensor(*from_tensor))
        << "Feedback output tensors must not be dynamic.";

    const auto [input_unused_iter, input_was_inserted] =
        feedback_input_indices_.insert(link.to_idx);
    RET_CHECK(input_was_inserted) << "Feedback input tensors must be unique.";
    TfLiteTensor* to_tensor = interpreter_->tensor(model_inputs[link.to_idx]);
    RET_CHECK(!util::tflite::IsDynamicTensor(*to_tensor))
        << "Feedback input tensors must not be dynamic.";

    RET_CHECK(TfLiteTensorSpecEqual(*from_tensor, *to_tensor))
        << "Feedback tensors must have the same spec.";
    // Since the TfLite API isn't specific about the initialization of newly
    // allocated Tensor memory, we initialize the input to_tensor tensor with
    // zeros.
    std::memset(to_tensor->data.raw, 0, to_tensor->bytes);
  }

  // Populate input_tensor_to_model_indices_ which maps InferenceRunner input
  // tensors indices to the model input indices.
  input_tensor_to_model_indices_.reserve(model_inputs.size());
  const int num_model_inputs = static_cast<int>(model_inputs.size());
  for (int i = 0; i < num_model_inputs; ++i) {
    if (!feedback_input_indices_.contains(i)) {
      input_tensor_to_model_indices_.push_back(i);
    }
  }
  return absl::OkStatus();
}

void InferenceFeedbackManager::SwapFeedbackTensors() {
  for (const auto& link : feedback_tensor_indices_links_) {
    TfLiteTensor* from_tensor =
        interpreter_->tensor(interpreter_->outputs()[link.from_idx]);
    TfLiteTensor* to_tensor =
        interpreter_->tensor(interpreter_->inputs()[link.to_idx]);
    {
      using std::swap;
      // TODO b/338023494 - Use TfLite CustomAllocator to manage memory of
      // feedback tensors (replace std::swap)
      swap(*from_tensor, *to_tensor);
    }
  }
}

// static
//
// Translates the name-based feedback links in `io_config` into index-based
// links using the tensor-name lists of the model's only signature.
// - Each link's `from_idx` is the position of its output name in the
//   signature's output-name list.
// - Each link's `to_idx` is the position of its input name in the
//   signature's input-name list.
//
// Returns an empty list and logs a warning when the signature map does not
// hold exactly one entry. Returns an error if a feedback tensor name is also
// used for regular input/output mapping, or if a linked name is absent from
// the signature.
absl::StatusOr<std::vector<InferenceFeedbackManager::TensorFeedbackIndicesLink>>
InferenceFeedbackManager::ConvertSignatureTensorNamesToModelIndices(
    const InferenceCalculatorOptions::InputOutputConfig& io_config,
    const InputOutputTensorNames& input_output_tensor_names_map) {
  std::vector<TensorFeedbackIndicesLink> resolved_links;

  // Zero or several signatures: disable feedback support gracefully instead
  // of reporting an error.
  if (input_output_tensor_names_map.size() != 1) {
    ABSL_LOG(WARNING)
        << "Feedback manager requires a model with a single signature "
           "inference. Disabling support for feedback tensors.";
    return resolved_links;
  }

  const auto& signature_names = input_output_tensor_names_map.begin()->second;
  const auto input_index_by_name =
      CreateNameToIndexMap(signature_names.input_tensor_names);
  const auto output_index_by_name =
      CreateNameToIndexMap(signature_names.output_tensor_names);

  // All tensor names used by the InferenceCalculator's regular I/O mapping.
  // Feedback tensors must not appear in this set.
  absl::flat_hash_set<std::string> io_mapped_names;
  const auto collect_names = [&io_mapped_names](const auto& names) {
    for (const auto& name : names) io_mapped_names.insert(name);
  };
  collect_names(io_config.input_tensor_names_map().tensor_names());
  collect_names(io_config.output_tensor_names_map().tensor_names());

  for (const auto& feedback_link : io_config.feedback_tensor_links()) {
    const auto& output_name = feedback_link.from_output_tensor_name();
    const auto& input_name = feedback_link.to_input_tensor_name();

    RET_CHECK(!io_mapped_names.contains(output_name))
        << absl::StrFormat(
               "Feedback output tensor [%s] cannot be used for input/output "
               "mapping. Input/output mapping tensor names: [%s]",
               output_name, absl::StrJoin(io_mapped_names, ", "));
    RET_CHECK(!io_mapped_names.contains(input_name))
        << absl::StrFormat(
               "Feedback input tensor [%s] cannot be used for input/output "
               "mapping. Input/output mapping tensor names: [%s]",
               input_name, absl::StrJoin(io_mapped_names, ", "));

    const auto output_it = output_index_by_name.find(output_name);
    RET_CHECK(output_it != output_index_by_name.end())
        << "Output tensor name not found: " << output_name;
    const auto input_it = input_index_by_name.find(input_name);
    RET_CHECK(input_it != input_index_by_name.end())
        << "Input tensor name not found: " << input_name;

    TensorFeedbackIndicesLink resolved;
    resolved.from_idx = output_it->second;
    resolved.to_idx = input_it->second;
    resolved_links.push_back(resolved);
  }
  return resolved_links;
}

bool InferenceFeedbackManager::IsFeedbackInputTensorAtIndex(int idx) const {
  // True iff model input `idx` is the receiving end of a feedback link.
  return feedback_input_indices_.find(idx) != feedback_input_indices_.end();
}

bool InferenceFeedbackManager::IsFeedbackOutputTensorAtIndex(int idx) const {
  // True iff model output `idx` is the source end of a feedback link.
  return feedback_output_indices_.find(idx) != feedback_output_indices_.end();
}

// Maps the position of a non-feedback input tensor, as seen by the
// InferenceRunner, to the corresponding model input index.
// Returns an error if `input_idx` is outside
// [0, GetNumberOfNonFeedbackInputTensors()).
absl::StatusOr<int> InferenceFeedbackManager::MapInputTensorToModelIndex(
    int input_idx) const {
  // The upper bound is exclusive: input_idx == size() would read one past the
  // end of the vector (operator[] performs no bounds checking).
  RET_CHECK(input_idx >= 0 &&
            static_cast<size_t>(input_idx) <
                input_tensor_to_model_indices_.size())
      << "Invalid input tensor index: " << input_idx;
  return input_tensor_to_model_indices_[input_idx];
}

// Number of model inputs that are not fed by a feedback link.
int InferenceFeedbackManager::GetNumberOfNonFeedbackInputTensors() const {
  return static_cast<int>(input_tensor_to_model_indices_.size());
}

// Number of configured feedback links (output -> input tensor pairs).
int InferenceFeedbackManager::GetNumberOfFeedbackTensors() const {
  return static_cast<int>(feedback_tensor_indices_links_.size());
}
}  // namespace mediapipe