Skip to content

File template_expander.cc

File List > framework > tool > template_expander.cc

Go to the documentation of this file

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/framework/tool/template_expander.h"

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/port/numbers.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/calculator_graph_template.pb.h"
#include "mediapipe/framework/tool/proto_util_lite.h"

namespace mediapipe {

namespace tool {

using mediapipe::proto_ns::MessageLite;
using mediapipe::tool::ProtoUtilLite;
using WireFormatLite = ProtoUtilLite::WireFormatLite;
using FieldValue = ProtoUtilLite::FieldValue;
using FieldType = ProtoUtilLite::FieldType;
using ProtoPath = ProtoUtilLite::ProtoPath;
using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry;

namespace {

// Returns a template argument by name.
TemplateArgument* GetItem(TemplateDict* args, const std::string& name) {
  for (TemplateDict::Parameter& arg : *args->mutable_arg()) {
    if (arg.key() == name) {
      return arg.mutable_value();
    }
  }
  return nullptr;
}

// Sets the template argument for a param name.
void PutItem(TemplateDict* args, const std::string& name,
             const TemplateArgument* value) {
  for (int i = args->arg_size() - 1; i >= 0; --i) {
    if (args->arg()[i].key() == name) {
      if (value != nullptr) {
        *args->mutable_arg(i)->mutable_value() = *value;
      } else {
        args->mutable_arg()->erase(args->mutable_arg()->begin() + i);
      }
      return;
    }
  }
  if (value != nullptr) {
    TemplateDict::Parameter* arg = args->add_arg();
    *arg->mutable_key() = name;
    *arg->mutable_value() = *value;
  }
}

// Creates a deep copy of a message.
std::unique_ptr<MessageLite> CloneMessage(const MessageLite& message) {
  std::unique_ptr<MessageLite> result(message.New());
  result->CheckTypeAndMergeFrom(message);
  return result;
}

// Parses one ProtoPathEntry.
// The parsed entry is appended to `result` and removed from `path`.
// ProtoPathEntry::key_value stores map key text.  Use SetMapKeyTypes
// to serialize the key text to protobuf wire format.
absl::Status ParseEntry(absl::string_view& path, ProtoPath* result) {
  bool ok = true;
  size_t sb = path.find('[');
  size_t eb = path.find(']');
  int field_id = -1;
  ok &= absl::SimpleAtoi(path.substr(0, sb), &field_id);
  auto selector = path.substr(sb + 1, eb - 1 - sb);
  if (absl::StartsWith(selector, "@")) {
    size_t eq = selector.find('=');
    int key_id = -1;
    ok &= absl::SimpleAtoi(selector.substr(1, eq - 1), &key_id);
    auto key_text = selector.substr(eq + 1);
    FieldType key_type = FieldType::TYPE_STRING;
    result->push_back({field_id, key_id, key_type, std::string(key_text)});
  } else {
    int index = 0;
    ok &= absl::SimpleAtoi(selector, &index);
    result->push_back({field_id, index});
  }
  size_t end = path.find('/', eb);
  if (end == std::string::npos) {
    path = "";
  } else {
    path = path.substr(end + 1);
  }
  return ok ? absl::OkStatus()
            : absl::InvalidArgumentError(
                  absl::StrCat("Failed to parse ProtoPath entry: ", path));
}

// Specifies the FieldTypes for protobuf map keys in a ProtoPath.
// Each ProtoPathEntry::key_value is converted from text to the protobuf
// wire format for its key type.
absl::Status SetMapKeyTypes(const std::vector<FieldType>& key_types,
                            ProtoPath* result) {
  int i = 0;
  for (ProtoPathEntry& entry : *result) {
    if (entry.map_id >= 0) {
      FieldType key_type = key_types[i++];
      std::vector<FieldValue> key_value;
      MP_RETURN_IF_ERROR(
          ProtoUtilLite::Serialize({entry.key_value}, key_type, &key_value));
      entry.key_type = key_type;
      entry.key_value = key_value.front();
    }
  }
  return absl::OkStatus();
}

// Returns the (tag, index) pairs in a field path.
// For example, returns {{1, 1}, {2, 1}, {3, 1}} for "/1[1]/2[1]/3[1]",
// returns {{1, 1}, {2, 1, "INPUT_FRAMES"}} for "/1[1]/2[@1=INPUT_FRAMES]".
absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) {
  result->clear();
  absl::string_view rest = path;
  if (absl::StartsWith(rest, "/")) {
    rest = rest.substr(1);
  }
  while (!rest.empty()) {
    MP_RETURN_IF_ERROR(ParseEntry(rest, result));
  }
  return absl::OkStatus();
}

// Parse the TemplateExpression.path field into a ProtoPath struct.
absl::Status ParseProtoPath(const TemplateExpression& rule,
                            std::string base_path, ProtoPath* result) {
  ProtoPath base_entries;
  MP_RETURN_IF_ERROR(ProtoPathSplit(base_path, &base_entries));
  MP_RETURN_IF_ERROR(ProtoPathSplit(rule.path(), result));
  std::vector<FieldType> key_types;
  for (int type : rule.key_type()) {
    key_types.push_back(static_cast<FieldType>(type));
  }
  MP_RETURN_IF_ERROR(SetMapKeyTypes(key_types, result));
  result->erase(result->begin(), result->begin() + (int)base_entries.size());
  return absl::OkStatus();
}

// Returns true if one proto path is prefix by another.
bool ProtoPathStartsWith(const std::string& path, const std::string& prefix) {
  return absl::StartsWith(path, prefix);
}

// Returns the target ProtoUtilLite::FieldType of a rule.
FieldType GetFieldType(const TemplateExpression& rule) {
  return static_cast<FieldType>(rule.field_type());
}

// Returns the count of field values at a ProtoPath.
int FieldCount(const FieldValue& base, ProtoPath field_path,
               FieldType field_type) {
  int result = 0;
  ABSL_CHECK_OK(
      ProtoUtilLite::GetFieldCount(base, field_path, field_type, &result));
  return result;
}

}  // namespace

// The default implementation for the mediapipe template rule interpreter.
class TemplateExpanderImpl {
 public:
  explicit TemplateExpanderImpl(std::vector<absl::Status>* errors)
      : errors_(errors) {}

  // Applies the rules specified in a CalculatorGraphTemplate to a
  // CalculatorGraphConfig.  Each rule references a nested field-value or
  // message and defines zero or more replacement values for it.
  bool ExpandTemplates(const TemplateDict& args,
                       const CalculatorGraphTemplate& templ,
                       CalculatorGraphConfig* output) {
    // Extract the serialized CalculatorGraphConfig.
    FieldValue base_value;
    if (!templ.config().SerializeToString(&base_value)) {
      return false;
    }

    // Extract the CalculatorGraphTemplate rules.
    template_rules_ = templ;
    template_rules_.clear_config();

    // Invoke recursive rule expansion.
    environment_ = args;
    std::vector<FieldValue> result;
    if (!ExpandNestedRules(0, "", base_value, &result)) {
      return false;
    }
    return output->ParseFromString(result[0]);
  }

 private:
  // Expands a template rule of a specific type.
  // Modifies a base message to produce one or more expanded messages.
  // Ownership of the result messages is transferred to the caller.
  bool ExpandTemplateRule(int base_index, const FieldValue& base_message,
                          std::vector<FieldValue>* result) {
    // Exapand a template rule of a specific type.
    const TemplateExpression& rule = template_rules_.rule().Get(base_index);
    if (rule.op() == "for") {
      ExpandIterationRule(base_index, base_message, result);
    } else if (rule.op() == "if") {
      ExpandConditionalRule(base_index, base_message, result);
    } else if (rule.op() == "param") {
      ExpandDeclaration(base_index, base_message, result);
    } else {
      ExpandExpressionRule(base_index, result);
    }
    return true;
  }

  // Apply any remaining rules on the current field.
  // If the next rule also applies to the current field, apply it.
  // Otherwise, apply rules for nested fields.
  bool ExpandPeerRules(int base_index, const FieldValue& base_message,
                       std::vector<FieldValue>* result) {
    // If the next rule applies to the same message, apply it now.
    auto& base_rule = template_rules_.rule().Get(base_index);
    int next_index = base_index + 1;
    if (next_index < template_rules_.rule().size()) {
      auto& next_rule = template_rules_.rule().Get(next_index);
      if (next_rule.path() == base_rule.path()) {
        return ExpandTemplateRule(next_index, base_message, result);
      }
    }

    // Otheriwse, apply rules for nested fields.
    return ExpandNestedRules(next_index, base_rule.path(), base_message,
                             result);
  }

  // Return the field values addressed by a template rule.
  absl::Status GetBaseValue(const std::string& base_path,
                            const TemplateExpression& rule,
                            const FieldValue& output,
                            std::vector<FieldValue>* base) {
    if (!rule.has_path()) {
      base->push_back(output);
      return absl::OkStatus();
    }
    if (rule.has_field_value()) {
      // For a non-repeated field, the field value is stored only in the rule.
      base->push_back(rule.field_value());
      return absl::OkStatus();
    }
    ProtoPath field_path;
    MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path));
    return ProtoUtilLite::GetFieldRange(output, field_path, 1,
                                        GetFieldType(rule), base);
  }

  // Replace the field values addressed by a template rule.
  absl::Status ReplaceBaseValue(const std::string& base_path,
                                const TemplateExpression& rule,
                                const std::vector<FieldValue>& field_values,
                                FieldValue* output) {
    if (!rule.has_path()) {
      if (!field_values.empty()) {
        *output = field_values[0];
      }
      return absl::OkStatus();
    }
    ProtoPath field_path;
    MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path));
    int field_count = 1;
    if (rule.has_field_value()) {
      // For a non-repeated field, only one value can be specified.
      if (!field_values.empty() &&
          FieldCount(*output, field_path, GetFieldType(rule)) > 0) {
        return absl::InvalidArgumentError(absl::StrCat(
            "Multiple values specified for non-repeated field: ", rule.path()));
      }
      // For a non-repeated field, the field value is stored only in the rule.
      field_path[field_path.size() - 1].index = 0;
      field_count = 0;
    }
    return ProtoUtilLite::ReplaceFieldRange(output, field_path, field_count,
                                            GetFieldType(rule), field_values);
  }

  // Replaces nested fields by following nested template rules.
  bool ExpandNestedRules(int base_index, const std::string& base_path,
                         const FieldValue& base_message,
                         std::vector<FieldValue>* result) {
    absl::Status status;
    FieldValue output = base_message;

    // Evaluate the rules nested below base_path in lexical order.
    std::vector<int> rules = GetNestedRules(base_index, base_path);
    std::vector<std::vector<FieldValue>> edits;
    for (int i = 0; i < (int)rules.size(); ++i) {
      const auto& rule = template_rules_.rule().Get(rules[i]);
      std::vector<FieldValue> base;
      status = GetBaseValue(base_path, rule, output, &base);
      if (!status.ok()) break;
      std::vector<FieldValue> values;
      if (!ExpandTemplateRule(rules[i], base[0], &values)) {
        status = absl::InternalError("ExpandTemplateRule failed");
        break;
      }
      edits.push_back(values);
    }
    if (!status.ok()) {
      RecordError(status);
      return false;
    }
    // Replace base field values with the evaluated results.
    // Edits are applied in reverse order since later indices are invalidated.
    for (int i = (int)edits.size() - 1; i >= 0; --i) {
      const auto& rule = template_rules_.rule().Get(rules[i]);
      status = ReplaceBaseValue(base_path, rule, edits[i], &output);
      if (!status.ok()) break;
    }
    if (!status.ok()) {
      RecordError(status);
      return false;
    }
    result->push_back(output);
    return true;
  }

  // Returns indexes of the rules directly nested within a certain rule.
  std::vector<int> GetNestedRules(int rule_index,
                                  const std::string& rule_path) {
    std::vector<int> result;
    std::string prev_path = "-1[-1]";
    for (int i = rule_index; i < template_rules_.rule().size(); ++i) {
      auto& rule = template_rules_.rule().Get(i);
      if (!ProtoPathStartsWith(rule.path(), rule_path)) {
        break;
      }
      if (!ProtoPathStartsWith(rule.path(), prev_path)) {
        result.push_back(i);
        prev_path = rule.path();
      }
    }
    return result;
  }

  // Apply a "for" operation to a base message.
  // Expands nested rules once for each iteration range value.
  bool ExpandIterationRule(int base_index, const FieldValue& base_message,
                           std::vector<FieldValue>* result) {
    // Retrieve the var param and the range expression.
    const TemplateExpression& rule = template_rules_.rule().Get(base_index);
    std::string var_param = rule.arg().Get(0).param();
    const TemplateExpression& range_expr = rule.arg().Get(1);
    TemplateArgument range = EvalExpression(range_expr);

    // For each value of the range param, expand all nested rules.
    TemplateArgument* shadow_item = GetItem(&environment_, var_param);
    for (const TemplateArgument& item : range.element()) {
      PutItem(&environment_, var_param, &item);
      ExpandPeerRules(base_index, base_message, result);
    }
    PutItem(&environment_, var_param, shadow_item);
    return true;
  }

  // Initializes a parameter in the parameter environment.
  bool ExpandDeclaration(int base_index, const FieldValue& base_message,
                         std::vector<FieldValue>* result) {
    // Retrieve the var param and the range expression.
    const TemplateExpression& rule = template_rules_.rule().Get(base_index);
    if (rule.arg().empty() || rule.arg().size() > 2) {
      RecordError(absl::InvalidArgumentError(
          "Param declaration must specify a parameter name and "
          "may specify a single default value."));
    }
    // TODO: Validate that all params are declared or none.
    // Delarations for required params will have no default value.
    if (rule.arg().size() == 2) {
      std::string var_param = rule.arg().Get(0).param();
      const TemplateExpression& item_expr = rule.arg().Get(1);
      TemplateArgument item = EvalExpression(item_expr);
      // The parameter default value is used if no other value is specified.
      if (GetItem(&environment_, var_param) == nullptr) {
        PutItem(&environment_, var_param, &item);
      }
    }
    ExpandPeerRules(base_index, base_message, result);
    return true;
  }

  // Applies an "if" operation to a base message.
  // Expands nested rules zero or more times.
  bool ExpandConditionalRule(int base_index, const FieldValue& base_message,
                             std::vector<FieldValue>* result) {
    // Retrieve the condition expression.
    const TemplateExpression& rule = template_rules_.rule().Get(base_index);
    // Expand this template zero or one times.
    bool condition = AsBool(EvalExpression(rule.arg(0)));
    if (condition) {
      ExpandPeerRules(base_index, base_message, result);
    }
    return true;
  }

  // A self-contained expression just defines a single result value.
  bool ExpandExpressionRule(int base_index, std::vector<FieldValue>* result) {
    const TemplateExpression& rule = template_rules_.rule().Get(base_index);
    TemplateArgument item = EvalExpression(rule);
    std::vector<FieldValue> values;
    absl::Status status = AsFieldValues(std::vector<TemplateArgument>{item},
                                        GetFieldType(rule), &values);
    if (!status.ok()) {
      RecordError(status);
      return false;
    }
    result->push_back(values[0]);
    return true;
  }

  // The "param" operation does variable environment lookup.
  TemplateArgument EvalParam(const TemplateExpression& expr) {
    TemplateArgument* result = GetItem(&environment_, expr.param());
    if (result == nullptr) {
      RecordError(absl::NotFoundError(absl::StrCat("param: ", expr.param())));
      return AsArgument(0.0);
    }
    return *result;
  }

  // The "." operator does template dict lookup.
  TemplateArgument EvalDot(const TemplateExpression& expr) {
    TemplateArgument lhs = EvalExpression(expr.arg(0));
    TemplateArgument* result = GetItem(lhs.mutable_dict(), expr.arg(1).param());
    if (result == nullptr) {
      RecordError(absl::NotFoundError(
          absl::StrCat("param field: ", expr.arg(1).param())));
      return AsArgument(0.0);
    }
    return *result;
  }

  // Converts a TemplateArgument to double.
  double AsNum(const TemplateArgument& value) {
    double result = 0;
    if (value.has_num()) {
      result = value.num();
    }
    if (value.has_str()) {
      if (!absl::SimpleAtod(value.str(), &result)) {
        RecordError(absl::InvalidArgumentError(value.str()));
      }
    }
    return result;
  }

  // Converts a TemplateArgument to string.
  std::string AsString(const TemplateArgument& value) {
    std::string result;
    if (value.has_num()) {
      result = absl::StrCat(value.num());
    }
    if (value.has_str()) {
      result = value.str();
    }
    return result;
  }

  // Converts a TemplateArgument to bool.
  bool AsBool(const TemplateArgument& value) {
    bool result = false;
    if (value.has_num()) {
      return value.num() != 0;
    } else if (value.has_str()) {
      if (!absl::SimpleAtob(value.str(), &result)) {
        RecordError(absl::InvalidArgumentError(value.str()));
      }
    }
    return result;
  }

  // Converts a vector of TemplateArguments to a dict TemplateArgument.
  TemplateArgument AsDict(const std::vector<TemplateArgument>& args) {
    TemplateArgument result;
    if (args.size() % 2 != 0) {
      RecordError(absl::InvalidArgumentError(absl::StrCat(
          "Dict requires an even number of arguments, got: ", args.size())));
      return result;
    }
    TemplateDict* dict = result.mutable_dict();
    for (int i = 0; i < (int)args.size(); i += 2) {
      TemplateDict::Parameter* p = dict->add_arg();
      *p->mutable_key() = AsString(args[i]);
      *p->mutable_value() = args[i + 1];
    }
    return result;
  }

  // Converts a vector of TemplateArguments to a list TemplateArgument.
  TemplateArgument AsList(const std::vector<TemplateArgument>& args) {
    TemplateArgument result;
    auto list = result.mutable_element();
    for (int i = 0; i < (int)args.size(); ++i) {
      *list->Add() = args[i];
    }
    return result;
  }

  // Evaluate each of the sub-expressions of a TemplateExpression.
  void EvalNestedExpressions(const TemplateExpression& expr,
                             std::vector<TemplateArgument>* result) {
    for (const TemplateExpression& e : expr.arg()) {
      result->push_back(EvalExpression(e));
    }
  }

  // Returns true if a TemplateArgument represents a number.
  bool IsNum(const TemplateArgument& value) {
    double r = 0;
    return value.has_num() || absl::SimpleAtod(value.str(), &r);
  }

  // Returns 0 if v1 == v1, positive if v1 > v2, negative if v1 < v2.
  int CompareArgs(const TemplateArgument& v1, const TemplateArgument& v2) {
    if (IsNum(v1) && IsNum(v2)) {
      double d = AsNum(v1) - AsNum(v2);
      return (d < 0) ? -1 : (d > 0) ? 1 : 0;
    } else {
      return AsString(v1).compare(AsString(v2));
    }
  }

  // Evaluates a TemplateExpression to produce a template argument.
  TemplateArgument EvalExpression(const TemplateExpression& expr) {
    if (expr.op() == "literal") {
      return AsArgument(expr.param());
    } else if (expr.op() == ".") {
      return EvalDot(expr);
    } else if (expr.has_param()) {
      return EvalParam(expr);
    }
    std::vector<TemplateArgument> args;
    EvalNestedExpressions(expr, &args);
    TemplateArgument result;
    if (expr.op() == "paren") {
      result = args[0];
    } else if (expr.op() == "+") {
      if (IsNum(args[0]) && IsNum(args[1])) {
        result = AsArgument(AsNum(args[0]) + AsNum(args[1]));
      } else {
        result = AsArgument(AsString(args[0]) + AsString(args[1]));
      }
    } else if (expr.op() == "-") {
      result = AsArgument(AsNum(args[0]) - AsNum(args[1]));
    } else if (expr.op() == "*") {
      result = AsArgument(AsNum(args[0]) * AsNum(args[1]));
    } else if (expr.op() == "/") {
      result = AsArgument(AsNum(args[0]) / AsNum(args[1]));
    } else if (expr.op() == ">") {
      result = AsArgument(CompareArgs(args[0], args[1]) > 0);
    } else if (expr.op() == "<") {
      result = AsArgument(CompareArgs(args[0], args[1]) < 0);
    } else if (expr.op() == ">=") {
      result = AsArgument(CompareArgs(args[0], args[1]) >= 0);
    } else if (expr.op() == "<=") {
      result = AsArgument(CompareArgs(args[0], args[1]) <= 0);
    } else if (expr.op() == "==") {
      result = AsArgument(CompareArgs(args[0], args[1]) == 0);
    } else if (expr.op() == "!=") {
      result = AsArgument(CompareArgs(args[0], args[1]) != 0);
    } else if (expr.op() == "&&") {
      result = AsArgument(AsBool(args[0]) && AsBool(args[1]));
    } else if (expr.op() == "||") {
      result = AsArgument(AsBool(args[0]) || AsBool(args[1]));
    } else if (expr.op() == "!") {
      result = AsArgument(!(AsBool(args[0])));
    } else if (expr.op() == "min") {
      result = AsArgument(std::min(AsNum(args[0]), AsNum(args[1])));
    } else if (expr.op() == "max") {
      result = AsArgument(std::max(AsNum(args[0]), AsNum(args[1])));
    } else if (expr.op() == "concat") {
      result = AsArgument(AsString(args[0]) + AsString(args[1]));
    } else if (expr.op() == "lowercase") {
      result = AsArgument(absl::AsciiStrToLower(AsString(args[0])));
    } else if (expr.op() == "uppercase") {
      result = AsArgument(absl::AsciiStrToUpper(AsString(args[0])));
    } else if (expr.op() == "dict") {
      result = AsDict(args);
    } else if (expr.op() == "list") {
      result = AsList(args);
    } else if (expr.op() == "size") {
      return AsArgument(static_cast<double>(
          args[0].has_dict() ? args[0].mutable_dict()->arg_size()
                             : args[0].mutable_element()->size()));
    }
    return result;
  }

  // Converts a simple value to a template argument for further processing.
  TemplateArgument AsArgument(const std::string& value) {
    TemplateArgument result;
    result.set_str(value);
    return result;
  }

  // Converts a simple value to a template argument for further processing.
  TemplateArgument AsArgument(double value) {
    TemplateArgument result;
    result.set_num(value);
    return result;
  }

  // Converts a boolean result into a template argument for further processing.
  TemplateArgument AsArgument(bool b) {
    return AsArgument(static_cast<double>(b));
  }

  // Convert between a proto field value and a template argument.
  absl::Status AsFieldValues(const std::vector<TemplateArgument>& args,
                             FieldType field_type,
                             std::vector<FieldValue>* result) {
    for (int i = 0; i < (int)args.size(); ++i) {
      if (args[i].has_dict()) {
        FieldValue dict_bytes;
        ABSL_CHECK(args[i].dict().SerializePartialToString(&dict_bytes));
        result->push_back(dict_bytes);
      } else if (args[i].has_num() || args[i].has_str()) {
        std::string text_value = args[i].has_num()
                                     ? mediapipe::SimpleDtoa(args[i].num())
                                     : args[i].str();
        std::vector<FieldValue> r;
        MP_RETURN_IF_ERROR(
            ProtoUtilLite::Serialize({text_value}, field_type, &r));
        result->push_back(r[0]);
      }
    }
    return absl::OkStatus();
  }

  // Record a Status if it indicates an error.
  void RecordError(const absl::Status& status) {
    if (!status.ok()) {
      errors_->push_back(status);
    }
  }

 private:
  // The list of template rules.
  mediapipe::CalculatorGraphTemplate template_rules_;

  // The template variable environment.
  TemplateDict environment_;

  // List of errors found in template parameters.
  std::vector<absl::Status>* errors_;
};

TemplateExpander::TemplateExpander() {}

// Expands template rules within a proto message.
// Replaces template rules with expanded sub-messages.
absl::Status TemplateExpander::ExpandTemplates(
    const TemplateDict& args, const CalculatorGraphTemplate& templ,
    CalculatorGraphConfig* output) {
  errors_.clear();
  TemplateExpanderImpl expander(&errors_);
  if (!expander.ExpandTemplates(args, templ, output)) {
    errors_.push_back(absl::InternalError("ExpandTemplates failed"));
  }
  absl::Status status;
  for (const absl::Status& error : errors_) {
    ABSL_LOG(ERROR) << error;
    status.Update(error);
  }
  return status;
}

}  // namespace tool
}  // namespace mediapipe