Skip to content

File inference_calculator_utils.cc

File List > calculators > tensor > inference_calculator_utils.cc

Go to the documentation of this file

// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/calculators/tensor/inference_calculator_utils.h"

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <ostream>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port.h"  // NOLINT: provides MEDIAPIPE_ANDROID/IOS
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"
#include "tensorflow/lite/string_util.h"

#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
#include "mediapipe/util/cpu_util.h"
#endif  // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__

namespace mediapipe {

namespace {

// Returns the XNNPACK thread count used when the options do not request one.
//
// On Android, iOS and Emscripten-with-pthreads builds this is half of
// NumCPUCores(), clamped to the range [1, 4]. Every other platform uses a
// single thread.
int GetXnnpackDefaultNumThreads() {
#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_IOS) || \
    defined(__EMSCRIPTEN_PTHREADS__)
  constexpr int kMinNumThreadsByDefault = 1;
  constexpr int kMaxNumThreadsByDefault = 4;
  // std::clamp comes from <algorithm>.
  return std::clamp(NumCPUCores() / 2, kMinNumThreadsByDefault,
                    kMaxNumThreadsByDefault);
#else
  return 1;
#endif  // MEDIAPIPE_ANDROID || MEDIAPIPE_IOS || __EMSCRIPTEN_PTHREADS__
}

// Returns true when the MediaPipe element type `tensor_type` corresponds to the
// TfLite type `tflite_type`. kChar is paired with kTfLiteString. Element types
// not listed below never compare equal to any TfLite type.
bool operator==(Tensor::ElementType tensor_type, TfLiteType tflite_type) {
  // The TfLite type that `tensor_type` is expected to map to.
  TfLiteType counterpart = TfLiteType::kTfLiteNoType;
  switch (tensor_type) {
    case Tensor::ElementType::kNone:
      // NOTE(review): confirm kNone <-> kTfLiteNoType is the intended pairing.
      counterpart = TfLiteType::kTfLiteNoType;
      break;
    case Tensor::ElementType::kFloat16:
      counterpart = TfLiteType::kTfLiteFloat16;
      break;
    case Tensor::ElementType::kFloat32:
      counterpart = TfLiteType::kTfLiteFloat32;
      break;
    case Tensor::ElementType::kUInt8:
      counterpart = TfLiteType::kTfLiteUInt8;
      break;
    case Tensor::ElementType::kInt8:
      counterpart = TfLiteType::kTfLiteInt8;
      break;
    case Tensor::ElementType::kInt32:
      counterpart = TfLiteType::kTfLiteInt32;
      break;
    case Tensor::ElementType::kInt64:
      counterpart = TfLiteType::kTfLiteInt64;
      break;
    case Tensor::ElementType::kBool:
      counterpart = TfLiteType::kTfLiteBool;
      break;
    case Tensor::ElementType::kChar:
      counterpart = TfLiteType::kTfLiteString;
      break;
    default:
      return false;
  }
  return counterpart == tflite_type;
}

// Copies the CPU contents of `input_tensor` into the already-allocated
// `tflite_tensor` with a raw byte copy.
//
// `T` is the element type both tensors are expected to hold. The TfLite side
// is validated against `tflite::typeToTfLiteType<T>()`. A type or byte-size
// mismatch returns kInvalidArgument. A null data pointer on either side
// returns a RET_CHECK error.
template <typename T>
absl::Status CopyTensorToTfLiteTensor(const Tensor& input_tensor,
                                      TfLiteTensor& tflite_tensor) {
  // Hold the read view for the whole copy; the buffer pointer comes from it.
  auto input_tensor_view = input_tensor.GetCpuReadView();
  const T* input_tensor_buffer = input_tensor_view.buffer<T>();
  RET_CHECK(input_tensor_buffer) << "Input tensor buffer is null.";
  RET_CHECK_EQ(tflite_tensor.type, tflite::typeToTfLiteType<T>())
          .SetCode(absl::StatusCode::kInvalidArgument)
      << "Tensor and TfLiteTensor types do not match.";
  void* local_tensor_buffer = tflite_tensor.data.raw;
  RET_CHECK(local_tensor_buffer) << "TfLiteTensor data is null.";
  // Equal byte counts guarantee the memcpy below stays in bounds on both sides.
  RET_CHECK_EQ(tflite_tensor.bytes, static_cast<size_t>(input_tensor.bytes()))
          .SetCode(absl::StatusCode::kInvalidArgument)
      << "TfLiteTensor and Tensor sizes do not match.";
  std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
  return absl::OkStatus();
}

// Specialization for kChar tensors. The `input_tensor.shape().num_elements()`
// characters of the input buffer are written into `tflite_tensor` as a single
// string via tflite::DynamicBuffer. `tflite_tensor` must be of type
// kTfLiteString; otherwise kInvalidArgument is returned.
template <>
absl::Status CopyTensorToTfLiteTensor<char>(const Tensor& input_tensor,
                                            TfLiteTensor& tflite_tensor) {
  // The view must outlive every use of the pointer obtained from it. Taking
  // buffer<char>() from a temporary view would release the view before the
  // characters are read, so keep it in a named local like the generic
  // template does.
  auto input_tensor_view = input_tensor.GetCpuReadView();
  const char* input_tensor_buffer = input_tensor_view.buffer<char>();
  RET_CHECK(input_tensor_buffer) << "Char-typed input tensor buffer is null.";
  RET_CHECK_EQ(tflite_tensor.type, TfLiteType::kTfLiteString)
          .SetCode(absl::StatusCode::kInvalidArgument)
      << "TfLiteTensor type is not kTfLiteString while Tensor type is kChar.";
  tflite::DynamicBuffer dynamic_buffer;
  dynamic_buffer.AddString(input_tensor_buffer,
                           input_tensor.shape().num_elements());
  dynamic_buffer.WriteToTensorAsVector(&tflite_tensor);
  return absl::OkStatus();
}

// Returns true when `lhs` holds exactly the same values, in the same order, as
// `rhs`.
bool operator==(const TfLiteIntArray& lhs, const std::vector<int>& rhs) {
  const size_t count = rhs.size();
  if (static_cast<size_t>(lhs.size) != count) {
    return false;
  }
  for (size_t idx = 0; idx < count; ++idx) {
    if (rhs[idx] != lhs.data[idx]) {
      return false;
    }
  }
  return true;
}

std::ostream& operator<<(std::ostream& os, const TfLiteIntArray& array) {
  return os << '[' << absl::StrJoin(absl::MakeSpan(array.data, array.size), ",")
            << ']';
}

// Copies the contents of `tflite_tensor` into the CPU buffer of the
// already-allocated `output_tensor` with a raw byte copy.
//
// Three checks must pass, otherwise kInvalidArgument is returned:
//   - `T` must match the TfLite tensor type (`tflite::typeToTfLiteType<T>()`).
//   - `output_tensor` must have the corresponding element type.
//   - Both tensors must have identical dims.
// Null buffers or null dims return a RET_CHECK error.
template <typename T>
absl::Status CopyTfLiteTensorToTensor(const TfLiteTensor& tflite_tensor,
                                      Tensor& output_tensor) {
  auto output_tensor_view = output_tensor.GetCpuWriteView();
  T* output_tensor_buffer = output_tensor_view.buffer<T>();
  RET_CHECK(output_tensor_buffer) << "Output tensor buffer is null.";
  RET_CHECK_EQ(tflite_tensor.type, tflite::typeToTfLiteType<T>())
          .SetCode(absl::StatusCode::kInvalidArgument)
      << "TfLite tensor type and requested output type do not match.";
  const Tensor::ElementType output_tensor_type = output_tensor.element_type();
  RET_CHECK(output_tensor_type == tflite_tensor.type)
          .SetCode(absl::StatusCode::kInvalidArgument)
      << "Output and TfLiteTensor types do not match";
  const void* local_tensor_buffer = tflite_tensor.data.raw;
  RET_CHECK(local_tensor_buffer) << "TfLiteTensor tensor buffer is null.";
  RET_CHECK(tflite_tensor.dims) << "TfLiteTensor dims are null.";
  // Not using RET_CHECK_EQ because the macro triggers an array copy.
  // Explicitly use == to compare with a const reference.
  // The dims are formatted explicitly: streaming `tflite_tensor.dims` directly
  // would print the pointer value rather than the dimensions.
  RET_CHECK(*tflite_tensor.dims == output_tensor.shape().dims)
          .SetCode(absl::StatusCode::kInvalidArgument)
      << "TfLiteTensor and Tensor shape do not match: ["
      << absl::StrJoin(absl::MakeConstSpan(tflite_tensor.dims->data,
                                           tflite_tensor.dims->size),
                       ",")
      << "] vs [" << absl::StrJoin(output_tensor.shape().dims, ",") << ']';
  std::memcpy(output_tensor_buffer, local_tensor_buffer, output_tensor.bytes());
  return absl::OkStatus();
}

// Specialization for kTfLiteString -> kChar. The TfLite tensor must hold
// exactly one string, and that string's length must equal
// `output_tensor.shape().num_elements()`. Its characters are copied into the
// output buffer. Type or length mismatches return kInvalidArgument.
template <>
absl::Status CopyTfLiteTensorToTensor<char>(const TfLiteTensor& tflite_tensor,
                                            Tensor& output_tensor) {
  auto output_tensor_view = output_tensor.GetCpuWriteView();
  char* output_tensor_buffer = output_tensor_view.buffer<char>();
  RET_CHECK(output_tensor_buffer) << "Output tensor buffer is null.";
  RET_CHECK_EQ(tflite_tensor.type, kTfLiteString)
          .SetCode(absl::StatusCode::kInvalidArgument)
      << "TfLiteTensor type and requested output type do not match.";
  const Tensor::ElementType output_tensor_type = output_tensor.element_type();
  RET_CHECK(output_tensor_type == Tensor::ElementType::kChar)
          .SetCode(absl::StatusCode::kInvalidArgument)
      << "Output and TfLiteTensor types do not match";

  // Only one string expected.
  RET_CHECK_EQ(tflite::GetStringCount(&tflite_tensor), 1);
  const tflite::StringRef string_ref = tflite::GetString(&tflite_tensor, 0);
  // Validate and copy straight from the TfLite-owned bytes. An intermediate
  // std::string would only add a heap allocation and an extra copy.
  const size_t string_size = static_cast<size_t>(string_ref.len);
  RET_CHECK(string_size ==
            static_cast<size_t>(output_tensor.shape().num_elements()))
          .SetCode(absl::StatusCode::kInvalidArgument)
      << absl::StrFormat(
             "TfLiteTensor and Tensor shape do not match: %d vs [%s]",
             string_size, absl::StrJoin(output_tensor.shape().dims, ","));
  std::memcpy(output_tensor_buffer, string_ref.str, string_size);
  return absl::OkStatus();
}

}  // namespace

// Returns the number of XNNPACK threads to use.
//
// The delegate's `xnnpack().num_threads()` is honored when three conditions
// hold: `opts_has_delegate` is true, the delegate carries XNNPACK options, and
// the value is not -1. In every other case GetXnnpackDefaultNumThreads()
// decides.
int GetXnnpackNumThreads(
    const bool opts_has_delegate,
    const mediapipe::InferenceCalculatorOptions::Delegate& opts_delegate) {
  // Value treated as "no explicit thread count; use the default".
  static constexpr int kDefaultNumThreads = -1;
  if (!opts_has_delegate || !opts_delegate.has_xnnpack()) {
    return GetXnnpackDefaultNumThreads();
  }
  const int requested_threads = opts_delegate.xnnpack().num_threads();
  if (requested_threads == kDefaultNumThreads) {
    return GetXnnpackDefaultNumThreads();
  }
  return requested_threads;
}

// Copies `input_tensor` into the interpreter input at position
// `input_tensor_index` (i.e. `interpreter.input_tensor(input_tensor_index)`).
//
// An index outside [0, interpreter.inputs().size()) returns kInvalidArgument.
// Errors from the copy itself are annotated with the index.
absl::Status CopyCpuInputIntoInterpreterTensor(const Tensor& input_tensor,
                                               tflite::Interpreter& interpreter,
                                               int input_tensor_index) {
  // TfLite's Interpreter::input_tensor() indexes inputs() without a range
  // check, so an invalid position must be rejected here rather than becoming
  // an out-of-bounds vector access.
  RET_CHECK(input_tensor_index >= 0 &&
            static_cast<size_t>(input_tensor_index) <
                interpreter.inputs().size())
          .SetCode(absl::StatusCode::kInvalidArgument)
      << "Input tensor index " << input_tensor_index
      << " is out of range; interpreter has " << interpreter.inputs().size()
      << " inputs.";
  auto* tflite_tensor = interpreter.input_tensor(input_tensor_index);
  RET_CHECK(tflite_tensor);
  MP_RETURN_IF_ERROR(CopyCpuInputIntoTfLiteTensor(input_tensor, *tflite_tensor))
      << " at index " << input_tensor_index;
  return absl::OkStatus();
}

// Copies the CPU contents of `input_tensor` into `tflite_tensor`, dispatching
// on the TfLite tensor type.
//
// The MediaPipe element type must correspond to the TfLite type according to
// operator==(Tensor::ElementType, TfLiteType); otherwise kInvalidArgument is
// returned. Unsupported types also return kInvalidArgument.
absl::Status CopyCpuInputIntoTfLiteTensor(const Tensor& input_tensor,
                                          TfLiteTensor& tflite_tensor) {
  const TfLiteType interpreter_tensor_type = tflite_tensor.type;
  const Tensor::ElementType input_tensor_type = input_tensor.element_type();
  RET_CHECK(input_tensor_type == interpreter_tensor_type)
          .SetCode(absl::StatusCode::kInvalidArgument)
      << "Input and interpreter tensor type do not match.";
  switch (interpreter_tensor_type) {
    case TfLiteType::kTfLiteFloat16: {
      // Half-precision data is copied bit-for-bit. TfLiteFloat16 is the C++
      // type that typeToTfLiteType<> maps to kTfLiteFloat16. Routing this case
      // through <float> made the helper's type check fail unconditionally.
      MP_RETURN_IF_ERROR(
          CopyTensorToTfLiteTensor<TfLiteFloat16>(input_tensor, tflite_tensor));
      break;
    }
    case TfLiteType::kTfLiteFloat32: {
      MP_RETURN_IF_ERROR(
          CopyTensorToTfLiteTensor<float>(input_tensor, tflite_tensor));
      break;
    }
    case TfLiteType::kTfLiteUInt8: {
      MP_RETURN_IF_ERROR(
          CopyTensorToTfLiteTensor<uint8_t>(input_tensor, tflite_tensor));
      break;
    }
    case TfLiteType::kTfLiteInt8: {
      MP_RETURN_IF_ERROR(
          CopyTensorToTfLiteTensor<int8_t>(input_tensor, tflite_tensor));
      break;
    }
    case TfLiteType::kTfLiteInt32: {
      MP_RETURN_IF_ERROR(
          CopyTensorToTfLiteTensor<int32_t>(input_tensor, tflite_tensor));
      break;
    }
    case TfLiteType::kTfLiteInt64: {
      MP_RETURN_IF_ERROR(
          CopyTensorToTfLiteTensor<int64_t>(input_tensor, tflite_tensor));
      break;
    }
    case TfLiteType::kTfLiteString: {
      MP_RETURN_IF_ERROR(
          CopyTensorToTfLiteTensor<char>(input_tensor, tflite_tensor));
      break;
    }
    case TfLiteType::kTfLiteBool: {
      MP_RETURN_IF_ERROR(
          CopyTensorToTfLiteTensor<bool>(input_tensor, tflite_tensor));
      break;
    }
    default:
      return absl::InvalidArgumentError(
          absl::StrCat("Unsupported input data type: ", input_tensor_type));
  }
  return absl::OkStatus();
}

// Copies the interpreter tensor `interpreter.tensor(output_tensor_index)` into
// the already-allocated `output_tensor`. Copy errors are annotated with the
// index.
// NOTE(review): the index is passed to Interpreter::tensor(), so it is treated
// as an interpreter tensor index (such as an entry of outputs()), not a
// position within outputs(). Confirm that callers pass tensor indices.
absl::Status CopyInterpreterTensorIntoCpuOutput(
    const tflite::Interpreter& interpreter, int output_tensor_index,
    Tensor& output_tensor) {
  const TfLiteTensor* const tflite_tensor =
      interpreter.tensor(output_tensor_index);
  RET_CHECK(tflite_tensor);
  const TfLiteTensor& source = *tflite_tensor;
  MP_RETURN_IF_ERROR(CopyTfLiteTensorIntoCpuOutput(source, output_tensor))
      << " at index " << output_tensor_index;
  return absl::OkStatus();
}

// Copies `tflite_tensor` into the already-allocated `output_tensor`, choosing
// the element type from the TfLite tensor type.
//
// `output_tensor` must have the matching element type and identical dims. A
// kTfLiteFloat16 tensor is copied as raw half-precision data, so it requires a
// kFloat16 output. Unsupported TfLite types return kInvalidArgument.
absl::Status CopyTfLiteTensorIntoCpuOutput(const TfLiteTensor& tflite_tensor,
                                           Tensor& output_tensor) {
  const TfLiteType tflite_tensor_type = tflite_tensor.type;
  switch (tflite_tensor_type) {
    case TfLiteType::kTfLiteFloat16: {
      // TfLiteFloat16 is the C++ type matching kTfLiteFloat16. Dispatching
      // this case to <float> made the type check fail unconditionally.
      MP_RETURN_IF_ERROR(CopyTfLiteTensorToTensor<TfLiteFloat16>(
          tflite_tensor, output_tensor));
      break;
    }
    case TfLiteType::kTfLiteFloat32: {
      MP_RETURN_IF_ERROR(
          CopyTfLiteTensorToTensor<float>(tflite_tensor, output_tensor));
      break;
    }
    case TfLiteType::kTfLiteUInt8: {
      MP_RETURN_IF_ERROR(
          CopyTfLiteTensorToTensor<uint8_t>(tflite_tensor, output_tensor));
      break;
    }
    case TfLiteType::kTfLiteInt8: {
      MP_RETURN_IF_ERROR(
          CopyTfLiteTensorToTensor<int8_t>(tflite_tensor, output_tensor));
      break;
    }
    case TfLiteType::kTfLiteInt32: {
      MP_RETURN_IF_ERROR(
          CopyTfLiteTensorToTensor<int32_t>(tflite_tensor, output_tensor));
      break;
    }
    case TfLiteType::kTfLiteInt64: {
      MP_RETURN_IF_ERROR(
          CopyTfLiteTensorToTensor<int64_t>(tflite_tensor, output_tensor));
      break;
    }
    case TfLiteType::kTfLiteString: {
      MP_RETURN_IF_ERROR(
          CopyTfLiteTensorToTensor<char>(tflite_tensor, output_tensor));
      break;
    }
    case TfLiteType::kTfLiteBool: {
      MP_RETURN_IF_ERROR(
          CopyTfLiteTensorToTensor<bool>(tflite_tensor, output_tensor));
      break;
    }
    default:
      return absl::InvalidArgumentError(
          absl::StrCat("Unsupported output data type: ", tflite_tensor_type));
  }
  return absl::OkStatus();
}

// Creates a new MediaPipe Tensor with the same dims as `tflite_tensor` and
// copies its data into it.
//
// Supported types:
//   - kTfLiteFloat16 -> kFloat16 (raw half-precision data; no conversion).
//   - kTfLiteFloat32 -> kFloat32.
//   - kTfLiteInt32   -> kInt32.
// Any other type returns kInvalidArgument. Null dims return a RET_CHECK error.
absl::StatusOr<Tensor> ConvertTfLiteTensorToTensor(
    const TfLiteTensor& tflite_tensor) {
  RET_CHECK(tflite_tensor.dims) << "TfLiteTensor dims are null.";
  Tensor::Shape shape{
      std::vector<int>{tflite_tensor.dims->data,
                       tflite_tensor.dims->data + tflite_tensor.dims->size}};
  switch (tflite_tensor.type) {
    case TfLiteType::kTfLiteFloat16: {
      // Previously this case built a kFloat32 tensor and copied through
      // <float>, which always failed the type check. Keep the data in its
      // native half precision instead.
      Tensor output_tensor(Tensor::ElementType::kFloat16, shape);
      MP_RETURN_IF_ERROR(CopyTfLiteTensorToTensor<TfLiteFloat16>(
          tflite_tensor, output_tensor));
      return output_tensor;
    }
    case TfLiteType::kTfLiteFloat32: {
      Tensor output_tensor(Tensor::ElementType::kFloat32, shape);
      MP_RETURN_IF_ERROR(
          CopyTfLiteTensorToTensor<float>(tflite_tensor, output_tensor));
      return output_tensor;
    }
    case TfLiteType::kTfLiteInt32: {
      Tensor output_tensor(Tensor::ElementType::kInt32, shape);
      MP_RETURN_IF_ERROR(
          CopyTfLiteTensorToTensor<int32_t>(tflite_tensor, output_tensor));
      return output_tensor;
    }
    default:
      return absl::InvalidArgumentError(
          absl::StrCat("Unsupported output data type: ", tflite_tensor.type));
  }
}

}  // namespace mediapipe