Skip to content

File packet_generator_graph.cc

File List > framework > packet_generator_graph.cc

Go to the documentation of this file

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/framework/packet_generator_graph.h"

#include <deque>
#include <functional>
#include <memory>
#include <utility>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/delegating_executor.h"
#include "mediapipe/framework/executor.h"
#include "mediapipe/framework/packet_generator.h"
#include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_builder.h"
#include "mediapipe/framework/thread_pool_executor.h"
#include "mediapipe/framework/tool/status_util.h"

namespace mediapipe {

namespace {

// Builds the input side packet set of one PacketGenerator from the side
// packets that are currently available.
//
//   validated_graph        graph containing the generator.
//   generator_index        index of the generator in the canonical config.
//   side_packets           side packets available so far, keyed by name.
//   input_side_packet_set  receives a copy of every available input packet.
//   unrunnable             set to true when at least one input side packet is
//                          not yet in side_packets, and to false otherwise.
//
// Returns a combined error listing every available input packet whose type
// does not match the generator's declaration.  When an error is returned the
// contents of unrunnable and input_side_packet_set are undefined.
absl::Status CreateInputsForGenerator(
    const ValidatedGraphConfig& validated_graph, int generator_index,
    const std::map<std::string, Packet>& side_packets,
    PacketSet* input_side_packet_set, bool* unrunnable) {
  const auto& generator_name = validated_graph.Config()
                                   .packet_generator(generator_index)
                                   .packet_generator();
  const auto& input_types =
      validated_graph.GeneratorInfos()[generator_index].InputSidePacketTypes();

  std::vector<absl::Status> type_errors;
  *unrunnable = false;
  for (CollectionItemId id = input_types.BeginId(); id < input_types.EndId();
       ++id) {
    const std::string& packet_name = input_types.TagMap()->Names()[id.value()];
    const auto found = side_packets.find(packet_name);
    if (found == side_packets.end()) {
      // A missing input only postpones the generator; keep iterating so that
      // the packets which are present still get type checked.
      *unrunnable = true;
      continue;
    }
    input_side_packet_set->Get(id) = found->second;
    const absl::Status type_check =
        input_types.Get(id).Validate(input_side_packet_set->Get(id));
    if (type_check.ok()) {
      continue;
    }
    type_errors.push_back(tool::AddStatusPrefix(
        absl::StrCat("Input side packet \"", packet_name,
                     "\" for PacketGenerator \"", generator_name,
                     "\" is not of the correct type: "),
        type_check));
  }
  if (type_errors.empty()) {
    return absl::OkStatus();
  }
  return tool::CombinedStatus(
      absl::StrCat(generator_name, " had invalid configuration."),
      type_errors);
}

// Runs the PacketGenerator at generator_index on input_side_packet_set and
// stores what it produces in output_side_packet_set.
//
// Fails if the generator name cannot be resolved in the registry, if the
// generator's Generate() call reports an error, or if the produced packets do
// not match the generator's declared output side packet types.
absl::Status Generate(const ValidatedGraphConfig& validated_graph,
                      int generator_index,
                      const PacketSet& input_side_packet_set,
                      PacketSet* output_side_packet_set) {
  const PacketGeneratorConfig& config =
      validated_graph.Config().packet_generator(generator_index);
  const auto& name = config.packet_generator();

  // Resolve the generator by name within the graph's package.
  MP_ASSIGN_OR_RETURN(
      auto registry_access,
      internal::StaticAccessToGeneratorRegistry::CreateByNameInNamespace(
          validated_graph.Package(), name),
      _ << name << " is not a valid PacketGenerator.");

  MP_RETURN_IF_ERROR(registry_access->Generate(config.options(),
                                               input_side_packet_set,
                                               output_side_packet_set))
          .SetPrepend()
      << name << "::Generate() failed. ";

  // Check the produced packets against the declared output types.
  const auto& output_types =
      validated_graph.GeneratorInfos()[generator_index].OutputSidePacketTypes();
  MP_RETURN_IF_ERROR(ValidatePacketSet(output_types, *output_side_packet_set))
          .SetPrepend()
      << name << "::Generate() output packets were of incorrect type: ";
  return absl::OkStatus();
}

// GeneratorScheduler schedules the packet generators in a validated graph for
// execution on an executor.
//
// The usage visible in this file (see PacketGeneratorGraph::ExecuteGenerators)
// is: construct, call ScheduleAllRunnableGenerators() once, call
// WaitUntilIdle(), then read the outcome with GetNonScheduledGenerators().
class GeneratorScheduler {
 public:
  // If "executor" is null, a DelegatingExecutor will be created internally.
  // "initial" must be set to true for the first pass and false for subsequent
  // passes. If "initial" is false, non_base_generators contains the non-base
  // PacketGenerators (those not run at initialize time due to missing
  // dependencies).
  //
  // On the initial pass every generator starts out as not scheduled.  On later
  // passes every generator starts out marked as scheduled except the ones
  // listed in non_base_generators.
  GeneratorScheduler(const ValidatedGraphConfig* validated_graph,
                     mediapipe::Executor* executor,
                     const std::vector<int>& non_base_generators, bool initial);

  // Run a PacketGenerator on a given executor on the provided input
  // side packets.  After running the generator, schedule any generators
  // which became runnable.
  //
  // Does nothing if an error has already been recorded.  On success the
  // generator's output packets are inserted into *side_packets; a name that
  // is already present is recorded as an error.
  void GenerateAndScheduleNext(int generator_index,
                               std::map<std::string, Packet>* side_packets,
                               std::unique_ptr<PacketSet> input_side_packet_set)
      ABSL_LOCKS_EXCLUDED(mutex_);

  // Iterate through all generators in the config, scheduling any that
  // are runnable (and haven't been scheduled yet).
  void ScheduleAllRunnableGenerators(
      std::map<std::string, Packet>* side_packets) ABSL_LOCKS_EXCLUDED(mutex_);

  // Waits until there are no pending tasks.
  // When the internally created DelegatingExecutor is in use, the queued
  // tasks are run on the calling thread instead of being waited for.
  void WaitUntilIdle() ABSL_LOCKS_EXCLUDED(mutex_);

  // Stores the indexes of the packet generators that were not scheduled (or
  // rather, not executed) in non_scheduled_generators. Returns the combined
  // error status if there were errors while running the packet generators.
  // NOTE: This method should only be called when there are no pending tasks.
  absl::Status GetNonScheduledGenerators(
      std::vector<int>* non_scheduled_generators) const;

 private:
  // Called by delegating_executor_ to add a task.
  void AddApplicationThreadTask(std::function<void()> task);

  // Run all the application thread tasks (which are kept track of in
  // app_thread_tasks_).
  void RunApplicationThreadTasks() ABSL_LOCKS_EXCLUDED(app_thread_mutex_);

  // The graph whose packet generators are scheduled.  Held as a raw pointer;
  // this class never deletes it.
  const ValidatedGraphConfig* const validated_graph_;
  // Executor on which generator tasks are scheduled.  Points at
  // delegating_executor_ when the constructor received a null executor.
  mediapipe::Executor* executor_;

  // Guards num_tasks_, statuses_ and scheduled_generators_.
  mutable absl::Mutex mutex_;
  // The number of pending tasks.
  int num_tasks_ ABSL_GUARDED_BY(mutex_) = 0;
  // This condition variable is signaled when num_tasks_ becomes 0.
  absl::CondVar idle_condvar_;
  // Accumulates the error statuses while running the packet generators.
  std::vector<absl::Status> statuses_ ABSL_GUARDED_BY(mutex_);
  // scheduled_generators_[i] is true if the packet generator with index i was
  // scheduled (or rather, executed).
  std::vector<bool> scheduled_generators_ ABSL_GUARDED_BY(mutex_);

  // Guards app_thread_tasks_.
  absl::Mutex app_thread_mutex_;
  // Tasks to be executed on the application thread.
  std::deque<std::function<void()>> app_thread_tasks_
      ABSL_GUARDED_BY(app_thread_mutex_);
  // Only created when the constructor received a null executor; it is handed
  // AddApplicationThreadTask() as its callback.
  std::unique_ptr<internal::DelegatingExecutor> delegating_executor_;
};

// Sets up the scheduler for one pass over the graph's packet generators.
//
// On the initial pass every generator starts as "not scheduled".  On later
// passes every generator starts as "scheduled" and only the entries listed in
// non_base_generators are reset, so only those can be picked up again.
GeneratorScheduler::GeneratorScheduler(
    const ValidatedGraphConfig* validated_graph, mediapipe::Executor* executor,
    const std::vector<int>& non_base_generators, bool initial)
    : validated_graph_(validated_graph),
      executor_(executor),
      scheduled_generators_(validated_graph_->Config().packet_generator_size(),
                            !initial) {
  if (executor_ == nullptr) {
    // No executor supplied: queue the tasks so that they can be run on the
    // application thread.
    delegating_executor_ = absl::make_unique<internal::DelegatingExecutor>(
        [this](std::function<void()> task) {
          AddApplicationThreadTask(std::move(task));
        });
    executor_ = delegating_executor_.get();
  }

  if (initial) {
    return;
  }
  // Only the non-base generators remain eligible for scheduling.
  for (const int generator_index : non_base_generators) {
    scheduled_generators_[generator_index] = false;
  }
}

void GeneratorScheduler::GenerateAndScheduleNext(
    int generator_index, std::map<std::string, Packet>* side_packets,
    std::unique_ptr<PacketSet> input_side_packet_set) {
  {
    absl::MutexLock lock(&mutex_);
    if (!statuses_.empty()) {
      // Return early, don't run the generator if we already have errors.
      return;
    }
  }
  PacketSet output_side_packet_set(
      validated_graph_->GeneratorInfos()[generator_index]
          .OutputSidePacketTypes()
          .TagMap());
  VLOG(1) << "Running generator " << generator_index;
  absl::Status status =
      Generate(*validated_graph_, generator_index, *input_side_packet_set,
               &output_side_packet_set);

  {
    absl::MutexLock lock(&mutex_);
    if (!status.ok()) {
      statuses_.push_back(std::move(status));
      return;
    }
    // Add packets to side_packets .
    for (CollectionItemId id = output_side_packet_set.BeginId();
         id < output_side_packet_set.EndId(); ++id) {
      const auto& name = output_side_packet_set.TagMap()->Names()[id.value()];
      auto item = side_packets->emplace(name, output_side_packet_set.Get(id));
      if (!item.second) {
        statuses_.push_back(absl::AlreadyExistsError(
            absl::StrCat("Side packet \"", name, "\" was defined twice.")));
      }
    }
    if (!statuses_.empty()) {
      return;
    }
  }

  // Check all generators and schedule any that have become runnable.
  // TODO Instead of checking all of them, only check ones
  // that have input side packets which we have just produced.
  ScheduleAllRunnableGenerators(side_packets);
}

void GeneratorScheduler::ScheduleAllRunnableGenerators(
    std::map<std::string, Packet>* side_packets) {
  absl::MutexLock lock(&mutex_);
  const auto& generators = validated_graph_->Config().packet_generator();

  for (int index = 0; index < generators.size(); ++index) {
    if (scheduled_generators_[index]) {
      continue;
    }
    bool is_unrunnable = false;
    // TODO Input side packet set should only be created once.
    auto input_side_packet_set =
        absl::make_unique<PacketSet>(validated_graph_->GeneratorInfos()[index]
                                         .InputSidePacketTypes()
                                         .TagMap());

    absl::Status status =
        CreateInputsForGenerator(*validated_graph_, index, *side_packets,
                                 input_side_packet_set.get(), &is_unrunnable);
    if (!status.ok()) {
      statuses_.push_back(std::move(status));
      continue;
    }
    if (is_unrunnable) {
      continue;
    }
    // The Generator is runnable, schedule a callback to run it.
    scheduled_generators_[index] = true;
    VLOG(1) << "Scheduling generator " << index;
    // Get around the fact that we can't capture a unique_ptr (this
    // means a memory leak will result if the lambda is not run).
    PacketSet* input_side_packet_set_ptr = input_side_packet_set.release();
    ++num_tasks_;
    mutex_.Unlock();
    executor_->Schedule(
        [this, index, side_packets, input_side_packet_set_ptr]() {
          GenerateAndScheduleNext(
              index, side_packets,
              std::unique_ptr<PacketSet>(input_side_packet_set_ptr));
          {
            absl::MutexLock lock(&mutex_);
            --num_tasks_;
            if (num_tasks_ == 0) {
              idle_condvar_.Signal();
            }
          }
        });
    mutex_.Lock();
  }
}

// Returns once no scheduled task is pending.
//
// With the internally created delegating executor nothing runs on its own, so
// the queued tasks are executed here on the calling thread.  With an external
// executor this blocks until the pending-task counter reaches zero.
void GeneratorScheduler::WaitUntilIdle() {
  const bool uses_application_thread =
      (delegating_executor_.get() == executor_);
  if (uses_application_thread) {
    RunApplicationThreadTasks();
    return;
  }

  absl::MutexLock lock(&mutex_);
  // Loop to guard against wake-ups that arrive while tasks are still pending.
  while (num_tasks_ != 0) {
    idle_condvar_.Wait(&mutex_);
  }
}

// Reports the outcome of a scheduling pass.
//
// non_scheduled_generators is always cleared first.  If any error was
// recorded, the combined error is returned and the vector stays empty;
// otherwise it receives, in ascending order, the index of every generator
// that is not marked as scheduled.
absl::Status GeneratorScheduler::GetNonScheduledGenerators(
    std::vector<int>* non_scheduled_generators) const {
  non_scheduled_generators->clear();

  absl::MutexLock lock(&mutex_);
  if (!statuses_.empty()) {
    return tool::CombinedStatus("PacketGeneratorGraph failed.", statuses_);
  }

  const int generator_count = static_cast<int>(scheduled_generators_.size());
  for (int index = 0; index < generator_count; ++index) {
    if (scheduled_generators_[index]) {
      continue;
    }
    non_scheduled_generators->push_back(index);
  }
  return absl::OkStatus();
}

// Appends task to the queue of work that RunApplicationThreadTasks() drains.
void GeneratorScheduler::AddApplicationThreadTask(std::function<void()> task) {
  absl::MutexLock lock(&app_thread_mutex_);
  app_thread_tasks_.emplace_back(std::move(task));
}

// Drains app_thread_tasks_ on the calling thread, in FIFO order, until the
// queue is empty.  Tasks added while draining are run as well.
void GeneratorScheduler::RunApplicationThreadTasks() {
  for (;;) {
    std::function<void()> next_task;
    {
      // Pop the oldest queued task, or finish when nothing is left.
      absl::MutexLock lock(&app_thread_mutex_);
      if (app_thread_tasks_.empty()) {
        return;
      }
      next_task = std::move(app_thread_tasks_.front());
      app_thread_tasks_.pop_front();
    }
    // The lock is released before running: the task may itself enqueue
    // further application thread tasks.
    next_task();
  }
}

}  // namespace

// Out-of-line destructor; performs no work beyond destroying the members.
PacketGeneratorGraph::~PacketGeneratorGraph() = default;

// Records the graph, the executor and the initial side packets, then runs the
// first ("initial") generator pass.
//
// The outputs of that pass are added to base_packets_, and the indexes of the
// generators that could not run yet are stored in non_base_generators_.
// Returns an error if the graph cannot accept input_side_packets or if the
// generator pass fails.
absl::Status PacketGeneratorGraph::Initialize(
    const ValidatedGraphConfig* validated_graph, mediapipe::Executor* executor,
    const std::map<std::string, Packet>& input_side_packets) {
  // The members are stored before validation, so they are set even when an
  // error is returned below.
  base_packets_ = input_side_packets;
  executor_ = executor;
  validated_graph_ = validated_graph;

  MP_RETURN_IF_ERROR(validated_graph->CanAcceptSidePackets(input_side_packets));
  return ExecuteGenerators(&base_packets_, &non_base_generators_,
                           /*initial=*/true);
}

// Prepares the side packets for one graph run.
//
// *output_side_packets is overwritten with base_packets_ plus
// input_side_packets; a name that occurs in both is an error.  The generators
// that are still eligible are then executed, adding their outputs to
// *output_side_packets.  non_scheduled_generators may be null; when given, it
// receives the indexes of the generators that still could not be run.
absl::Status PacketGeneratorGraph::RunGraphSetup(
    const std::map<std::string, Packet>& input_side_packets,
    std::map<std::string, Packet>* output_side_packets,
    std::vector<int>* non_scheduled_generators) const {
  *output_side_packets = base_packets_;
  for (const auto& entry : input_side_packets) {
    // insert() leaves the map untouched when the name is already taken.
    const bool inserted = output_side_packets->insert(entry).second;
    if (!inserted) {
      return absl::AlreadyExistsError(absl::StrCat(
          "Side packet \"", entry.first, "\" was defined twice."));
    }
  }

  // Callers that do not care about the leftover generators may pass null.
  std::vector<int> discarded_generators;
  if (non_scheduled_generators == nullptr) {
    non_scheduled_generators = &discarded_generators;
  }

  MP_RETURN_IF_ERROR(
      validated_graph_->CanAcceptSidePackets(input_side_packets));
  // This type check on the required side packets is redundant with
  // error checking in ExecuteGenerators, but we do it now to fail early.
  MP_RETURN_IF_ERROR(
      validated_graph_->ValidateRequiredSidePackets(*output_side_packets));
  MP_RETURN_IF_ERROR(ExecuteGenerators(
      output_side_packets, non_scheduled_generators, /*initial=*/false));
  return absl::OkStatus();
}

// Runs one scheduling pass over the graph's packet generators.
//
// As many output side packets as possible are produced into
// *output_side_packets.  Generators that lack some required input side packet
// are reported through non_scheduled_generators.  The ValidatedGraphConfig is
// expected to already hold its generators in topological order.
absl::Status PacketGeneratorGraph::ExecuteGenerators(
    std::map<std::string, Packet>* output_side_packets,
    std::vector<int>* non_scheduled_generators, bool initial) const {
  VLOG(1) << "ExecuteGenerators initial == " << initial;

  GeneratorScheduler generator_scheduler(validated_graph_, executor_,
                                         non_base_generators_, initial);
  generator_scheduler.ScheduleAllRunnableGenerators(output_side_packets);

  // Even if scheduling already recorded an error, do not return before the
  // wait: the lambdas handed to the executor must run in order to free
  // resources.
  generator_scheduler.WaitUntilIdle();

  // Every task has run by now, so the results can be read safely.
  return generator_scheduler.GetNonScheduledGenerators(
      non_scheduled_generators);
}

}  // namespace mediapipe