saluki_io/net/server/
multiplex_service.rs

1use std::{
2    convert::Infallible,
3    task::{Context, Poll},
4};
5
6use axum::{
7    http::header::CONTENT_TYPE,
8    response::{IntoResponse, Response},
9};
10use futures::{future::BoxFuture, ready};
11use http::Request;
12use hyper::body::Incoming;
13use tower::Service;
14
/// A [`Service`] that routes each incoming request to either a REST-ful service or a gRPC service.
///
/// Since gRPC is carried over HTTP/2, gRPC services can be exposed on the same HTTP endpoint as existing REST-ful
/// routes. Serving both from a single endpoint/port reduces configuration and overall complexity.
///
/// Routing is decided per request by looking at the `Content-Type` header: when its value starts with
/// `application/grpc`, the request is handed to the gRPC service; in every other case (including a missing header)
/// it is handed to the REST-ful service.
///
/// Because the destination of the next request is unknown until it arrives, this service only reports itself ready
/// once *both* inner services have reported readiness.
pub struct MultiplexService<A, B> {
    /// Inner service receiving REST-ful (non-gRPC) requests.
    rest: A,
    /// Inner service receiving gRPC requests.
    grpc: B,
    /// Whether `rest` has reported readiness since it was last called.
    rest_ready: bool,
    /// Whether `grpc` has reported readiness since it was last called.
    grpc_ready: bool,
}
31
32impl<A, B> MultiplexService<A, B> {
33    /// Creates a new `MultiplexService` from the given REST-ful and gRPC services.
34    pub fn new(rest: A, grpc: B) -> Self {
35        Self {
36            rest,
37            rest_ready: false,
38            grpc,
39            grpc_ready: false,
40        }
41    }
42}
43
44impl<A, B> Clone for MultiplexService<A, B>
45where
46    A: Clone,
47    B: Clone,
48{
49    fn clone(&self) -> Self {
50        Self {
51            rest: self.rest.clone(),
52            grpc: self.grpc.clone(),
53            // Don't assume either service is ready when cloning.
54            rest_ready: false,
55            grpc_ready: false,
56        }
57    }
58}
59
60impl<A, B> Service<Request<Incoming>> for MultiplexService<A, B>
61where
62    A: Service<Request<Incoming>, Error = Infallible>,
63    A::Response: IntoResponse,
64    A::Future: Send + 'static,
65    B: Service<Request<Incoming>>,
66    B::Response: IntoResponse,
67    B::Future: Send + 'static,
68{
69    type Response = Response;
70    type Error = B::Error;
71    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
72
73    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
74        // We drive both services to readiness as we need to ensure either service can handle a request before we can
75        // actually accept a call, given that we don't know what type of request we're about to get.
76        loop {
77            match (self.rest_ready, self.grpc_ready) {
78                (true, true) => {
79                    return Ok(()).into();
80                }
81                (false, _) => {
82                    ready!(self.rest.poll_ready(cx)).map_err(|err| match err {})?;
83                    self.rest_ready = true;
84                }
85                (_, false) => {
86                    ready!(self.grpc.poll_ready(cx))?;
87                    self.grpc_ready = true;
88                }
89            }
90        }
91    }
92
93    fn call(&mut self, req: Request<Incoming>) -> Self::Future {
94        assert!(
95            self.grpc_ready,
96            "grpc service not ready. Did you forget to call `poll_ready`?"
97        );
98        assert!(
99            self.rest_ready,
100            "rest service not ready. Did you forget to call `poll_ready`?"
101        );
102
103        // Figure out which service this request should go to, reset that service's readiness, and call it.
104        if is_grpc_request(&req) {
105            self.grpc_ready = false;
106            let future = self.grpc.call(req);
107            Box::pin(async move {
108                let res = future.await?;
109                Ok(res.into_response())
110            })
111        } else {
112            self.rest_ready = false;
113            let future = self.rest.call(req);
114            Box::pin(async move {
115                let res = future.await.map_err(|err| match err {})?;
116                Ok(res.into_response())
117            })
118        }
119    }
120}
121
122fn is_grpc_request<B>(req: &Request<B>) -> bool {
123    // We specifically check if the header value _starts_ with `application/grpc` as the gRPC spec allows for additional
124    // suffixes to describe how the payload is encoded (i.e. `application/grpc+proto` when encoded via Protocol Buffers
125    // vs `application/grpc+json` when encoded via JSON for gRPC-Web).
126    req.headers()
127        .get(CONTENT_TYPE)
128        .map(|content_type| content_type.as_bytes())
129        .filter(|content_type| content_type.starts_with(b"application/grpc"))
130        .is_some()
131}