Ray_RPC Note

本文介绍 ray 框架关于 RPC 的使用

grpc_server

宏定义

首先定义了两种宏

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// core_worker_server.h 30
// Lists every CoreWorkerService RPC as one RPC_SERVICE_HANDLER entry. Each entry passes
// -1 as MAX_ACTIVE_RPCS, which RPC_SERVICE_HANDLER documents as "no limit". The list is
// expanded inside CoreWorkerGrpcService::InitServerCallFactories, so every entry
// registers one call factory there.
#define RAY_CORE_WORKER_RPC_HANDLERS                                         \
  RPC_SERVICE_HANDLER(CoreWorkerService, PushTask, -1)                       \
  RPC_SERVICE_HANDLER(CoreWorkerService, StealTasks, -1)                     \
  RPC_SERVICE_HANDLER(CoreWorkerService, DirectActorCallArgWaitComplete, -1) \
  RPC_SERVICE_HANDLER(CoreWorkerService, GetObjectStatus, -1)                \
  RPC_SERVICE_HANDLER(CoreWorkerService, WaitForActorOutOfScope, -1)         \
  RPC_SERVICE_HANDLER(CoreWorkerService, PubsubLongPolling, -1)              \
  RPC_SERVICE_HANDLER(CoreWorkerService, PubsubCommandBatch, -1)             \
  RPC_SERVICE_HANDLER(CoreWorkerService, UpdateObjectLocationBatch, -1)      \
  RPC_SERVICE_HANDLER(CoreWorkerService, GetObjectLocationsOwner, -1)        \
  RPC_SERVICE_HANDLER(CoreWorkerService, KillActor, -1)                      \
  RPC_SERVICE_HANDLER(CoreWorkerService, CancelTask, -1)                     \
  RPC_SERVICE_HANDLER(CoreWorkerService, RemoteCancelTask, -1)               \
  RPC_SERVICE_HANDLER(CoreWorkerService, GetCoreWorkerStats, -1)             \
  RPC_SERVICE_HANDLER(CoreWorkerService, LocalGC, -1)                        \
  RPC_SERVICE_HANDLER(CoreWorkerService, SpillObjects, -1)                   \
  RPC_SERVICE_HANDLER(CoreWorkerService, RestoreSpilledObjects, -1)          \
  RPC_SERVICE_HANDLER(CoreWorkerService, DeleteSpilledObjects, -1)           \
  RPC_SERVICE_HANDLER(CoreWorkerService, AddSpilledUrl, -1)                  \
  RPC_SERVICE_HANDLER(CoreWorkerService, PlasmaObjectReady, -1)              \
  RPC_SERVICE_HANDLER(CoreWorkerService, Exit, -1)                           \
  RPC_SERVICE_HANDLER(CoreWorkerService, AssignObjectOwner, -1)
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// core_worker_server.h 53
// Lists one DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD entry per CoreWorkerService RPC, i.e.
// one pure virtual `Handle<Method>` declaration for each method named below. The method
// names match the entries of RAY_CORE_WORKER_RPC_HANDLERS one-to-one.
// NOTE(review): presumably expanded inside the CoreWorkerServiceHandler interface class,
// which is not shown here -- confirm.
#define RAY_CORE_WORKER_DECLARE_RPC_HANDLERS                              \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PushTask)                       \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(StealTasks)                     \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(DirectActorCallArgWaitComplete) \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(GetObjectStatus)                \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(WaitForActorOutOfScope)         \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PubsubLongPolling)              \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PubsubCommandBatch)             \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(UpdateObjectLocationBatch)      \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(GetObjectLocationsOwner)        \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(KillActor)                      \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(CancelTask)                     \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(RemoteCancelTask)               \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(GetCoreWorkerStats)             \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(LocalGC)                        \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(SpillObjects)                   \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(RestoreSpilledObjects)          \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(DeleteSpilledObjects)           \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(AddSpilledUrl)                  \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PlasmaObjectReady)              \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(Exit)                           \
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(AssignObjectOwner)

其中嵌套的宏定义如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
/// Creates a `ServerCallFactoryImpl` for one RPC method and appends it to
/// `server_call_factories`.
///
/// The expansion refers to `service_`, `service_handler_`, `cq`, `main_service_` and
/// `server_call_factories`, so all of these names must be in scope where the macro is
/// used. The call name passed to the factory is the string
/// "<SERVICE>.grpc_server.<HANDLER>", built from stringized arguments.
///
/// \param SERVICE Name of the gRPC service; also used to form `SERVICE##Handler`.
/// \param HANDLER Name of the RPC method; used to form `HANDLER##Request`,
/// `HANDLER##Reply`, `Request##HANDLER` and `Handle##HANDLER`.
/// \param MAX_ACTIVE_RPCS Maximum number of RPCs to handle at the same time. -1 means no
/// limit.
#define RPC_SERVICE_HANDLER(SERVICE, HANDLER, MAX_ACTIVE_RPCS)                  \
  std::unique_ptr<ServerCallFactory> HANDLER##_call_factory(                    \
      new ServerCallFactoryImpl<SERVICE, SERVICE##Handler, HANDLER##Request,    \
                                HANDLER##Reply>(                                \
          service_, &SERVICE::AsyncService::Request##HANDLER, service_handler_, \
          &SERVICE##Handler::Handle##HANDLER, cq, main_service_,                \
          #SERVICE ".grpc_server." #HANDLER, MAX_ACTIVE_RPCS));                 \
  server_call_factories->emplace_back(std::move(HANDLER##_call_factory));

// Declares a pure virtual `Handle<METHOD>` service handler method returning void. The
// method takes the request by const reference, a pointer to the reply, and a
// `SendReplyCallback`.
#define DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(METHOD)                   \
  virtual void Handle##METHOD(const ::ray::rpc::METHOD##Request &request, \
                              ::ray::rpc::METHOD##Reply *reply,           \
                              ::ray::rpc::SendReplyCallback send_reply_callback) = 0;

其中的 ServerCallFactory 是一个抽象类(只包含纯虚函数的接口类)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
// server_call.h 100
/// Abstract factory for one particular kind of `ServerCall` objects.
class ServerCallFactory {
 public:
  virtual ~ServerCallFactory() = default;

  /// Creates a new `ServerCall` and asks the gRPC runtime to start accepting requests
  /// of the corresponding type.
  virtual void CreateCall() const = 0;

  /// Returns how many requests may be handled at the same time; -1 means no limit.
  virtual int64_t GetMaxActiveRPCs() const = 0;
};

实际使用的实现类是 ServerCallFactoryImpl, 如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
// server_call.h 295
/// Implementation of `ServerCallFactory`
///
/// \tparam GrpcService Type of the gRPC-generated service class.
/// \tparam ServiceHandler Type of the handler that handles the request.
/// \tparam Request Type of the request message.
/// \tparam Reply Type of the reply message.
template <class GrpcService, class ServiceHandler, class Request, class Reply>
class ServerCallFactoryImpl : public ServerCallFactory {...}

// server_call.h 306
  /// Constructor. Stores every argument in the member of the matching name.
  ///
  /// \param[in] service The gRPC-generated `AsyncService`.
  /// \param[in] request_call_function Pointer to the `AsyncService::RequestMethod`
  ///  function.
  /// \param[in] service_handler The service handler that handles the request.
  /// \param[in] handle_request_function Pointer to the service handler function.
  /// \param[in] cq The `CompletionQueue`.
  /// \param[in] io_service The event loop.
  /// \param[in] call_name Name of this call; taken by value and moved into `call_name_`.
  /// `RPC_SERVICE_HANDLER` passes "<Service>.grpc_server.<Handler>" here.
  /// \param[in] max_active_rpcs Maximum request number to handle at the same time. -1
  /// means no limit.
  ServerCallFactoryImpl(
      AsyncService &service,
      RequestCallFunction<GrpcService, Request, Reply> request_call_function,
      ServiceHandler &service_handler,
      HandleRequestFunction<ServiceHandler, Request, Reply> handle_request_function,
      const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
      instrumented_io_context &io_service, std::string call_name, int64_t max_active_rpcs)
      : service_(service),
        request_call_function_(request_call_function),
        service_handler_(service_handler),
        handle_request_function_(handle_request_function),
        cq_(cq),
        io_service_(io_service),
        call_name_(std::move(call_name)),
        max_active_rpcs_(max_active_rpcs) {}

所以我们可以得到,比如 RPC_SERVICE_HANDLER(CoreWorkerService, PushTask, -1) 经过宏展开后得到的代码为:

1
2
3
4
5
6
7
// RPC_SERVICE_HANDLER(CoreWorkerService, PushTask, -1) after macro expansion. The
// MAX_ACTIVE_RPCS argument is replaced by -1 (no limit), and the stringized names are
// adjacent string literals that the compiler concatenates into
// "CoreWorkerService.grpc_server.PushTask".
std::unique_ptr<ServerCallFactory> PushTask_call_factory(
      new ServerCallFactoryImpl<CoreWorkerService, CoreWorkerServiceHandler, PushTaskRequest,
                                PushTaskReply>(
          service_, &CoreWorkerService::AsyncService::RequestPushTask, service_handler_,
          &CoreWorkerServiceHandler::HandlePushTask, cq, main_service_,
          "CoreWorkerService" ".grpc_server." "PushTask", -1));
  server_call_factories->emplace_back(std::move(PushTask_call_factory));

RAY_CORE_WORKER_RPC_HANDLERS 的使用在类 CoreWorkerGrpcService 中

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// core_worker_server.h 93
/// `GrpcService` implementation that exposes `CoreWorkerService`.
class CoreWorkerGrpcService : public GrpcService {
 public:
  /// Builds the service.
  ///
  /// \param[in] main_service Forwarded to the `GrpcService` base class constructor.
  /// \param[in] service_handler Handler that actually processes the requests. Only a
  /// reference to it is stored.
  CoreWorkerGrpcService(instrumented_io_context &main_service,
                        CoreWorkerServiceHandler &service_handler)
      : GrpcService(main_service), service_handler_(service_handler) {}

 protected:
  /// Returns the async service object owned by this class.
  grpc::Service &GetGrpcService() override {
    return service_;
  }

  /// Appends one call factory per CoreWorkerService RPC to `server_call_factories` by
  /// expanding RAY_CORE_WORKER_RPC_HANDLERS. The expansion uses `cq`, `service_`,
  /// `service_handler_` and `main_service_`.
  /// NOTE(review): `main_service_` is presumably a member of the `GrpcService` base
  /// class, which is not shown here -- confirm.
  void InitServerCallFactories(
      const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
      std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
    RAY_CORE_WORKER_RPC_HANDLERS
  }

 private:
  /// gRPC-generated async service object for `CoreWorkerService`.
  CoreWorkerService::AsyncService service_;

  /// Reference to the handler that actually processes the requests.
  CoreWorkerServiceHandler &service_handler_;
};

在展开后的代码中出现的 service_,定义在类 CoreWorkerGrpcService 中(类型为 CoreWorkerService::AsyncService)。展开后的代码所做的事情其实就是:创建一个 ServerCallFactoryImpl 类的实例,然后将持有该实例的智能指针移入容器 server_call_factories (std::vector<std::unique_ptr<ServerCallFactory>>).

其中构建 ServerCallFactoryImpl 类的实例时传入的参数 &CoreWorkerService::AsyncService::RequestPushTask 和 &CoreWorkerServiceHandler::HandlePushTask 分别来自:

CoreWorkerService::AsyncService::RequestPushTask 来自 protoc 编译生成的文件 core_worker.grpc.pb.h

1
2
3
// Generated request method for the PushTask RPC (from core_worker.grpc.pb.h). It forwards
// all of its arguments unchanged to ::grpc::Service::RequestAsyncUnary, preceded by the
// constant 0.
// NOTE(review): 0 is presumably the index of PushTask within the service's method list --
// confirm against the generated header.
void RequestPushTask(::grpc::ServerContext* context, ::ray::rpc::PushTaskRequest* request, ::grpc::ServerAsyncResponseWriter< ::ray::rpc::PushTaskReply>* response, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) {
  ::grpc::Service::RequestAsyncUnary(0, context, request, response, new_call_cq, notification_cq, tag);
}

&CoreWorkerServiceHandler::HandlePushTask 如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
// core_worker_server.h 53
#define RAY_CORE_WORKER_DECLARE_RPC_HANDLERS
  DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PushTask)

  // grpc_server.h 40
  // Define a void RPC client method.
#define DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(METHOD)                   \
  virtual void Handle##METHOD(const ::ray::rpc::METHOD##Request &request, \
                              ::ray::rpc::METHOD##Reply *reply,           \
                              ::ray::rpc::SendReplyCallback send_reply_callback) = 0;
  //宏解析后的代码为:
  virtual void HandlePushTask(const ::ray::rpc::PushTaskRequest &request, \
                              ::ray::rpc::PushTaskReply *reply,           \
                              ::ray::rpc::SendReplyCallback send_reply_callback) = 0;

GrpcServer

GrpcServer 类的定义如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
// grpc_server.h 48
/// Class that represents an gRPC server.
///
/// A `GrpcServer` listens on a specific port. It owns
/// 1) a `ServerCompletionQueue` that is used for polling events from gRPC,
/// 2) and a thread that polls events from the `ServerCompletionQueue`.
///
/// Subclasses can register one or multiple services to a `GrpcServer`, see
/// `RegisterServices`. And they should also implement `InitServerCallFactories` to decide
/// which kinds of requests this server should accept.
class GrpcServer {
 public:
  /// Construct a gRPC server that listens on a TCP port.
  ///
  /// \param[in] name Name of this server, used for logging and debugging purpose.
  /// \param[in] port The port to bind this server to. If it's 0, a random available port
  ///  will be chosen.
  GrpcServer(std::string name, const uint32_t port, bool listen_to_localhost_only,
             int num_threads = 1,
             int64_t keepalive_time_ms = 7200000 /*2 hours, grpc default*/);

在类 GrpcServer 的 Run 函数内关于 service 注册关键代码:

1
2
3
4
// grpc_server.cc 91
  for (auto &entry : services_) {
    builder.RegisterService(&entry.get());
  }

builder 的类型为 grpc::ServerBuilder,其 RegisterService 用于注册 service 的实现,传入的一般是继承自 *::Service 的实现类的实例。其中 services_ 的声明如下:

1
2
3
// grpc_server.h 117
  /// The `grpc::Service` objects which should be registered to `ServerBuilder`.
  std::vector<std::reference_wrapper<grpc::Service>> services_;

通过 GrpcServer 里的函数 RegisterService 往 services_ 容器内添加 service.

1
2
3
4
5
6
7
8
// grpc_server.cc 137
/// Records the gRPC service object of `service` in `services_` and lets `service` add
/// its call factories to `server_call_factories_`, once for each of the first
/// `num_threads_` completion queues in `cqs_`.
void GrpcServer::RegisterService(GrpcService &service) {
  grpc::Service &grpc_service = service.GetGrpcService();
  services_.emplace_back(grpc_service);

  int cq_index = 0;
  while (cq_index < num_threads_) {
    service.InitServerCallFactories(cqs_[cq_index], &server_call_factories_);
    ++cq_index;
  }
}

通过 RPC 模块,CoreWorkerService 中各个 RPC 方法的实现都对应到一个 Handle* 函数,例如 PushTask 对应 HandlePushTask。

updated 2021-12-16