Ray_workflow Note

This note documents the workflow of the Ray framework.

Workflow

Sending task

# actor.py 949
object_refs = worker.core_worker.submit_actor_task(
    self._ray_actor_language, self._ray_actor_id, function_descriptor,
    list_args, name, num_returns, self._ray_actor_method_cpus)

This ends up calling the SubmitActorTask function in core_worker.cc:

// core_worker.cc 1951
std::vector<rpc::ObjectReference> CoreWorker::SubmitActorTask(
    const ActorID &actor_id, const RayFunction &function,
    const std::vector<std::unique_ptr<TaskArg>> &args, const TaskOptions &task_options) {
  auto actor_handle = actor_manager_->GetActorHandle(actor_id);

  // Add one for actor cursor object id for tasks.
  const int num_returns = task_options.num_returns + 1;

  // Build common task spec.
  TaskSpecBuilder builder;
  const auto next_task_index = worker_context_.GetNextTaskIndex();
  const TaskID actor_task_id = TaskID::ForActorTask(
      worker_context_.GetCurrentJobID(), worker_context_.GetCurrentTaskID(),
      next_task_index, actor_handle->GetActorID());
  const std::unordered_map<std::string, double> required_resources;
  const auto task_name = task_options.name.empty()
                             ? function.GetFunctionDescriptor()->DefaultTaskName()
                             : task_options.name;
  BuildCommonTaskSpec(builder, actor_handle->CreationJobID(), actor_task_id, task_name,
                      worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(),
                      rpc_address_, function, args, num_returns, task_options.resources,
                      required_resources, std::make_pair(PlacementGroupID::Nil(), -1),
                      true, /* placement_group_capture_child_tasks */
                      "",   /* debugger_breakpoint */
                      "{}", /* serialized_runtime_env */
                      {},   /* runtime_env_uris */
                      task_options.concurrency_group_name);
  // NOTE: placement_group_capture_child_tasks and runtime_env will
  // be ignored in the actor because we should always follow the actor's option.

  // TODO(swang): Do we actually need to set this ObjectID?
  const ObjectID new_cursor = ObjectID::FromIndex(actor_task_id, num_returns);
  actor_handle->SetActorTaskSpec(builder, new_cursor);

  // Submit task.
  TaskSpecification task_spec = builder.Build();
  std::vector<rpc::ObjectReference> returned_refs;
  if (options_.is_local_mode) {
    returned_refs = ExecuteTaskLocalMode(task_spec, actor_id);
  } else {
    returned_refs = task_manager_->AddPendingTask(
        rpc_address_, task_spec, CurrentCallSite(), actor_handle->MaxTaskRetries());
    io_service_.post(
        [this, task_spec]() {
          RAY_UNUSED(direct_actor_submitter_->SubmitTask(task_spec));
        },
        "CoreWorker.SubmitActorTask");
  }
  return returned_refs;
}

The RayFunction argument to SubmitActorTask carries a function descriptor whose base class is FunctionDescriptorInterface:

// function_descriptor.h 26
class FunctionDescriptorInterface : public MessageWrapper<rpc::FunctionDescriptor> {
 public:
  virtual ~FunctionDescriptorInterface() {}

  /// Construct an empty FunctionDescriptor.
  FunctionDescriptorInterface() : MessageWrapper() {}

  /// Construct from a protobuf message object.
  /// The input message will be **copied** into this object.
  ///
  /// \param message The protobuf message.
  FunctionDescriptorInterface(rpc::FunctionDescriptor message)
      : MessageWrapper(std::move(message)) {}

  ray::FunctionDescriptorType Type() const {
    return message_->function_descriptor_case();
  }

  virtual size_t Hash() const = 0;

  // DO NOT define operator==() or operator!=() in the base class.
  // Let the derived classes define and implement.
  // This is to avoid unexpected behaviors when comparing function descriptors of
  // different declard types, as in this case, the base class version is invoked.

  virtual std::string ToString() const = 0;

  // A one-word summary of the function call site (e.g., __main__.foo).
  virtual std::string CallSiteString() const { return CallString(); }

  // The function or method call, e.g. "foo()" or "Bar.foo()". This does not include the
  // module/library.
  virtual std::string CallString() const = 0;

  // The default name for a task that executes this function.
  virtual std::string DefaultTaskName() const { return CallString() + "()"; }

  template <typename Subtype>
  Subtype *As() {
    return reinterpret_cast<Subtype *>(this);
  }
};

Running plot_pong_example.py, the RayFunction's descriptor prints as:

ToString: {type=PythonFunctionDescriptor, module_name=plot_pong_example, class_name=RolloutWorker, function_name=compute_gradient, function_hash=}
CallSiteString: plot_pong_example.RolloutWorker.compute_gradient
CallString: RolloutWorker.compute_gradient
DefaultTaskName: RolloutWorker.compute_gradient()
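
These strings are produced by the PythonFunctionDescriptor attached to the RayFunction. A sketch of building the same descriptor directly, assuming FunctionDescriptorBuilder::BuildPython from function_descriptor.h:

// Sketch only; assumes FunctionDescriptorBuilder::BuildPython from
// src/ray/common/function_descriptor.h.
ray::FunctionDescriptor fd = ray::FunctionDescriptorBuilder::BuildPython(
    /*module_name=*/"plot_pong_example",
    /*class_name=*/"RolloutWorker",
    /*function_name=*/"compute_gradient",
    /*function_hash=*/"");
// fd->ToString()        -> the ToString line above
// fd->CallSiteString()  -> plot_pong_example.RolloutWorker.compute_gradient
// fd->CallString()      -> RolloutWorker.compute_gradient
// fd->DefaultTaskName() -> RolloutWorker.compute_gradient()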

At the end of SubmitActorTask, io_service_.post() is called:

 io_service_.post(
        [this, task_spec]() {
          RAY_UNUSED(direct_actor_submitter_->SubmitTask(task_spec));
        },
        "CoreWorker.SubmitActorTask");

io_service_ is an instance of a class that inherits from boost::asio::io_context; once post() is called, the event loop later invokes the passed-in lambda:

// instrumented_io_context.h 97
class instrumented_io_context : public boost::asio::io_context {
 public:
  /// Initializes the global stats struct after calling the base contructor.
  instrumented_io_context() : global_stats_(std::make_shared<GuardedGlobalStats>()) {}

  /// A proxy post function that collects count, queueing, and execution statistics for
  /// the given handler.
  ///
  /// \param handler The handler to be posted to the event loop.
  /// \param name A human-readable name for the handler, to be used for viewing stats
  /// for the provided handler. Defaults to UNKNOWN.
  void post(std::function<void()> handler, const std::string name = "UNKNOWN")
      LOCKS_EXCLUDED(mutex_);

  /// A proxy post function where the operation start is manually recorded. For example,
  /// this is useful for tracking the number of active outbound RPC calls.
  ///
  /// \param handler The handler to be posted to the event loop.
  /// \param handle The stats handle returned by RecordStart() previously.
  void post(std::function<void()> handler, std::shared_ptr<StatsHandle> handle)
      LOCKS_EXCLUDED(mutex_);

  /// A proxy post function that collects count, queueing, and execution statistics for
  /// the given handler.
  ///
  /// \param handler The handler to be posted to the event loop.
  /// \param name A human-readable name for the handler, to be used for viewing stats
  /// for the provided handler. Defaults to UNKNOWN.
  void dispatch(std::function<void()> handler, const std::string name = "UNKNOWN")
      LOCKS_EXCLUDED(mutex_);

  /// Sets the queueing start time, increments the current and cumulative counts and
  /// returns an opaque handle for these stats. This is used in conjunction with
  /// RecordExecution() to manually instrument an event loop handler that doesn't call
  /// .post().
  ///
  /// The returned opaque stats handle should be given to a subsequent RecordExecution()
  /// call.
  ///
  /// \param name A human-readable name to which collected stats will be associated.
  /// \param expected_queueing_delay_ns How much to pad the observed queueing start time,
  ///  in nanoseconds.
  /// \return An opaque stats handle, to be given to RecordExecution().
  std::shared_ptr<StatsHandle> RecordStart(const std::string &name,
                                           int64_t expected_queueing_delay_ns = 0);

  /// Records stats about the provided function's execution. This is used in conjunction
  /// with RecordStart() to manually instrument an event loop handler that doesn't call
  /// .post().
  ///
  /// \param fn The function to execute and instrument.
  /// \param handle An opaque stats handle returned by RecordStart().
  static void RecordExecution(const std::function<void()> &fn,
                              std::shared_ptr<StatsHandle> handle);

  /// Returns a snapshot view of the global count, queueing, and execution statistics
  /// across all handlers.
  ///
  /// \return A snapshot view of the global handler stats.
  GlobalStats get_global_stats() const;

  /// Returns a snapshot view of the count, queueing, and execution statistics for the
  /// provided handler.
  ///
  /// \param handler_name The name of the handler whose stats should be returned.
  /// \return A snapshot view of the handler's stats.
  absl::optional<HandlerStats> get_handler_stats(const std::string &handler_name) const
      LOCKS_EXCLUDED(mutex_);

  /// Returns snapshot views of the count, queueing, and execution statistics for all
  /// handlers.
  ///
  /// \return A vector containing snapshot views of stats for all handlers.
  std::vector<std::pair<std::string, HandlerStats>> get_handler_stats() const
      LOCKS_EXCLUDED(mutex_);

  /// Builds and returns a statistics summary string. Used by the DebugString() of
  /// objects that used this io_context wrapper, such as the raylet and the core worker.
  ///
  /// \return A stats summary string, suitable for inclusion in an object's
  /// DebugString().
  std::string StatsString() const LOCKS_EXCLUDED(mutex_);

 private:
  using HandlerStatsTable =
      absl::flat_hash_map<std::string, std::shared_ptr<GuardedHandlerStats>>;
  /// Get the mutex-guarded stats for this handler if it exists, otherwise create the
  /// stats for this handler and return an iterator pointing to it.
  ///
  /// \param name A human-readable name for the handler, to be used for viewing stats
  /// for the provided handler.
  std::shared_ptr<GuardedHandlerStats> GetOrCreate(const std::string &name);

  /// Global stats, across all handlers.
  std::shared_ptr<GuardedGlobalStats> global_stats_;

  /// Table of per-handler post stats.
  /// We use a std::shared_ptr value in order to ensure pointer stability.
  HandlerStatsTable post_handler_stats_ GUARDED_BY(mutex_);

  /// Protects access to the per-handler post stats table.
  mutable absl::Mutex mutex_;
};
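
instrumented_io_context adds statistics on top of the standard Boost.Asio pattern. A minimal standalone sketch of that post-then-run pattern, using plain boost::asio::io_context without the instrumentation:

// Minimal sketch of the post/run pattern used above (plain Boost.Asio,
// no stats instrumentation).
#include <boost/asio.hpp>
#include <iostream>

int main() {
  boost::asio::io_context io;
  // Keep run() from returning while there is no work yet, mirroring the
  // io_work_ member in CoreWorker.
  auto work = boost::asio::make_work_guard(io);

  // The counterpart of io_service_.post(lambda, "name"): queue a handler.
  boost::asio::post(io, [] { std::cout << "handler ran on the event loop\n"; });

  work.reset();  // let run() exit once the queue drains
  io.run();      // the counterpart of task_execution_service_.run() shown later
}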

The function invoked inside post is direct_actor_submitter_->SubmitTask(task_spec), where direct_actor_submitter_ is an instance of CoreWorkerDirectActorTaskSubmitter (direct_actor_transport.h 66). Eventually, inside void CoreWorkerDirectActorTaskSubmitter::SendPendingTasks(), the PushActorTask function declared in direct_actor_transport.h is called, and the RPC that delivers the task is issued from there:

// direct_actor_transport.cc 268
void CoreWorkerDirectActorTaskSubmitter::SendPendingTasks(const ActorID &actor_id) {
  auto it = client_queues_.find(actor_id);
  RAY_CHECK(it != client_queues_.end());
  if (!it->second.rpc_client) {
    return;
  }
  auto &client_queue = it->second;

  // Check if there is a pending force kill. If there is, send it and disconnect the
  // client.
  if (client_queue.pending_force_kill) {
    RAY_LOG(INFO) << "Sending KillActor request to actor " << actor_id;
    // It's okay if this fails because this means the worker is already dead.
    client_queue.rpc_client->KillActor(*client_queue.pending_force_kill, nullptr);
    client_queue.pending_force_kill.reset();
  }

  // Submit all pending requests.
  auto &requests = client_queue.requests;
  auto head = requests.begin();
  while (head != requests.end() &&
         (/*seqno*/ head->first <= client_queue.next_send_position) &&
         (/*dependencies_resolved*/ head->second.second)) {
    // If the task has been sent before, skip the other tasks in the send
    // queue.
    bool skip_queue = head->first < client_queue.next_send_position;
    auto task_spec = std::move(head->second.first);
    head = requests.erase(head);

    RAY_CHECK(!client_queue.worker_id.empty());
    PushActorTask(client_queue, task_spec, skip_queue);
    client_queue.next_send_position++;
  }
}
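
The ordering invariant here, send only the contiguous, dependency-resolved prefix starting at next_send_position, can be isolated in a small sketch (PendingQueue and the payloads below are made up for illustration):

// Illustrative sketch of the ordering invariant in SendPendingTasks:
// requests are keyed by sequence number, and only the contiguous,
// dependency-resolved prefix starting at next_send_position is sent.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <utility>

struct PendingQueue {
  // seqno -> (task payload, dependencies_resolved)
  std::map<uint64_t, std::pair<std::string, bool>> requests;
  uint64_t next_send_position = 0;

  void Drain() {
    auto head = requests.begin();
    while (head != requests.end() && head->first <= next_send_position &&
           head->second.second) {
      std::cout << "send seqno " << head->first << " (" << head->second.first << ")\n";
      head = requests.erase(head);
      next_send_position++;
    }
  }
};

int main() {
  PendingQueue q;
  q.requests[0] = {"task-a", true};
  q.requests[1] = {"task-b", false};  // dependencies not yet resolved
  q.requests[2] = {"task-c", true};
  q.Drain();  // sends only seqno 0; seqno 1 blocks seqno 2 despite being ready
  q.requests[1].second = true;
  q.Drain();  // now sends seqno 1 and 2, in order
}

PushActorTask, which performs the actual RPC, is declared as follows:
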
// direct_actor_transport.h 228
  /// Push a task to a remote actor via the given client.
  /// Note, this function doesn't return any error status code. If an error occurs while
  /// sending the request, this task will be treated as failed.
  ///
  /// \param[in] queue The actor queue. Contains the RPC client state.
  /// \param[in] task_spec The task to send.
  /// \param[in] skip_queue Whether to skip the task queue. This will send the
  /// task for execution immediately.
  /// \return Void.
  void PushActorTask(const ClientQueue &queue, const TaskSpecification &task_spec,
                     bool skip_queue) EXCLUSIVE_LOCKS_REQUIRED(mu_);
// direct_actor_transport.cc 322
void CoreWorkerDirectActorTaskSubmitter::PushActorTask(const ClientQueue &queue,
                                                       const TaskSpecification &task_spec,
                                                       bool skip_queue) {
  auto request = std::make_unique<rpc::PushTaskRequest>();
  // NOTE(swang): CopyFrom is needed because if we use Swap here and the task
  // fails, then the task data will be gone when the TaskManager attempts to
  // access the task.
  request->mutable_task_spec()->CopyFrom(task_spec.GetMessage());

  request->set_intended_worker_id(queue.worker_id);
  RAY_CHECK(task_spec.ActorCounter() >= queue.caller_starts_at)
      << "actor counter " << task_spec.ActorCounter() << " " << queue.caller_starts_at;
  request->set_sequence_number(task_spec.ActorCounter() - queue.caller_starts_at);

  const auto task_id = task_spec.TaskId();
  const auto actor_id = task_spec.ActorId();
  const auto actor_counter = task_spec.ActorCounter();
  const auto task_skipped = task_spec.GetMessage().skip_execution();
  const auto num_queued =
      request->sequence_number() - queue.rpc_client->ClientProcessedUpToSeqno();
  RAY_LOG(DEBUG) << "Pushing task " << task_id << " to actor " << actor_id
                 << " actor counter " << actor_counter << " seq no "
                 << request->sequence_number() << " num queued " << num_queued;
  if (num_queued >= next_queueing_warn_threshold_) {
    // TODO(ekl) add more debug info about the actor name, etc.
    warn_excess_queueing_(actor_id, num_queued);
    next_queueing_warn_threshold_ *= 2;
  }

  rpc::Address addr(queue.rpc_client->Addr());
  queue.rpc_client->PushActorTask(
      std::move(request), skip_queue,
      [this, addr, task_id, actor_id, actor_counter, task_spec, task_skipped](
          Status status, const rpc::PushTaskReply &reply) {
        bool increment_completed_tasks = true;

        if (task_skipped) {
          // NOTE(simon):Increment the task counter regardless of the status because the
          // reply for a previously completed task. We are not calling CompletePendingTask
          // because the tasks are pushed directly to the actor, not placed on any queues
          // in task_finisher_.
        } else if (status.ok()) {
          task_finisher_.CompletePendingTask(task_id, reply, addr);
        } else {
          // push task failed due to network error. For example, actor is dead
          // and no process response for the push task.
          absl::MutexLock lock(&mu_);
          auto queue_pair = client_queues_.find(actor_id);
          RAY_CHECK(queue_pair != client_queues_.end());
          auto &queue = queue_pair->second;

          bool immediately_mark_object_fail = (queue.state == rpc::ActorTableData::DEAD);
          bool will_retry = task_finisher_.PendingTaskFailed(
              task_id, rpc::ErrorType::ACTOR_DIED, &status, queue.creation_task_exception,
              immediately_mark_object_fail);
          if (will_retry) {
            increment_completed_tasks = false;
          } else if (!immediately_mark_object_fail) {
            // put it to wait_for_death_info_tasks and wait for Death info
            int64_t death_info_timeout_ts =
                current_time_ms() +
                RayConfig::instance().timeout_ms_task_wait_for_death_info();
            queue.wait_for_death_info_tasks.emplace_back(death_info_timeout_ts,
                                                         task_spec);
            RAY_LOG(INFO)
                << "PushActorTask failed because of network error, this task "
                   "will be stashed away and waiting for Death info from GCS, task_id="
                << task_spec.TaskId()
                << ", wait queue size=" << queue.wait_for_death_info_tasks.size();
          }
        }

        if (increment_completed_tasks) {
          absl::MutexLock lock(&mu_);
          auto queue_pair = client_queues_.find(actor_id);
          RAY_CHECK(queue_pair != client_queues_.end());
          auto &queue = queue_pair->second;

          // Try to increment queue.next_task_reply_position consecutively until we
          // cannot. In the case of tasks not received in order, the following block
          // ensure queue.next_task_reply_position are incremented to the max possible
          // value.
          queue.out_of_order_completed_tasks.insert({actor_counter, task_spec});
          auto min_completed_task = queue.out_of_order_completed_tasks.begin();
          while (min_completed_task != queue.out_of_order_completed_tasks.end()) {
            if (min_completed_task->first == queue.next_task_reply_position) {
              queue.next_task_reply_position++;
              // increment the iterator and erase the old value
              queue.out_of_order_completed_tasks.erase(min_completed_task++);
            } else {
              break;
            }
          }

          RAY_LOG(DEBUG) << "Got PushTaskReply for actor " << actor_id
                         << " with actor_counter " << actor_counter
                         << " new queue.next_task_reply_position is "
                         << queue.next_task_reply_position
                         << " and size of out_of_order_tasks set is "
                         << queue.out_of_order_completed_tasks.size();
        }
      });
}

The key call here is queue.rpc_client->PushActorTask, a method of the CoreWorkerClient class, defined as follows:

// core_worker_client.h 266
  void PushActorTask(std::unique_ptr<PushTaskRequest> request, bool skip_queue,
                     const ClientCallback<PushTaskReply> &callback) override {
    if (skip_queue) {
      // Set this value so that the actor does not skip any tasks when
      // processing this request. We could also set it to max_finished_seq_no_,
      // but we just set it to the default of -1 to avoid taking the lock.
      request->set_client_processed_up_to(-1);
      INVOKE_RPC_CALL(CoreWorkerService, PushTask, *request, callback, grpc_client_);
      return;
    }

    {
      absl::MutexLock lock(&mutex_);
      send_queue_.push_back(std::make_pair(
          std::move(request),
          std::move(const_cast<ClientCallback<PushTaskReply> &>(callback))));
    }
    SendRequests();
  }

PushActorTask then calls the SendRequests method, defined as follows:

// core_worker_client.h 303
 /// Send as many pending tasks as possible. This method is thread-safe.
  ///
  /// The client will guarantee no more than kMaxBytesInFlight bytes of RPCs are being
  /// sent at once. This prevents the server scheduling queue from being overwhelmed.
  /// See direct_actor.proto for a description of the ordering protocol.
  void SendRequests() {
    absl::MutexLock lock(&mutex_);
    auto this_ptr = this->shared_from_this();

    while (!send_queue_.empty() && rpc_bytes_in_flight_ < kMaxBytesInFlight) {
      auto pair = std::move(*send_queue_.begin());
      send_queue_.pop_front();

      auto request = std::move(pair.first);
      int64_t task_size = RequestSizeInBytes(*request);
      int64_t seq_no = request->sequence_number();
      request->set_client_processed_up_to(max_finished_seq_no_);
      rpc_bytes_in_flight_ += task_size;

      auto rpc_callback = [this, this_ptr, seq_no, task_size,
                           callback = std::move(pair.second)](
                              Status status, const rpc::PushTaskReply &reply) {
        {
          absl::MutexLock lock(&mutex_);
          if (seq_no > max_finished_seq_no_) {
            max_finished_seq_no_ = seq_no;
          }
          rpc_bytes_in_flight_ -= task_size;
          RAY_CHECK(rpc_bytes_in_flight_ >= 0);
        }
        SendRequests();
        callback(status, reply);
      };

      RAY_UNUSED(INVOKE_RPC_CALL(CoreWorkerService, PushTask, *request,
                                 std::move(rpc_callback), grpc_client_));
    }

    if (!send_queue_.empty()) {
      RAY_LOG(DEBUG) << "client send queue size " << send_queue_.size();
    }
  }
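
The windowing logic can likewise be isolated in a sketch (the names and the cap value below are illustrative, not Ray's actual kMaxBytesInFlight):

// Illustrative sketch of the byte-window flow control in SendRequests: a
// request is sent only while the in-flight byte count is under the cap
// (checked before each send, as in the loop above); each reply returns its
// bytes to the window and re-drains the queue.
#include <cstdint>
#include <deque>
#include <iostream>

constexpr int64_t kMaxBytesInFlight = 16 << 10;  // 16 KiB, for the example

struct Window {
  std::deque<int64_t> queue;  // pending request sizes, in bytes
  int64_t bytes_in_flight = 0;

  void Drain() {
    while (!queue.empty() && bytes_in_flight < kMaxBytesInFlight) {
      bytes_in_flight += queue.front();
      std::cout << "send " << queue.front() << " bytes\n";
      queue.pop_front();
    }
  }
  void OnReply(int64_t size) {  // the rpc_callback's role above
    bytes_in_flight -= size;
    Drain();
  }
};

int main() {
  Window w;
  w.queue = {10 << 10, 10 << 10, 4 << 10};
  w.Drain();            // sends the first two (cap is checked before each send)
  w.OnReply(10 << 10);  // a reply frees the window; the 4 KiB request goes out
}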

The key call, RAY_UNUSED(INVOKE_RPC_CALL(CoreWorkerService, PushTask, *request, std::move(rpc_callback), grpc_client_));, invokes the following macro:

// grpc_client.h 30
// This macro wraps the logic to call a specific RPC method of a service,
// to make it easier to implement a new RPC client.
#define INVOKE_RPC_CALL(SERVICE, METHOD, request, callback, rpc_client) \
  (rpc_client->CallMethod<METHOD##Request, METHOD##Reply>(              \
      &SERVICE::Stub::PrepareAsync##METHOD, request, callback,          \
      #SERVICE ".grpc_client." #METHOD))
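
To make the token pasting concrete: substituting SERVICE=CoreWorkerService and METHOD=PushTask, the call in SendRequests expands to:

// Expansion of RAY_UNUSED(INVOKE_RPC_CALL(CoreWorkerService, PushTask,
//     *request, std::move(rpc_callback), grpc_client_)):
grpc_client_->CallMethod<PushTaskRequest, PushTaskReply>(
    &CoreWorkerService::Stub::PrepareAsyncPushTask, *request,
    std::move(rpc_callback), "CoreWorkerService.grpc_client.PushTask");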

The macro body calls the CallMethod method of the GrpcClient class, defined as follows:

// grpc_client.h 78
  /// Create a new `ClientCall` and send request.
  ///
  /// \tparam Request Type of the request message.
  /// \tparam Reply Type of the reply message.
  ///
  /// \param[in] prepare_async_function Pointer to the gRPC-generated
  /// `FooService::Stub::PrepareAsyncBar` function.
  /// \param[in] request The request message.
  /// \param[in] callback The callback function that handles reply.
  ///
  /// \return Status.
  template <class Request, class Reply>
  void CallMethod(
      const PrepareAsyncFunction<GrpcService, Request, Reply> prepare_async_function,
      const Request &request, const ClientCallback<Reply> &callback,
      std::string call_name = "UNKNOWN_RPC") {
    auto call = client_call_manager_.CreateCall<GrpcService, Request, Reply>(
        *stub_, prepare_async_function, request, callback, std::move(call_name));
    RAY_CHECK(call != nullptr);
  }

CallMethod in turn calls the CreateCall method of the ClientCallManager class:

// client_call.h 213
  /// Create a new `ClientCall` and send request.
  ///
  /// \tparam GrpcService Type of the gRPC-generated service class.
  /// \tparam Request Type of the request message.
  /// \tparam Reply Type of the reply message.
  ///
  /// \param[in] stub The gRPC-generated stub.
  /// \param[in] prepare_async_function Pointer to the gRPC-generated
  /// `FooService::Stub::PrepareAsyncBar` function.
  /// \param[in] request The request message.
  /// \param[in] callback The callback function that handles reply.
  ///
  /// \return A `ClientCall` representing the request that was just sent.
  template <class GrpcService, class Request, class Reply>
  std::shared_ptr<ClientCall> CreateCall(
      typename GrpcService::Stub &stub,
      const PrepareAsyncFunction<GrpcService, Request, Reply> prepare_async_function,
      const Request &request, const ClientCallback<Reply> &callback,
      std::string call_name) {
    auto stats_handle = main_service_.RecordStart(call_name);
    auto call = std::make_shared<ClientCallImpl<Reply>>(callback, std::move(stats_handle),
                                                        call_timeout_ms_);
    // Send request.
    // Find the next completion queue to wait for response.
    call->response_reader_ = (stub.*prepare_async_function)(
        &call->context_, request, cqs_[rr_index_++ % num_threads_].get());
    call->response_reader_->StartCall();
    // Create a new tag object. This object will eventually be deleted in the
    // `ClientCallManager::PollEventsFromCompletionQueue` when reply is received.
    //
    // NOTE(chen): Unlike `ServerCall`, we can't directly use `ClientCall` as the tag.
    // Because this function must return a `shared_ptr` to make sure the returned
    // `ClientCall` is safe to use. But `response_reader_->Finish` only accepts a raw
    // pointer.
    auto tag = new ClientCallTag(call);
    call->response_reader_->Finish(&call->reply_, &call->status_, (void *)tag);
    return call;
  }

The asynchronous RPC invocation is encapsulated behind the PrepareAsyncFunction alias:

// client_call.h 163
/// Represents the generic signature of a `FooService::Stub::PrepareAsyncBar`
/// function, where `Foo` is the service name and `Bar` is the rpc method name.
///
/// \tparam GrpcService Type of the gRPC-generated service class.
/// \tparam Request Type of the request message.
/// \tparam Reply Type of the reply message.
template <class GrpcService, class Request, class Reply>
using PrepareAsyncFunction = std::unique_ptr<grpc::ClientAsyncResponseReader<Reply>> (
    GrpcService::Stub::*)(grpc::ClientContext *context, const Request &request,
                          grpc::CompletionQueue *cq);
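
For context, CreateCall is Ray's wrapper around the standard gRPC async-unary pattern. A generic sketch of that pattern (EchoService and its messages are hypothetical and would come from generated proto code; the grpc:: calls mirror what CreateCall does above):

// Generic gRPC async-unary sketch. EchoService/EchoRequest/EchoReply are
// hypothetical; the grpc:: API calls mirror CreateCall above.
#include <grpcpp/grpcpp.h>

void AsyncEcho(EchoService::Stub *stub) {
  grpc::ClientContext context;
  grpc::CompletionQueue cq;
  EchoRequest request;
  EchoReply reply;
  grpc::Status status;

  // The generated PrepareAsyncEcho is one instance of PrepareAsyncFunction.
  auto reader = stub->PrepareAsyncEcho(&context, request, &cq);
  reader->StartCall();
  reader->Finish(&reply, &status, /*tag=*/(void *)1);

  // Ray does this on a polling thread (PollEventsFromCompletionQueue):
  // block on the completion queue, then dispatch the tag's callback.
  void *tag;
  bool ok;
  if (cq.Next(&tag, &ok) && tag == (void *)1 && ok) {
    // invoke the ClientCallback with (status, reply)
  }
}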

Protocol definition

core_worker.proto defines a PushTaskRequest message:

// core_worker.proto 86
message PushTaskRequest {
 // The ID of the worker this message is intended for.
 bytes intended_worker_id = 1;
 // The task to be pushed.
 TaskSpec task_spec = 2;
 // The sequence number of the task for this client. This must increase
 // sequentially starting from zero for each actor handle. The server
 // will guarantee tasks execute in this sequence, waiting for any
 // out-of-order request messages to arrive as necessary.
 // If set to -1, ordering is disabled and the task executes immediately.
 int64 sequence_number = 3;
 // The max sequence number the client has processed responses for. This
 // is a performance optimization that allows the client to tell the server
 // to cancel any PushTaskRequests with seqno <= this value, rather than
 // waiting for the server to time out waiting for missing messages.
 int64 client_processed_up_to = 4;
 // Resource mapping ids assigned to the worker executing the task.
 repeated ResourceMapEntry resource_mapping = 5;
}
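
A quick sketch of populating this message through the protobuf-generated C++ setters (the values and the worker_id_binary variable are illustrative; compare with PushActorTask above):

// Sketch: filling a PushTaskRequest via the protobuf-generated setters.
rpc::PushTaskRequest request;
request.set_intended_worker_id(worker_id_binary);  // bytes field, binary worker ID
request.mutable_task_spec()->set_name("RolloutWorker.compute_gradient");
request.set_sequence_number(42);         // in-order delivery; -1 disables ordering
request.set_client_processed_up_to(40);  // server can cancel stale seqnos <= 40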

The key TaskSpec message inside it is defined as follows:

// common.proto 170
/// The task specification encapsulates all immutable information about the
/// task. These fields are determined at submission time, converse to the
/// `TaskExecutionSpec` may change at execution time.
message TaskSpec {
 // Type of this task.
 TaskType type = 1;
 // Name of this task.
 string name = 2;
 // Language of this task.
 Language language = 3;
 // Function descriptor of this task uniquely describe the function to execute.
 FunctionDescriptor function_descriptor = 4;
 // ID of the job that this task belongs to.
 bytes job_id = 5;
 // Task ID of the task.
 bytes task_id = 6;
 // Task ID of the parent task.
 bytes parent_task_id = 7;
 // A count of the number of tasks submitted by the parent task before this one.
 uint64 parent_counter = 8;
 // Task ID of the caller. This is the same as parent_task_id for non-actors.
 // This is the actor ID (embedded in a nil task ID) for actors.
 bytes caller_id = 9;
 /// Address of the caller.
 Address caller_address = 10;
 // Task arguments.
 repeated TaskArg args = 11;
 // Number of return objects.
 uint64 num_returns = 12;
 // Quantities of the different resources required by this task.
 map<string, double> required_resources = 13;
 // The resources required for placing this task on a node. If this is empty,
 // then the placement resources are equal to the required_resources.
 map<string, double> required_placement_resources = 14;
 // Task specification for an actor creation task.
 // This field is only valid when `type == ACTOR_CREATION_TASK`.
 ActorCreationTaskSpec actor_creation_task_spec = 15;
 // Task specification for an actor task.
 // This field is only valid when `type == ACTOR_TASK`.
 ActorTaskSpec actor_task_spec = 16;
 // number of times this task may be retried on worker failure.
 int32 max_retries = 17;
 // placement group that is associated with this task.
 bytes placement_group_id = 18;
 // placement group bundle that is associated with this task.
 int64 placement_group_bundle_index = 19;
 // whether or not this task should capture parent's placement group automatically.
 bool placement_group_capture_child_tasks = 20;
 // whether or not to skip the execution of this task. when it's true,
 // the receiver will not execute the task. this field is used by async actors
 // to guarantee task submission order after restart.
 bool skip_execution = 21;
 // breakpoint if this task should drop into the debugger when it starts executing
 // and "" if the task should not drop into the debugger.
 bytes debugger_breakpoint = 22;
 // runtime environment for this task.
 RuntimeEnv runtime_env = 23;
 // the concurrency group name in which this task will be performed.
 string concurrency_group_name = 24;
 // whether application-level errors (exceptions) should be retried.
 bool retry_exceptions = 25;
}

The key FunctionDescriptor message is defined as follows:

// common.proto 81
message FunctionDescriptor {
 oneof function_descriptor {
   JavaFunctionDescriptor java_function_descriptor = 1;
   PythonFunctionDescriptor python_function_descriptor = 2;
   CppFunctionDescriptor cpp_function_descriptor = 3;
 }
}

/// Function descriptor for Java.
message JavaFunctionDescriptor {
 string class_name = 1;
 string function_name = 2;
 string signature = 3;
}

/// Function descriptor for Python.
message PythonFunctionDescriptor {
 string module_name = 1;
 string class_name = 2;
 string function_name = 3;
 string function_hash = 4;
}

/// Function descriptor for C/C++.
message CppFunctionDescriptor {
 /// Remote function name.
 string function_name = 1;
}
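
The oneof is what FunctionDescriptorInterface::Type() reads via function_descriptor_case(). Dispatching on it with the generated API looks like this (a sketch; the enum case names follow protobuf oneof codegen):

// Sketch: dispatching on the oneof. fd is an rpc::FunctionDescriptor.
switch (fd.function_descriptor_case()) {
  case rpc::FunctionDescriptor::kPythonFunctionDescriptor: {
    const auto &py = fd.python_function_descriptor();
    // e.g. py.module_name(), py.class_name(), py.function_name()
    break;
  }
  case rpc::FunctionDescriptor::kJavaFunctionDescriptor:
  case rpc::FunctionDescriptor::kCppFunctionDescriptor:
    break;
  case rpc::FunctionDescriptor::FUNCTION_DESCRIPTOR_NOT_SET:
    break;
}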

Receiving task

On the receiving side, the raylet process first calls the fork() system call to spawn a child process; the child then executes python default_worker.py ... via execvpe. By the time default_worker.py runs, the worker processes that will execute tasks have already been created.
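
A minimal sketch of that spawn pattern (execvpe is a GNU extension; the command line below is illustrative, not the raylet's actual one):

// Sketch of the fork + execvpe spawn pattern described above.
#define _GNU_SOURCE
#include <unistd.h>

extern char **environ;

pid_t SpawnWorker() {
  pid_t pid = fork();
  if (pid == 0) {  // child: replace the process image with the Python worker
    char *const argv[] = {(char *)"python", (char *)"default_worker.py", nullptr};
    execvpe("python", argv, environ);
    _exit(127);  // reached only if exec failed
  }
  return pid;  // parent: child pid, or -1 if fork failed
}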

In default_worker.py:

# default_worker.py 219
if mode == ray.WORKER_MODE:
    ray.worker.global_worker.main_loop()
elif mode in [ray.RESTORE_WORKER_MODE, ray.SPILL_WORKER_MODE]:
    # It is handled by another thread in the C++ core worker.
    # We just need to keep the worker alive.
    while True:
        time.sleep(100000)

main_loop then calls into the C++ implementation, CoreWorkerProcess::RunTaskExecutionLoop():

def main_loop(self):
    """The main loop a worker runs to receive and execute tasks."""

    def sigterm_handler(signum, frame):
        shutdown(True)
        sys.exit(1)

    ray._private.utils.set_sigterm_handler(sigterm_handler)
    # liudy: defined in _raylet.pyx
    self.core_worker.run_task_loop()
    sys.exit(0)
# _raylet.pyx 1108
    def run_task_loop(self):
        with nogil:
            CCoreWorkerProcess.RunTaskExecutionLoop()

CoreWorkerProcess::RunTaskExecutionLoop() is implemented as follows; it creates the worker(s) (CoreWorker) and calls CoreWorker::RunTaskExecutionLoop:

// core_worker.cc 370
void CoreWorkerProcess::RunTaskExecutionLoop() {
  EnsureInitialized(/*quick_exit*/ false);
  RAY_CHECK(core_worker_process->options_.worker_type == WorkerType::WORKER);
  if (core_worker_process->options_.num_workers == 1) {
    // Run the task loop in the current thread only if the number of workers is 1.
    auto worker = core_worker_process->GetGlobalWorker();
    if (!worker) {
      worker = core_worker_process->CreateWorker();
    }
    worker->RunTaskExecutionLoop();
    RAY_LOG(DEBUG) << "Task execution loop terminated. Removing the global worker.";
    core_worker_process->RemoveWorker(worker);
  } else {
    std::vector<std::thread> worker_threads;
    for (int i = 0; i < core_worker_process->options_.num_workers; i++) {
      worker_threads.emplace_back([i] {
        SetThreadName("worker.task" + std::to_string(i));
        auto worker = core_worker_process->CreateWorker();
        worker->RunTaskExecutionLoop();
        RAY_LOG(INFO) << "Task execution loop terminated for a thread "
                      << std::to_string(i) << ". Removing a worker.";
        core_worker_process->RemoveWorker(worker);
      });
    }
    for (auto &thread : worker_threads) {
      thread.join();
    }
  }

  core_worker_process.reset();
}

CoreWorker::RunTaskExecutionLoop simply runs task_execution_service_, which is an instrumented_io_context:

// core_worker.cc 2190
void CoreWorker::RunTaskExecutionLoop() { task_execution_service_.run(); }

Most of the setup work is done in the CoreWorker constructor:

// core_worker.cc 402
CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_id)
    : options_(options),
      get_call_site_(RayConfig::instance().record_ref_creation_sites()
                         ? options_.get_lang_stack
                         : nullptr),
      worker_context_(options_.worker_type, worker_id, GetProcessJobID(options_)),
      io_work_(io_service_),
      client_call_manager_(new rpc::ClientCallManager(io_service_)),
      periodical_runner_(io_service_),
      task_queue_length_(0),
      num_executed_tasks_(0),
      resource_ids_(new ResourceMappingType()),
      grpc_service_(io_service_, *this),
      task_execution_service_work_(task_execution_service_) {
  RAY_LOG(DEBUG) << "Constructing CoreWorker, worker_id: " << worker_id;

  // Initialize task receivers.
  if (options_.worker_type == WorkerType::WORKER || options_.is_local_mode) {
    RAY_CHECK(options_.task_execution_callback != nullptr);
    auto execute_task = std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1,
                                  std::placeholders::_2, std::placeholders::_3,
                                  std::placeholders::_4, std::placeholders::_5);
    direct_task_receiver_ = std::make_unique<CoreWorkerDirectTaskReceiver>(
        worker_context_, task_execution_service_, execute_task,
        [this] { return local_raylet_client_->TaskDone(); });
  }

  // Initialize raylet client.
  // NOTE(edoakes): the core_worker_server_ must be running before registering with
  // the raylet, as the raylet will start sending some RPC messages immediately.
  // TODO(zhijunfu): currently RayletClient would crash in its constructor if it cannot
  // connect to Raylet after a number of retries, this can be changed later
  // so that the worker (java/python .etc) can retrieve and handle the error
  // instead of crashing.
  // Connect to the NodeManager via RPC
  auto grpc_client = rpc::NodeManagerWorkerClient::make(
      options_.raylet_ip_address, options_.node_manager_port, *client_call_manager_);
  Status raylet_client_status;
  NodeID local_raylet_id;
  int assigned_port;
  std::string serialized_job_config = options_.serialized_job_config;

  // Build a client that connects to the raylet
  local_raylet_client_ = std::make_shared<raylet::RayletClient>(
      io_service_, std::move(grpc_client), options_.raylet_socket, GetWorkerID(),
      options_.worker_type, worker_context_.GetCurrentJobID(), options_.runtime_env_hash,
      options_.language, options_.node_ip_address, &raylet_client_status,
      &local_raylet_id, &assigned_port, &serialized_job_config, options_.worker_shim_pid,
      options_.startup_token);

  if (!raylet_client_status.ok()) {
    // Avoid using FATAL log or RAY_CHECK here because they may create a core dump file.
    RAY_LOG(ERROR) << "Failed to register worker " << worker_id << " to Raylet. "
                   << raylet_client_status;
    // Quit the process immediately.
    QuickExit();
  }

  connected_ = true;

  RAY_CHECK(assigned_port >= 0);

  // Parse job config from serialized string.
  job_config_.reset(new rpc::JobConfig());
  job_config_->ParseFromString(serialized_job_config);

  // Start RPC server after all the task receivers are properly initialized and we have
  // our assigned port from the raylet.
  // Start the RPC server to accept tasks
  core_worker_server_ = std::make_unique<rpc::GrpcServer>(
      WorkerTypeString(options_.worker_type), assigned_port,
      options_.node_ip_address == "127.0.0.1");
  core_worker_server_->RegisterService(grpc_service_);
  core_worker_server_->Run();

  // Set our own address.
  RAY_CHECK(!local_raylet_id.IsNil());
  rpc_address_.set_ip_address(options_.node_ip_address);
  rpc_address_.set_port(core_worker_server_->GetPort());
  rpc_address_.set_raylet_id(local_raylet_id.Binary());
  rpc_address_.set_worker_id(worker_context_.GetWorkerID().Binary());
  RAY_LOG(INFO) << "Initializing worker at address: " << rpc_address_.ip_address() << ":"
                << rpc_address_.port() << ", worker ID " << worker_context_.GetWorkerID()
                << ", raylet " << local_raylet_id;

  // Begin to get gcs server address from raylet.
  gcs_server_address_updater_ = std::make_unique<GcsServerAddressUpdater>(
      options_.raylet_ip_address, options_.node_manager_port,
      [this](std::string ip, int port) {
        absl::MutexLock lock(&gcs_server_address_mutex_);
        gcs_server_address_.first = ip;
        gcs_server_address_.second = port;
      });

  // Initialize gcs client.
  // As the synchronous and the asynchronous context of redis client is not used in this
  // gcs client. We would not open connection for it by setting `enable_sync_conn` and
  // `enable_async_conn` as false.
  gcs::GcsClientOptions gcs_options = gcs::GcsClientOptions(
      options_.gcs_options.server_ip_, options_.gcs_options.server_port_,
      options_.gcs_options.password_,
      /*enable_sync_conn=*/false, /*enable_async_conn=*/false,
      /*enable_subscribe_conn=*/true);
  gcs_client_ = std::make_shared<gcs::GcsClient>(
      gcs_options, [this](std::pair<std::string, int> *address) {
        absl::MutexLock lock(&gcs_server_address_mutex_);
        if (gcs_server_address_.second != 0) {
          address->first = gcs_server_address_.first;
          address->second = gcs_server_address_.second;
          return true;
        }
        return false;
      });

  RAY_CHECK_OK(gcs_client_->Connect(io_service_));
  RegisterToGcs();

  // Register a callback to monitor removed nodes.
  auto on_node_change = [this](const NodeID &node_id, const rpc::GcsNodeInfo &data) {
    if (data.state() == rpc::GcsNodeInfo::DEAD) {
      OnNodeRemoved(node_id);
    }
  };
  RAY_CHECK_OK(gcs_client_->Nodes().AsyncSubscribeToNodeChange(on_node_change, nullptr));

  // Initialize profiler.
  profiler_ = std::make_shared<worker::Profiler>(
      worker_context_, options_.node_ip_address, io_service_, gcs_client_);

  core_worker_client_pool_ =
      std::make_shared<rpc::CoreWorkerClientPool>(*client_call_manager_);

  object_info_publisher_ = std::make_unique<pubsub::Publisher>(
      /*channels=*/std::vector<
          rpc::ChannelType>{rpc::ChannelType::WORKER_OBJECT_EVICTION,
                            rpc::ChannelType::WORKER_REF_REMOVED_CHANNEL,
                            rpc::ChannelType::WORKER_OBJECT_LOCATIONS_CHANNEL},
      /*periodical_runner=*/&periodical_runner_,
      /*get_time_ms=*/[]() { return absl::GetCurrentTimeNanos() / 1e6; },
      /*subscriber_timeout_ms=*/RayConfig::instance().subscriber_timeout_ms(),
      /*publish_batch_size_=*/RayConfig::instance().publish_batch_size());
  object_info_subscriber_ = std::make_unique<pubsub::Subscriber>(
      /*subscriber_id=*/GetWorkerID(),
      /*channels=*/
      std::vector<rpc::ChannelType>{rpc::ChannelType::WORKER_OBJECT_EVICTION,
                                    rpc::ChannelType::WORKER_REF_REMOVED_CHANNEL,
                                    rpc::ChannelType::WORKER_OBJECT_LOCATIONS_CHANNEL},
      /*max_command_batch_size*/ RayConfig::instance().max_command_batch_size(),
      /*get_client=*/
      [this](const rpc::Address &address) {
        return core_worker_client_pool_->GetOrConnect(address);
      },
      /*callback_service*/ &io_service_);

  reference_counter_ = std::make_shared<ReferenceCounter>(
      rpc_address_,
      /*object_info_publisher=*/object_info_publisher_.get(),
      /*object_info_subscriber=*/object_info_subscriber_.get(),
      RayConfig::instance().lineage_pinning_enabled(), [this](const rpc::Address &addr) {
        return std::shared_ptr<rpc::CoreWorkerClient>(
            new rpc::CoreWorkerClient(addr, *client_call_manager_));
      });

  if (options_.worker_type == WorkerType::WORKER) {
    periodical_runner_.RunFnPeriodically(
        [this] { CheckForRayletFailure(); },
        RayConfig::instance().raylet_death_check_interval_milliseconds());
  }

  plasma_store_provider_.reset(new CoreWorkerPlasmaStoreProvider(
      options_.store_socket, local_raylet_client_, reference_counter_,
      options_.check_signals,
      /*warmup=*/
      (options_.worker_type != WorkerType::SPILL_WORKER &&
       options_.worker_type != WorkerType::RESTORE_WORKER),
      /*get_current_call_site=*/boost::bind(&CoreWorker::CurrentCallSite, this)));
  memory_store_.reset(new CoreWorkerMemoryStore(
      reference_counter_, local_raylet_client_, options_.check_signals,
      [this](const RayObject &obj) {
        // Run this on the event loop to avoid calling back into the language runtime
        // from the middle of user operations.
        io_service_.post(
            [this, obj]() {
              if (options_.unhandled_exception_handler != nullptr) {
                options_.unhandled_exception_handler(obj);
              }
            },
            "CoreWorker.HandleException");
      }));

  periodical_runner_.RunFnPeriodically([this] { InternalHeartbeat(); },
                                       kInternalHeartbeatMillis);

  auto check_node_alive_fn = [this](const NodeID &node_id) {
    auto node = gcs_client_->Nodes().Get(node_id);
    return node != nullptr;
  };
  auto reconstruct_object_callback = [this](const ObjectID &object_id) {
    io_service_.post(
        [this, object_id]() {
          RAY_CHECK(object_recovery_manager_->RecoverObject(object_id));
        },
        "CoreWorker.ReconstructObject");
  };
  auto push_error_callback = [this](const JobID &job_id, const std::string &type,
                                    const std::string &error_message, double timestamp) {
    return PushError(job_id, type, error_message, timestamp);
  };
  task_manager_.reset(new TaskManager(
      memory_store_, reference_counter_,
      /*put_in_local_plasma_callback=*/
      [this](const RayObject &object, const ObjectID &object_id) {
        RAY_CHECK_OK(PutInLocalPlasmaStore(object, object_id, /*pin_object=*/true));
      },
      /* retry_task_callback= */
      [this](TaskSpecification &spec, bool delay) {
        if (delay) {
          // Retry after a delay to emulate the existing Raylet reconstruction
          // behaviour. TODO(ekl) backoff exponentially.
          uint32_t delay = RayConfig::instance().task_retry_delay_ms();
          RAY_LOG(INFO) << "Will resubmit task after a " << delay
                        << "ms delay: " << spec.DebugString();
          absl::MutexLock lock(&mutex_);
          to_resubmit_.push_back(std::make_pair(current_time_ms() + delay, spec));
        } else {
          RAY_LOG(INFO) << "Resubmitting task that produced lost plasma object: "
                        << spec.DebugString();
          if (spec.IsActorTask()) {
            auto actor_handle = actor_manager_->GetActorHandle(spec.ActorId());
            actor_handle->SetResubmittedActorTaskSpec(spec, spec.ActorDummyObject());
            RAY_CHECK_OK(direct_actor_submitter_->SubmitTask(spec));
          } else {
            RAY_CHECK_OK(direct_task_submitter_->SubmitTask(spec));
          }
        }
      },
      check_node_alive_fn, reconstruct_object_callback, push_error_callback));

  // Create an entry for the driver task in the task table. This task is
  // added immediately with status RUNNING. This allows us to push errors
  // related to this driver task back to the driver. For example, if the
  // driver creates an object that is later evicted, we should notify the
  // user that we're unable to reconstruct the object, since we cannot
  // rerun the driver.
  if (options_.worker_type == WorkerType::DRIVER) {
    TaskSpecBuilder builder;
    const TaskID task_id = TaskID::ForDriverTask(worker_context_.GetCurrentJobID());
    builder.SetDriverTaskSpec(task_id, options_.language,
                              worker_context_.GetCurrentJobID(),
                              TaskID::ComputeDriverTaskId(worker_context_.GetWorkerID()),
                              GetCallerId(), rpc_address_);

    std::shared_ptr<rpc::TaskTableData> data = std::make_shared<rpc::TaskTableData>();
    data->mutable_task()->mutable_task_spec()->CopyFrom(builder.Build().GetMessage());
    SetCurrentTaskId(task_id);
  }

  auto raylet_client_factory = [this](const std::string ip_address, int port) {
    auto grpc_client =
        rpc::NodeManagerWorkerClient::make(ip_address, port, *client_call_manager_);
    return std::shared_ptr<raylet::RayletClient>(
        new raylet::RayletClient(std::move(grpc_client)));
  };

  auto on_excess_queueing = [this](const ActorID &actor_id, int64_t num_queued) {
    auto timestamp = std::chrono::duration_cast<std::chrono::seconds>(
                         std::chrono::system_clock::now().time_since_epoch())
                         .count();
    std::ostringstream stream;
    stream << "Warning: More than " << num_queued
           << " tasks are pending submission to actor " << actor_id
           << ". To reduce memory usage, wait for these tasks to finish before sending "
              "more.";
    RAY_CHECK_OK(
        PushError(options_.job_id, "excess_queueing_warning", stream.str(), timestamp));
  };

  actor_creator_ = std::make_shared<DefaultActorCreator>(gcs_client_);

  direct_actor_submitter_ = std::shared_ptr<CoreWorkerDirectActorTaskSubmitter>(
      new CoreWorkerDirectActorTaskSubmitter(*core_worker_client_pool_, *memory_store_,
                                             *task_manager_, *actor_creator_,
                                             on_excess_queueing));

  auto node_addr_factory = [this](const NodeID &node_id) {
    absl::optional<rpc::Address> addr;
    if (auto node_info = gcs_client_->Nodes().Get(node_id)) {
      rpc::Address address;
      address.set_raylet_id(node_info->node_id());
      address.set_ip_address(node_info->node_manager_address());
      address.set_port(node_info->node_manager_port());
      addr = address;
    }
    return addr;
  };
  auto lease_policy = RayConfig::instance().locality_aware_leasing_enabled()
                          ? std::shared_ptr<LeasePolicyInterface>(
                                std::make_shared<LocalityAwareLeasePolicy>(
                                    reference_counter_, node_addr_factory, rpc_address_))
                          : std::shared_ptr<LeasePolicyInterface>(
                                std::make_shared<LocalLeasePolicy>(rpc_address_));

  direct_task_submitter_ = std::make_unique<CoreWorkerDirectTaskSubmitter>(
      rpc_address_, local_raylet_client_, core_worker_client_pool_, raylet_client_factory,
      std::move(lease_policy), memory_store_, task_manager_, local_raylet_id,
      RayConfig::instance().worker_lease_timeout_milliseconds(), actor_creator_,
      RayConfig::instance().max_tasks_in_flight_per_worker(),
      boost::asio::steady_timer(io_service_),
      RayConfig::instance().max_pending_lease_requests_per_scheduling_category());
  auto report_locality_data_callback =
      [this](const ObjectID &object_id, const absl::flat_hash_set<NodeID> &locations,
             uint64_t object_size) {
        reference_counter_->ReportLocalityData(object_id, locations, object_size);
      };
  future_resolver_.reset(new FutureResolver(memory_store_, reference_counter_,
                                            std::move(report_locality_data_callback),
                                            core_worker_client_pool_, rpc_address_));

  // Unfortunately the raylet client has to be constructed after the receivers.
  if (direct_task_receiver_ != nullptr) {
    task_argument_waiter_.reset(new DependencyWaiterImpl(*local_raylet_client_));
    direct_task_receiver_->Init(core_worker_client_pool_, rpc_address_,
                                task_argument_waiter_);
  }

  actor_manager_ = std::make_unique<ActorManager>(gcs_client_, direct_actor_submitter_,
                                                  reference_counter_);

  std::function<Status(const ObjectID &object_id, const ObjectLookupCallback &callback)>
      object_lookup_fn;

  object_lookup_fn = [this, node_addr_factory](const ObjectID &object_id,
                                               const ObjectLookupCallback &callback) {
    std::vector<rpc::Address> locations;
    const absl::optional<absl::flat_hash_set<NodeID>> object_locations =
        reference_counter_->GetObjectLocations(object_id);
    if (object_locations.has_value()) {
      locations.reserve(object_locations.value().size());
      for (const auto &node_id : object_locations.value()) {
        absl::optional<rpc::Address> addr = node_addr_factory(node_id);
        if (addr.has_value()) {
          locations.push_back(addr.value());
        } else {
          // We're getting potentially stale locations directly from the reference
          // counter, so the location might be a dead node.
          RAY_LOG(DEBUG) << "Location " << node_id
                         << " is dead, not using it in the recovery of object "
                         << object_id;
        }
      }
    }
    callback(object_id, locations);
    return Status::OK();
  };
  object_recovery_manager_ = std::make_unique<ObjectRecoveryManager>(
      rpc_address_, raylet_client_factory, local_raylet_client_, object_lookup_fn,
      task_manager_, reference_counter_, memory_store_,
      [this](const ObjectID &object_id, rpc::ErrorType reason, bool pin_object) {
        RAY_LOG(DEBUG) << "Failed to recover object " << object_id << " due to "
                       << rpc::ErrorType_Name(reason);
        RAY_CHECK_OK(Put(RayObject(reason),
                         /*contained_object_ids=*/{}, object_id,
                         /*pin_object=*/pin_object));
      });

  // Start the IO thread after all other members have been initialized, in case
  // the thread calls back into any of our members.
  io_thread_ = std::thread([this]() { RunIOService(); });
  // Tell the raylet the port that we are listening on.
  // NOTE: This also marks the worker as available in Raylet. We do this at the
  // very end in case there is a problem during construction.
  if (options.connect_on_start) {
    RAY_CHECK_OK(
        local_raylet_client_->AnnounceWorkerPort(core_worker_server_->GetPort()));
  }
  // Used to detect if the object is in the plasma store.
  max_direct_call_object_size_ = RayConfig::instance().max_direct_call_object_size();

  /// If periodic asio stats print is enabled, it will print it.
  const auto event_stats_print_interval_ms =
      RayConfig::instance().event_stats_print_interval_ms();
  if (event_stats_print_interval_ms != -1 && RayConfig::instance().event_stats()) {
    periodical_runner_.RunFnPeriodically(
        [this] {
          RAY_LOG(INFO) << "Event stats:\n\n" << io_service_.StatsString() << "\n\n";
        },
        event_stats_print_interval_ms);
  }

  // Set event context for current core worker thread.
  RayEventContext::Instance().SetEventContext(
      ray::rpc::Event_SourceType::Event_SourceType_CORE_WORKER,
      {{"worker_id", worker_id.Hex()}});
}

The receiving side likewise uses Boost.Asio for asynchronous programming, posting callbacks onto the event loop:

// core_worker.cc 2527
void CoreWorker::HandlePushTask(const rpc::PushTaskRequest &request,
                                rpc::PushTaskReply *reply,
                                rpc::SendReplyCallback send_reply_callback) {
  if (HandleWrongRecipient(WorkerID::FromBinary(request.intended_worker_id()),
                           send_reply_callback)) {
    return;
  }

  // Increment the task_queue_length and per function counter.
  task_queue_length_ += 1;
  std::string func_name =
      FunctionDescriptorBuilder::FromProto(request.task_spec().function_descriptor())
          ->CallString();
  {
    absl::MutexLock l(&task_counter_.tasks_counter_mutex_);
    task_counter_.Add(TaskCounter::kPending, func_name, 1);
  }

  // For actor tasks, we just need to post a HandleActorTask instance to the task
  // execution service.
  if (request.task_spec().type() == TaskType::ACTOR_TASK) {
    task_execution_service_.post(
        [this, request, reply, send_reply_callback = std::move(send_reply_callback)] {
          // We have posted an exit task onto the main event loop,
          // so shouldn't bother executing any further work.
          if (exiting_) return;
          direct_task_receiver_->HandleTask(request, reply, send_reply_callback);
        },
        "CoreWorker.HandlePushTaskActor");
  } else {
    // Normal tasks are enqueued here, and we post a RunNormalTasksFromQueue instance to
    // the task execution service.
    direct_task_receiver_->HandleTask(request, reply, send_reply_callback);
    task_execution_service_.post(
        [=] {
          // We have posted an exit task onto the main event loop,
          // so shouldn't bother executing any further work.
          if (exiting_) return;
          direct_task_receiver_->RunNormalTasksFromQueue();
        },
        "CoreWorker.HandlePushTask");
  }
}

In the callback HandlePushTask, task_execution_service_.post is called (task_execution_service_ is an instrumented_io_context), and the posted callback invokes direct_task_receiver_->HandleTask to process the task. direct_task_receiver_ is declared as follows:

// core_worker.h 1411
  // Interface that receives tasks from direct actor calls.
  std::unique_ptr<CoreWorkerDirectTaskReceiver> direct_task_receiver_;

HandleTask is implemented as follows:

// direct_actor_transport.cc 444
void CoreWorkerDirectTaskReceiver::HandleTask(
    const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply,
    rpc::SendReplyCallback send_reply_callback) {
  RAY_CHECK(waiter_ != nullptr) << "Must call init() prior to use";
  // Use `mutable_task_spec()` here as `task_spec()` returns a const reference
  // which doesn't work with std::move.
  TaskSpecification task_spec(
      std::move(*(const_cast<rpc::PushTaskRequest &>(request).mutable_task_spec())));

  // If GCS server is restarted after sending an actor creation task to this core worker,
  // the restarted GCS server will send the same actor creation task to the core worker
  // again. We just need to ignore it and reply ok.
  if (task_spec.IsActorCreationTask() &&
      worker_context_.GetCurrentActorID() == task_spec.ActorCreationId()) {
    send_reply_callback(Status::OK(), nullptr, nullptr);
    RAY_LOG(INFO) << "Ignoring duplicate actor creation task for actor "
                  << task_spec.ActorCreationId()
                  << ". This is likely due to a GCS server restart.";
    return;
  }

  if (task_spec.IsActorCreationTask()) {
    worker_context_.SetCurrentActorId(task_spec.ActorCreationId());
    SetMaxActorConcurrency(task_spec.IsAsyncioActor(), task_spec.MaxActorConcurrency());
  }

  // Only assign resources for non-actor tasks. Actor tasks inherit the resources
  // assigned at initial actor creation time.
  std::shared_ptr<ResourceMappingType> resource_ids;
  if (!task_spec.IsActorTask()) {
    resource_ids.reset(new ResourceMappingType());
    for (const auto &mapping : request.resource_mapping()) {
      std::vector<std::pair<int64_t, double>> rids;
      for (const auto &ids : mapping.resource_ids()) {
        rids.push_back(std::make_pair(ids.index(), ids.quantity()));
      }
      (*resource_ids)[mapping.name()] = rids;
    }
  }

  auto accept_callback = [this, reply, task_spec,
                          resource_ids](rpc::SendReplyCallback send_reply_callback) {
    if (task_spec.GetMessage().skip_execution()) {
      send_reply_callback(Status::OK(), nullptr, nullptr);
      return;
    }

    auto num_returns = task_spec.NumReturns();
    if (task_spec.IsActorCreationTask() || task_spec.IsActorTask()) {
      // Decrease to account for the dummy object id.
      num_returns--;
    }
    RAY_CHECK(num_returns >= 0);

    std::vector<std::shared_ptr<RayObject>> return_objects;
    bool is_application_level_error = false;

    // The task is executed here
    auto status =
        task_handler_(task_spec, resource_ids, &return_objects,
                      reply->mutable_borrowed_refs(), &is_application_level_error);
    reply->set_is_application_level_error(is_application_level_error);

    bool objects_valid = return_objects.size() == num_returns;
    if (objects_valid) {
      for (size_t i = 0; i < return_objects.size(); i++) {
        auto return_object = reply->add_return_objects();
        ObjectID id = ObjectID::FromIndex(task_spec.TaskId(), /*index=*/i + 1);
        return_object->set_object_id(id.Binary());

        // The object is nullptr if it already existed in the object store.
        const auto &result = return_objects[i];
        return_object->set_size(result->GetSize());
        if (result->GetData() != nullptr && result->GetData()->IsPlasmaBuffer()) {
          return_object->set_in_plasma(true);
        } else {
          if (result->GetData() != nullptr) {
            return_object->set_data(result->GetData()->Data(), result->GetData()->Size());
          }
          if (result->GetMetadata() != nullptr) {
            return_object->set_metadata(result->GetMetadata()->Data(),
                                        result->GetMetadata()->Size());
          }
        }
        for (const auto &nested_ref : result->GetNestedRefs()) {
          return_object->add_nested_inlined_refs()->CopyFrom(nested_ref);
        }
      }

      if (task_spec.IsActorCreationTask()) {
        /// The default max concurrency for creating PoolManager should
        /// be 0 if this is an asyncio actor.
        const int default_max_concurrency =
            task_spec.IsAsyncioActor() ? 0 : task_spec.MaxActorConcurrency();
        pool_manager_ = std::make_shared<PoolManager>(task_spec.ConcurrencyGroups(),
                                                      default_max_concurrency);
        concurrency_groups_cache_[task_spec.TaskId().ActorId()] =
            task_spec.ConcurrencyGroups();
        RAY_LOG(INFO) << "Actor creation task finished, task_id: " << task_spec.TaskId()
                      << ", actor_id: " << task_spec.ActorCreationId();
        // Tell raylet that an actor creation task has finished execution, so that
        // raylet can publish actor creation event to GCS, and mark this worker as
        // actor, thus if this worker dies later raylet will restart the actor.
        RAY_CHECK_OK(task_done_());
      }
    }
    if (status.ShouldExitWorker()) {
      // Don't allow the worker to be reused, even though the reply status is OK.
      // The worker will be shutting down shortly.
      reply->set_worker_exiting(true);
      if (objects_valid) {
        // This happens when max_calls is hit. We still need to return the objects.
        send_reply_callback(Status::OK(), nullptr, nullptr);
      } else {
        send_reply_callback(status, nullptr, nullptr);
      }
    } else {
      RAY_CHECK(objects_valid) << return_objects.size() << "  " << num_returns;
      send_reply_callback(status, nullptr, nullptr);
    }
  };

  auto reject_callback = [](rpc::SendReplyCallback send_reply_callback) {
    send_reply_callback(Status::Invalid("client cancelled stale rpc"), nullptr, nullptr);
  };

  auto steal_callback = [this, task_spec,
                         reply](rpc::SendReplyCallback send_reply_callback) {
    RAY_LOG(DEBUG) << "Task " << task_spec.TaskId() << " was stolen from "
                   << worker_context_.GetWorkerID()
                   << "'s non_actor_task_queue_! Setting reply->set_task_stolen(true)!";
    reply->set_task_stolen(true);
    send_reply_callback(Status::OK(), nullptr, nullptr);
  };

  auto dependencies = task_spec.GetDependencies(false);

  if (task_spec.IsActorTask()) {
    auto it = actor_scheduling_queues_.find(task_spec.CallerWorkerId());
    if (it == actor_scheduling_queues_.end()) {
      auto cg_it = concurrency_groups_cache_.find(task_spec.ActorId());
      RAY_CHECK(cg_it != concurrency_groups_cache_.end());
      auto result = actor_scheduling_queues_.emplace(
          task_spec.CallerWorkerId(),
          std::unique_ptr<SchedulingQueue>(new ActorSchedulingQueue(
              task_main_io_service_, *waiter_, pool_manager_, is_asyncio_,
              fiber_max_concurrency_, cg_it->second)));
      it = result.first;
    }

    it->second->Add(request.sequence_number(), request.client_processed_up_to(),
                    std::move(accept_callback), std::move(reject_callback),
                    std::move(send_reply_callback), task_spec.ConcurrencyGroupName(),
                    task_spec.FunctionDescriptor(), nullptr, task_spec.TaskId(),
                    dependencies);
  } else {
    // Add the normal task's callbacks to the non-actor scheduling queue.
    normal_scheduling_queue_->Add(
        request.sequence_number(), request.client_processed_up_to(),
        std::move(accept_callback), std::move(reject_callback),
        std::move(send_reply_callback), "", task_spec.FunctionDescriptor(),
        std::move(steal_callback), task_spec.TaskId(), dependencies);
  }
}

The key piece for task execution is task_handler_, whose type TaskHandler is a std::function alias, defined as follows:

   // direct_actor_transport.h 944
  /// The callback function to process a task.
  TaskHandler task_handler_;

  // direct_actor_transport.h 887
    using TaskHandler =
      std::function<Status(const TaskSpecification &task_spec,
                           const std::shared_ptr<ResourceMappingType> resource_ids,
                           std::vector<std::shared_ptr<RayObject>> *return_objects,
                           ReferenceCounter::ReferenceTableProto *borrower_refs,
                           bool *is_application_level_error)>;

task_handler_ is assigned in the constructor of CoreWorkerDirectTaskReceiver:

  // direct_actor_transport.h 896
  CoreWorkerDirectTaskReceiver(WorkerContext &worker_context,
                               instrumented_io_context &main_io_service,
                               const TaskHandler &task_handler,
                               const OnTaskDone &task_done)
      : worker_context_(worker_context),
        task_handler_(task_handler),
        task_main_io_service_(main_io_service),
        task_done_(task_done),
        pool_manager_(std::make_shared<PoolManager>()) {}

The CoreWorkerDirectTaskReceiver instance is constructed inside the CoreWorker constructor:

  // core_worker.cc 419
  if (options_.worker_type == WorkerType::WORKER || options_.is_local_mode) {
    RAY_CHECK(options_.task_execution_callback != nullptr);
    auto execute_task = std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1,
                                  std::placeholders::_2, std::placeholders::_3,
                                  std::placeholders::_4, std::placeholders::_5);
    direct_task_receiver_ = std::make_unique<CoreWorkerDirectTaskReceiver>(
        worker_context_, task_execution_service_, execute_task,
        [this] { return local_raylet_client_->TaskDone(); });
  }

So task_handler_ is execute_task, which is a std::bind binding of the member function CoreWorker::ExecuteTask; a rough analogue of this binding pattern is sketched below, followed by the ExecuteTask implementation itself.
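
A rough Python analogue of that std::bind pattern (illustrative sketch only), binding a method of a specific instance into a plain callable that can be stored as a handler:

import functools

class Worker:
    def execute_task(self, task_spec):
        return f"executed {task_spec}"

worker = Worker()
# Equivalent in spirit to std::bind(&CoreWorker::ExecuteTask, this, _1):
task_handler = functools.partial(Worker.execute_task, worker)
print(task_handler("my_task"))  # -> executed my_task
# The bound method worker.execute_task would behave the same way.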

// core_worker.cc 2238
Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
                               const std::shared_ptr<ResourceMappingType> &resource_ids,
                               std::vector<std::shared_ptr<RayObject>> *return_objects,
                               ReferenceCounter::ReferenceTableProto *borrowed_refs,
                               bool *is_application_level_error) {
  RAY_LOG(DEBUG) << "Executing task, task info = " << task_spec.DebugString();
  task_queue_length_ -= 1;
  num_executed_tasks_ += 1;

  // Modify the worker's per function counters.
  std::string func_name = task_spec.FunctionDescriptor()->CallString();
  {
    absl::MutexLock l(&task_counter_.tasks_counter_mutex_);
    task_counter_.Add(TaskCounter::kPending, func_name, -1);
    task_counter_.Add(TaskCounter::kRunning, func_name, 1);
  }

  if (!options_.is_local_mode) {
    worker_context_.SetCurrentTask(task_spec);
    SetCurrentTaskId(task_spec.TaskId());
  }
  {
    absl::MutexLock lock(&mutex_);
    current_task_ = task_spec;
    if (resource_ids) {
      resource_ids_ = resource_ids;
    }
  }

  RayFunction func{task_spec.GetLanguage(), task_spec.FunctionDescriptor()};

  std::vector<std::shared_ptr<RayObject>> args;
  std::vector<rpc::ObjectReference> arg_refs;
  // This includes all IDs that were passed by reference and any IDs that were
  // inlined in the task spec. These references will be pinned during the task
  // execution and unpinned once the task completes. We will notify the caller
  // about any IDs that we are still borrowing by the time the task completes.
  std::vector<ObjectID> borrowed_ids;
  RAY_CHECK_OK(GetAndPinArgsForExecutor(task_spec, &args, &arg_refs, &borrowed_ids));

  std::vector<ObjectID> return_ids;
  for (size_t i = 0; i < task_spec.NumReturns(); i++) {
    return_ids.push_back(task_spec.ReturnId(i));
  }

  Status status;
  TaskType task_type = TaskType::NORMAL_TASK;
  if (task_spec.IsActorCreationTask()) {
    RAY_CHECK(return_ids.size() > 0);
    return_ids.pop_back();
    task_type = TaskType::ACTOR_CREATION_TASK;
    SetActorId(task_spec.ActorCreationId());
    {
      std::unique_ptr<ActorHandle> self_actor_handle(
          new ActorHandle(task_spec.GetSerializedActorHandle()));
      // Register the handle to the current actor itself.
      actor_manager_->RegisterActorHandle(std::move(self_actor_handle), ObjectID::Nil(),
                                          CurrentCallSite(), rpc_address_,
                                          /*is_self=*/true);
    }
    RAY_LOG(INFO) << "Creating actor: " << task_spec.ActorCreationId();
  } else if (task_spec.IsActorTask()) {
    RAY_CHECK(return_ids.size() > 0);
    return_ids.pop_back();
    task_type = TaskType::ACTOR_TASK;
  }

  // Because we support concurrent actor calls, we need to update the
  // worker ID for the current thread.
  CoreWorkerProcess::SetCurrentThreadWorkerId(GetWorkerID());

  std::shared_ptr<LocalMemoryBuffer> creation_task_exception_pb_bytes = nullptr;

  std::vector<ConcurrencyGroup> defined_concurrency_groups = {};
  std::string name_of_concurrency_group_to_execute;
  if (task_spec.IsActorCreationTask()) {
    defined_concurrency_groups = task_spec.ConcurrencyGroups();
  } else if (task_spec.IsActorTask()) {
    name_of_concurrency_group_to_execute = task_spec.ConcurrencyGroupName();
  }

  // This call performs the actual task processing
  status = options_.task_execution_callback(
      task_type, task_spec.GetName(), func,
      task_spec.GetRequiredResources().GetResourceUnorderedMap(), args, arg_refs,
      return_ids, task_spec.GetDebuggerBreakpoint(), return_objects,
      creation_task_exception_pb_bytes, is_application_level_error,
      defined_concurrency_groups, name_of_concurrency_group_to_execute);

  // Get the reference counts for any IDs that we borrowed during this task,
  // remove the local reference for these IDs, and return the ref count info to
  // the caller. This will notify the caller of any IDs that we (or a nested
  // task) are still borrowing. It will also notify the caller of any new IDs
  // that were contained in a borrowed ID that we (or a nested task) are now
  // borrowing.
  std::vector<ObjectID> deleted;
  if (!borrowed_ids.empty()) {
    reference_counter_->PopAndClearLocalBorrowers(borrowed_ids, borrowed_refs, &deleted);
  }
  memory_store_->Delete(deleted);

  if (task_spec.IsNormalTask() && reference_counter_->NumObjectIDsInScope() != 0) {
    RAY_LOG(DEBUG)
        << "There were " << reference_counter_->NumObjectIDsInScope()
        << " ObjectIDs left in scope after executing task " << task_spec.TaskId()
        << ". This is either caused by keeping references to ObjectIDs in Python "
           "between "
           "tasks (e.g., in global variables) or indicates a problem with Ray's "
           "reference counting, and may cause problems in the object store.";
  }

  if (!options_.is_local_mode) {
    SetCurrentTaskId(TaskID::Nil());
    worker_context_.ResetCurrentTask();
  }
  {
    absl::MutexLock lock(&mutex_);
    current_task_ = TaskSpecification();
    if (task_spec.IsNormalTask()) {
      resource_ids_.reset(new ResourceMappingType());
    }
  }

  // Modify the worker's per function counters.
  {
    absl::MutexLock l(&task_counter_.tasks_counter_mutex_);
    task_counter_.Add(TaskCounter::kRunning, func_name, -1);
    task_counter_.Add(TaskCounter::kFinished, func_name, 1);
  }

  RAY_LOG(DEBUG) << "Finished executing task " << task_spec.TaskId()
                 << ", status=" << status;
  if (status.IsCreationTaskError()) {
    Exit(rpc::WorkerExitType::CREATION_TASK_ERROR, creation_task_exception_pb_bytes);
  } else if (status.IsIntentionalSystemExit()) {
    Exit(rpc::WorkerExitType::INTENDED_EXIT, creation_task_exception_pb_bytes);
  } else if (status.IsUnexpectedSystemExit()) {
    Exit(rpc::WorkerExitType::SYSTEM_ERROR_EXIT, creation_task_exception_pb_bytes);
  } else if (!status.ok()) {
    RAY_LOG(FATAL) << "Unexpected task status type : " << status;
  }

  return status;
}

The key call for task processing in the code above is:

// core_worker.cc 2318
  status = options_.task_execution_callback(
      task_type, task_spec.GetName(), func,
      task_spec.GetRequiredResources().GetResourceUnorderedMap(), args, arg_refs,
      return_ids, task_spec.GetDebuggerBreakpoint(), return_objects,
      creation_task_exception_pb_bytes, is_application_level_error,
      defined_concurrency_groups, name_of_concurrency_group_to_execute);

Here options_ is the struct CoreWorkerOptions, defined as follows:

// If you change this options's definition, you must change the options used in
// other files. Please take a global search and modify them !!!
struct CoreWorkerOptions {
  // Callback that must be implemented and provided by the language-specific worker
  // frontend to execute tasks and return their results.
  using TaskExecutionCallback = std::function<Status(
      TaskType task_type, const std::string task_name, const RayFunction &ray_function,
      const std::unordered_map<std::string, double> &required_resources,
      const std::vector<std::shared_ptr<RayObject>> &args,
      const std::vector<rpc::ObjectReference> &arg_refs,
      const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
      std::vector<std::shared_ptr<RayObject>> *results,
      std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb_bytes,
      bool *is_application_level_error,
      // The following 2 parameters `defined_concurrency_groups` and
      // `name_of_concurrency_group_to_execute` are used for Python
      // asyncio actor only.
      //
      // Defined concurrency groups of this actor. Note this is only
      // used for actor creation task.
      const std::vector<ConcurrencyGroup> &defined_concurrency_groups,
      const std::string name_of_concurrency_group_to_execute)>;

  CoreWorkerOptions()
      : store_socket(""),
        raylet_socket(""),
        enable_logging(false),
        log_dir(""),
        install_failure_signal_handler(false),
        interactive(false),
        node_ip_address(""),
        node_manager_port(0),
        raylet_ip_address(""),
        driver_name(""),
        stdout_file(""),
        stderr_file(""),
        task_execution_callback(nullptr),
        check_signals(nullptr),
        gc_collect(nullptr),
        spill_objects(nullptr),
        restore_spilled_objects(nullptr),
        delete_spilled_objects(nullptr),
        unhandled_exception_handler(nullptr),
        get_lang_stack(nullptr),
        kill_main(nullptr),
        is_local_mode(false),
        num_workers(0),
        terminate_asyncio_thread(nullptr),
        serialized_job_config(""),
        metrics_agent_port(-1),
        connect_on_start(true),
        runtime_env_hash(0),
        worker_shim_pid(0) {}

  /// Type of this worker (i.e., DRIVER or WORKER).
  WorkerType worker_type;
  /// Application language of this worker (i.e., PYTHON or JAVA).
  Language language;
  /// Object store socket to connect to.
  std::string store_socket;
  /// Raylet socket to connect to.
  std::string raylet_socket;
  /// Job ID of this worker.
  JobID job_id;
  /// Options for the GCS client.
  gcs::GcsClientOptions gcs_options;
  /// Initialize logging if true. Otherwise, it must be initialized and cleaned up by the
  /// caller.
  bool enable_logging;
  /// Directory to write logs to. If this is empty, logs won't be written to a file.
  std::string log_dir;
  /// If false, will not call `RayLog::InstallFailureSignalHandler()`.
  bool install_failure_signal_handler;
  /// Whether this worker is running in a tty.
  bool interactive;
  /// IP address of the node.
  std::string node_ip_address;
  /// Port of the local raylet.
  int node_manager_port;
  /// IP address of the raylet.
  std::string raylet_ip_address;
  /// The name of the driver.
  std::string driver_name;
  /// The stdout file of this process.
  std::string stdout_file;
  /// The stderr file of this process.
  std::string stderr_file;
  /// Language worker callback to execute tasks.
  TaskExecutionCallback task_execution_callback;
  /// The callback to be called when shutting down a `CoreWorker` instance.
  std::function<void(const WorkerID &)> on_worker_shutdown;
  /// Application-language callback to check for signals that have been received
  /// since calling into C++. This will be called periodically (at least every
  /// 1s) during long-running operations. If the function returns anything but StatusOK,
  /// any long-running operations in the core worker will short circuit and return that
  /// status.
  std::function<Status()> check_signals;
  /// Application-language callback to trigger garbage collection in the language
  /// runtime. This is required to free distributed references that may otherwise
  /// be held up in garbage objects.
  std::function<void()> gc_collect;
  /// Application-language callback to spill objects to external storage.
  std::function<std::vector<std::string>(const std::vector<rpc::ObjectReference> &)>
      spill_objects;
  /// Application-language callback to restore objects from external storage.
  std::function<int64_t(const std::vector<rpc::ObjectReference> &,
                        const std::vector<std::string> &)>
      restore_spilled_objects;
  /// Application-language callback to delete objects from external storage.
  std::function<void(const std::vector<std::string> &, rpc::WorkerType)>
      delete_spilled_objects;
  /// Function to call on error objects never retrieved.
  std::function<void(const RayObject &error)> unhandled_exception_handler;
  /// Language worker callback to get the current call stack.
  std::function<void(std::string *)> get_lang_stack;
  // Function that tries to interrupt the currently running Python thread.
  std::function<bool()> kill_main;
  /// Is local mode being used.
  bool is_local_mode;
  /// The number of workers to be started in the current process.
  int num_workers;
  /// The function to destroy asyncio event and loops.
  std::function<void()> terminate_asyncio_thread;
  /// Serialized representation of JobConfig.
  std::string serialized_job_config;
  /// The port number of a metrics agent that imports metrics from core workers.
  /// -1 means there's no such agent.
  int metrics_agent_port;
  /// If false, the constructor won't connect and notify raylets that it is
  /// ready. It should be explicitly startd by a caller using CoreWorker::Start.
  /// TODO(sang): Use this method for Java and cpp frontend too.
  bool connect_on_start;
  /// The hash of the runtime env for this worker.
  int runtime_env_hash;
  /// The PID of the process for setup worker runtime env.
  pid_t worker_shim_pid;
  /// The startup token of the process assigned to it
  /// during startup via command line arguments.
  /// This is needed because the actual core worker process
  /// may not have the same pid as the process the worker pool
  /// starts (due to shim processes).
  StartupToken startup_token{0};
};

From the above, task_execution_callback is a std::function of type TaskExecutionCallback. Tracing back to where the options_ instance is assigned, we find the CoreWorker constructor:

// core_worker.cc 402
CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_id)
    : options_(options),
      get_call_site_(RayConfig::instance().record_ref_creation_sites()
                         ? options_.get_lang_stack
                         : nullptr),
      worker_context_(options_.worker_type, worker_id, GetProcessJobID(options_)),
      io_work_(io_service_),
      client_call_manager_(new rpc::ClientCallManager(io_service_)),
      periodical_runner_(io_service_),
      task_queue_length_(0),
      num_executed_tasks_(0),
      resource_ids_(new ResourceMappingType()),
      grpc_service_(io_service_, *this),
      task_execution_service_work_(task_execution_service_)

Tracing further back to where CoreWorker is instantiated, we find CoreWorkerProcess::RunTaskExecutionLoop, which calls the CreateWorker function:

// core_worker.cc 336
std::shared_ptr<CoreWorker> CoreWorkerProcess::CreateWorker() {
  auto worker = std::make_shared<CoreWorker>(
      options_,
      global_worker_id_ != WorkerID::Nil() ? global_worker_id_ : WorkerID::FromRandom());
  RAY_LOG(DEBUG) << "Worker " << worker->GetWorkerID() << " is created.";
  absl::WriterMutexLock lock(&mutex_);
  if (options_.num_workers == 1) {
    global_worker_ = worker;
  }
  current_core_worker_ = worker;

  workers_.emplace(worker->GetWorkerID(), worker);
  RAY_CHECK(workers_.size() <= static_cast<size_t>(options_.num_workers));
  return worker;
}

This uses the options_ member of CoreWorkerProcess; searching within CoreWorkerProcess, the assignment of options_ is in the CoreWorkerProcess constructor:

// core_worker.cc 95
CoreWorkerProcess::CoreWorkerProcess(const CoreWorkerOptions &options)
    : options_(options),
      global_worker_id_(
          options.worker_type == WorkerType::DRIVER
              ? ComputeDriverIdFromJob(options_.job_id)
              : (options_.num_workers == 1 ? WorkerID::FromRandom() : WorkerID::Nil()))

So we go back to default_worker.py to see how CoreWorkerProcess is instantiated.

    # default_worker.py 203
    # Add code search path to sys.path, set load_code_from_local.
    core_worker = ray.worker.global_worker.core_worker
    code_search_path = core_worker.get_job_config().code_search_path


    load_code_from_local = False
    if code_search_path:
        load_code_from_local = True
        for p in code_search_path:
            if os.path.isfile(p):
                p = os.path.dirname(p)
            sys.path.insert(0, p)
    ray.worker.global_worker.set_load_code_from_local(load_code_from_local)

    # Setup log file.
    out_file, err_file = node.get_log_file_handles(
        get_worker_log_file_name(args.worker_type))
    configure_log_file(out_file, err_file)

    if mode == ray.WORKER_MODE:
        ray.worker.global_worker.main_loop()
    elif mode in [ray.RESTORE_WORKER_MODE, ray.SPILL_WORKER_MODE]:
        # It is handled by another thread in the C++ core worker.
        # We just need to keep the worker alive.
        while True:
            time.sleep(100000)
    else:
        raise ValueError(f"Unexcepted worker mode: {mode}")

Here global_worker in core_worker = ray.worker.global_worker.core_worker is a global variable defined as:

# worker.py 566
global_worker = Worker()
"""Worker: The global Worker object for this worker process.

We use a global Worker object to ensure that there is a single worker object
per worker process.
"""

The Worker class itself contains no declaration of core_worker, only uses of it, so presumably the core_worker member is attached after global_worker is instantiated. The place where the member is added is indeed found here:

# worker.py 1418
    worker.core_worker = ray._raylet.CoreWorker(
        mode, node.plasma_store_socket_name, node.raylet_socket_name, job_id,
        gcs_options, node.get_logs_dir_path(), node.node_ip_address,
        node.node_manager_port, node.raylet_ip_address, (mode == LOCAL_MODE),
        driver_name, log_stdout_file_path, log_stderr_file_path,
        serialized_job_config, node.metrics_agent_port, runtime_env_hash,
        worker_shim_pid, startup_token)
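
This relies on Python allowing attributes to be attached to an instance after construction; a minimal illustration:

class Worker:
    pass

w = Worker()
w.core_worker = object()  # attach an attribute the class never declared
print(hasattr(Worker, "core_worker"), hasattr(w, "core_worker"))  # False True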

This call resolves to the CoreWorker class defined in _raylet.pyx:

# _raylet.pyx 1025
cdef class CoreWorker:

    def __cinit__(self, worker_type, store_socket, raylet_socket,
                  JobID job_id, GcsClientOptions gcs_options, log_dir,
                  node_ip_address, node_manager_port, raylet_ip_address,
                  local_mode, driver_name, stdout_file, stderr_file,
                  serialized_job_config, metrics_agent_port, runtime_env_hash,
                  worker_shim_pid, startup_token):
        self.is_local_mode = local_mode

        cdef CCoreWorkerOptions options = CCoreWorkerOptions()
        if worker_type in (ray.LOCAL_MODE, ray.SCRIPT_MODE):
            self.is_driver = True
            options.worker_type = WORKER_TYPE_DRIVER
        elif worker_type == ray.WORKER_MODE:
            self.is_driver = False
            options.worker_type = WORKER_TYPE_WORKER
        elif worker_type == ray.SPILL_WORKER_MODE:
            self.is_driver = False
            options.worker_type = WORKER_TYPE_SPILL_WORKER
        elif worker_type == ray.RESTORE_WORKER_MODE:
            self.is_driver = False
            options.worker_type = WORKER_TYPE_RESTORE_WORKER
        else:
            raise ValueError(f"Unknown worker type: {worker_type}")
        options.language = LANGUAGE_PYTHON
        options.store_socket = store_socket.encode("ascii")
        options.raylet_socket = raylet_socket.encode("ascii")
        options.job_id = job_id.native()
        options.gcs_options = gcs_options.native()[0]
        options.enable_logging = True
        options.log_dir = log_dir.encode("utf-8")
        options.install_failure_signal_handler = True
        # https://stackoverflow.com/questions/2356399/tell-if-python-is-in-interactive-mode
        options.interactive = hasattr(sys, "ps1")
        options.node_ip_address = node_ip_address.encode("utf-8")
        options.node_manager_port = node_manager_port
        options.raylet_ip_address = raylet_ip_address.encode("utf-8")
        options.driver_name = driver_name
        options.stdout_file = stdout_file
        options.stderr_file = stderr_file
        options.task_execution_callback = task_execution_handler
        options.check_signals = check_signals
        options.gc_collect = gc_collect
        options.spill_objects = spill_objects_handler
        options.restore_spilled_objects = restore_spilled_objects_handler
        options.delete_spilled_objects = delete_spilled_objects_handler
        options.unhandled_exception_handler = unhandled_exception_handler
        options.get_lang_stack = get_py_stack
        options.is_local_mode = local_mode
        options.num_workers = 1
        options.kill_main = kill_main_task
        options.terminate_asyncio_thread = terminate_asyncio_thread
        options.serialized_job_config = serialized_job_config
        options.metrics_agent_port = metrics_agent_port
        options.connect_on_start = False
        options.runtime_env_hash = runtime_env_hash
        options.worker_shim_pid = worker_shim_pid
        options.startup_token = startup_token
        CCoreWorkerProcess.Initialize(options)

        self.cgname_to_eventloop_dict = None
        self.fd_to_cgname_dict = None
        self.eventloop_for_default_cg = None

So the options passed to the CoreWorkerProcess constructor is a CCoreWorkerOptions, with a large number of member assignments performed here. CCoreWorkerOptions is simply the Cython declaration of the C++ CoreWorkerOptions:

# libcoreworker.pxd 259
cdef cppclass CCoreWorkerOptions "ray::core::CoreWorkerOptions":

Among the assignments to options above, task_execution_callback is set as follows:

# _raylet.pyx 1066
options.task_execution_callback = task_execution_handler

where task_execution_handler is:

# _raylet.pyx 738
cdef CRayStatus task_execution_handler(
        CTaskType task_type,
        const c_string task_name,
        const CRayFunction &ray_function,
        const unordered_map[c_string, double] &c_resources,
        const c_vector[shared_ptr[CRayObject]] &c_args,
        const c_vector[CObjectReference] &c_arg_refs,
        const c_vector[CObjectID] &c_return_ids,
        const c_string debugger_breakpoint,
        c_vector[shared_ptr[CRayObject]] *returns,
        shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes,
        c_bool *is_application_level_error,
        const c_vector[CConcurrencyGroup] &defined_concurrency_groups,
        const c_string name_of_concurrency_group_to_execute) nogil:
    with gil, disable_client_hook():
        try:
            try:
                # The call to execute_task should never raise an exception. If
                # it does, that indicates that there was an internal error.
                execute_task(task_type, task_name, ray_function, c_resources,
                             c_args, c_arg_refs, c_return_ids,
                             debugger_breakpoint, returns,
                             is_application_level_error,
                             defined_concurrency_groups,
                             name_of_concurrency_group_to_execute)
            except Exception as e:
                sys_exit = SystemExit()
                if isinstance(e, RayActorError) and \
                   e.has_creation_task_error():
                    traceback_str = str(e)
                    logger.error("Exception raised "
                                 f"in creation task: {traceback_str}")
                    # Cython's bug that doesn't allow reference assignment,
                    # this is a workaroud.
                    # See https://github.com/cython/cython/issues/1863
                    (&creation_task_exception_pb_bytes)[0] = (
                        ray_error_to_memory_buf(e))
                    sys_exit.is_creation_task_error = True
                else:
                    traceback_str = traceback.format_exc() + (
                        "An unexpected internal error "
                        "occurred while the worker "
                        "was executing a task.")
                    ray._private.utils.push_error_to_driver(
                        ray.worker.global_worker,
                        "worker_crash",
                        traceback_str,
                        job_id=None)
                raise sys_exit
        except SystemExit as e:
            # Tell the core worker to exit as soon as the result objects
            # are processed.
            if hasattr(e, "is_ray_terminate"):
                return CRayStatus.IntentionalSystemExit()
            elif hasattr(e, "is_creation_task_error"):
                return CRayStatus.CreationTaskError()
            elif e.code and e.code == 0:
                # This means the system exit was
                # normal based on the python convention.
                # https://docs.python.org/3/library/sys.html#sys.exit
                return CRayStatus.IntentionalSystemExit()
            else:
                logger.exception("SystemExit was raised from the worker")
                return CRayStatus.UnexpectedSystemExit()

    return CRayStatus.OK()

The implementation of task_execution_handler calls the execute_task function:

# _raylet.pyx 444
cdef execute_task(
        CTaskType task_type,
        const c_string name,
        const CRayFunction &ray_function,
        const unordered_map[c_string, double] &c_resources,
        const c_vector[shared_ptr[CRayObject]] &c_args,
        const c_vector[CObjectReference] &c_arg_refs,
        const c_vector[CObjectID] &c_return_ids,
        const c_string debugger_breakpoint,
        c_vector[shared_ptr[CRayObject]] *returns,
        c_bool *is_application_level_error,
        # This parameter is only used for actor creation task to define
        # the concurrency groups of this actor.
        const c_vector[CConcurrencyGroup] &c_defined_concurrency_groups,
        const c_string c_name_of_concurrency_group_to_execute):

        #...
        # 488
        function_descriptor = CFunctionDescriptorToPython(
        ray_function.GetFunctionDescriptor())
        #...
        # 514: on the first call, execution_info here is None; on the
        # second call it is <class 'ray._private.function_manager.FunctionExecutionInfo'>
        execution_info = execution_infos.get(function_descriptor)
        if not execution_info:
          execution_info = manager.get_execution_info(
            job_id, function_descriptor)
          execution_infos[function_descriptor] = execution_info
        #...

Inspecting the FunctionExecutionInfo class shows that it is just a named tuple:

# function_manager.py 33
FunctionExecutionInfo = namedtuple("FunctionExecutionInfo",
                                   ["function", "function_name", "max_calls"])
"""FunctionExecutionInfo: A named tuple storing remote function information."""

worker server

The core_worker_server is started in the CoreWorker constructor.

// core_worker.cc 465
  // Start RPC server after all the task receivers are properly initialized and we have
  // our assigned port from the raylet.
  core_worker_server_ = std::make_unique<rpc::GrpcServer>(
      WorkerTypeString(options_.worker_type), assigned_port,
      options_.node_ip_address == "127.0.0.1");
  core_worker_server_->RegisterService(grpc_service_);
  core_worker_server_->Run();

Looking further into the registered service, grpc_service_ is declared as follows:

// core_worker.h 1403
  /// Common rpc service for all worker modules.
  rpc::CoreWorkerGrpcService grpc_service_;

Following the RPC-related code, this registers the CoreWorkerService service, defined in core_worker.proto:

service CoreWorkerService {
  // Push a task directly to this worker from another.
  rpc PushTask(PushTaskRequest) returns (PushTaskReply);
  // Steal tasks from a worker if it has a surplus of work
  rpc StealTasks(StealTasksRequest) returns (StealTasksReply);
  // Reply from raylet that wait for direct actor call args has completed.
  rpc DirectActorCallArgWaitComplete(DirectActorCallArgWaitCompleteRequest)
      returns (DirectActorCallArgWaitCompleteReply);
  // Ask the object's owner about the object's current status.
  rpc GetObjectStatus(GetObjectStatusRequest) returns (GetObjectStatusReply);
  // Wait for the actor's owner to decide that the actor has gone out of scope.
  // Replying to this message indicates that the client should force-kill the
  // actor process, if still alive.
  rpc WaitForActorOutOfScope(WaitForActorOutOfScopeRequest)
      returns (WaitForActorOutOfScopeReply);
  /// The long polling request sent to the core worker for pubsub operations.
  /// It is replied once there are batch of objects that need to be published to
  /// the caller (subscriber).
  rpc PubsubLongPolling(PubsubLongPollingRequest) returns (PubsubLongPollingReply);
  /// The pubsub command batch request used by the subscriber.
  rpc PubsubCommandBatch(PubsubCommandBatchRequest) returns (PubsubCommandBatchReply);
  // Update the batched object location information to the ownership-based object
  // directory.
  rpc UpdateObjectLocationBatch(UpdateObjectLocationBatchRequest)
      returns (UpdateObjectLocationBatchReply);
  // Get object locations from the ownership-based object directory.
  rpc GetObjectLocationsOwner(GetObjectLocationsOwnerRequest)
      returns (GetObjectLocationsOwnerReply);
  // Request that the worker shut down without completing outstanding work.
  rpc KillActor(KillActorRequest) returns (KillActorReply);
  // Request that a worker cancels a task.
  rpc CancelTask(CancelTaskRequest) returns (CancelTaskReply);
  // Request for a worker to issue a cancelation.
  rpc RemoteCancelTask(RemoteCancelTaskRequest) returns (RemoteCancelTaskReply);
  // Get metrics from core workers.
  rpc GetCoreWorkerStats(GetCoreWorkerStatsRequest) returns (GetCoreWorkerStatsReply);
  // Trigger local GC on the worker.
  rpc LocalGC(LocalGCRequest) returns (LocalGCReply);
  // Spill objects to external storage. Caller: raylet; callee: I/O worker.
  rpc SpillObjects(SpillObjectsRequest) returns (SpillObjectsReply);
  // Restore spilled objects from external storage. Caller: raylet; callee: I/O worker.
  rpc RestoreSpilledObjects(RestoreSpilledObjectsRequest)
      returns (RestoreSpilledObjectsReply);
  // Delete spilled objects from external storage. Caller: raylet; callee: I/O worker.
  rpc DeleteSpilledObjects(DeleteSpilledObjectsRequest)
      returns (DeleteSpilledObjectsReply);
  // Add spilled URL, spilled node ID, and update object size for owned object.
  // Caller: raylet; callee: owner worker.
  rpc AddSpilledUrl(AddSpilledUrlRequest) returns (AddSpilledUrlReply);
  // Notification from raylet that an object ID is available in local plasma.
  rpc PlasmaObjectReady(PlasmaObjectReadyRequest) returns (PlasmaObjectReadyReply);
  // Request for a worker to exit.
  rpc Exit(ExitRequest) returns (ExitReply);
  // Assign the owner of an object to the intended worker.
  rpc AssignObjectOwner(AssignObjectOwnerRequest) returns (AssignObjectOwnerReply);
}

Through the RPC layer, every method declared in CoreWorkerService becomes a Handle* implementation, e.g. HandlePushTask; a toy sketch of this convention follows.
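
A toy sketch of that convention (not Ray's actual RPC layer), mapping an RPC method name onto its Handle* counterpart on the server object:

class CoreWorkerServiceHandler:
    def HandlePushTask(self, request):
        return {"ok": True, "task": request["task"]}

def dispatch(handler, method_name, request):
    # "PushTask" resolves to handler.HandlePushTask(request)
    return getattr(handler, "Handle" + method_name)(request)

print(dispatch(CoreWorkerServiceHandler(), "PushTask", {"task": "f()"}))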

So the server-side implementation of the RPC method PushTask is the HandlePushTask function already shown above (core_worker.cc 2527).


FunctionActorManager

In function_manager.py there is the class FunctionActorManager; from its docstring it is a key class:

# function_manager.py 40
class FunctionActorManager:
    """A class used to export/load remote functions and actors.
    Attributes:
        _worker: The associated worker that this manager related.
        _functions_to_export: The remote functions to export when
            the worker gets connected.
        _actors_to_export: The actors to export when the worker gets
            connected.
        _function_execution_info: The function_id
            and execution_info.
        _num_task_executions: The function
            execution times.
        imported_actor_classes: The set of actor classes keys (format:
            ActorClass:function_id) that are already in GCS.
    """

    def __init__(self, worker):
        self._worker = worker
        self._functions_to_export = []
        self._actors_to_export = []
        # This field is a dictionary that maps function IDs
        # to a FunctionExecutionInfo object. This should only be used on
        # workers that execute remote functions.
        self._function_execution_info = defaultdict(lambda: {})
        self._num_task_executions = defaultdict(lambda: {})
        # A set of all of the actor class keys that have been imported by the
        # import thread. It is safe to convert this worker into an actor of
        # these types.
        self.imported_actor_classes = set()
        self._loaded_actor_classes = {}
        # Deserialize an ActorHandle will call load_actor_class(). If a
        # function closure captured an ActorHandle, the deserialization of the
        # function will be:
        #     import_thread.py
        #         -> fetch_and_register_remote_function (acquire lock)
        #         -> _load_actor_class_from_gcs (acquire lock, too)
        # So, the lock should be a reentrant lock.
        self.lock = threading.RLock()
        self.cv = threading.Condition(lock=self.lock)
        self.execution_infos = {}

FunctionActorManager has an export function, defined as follows:

# function_manager.py 130
    def export(self, remote_function):
        """Pickle a remote function and export it to redis.
        Args:
            remote_function: the RemoteFunction object.
        """
        if self._worker.load_code_from_local:
            function_descriptor = remote_function._function_descriptor
            module_name, function_name = (
                function_descriptor.module_name,
                function_descriptor.function_name,
            )
            # If the function is dynamic, we still export it to GCS
            # even if load_code_from_local is set True.
            if self.load_function_or_class_from_local(
                    module_name, function_name) is not None:
                return
        function = remote_function._function
        pickled_function = remote_function._pickled_function

        check_oversized_function(pickled_function,
                                 remote_function._function_name,
                                 "remote function", self._worker)
        key = (b"RemoteFunction:" + self._worker.current_job_id.binary() + b":"
               + remote_function._function_descriptor.function_id.binary())
        if self._worker.redis_client.exists(key) == 1:
            return
        self._worker.redis_client.hset(
            key,
            mapping={
                "job_id": self._worker.current_job_id.binary(),
                "function_id": remote_function._function_descriptor.
                function_id.binary(),
                "function_name": remote_function._function_name,
                "module": function.__module__,
                "function": pickled_function,
                "collision_identifier": self.compute_collision_identifier(
                    function),
                "max_calls": remote_function._max_calls
            })
        self._worker.redis_client.rpush("Exports", key)

When running plot_pong.py on a single node, this function was never called; it is probably only invoked when crossing nodes (to be verified further).

FunctionActorManager also has a get_execution_info function, defined as follows:

# function_manager.py 248
    def get_execution_info(self, job_id, function_descriptor):
        """Get the FunctionExecutionInfo of a remote function.
        Args:
            job_id: ID of the job that the function belongs to.
            function_descriptor: The FunctionDescriptor of the function to get.
        Returns:
            A FunctionExecutionInfo object.
        """
        function_id = function_descriptor.function_id
        # If the function has already been loaded,
        # There's no need to load again
        if function_id in self._function_execution_info:
            return self._function_execution_info[function_id]
        if self._worker.load_code_from_local:
            # Load function from local code.
            if not function_descriptor.is_actor_method():
                # If the function is not able to be loaded,
                # try to load it from GCS,
                # even if load_code_from_local is set True
                if self._load_function_from_local(function_descriptor) is True:
                    return self._function_execution_info[function_id]
        # Load function from GCS.
        # Wait until the function to be executed has actually been
        # registered on this worker. We will push warnings to the user if
        # we spend too long in this loop.
        # The driver function may not be found in sys.path. Try to load
        # the function from GCS.
        with profiling.profile("wait_for_function"):
            self._wait_for_function(function_descriptor, job_id)
        try:
            function_id = function_descriptor.function_id
            info = self._function_execution_info[function_id]
        except KeyError as e:
            message = ("Error occurs in get_execution_info: "
                       "job_id: %s, function_descriptor: %s. Message: %s" %
                       (job_id, function_descriptor, e))
            raise KeyError(message)
        return info

Here function_descriptor is of type <class 'ray._raylet.PythonFunctionDescriptor'>.

The function above also calls _load_function_from_local, defined as follows:

# function_manager.py 287
    def _load_function_from_local(self, function_descriptor):
        assert not function_descriptor.is_actor_method()
        function_id = function_descriptor.function_id

        module_name, function_name = (
            function_descriptor.module_name,
            function_descriptor.function_name,
        )

        object = self.load_function_or_class_from_local(
            module_name, function_name)
        if object is not None:
            function = object._function
            self._function_execution_info[function_id] = (
                FunctionExecutionInfo(
                    function=function,
                    function_name=function_name,
                    max_calls=0,
                ))
            self._num_task_executions[function_id] = 0
            return True
        else:
            return False

# function_manager.py 117
    def load_function_or_class_from_local(self, module_name,
                                          function_or_class_name):
        """Try to load a function or class in the module from local."""
        module = importlib.import_module(module_name)
        parts = [part for part in function_or_class_name.split(".") if part]
        object = module
        try:
            for part in parts:
                object = getattr(object, part)
            return object
        except Exception:
            return None
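
For example, the same dotted-name resolution can be exercised standalone (a sketch of the lookup above):

import importlib

def load_from_local(module_name, function_or_class_name):
    # Same lookup as load_function_or_class_from_local: import the module,
    # then walk the dotted attribute path.
    obj = importlib.import_module(module_name)
    for part in [p for p in function_or_class_name.split(".") if p]:
        obj = getattr(obj, part)
    return obj

join = load_from_local("os.path", "join")
print(join("ray", "worker.py"))  # e.g. ray/worker.py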

For actor tasks, there is the function load_actor_class:

# function_manager.py 416
    def load_actor_class(self, job_id, actor_creation_function_descriptor):
        """Load the actor class.
        Args:
            job_id: job ID of the actor.
            actor_creation_function_descriptor: Function descriptor of
                the actor constructor.
        Returns:
            The actor class.
        """
        function_id = actor_creation_function_descriptor.function_id
        # Check if the actor class already exists in the cache.
        actor_class = self._loaded_actor_classes.get(function_id, None)
        if actor_class is None:
            # Load actor class.
            if self._worker.load_code_from_local:
                # Load actor class from local code first.
                actor_class = self._load_actor_class_from_local(
                    actor_creation_function_descriptor)
                # If the actor is unable to be loaded
                # from local, try to load it
                # from GCS even if load_code_from_local is set True
                if actor_class is None:
                    actor_class = self._load_actor_class_from_gcs(
                        job_id, actor_creation_function_descriptor)

            else:
                # Load actor class from GCS.
                actor_class = self._load_actor_class_from_gcs(
                    job_id, actor_creation_function_descriptor)
            # Save the loaded actor class in cache.
            self._loaded_actor_classes[function_id] = actor_class

            # Generate execution info for the methods of this actor class.
            module_name = actor_creation_function_descriptor.module_name
            actor_class_name = actor_creation_function_descriptor.class_name
            actor_methods = inspect.getmembers(
                actor_class, predicate=is_function_or_method)
            for actor_method_name, actor_method in actor_methods:
                # Actor creation function descriptor use a unique function
                # hash to solve actor name conflict. When constructing an
                # actor, the actor creation function descriptor will be the
                # key to find __init__ method execution info. So, here we
                # use actor creation function descriptor as method descriptor
                # for generating __init__ method execution info.
                if actor_method_name == "__init__":
                    method_descriptor = actor_creation_function_descriptor
                else:
                    method_descriptor = PythonFunctionDescriptor(
                        module_name, actor_method_name, actor_class_name)
                method_id = method_descriptor.function_id
                executor = self._make_actor_method_executor(
                    actor_method_name,
                    actor_method,
                    actor_imported=True,
                )
                self._function_execution_info[method_id] = (
                    FunctionExecutionInfo(
                        function=executor,
                        function_name=actor_method_name,
                        max_calls=0,
                    ))
                self._num_task_executions[method_id] = 0
            self._num_task_executions[function_id] = 0
        return actor_class

    def _load_actor_class_from_local(self, actor_creation_function_descriptor):
        """Load actor class from local code."""
        module_name, class_name = (
            actor_creation_function_descriptor.module_name,
            actor_creation_function_descriptor.class_name)

        object = self.load_function_or_class_from_local(
            module_name, class_name)

        if object is not None:
            if isinstance(object, ray.actor.ActorClass):
                return object.__ray_metadata__.modified_class
            else:
                return object
        else:
            return None

Classes and functions are shipped between processes by serializing and deserializing the Python code with cloudpickle.
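
A minimal sketch of that mechanism (assuming the cloudpickle package is available): a closure is serialized by value into bytes that can be shipped to another process and restored there:

import cloudpickle

def make_adder(n):
    def add(x):
        return x + n
    return add

blob = cloudpickle.dumps(make_adder(5))  # bytes, safe to ship over the wire
restored = cloudpickle.loads(blob)       # rebuilt on the receiving side
print(restored(10))  # -> 15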

updated 2021-12-21