File inference_feedback_manager.h

#ifndef MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_FEEDBACK_MANAGER_H_
#define MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_FEEDBACK_MANAGER_H_

#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/calculators/tensor/inference_io_mapper.h"
#include "tensorflow/lite/interpreter.h"

namespace mediapipe {

// Feedback tensors are pairs of model output / input tensors where the
// model output is used as a model input in the next model invocation. This
// makes it possible to maintain temporal state by continuously feeding the
// model's output back into its input on each inference step. The
// InferenceFeedbackManager initializes the feedback input tensors with zeros
// and efficiently swaps them from output to input without copies.
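//
// Illustrative usage sketch (not part of this header; the surrounding setup
// and variable names are assumptions):
//
//   InferenceFeedbackManager feedback_manager;
//   MP_RETURN_IF_ERROR(feedback_manager.Init(io_config,
//                                            input_output_tensor_names,
//                                            interpreter.get()));
//   // After each interpreter invocation, propagate the feedback outputs to
//   // the inputs of the next invocation.
//   RET_CHECK_EQ(interpreter->Invoke(), kTfLiteOk);
//   feedback_manager.SwapFeedbackTensors();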
class InferenceFeedbackManager {
 public:
  // Initializes the feedback tensors with zeros and generates
  // feedback_tensor_indices_links_. The provided interpreter must outlive the
  // InferenceFeedbackManager instance.
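  //
  // Illustrative InputOutputConfig with a single feedback link (the tensor
  // names are placeholders and the field names are an assumption about the
  // schema in inference_calculator.proto):
  //
  //   feedback_tensor_links {
  //     from_output_tensor_name: "recurrent_state_out"
  //     to_input_tensor_name: "recurrent_state_in"
  //   }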
  absl::Status Init(
      const mediapipe::InferenceCalculatorOptions::InputOutputConfig& io_config,
      const InputOutputTensorNames& input_output_tensor_names_map,
      tflite::Interpreter* interpreter);

  // Swaps the feedback tensors from output to input.
  void SwapFeedbackTensors();

  // Returns the number of expected non-feedback input tensors. This can be
  // used to validate the number of input tensors provided to the
  // InferenceRunner implementation.
  int GetNumberOfNonFeedbackInputTensors() const;

  // Returns the number of feedback tensor pairs.
  int GetNumberOfFeedbackTensors() const;

  // Since feedback tensors are excluded from the InferenceRunner input, this
  // method maps a tensor index from the InferenceRunner input to the
  // corresponding TfLite model input index.
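  //
  // Illustrative example (the index layout is an assumption): for a model
  // with input tensors [0, 1, 2] where input 1 is a feedback input,
  // GetNumberOfNonFeedbackInputTensors() returns 2,
  // MapInputTensorToModelIndex(0) returns 0, and
  // MapInputTensorToModelIndex(1) returns 2.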
  absl::StatusOr<int> MapInputTensorToModelIndex(int input_idx) const;

  // Returns true if the tensor at the given index is a feedback input tensor.
  bool IsFeedbackInputTensorAtIndex(int idx) const;

  // Returns true if the tensor at the given index is a feedback output tensor.
  bool IsFeedbackOutputTensorAtIndex(int idx) const;

 private:
  // Links between feedback tensors defined by model tensor indices.
  struct TensorFeedbackIndicesLink {
    int from_idx;
    int to_idx;
  };

  // Translates the tensor names from the input/output config into the
  // corresponding TfLite tensor indices.
  static absl::StatusOr<std::vector<TensorFeedbackIndicesLink>>
  ConvertSignatureTensorNamesToModelIndices(
      const mediapipe::InferenceCalculatorOptions::InputOutputConfig& io_config,
      const InputOutputTensorNames& input_output_tensor_names_map);

  // Non-owning reference to the TfLite interpreter.
  tflite::Interpreter* interpreter_ = nullptr;

  // List of tensor feedback pairs defined by model tensor indices.
  std::vector<TensorFeedbackIndicesLink> feedback_tensor_indices_links_;

  // Maps InferenceRunner input indices to TfLite model input indices.
  std::vector<int> input_tensor_to_model_indices_;

  // Set of feedback input model tensor indices.
  absl::flat_hash_set<int> feedback_input_indices_;

  // Set of feedback output model tensor indices.
  absl::flat_hash_set<int> feedback_output_indices_;
};
}  // namespace mediapipe

#endif  // MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_FEEDBACK_MANAGER_H_