// File: mediapipe/util/tflite/tflite_signature_reader.cc
#include "mediapipe/util/tflite/tflite_signature_reader.h"
#include <algorithm>
#include <cstdint>
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "tensorflow/lite/interpreter.h"
namespace mediapipe {
namespace {
// Builds the inverse of `map`: each (key, value) entry becomes (value, key).
// If several keys carry the same value, the key visited last while iterating
// `map` is the one that remains in the result.
absl::flat_hash_map<uint32_t, std::string> FlipKVInMap(
    const std::map<std::string, uint32_t>& map) {
  absl::flat_hash_map<uint32_t, std::string> flipped;
  for (const auto& [key, value] : map) {
    flipped.insert_or_assign(value, key);
  }
  return flipped;
}
} // namespace
// Looks up the signature-defined names of the tensors listed in
// `interpreter.inputs()` and `interpreter.outputs()`.
//
// `signature_key` selects the signature to read. It may be null only when the
// model defines exactly one signature, in which case that signature is used.
//
// Returns:
//   - InvalidArgumentError if the model has no signatures, or has more than
//     one signature while `signature_key` is null.
//   - A RET_CHECK failure status if `signature_key` is not one of the model's
//     signature keys.
//   - InternalError if a tensor id from `interpreter.inputs()` or
//     `interpreter.outputs()` has no name in the selected signature.
//   - Otherwise the input and output names, in the same order as
//     `interpreter.inputs()` and `interpreter.outputs()`.
absl::StatusOr<SignatureInputOutputTensorNames>
TfLiteSignatureReader::GetInputOutputTensorNamesFromTfliteSignature(
    const tflite::Interpreter& interpreter, const std::string* signature_key) {
  const std::vector<const std::string*> model_signature_keys =
      interpreter.signature_keys();
  if (model_signature_keys.empty()) {
    return absl::InvalidArgumentError("No signatures found.");
  }
  if (signature_key == nullptr && model_signature_keys.size() > 1) {
    // absl::StrJoin dereferences pointer elements by default, so the keys can
    // be joined directly without first copying them into a vector of strings.
    return absl::InvalidArgumentError(
        absl::StrCat("Model contains multiple signatures but no signature key "
                     "specified. Available signature keys: ",
                     absl::StrJoin(model_signature_keys, ", ")));
  }

  const std::string* signature_key_str = nullptr;
  if (signature_key != nullptr) {
    // The requested key must match one of the model's keys by value (the
    // caller's pointer does not have to be one of the interpreter's own).
    RET_CHECK(std::find_if(model_signature_keys.begin(),
                           model_signature_keys.end(),
                           [&](const std::string* model_signature_key) {
                             return *signature_key == *model_signature_key;
                           }) != model_signature_keys.end())
        << "Signature key not found in model.";
    signature_key_str = signature_key;
  } else {
    // Exactly one signature exists at this point; use it implicitly.
    signature_key_str = model_signature_keys[0];
  }

  // The interpreter exposes name -> tensor id; invert both maps so that names
  // can be looked up by tensor id.
  const absl::flat_hash_map<uint32_t, std::string>
      model_input_tensor_id_to_name_map =
          FlipKVInMap(interpreter.signature_inputs(signature_key_str->c_str()));
  const absl::flat_hash_map<uint32_t, std::string>
      model_output_tensor_id_to_name_map = FlipKVInMap(
          interpreter.signature_outputs(signature_key_str->c_str()));

  // Appends to `names` the signature name of every id in `tensor_ids`, keeping
  // the order of `tensor_ids`. `kind` ("Input" / "Output") only labels the
  // error message produced when an id has no name.
  const auto append_names = [](const std::vector<int>& tensor_ids,
                               const auto& id_to_name, const char* kind,
                               auto& names) -> absl::Status {
    names.reserve(tensor_ids.size());
    for (const int tensor_id : tensor_ids) {
      const auto it = id_to_name.find(tensor_id);
      if (it == id_to_name.end()) {
        return absl::InternalError(absl::StrCat(
            kind, " tensor id ", tensor_id, " not found in signature."));
      }
      names.push_back(it->second);
    }
    return absl::OkStatus();
  };

  // Maps the model input and outputs to internal model tensor ids.
  SignatureInputOutputTensorNames input_output_tensor_names;
  if (absl::Status status = append_names(
          interpreter.inputs(), model_input_tensor_id_to_name_map, "Input",
          input_output_tensor_names.input_tensor_names);
      !status.ok()) {
    return status;
  }
  if (absl::Status status = append_names(
          interpreter.outputs(), model_output_tensor_id_to_name_map, "Output",
          input_output_tensor_names.output_tensor_names);
      !status.ok()) {
    return status;
  }
  return input_output_tensor_names;
}
// Collects the input/output tensor names of every signature reported by
// `interpreter.signature_keys()`, keyed by signature key.
//
// Each signature is resolved with GetInputOutputTensorNamesFromTfliteSignature;
// the first error it returns is propagated. A signature key that occurs more
// than once triggers a RET_CHECK failure. A model without signatures yields an
// empty map.
absl::StatusOr<
    absl::flat_hash_map<SignatureName, SignatureInputOutputTensorNames>>
TfLiteSignatureReader::GetInputOutputTensorNamesFromAllTfliteSignatures(
    const tflite::Interpreter& interpreter) {
  const std::vector<const std::string*> all_keys =
      interpreter.signature_keys();
  absl::flat_hash_map<SignatureName, SignatureInputOutputTensorNames>
      names_by_signature;
  for (const std::string* key : all_keys) {
    MP_ASSIGN_OR_RETURN(
        SignatureInputOutputTensorNames tensor_names,
        GetInputOutputTensorNamesFromTfliteSignature(interpreter, key));
    // emplace() leaves an existing entry untouched and reports whether a new
    // one was created.
    const bool was_inserted =
        names_by_signature.emplace(*key, std::move(tensor_names)).second;
    RET_CHECK(was_inserted) << "Duplicate signature key: " << *key
                            << ". Available signature keys: "
                            << absl::StrJoin(all_keys, ", ");
  }
  return names_by_signature;
}
} // namespace mediapipe