// saluki_io/net/server/http.rs

//! Basic HTTP server, built on `hyper`, with optional TLS support.
2
3use std::{
4    future::Future,
5    pin::Pin,
6    sync::Arc,
7    task::{ready, Context, Poll},
8};
9
10use http::{Request, Response};
11use http_body::Body;
12use hyper::{body::Incoming, service::Service};
13use hyper_util::{
14    rt::{TokioExecutor, TokioIo},
15    server::conn::auto::Builder,
16};
17use rustls::ServerConfig;
18use saluki_common::{
19    sync::shutdown::{ShutdownCoordinator, ShutdownHandle},
20    task::{spawn_traced_named, HandleExt as _},
21};
22use saluki_error::GenericError;
23use tokio::{pin, runtime::Handle, select, sync::oneshot};
24use tokio_rustls::TlsAcceptor;
25use tracing::{debug, error, info};
26
27use crate::net::listener::ConnectionOrientedListener;
28
29/// An HTTP server.
30pub struct HttpServer<S> {
31    executor: Handle,
32    listener: ConnectionOrientedListener,
33    conn_builder: Builder<TokioExecutor>,
34    service: S,
35    tls_config: Option<ServerConfig>,
36}
37
38impl<S> HttpServer<S> {
39    /// Creates a new `HttpServer` from the given listener and service.
40    ///
41    /// # Panics
42    ///
43    /// This will panic if called outside the context of a Tokio runtime.
44    pub fn from_listener(listener: ConnectionOrientedListener, service: S) -> Self {
45        Self {
46            executor: Handle::current(),
47            listener,
48            conn_builder: Builder::new(TokioExecutor::new()),
49            service,
50            tls_config: None,
51        }
52    }
53
54    /// Sets the TLS configuration for the server.
55    ///
56    /// This will enable TLS for the server, and the server will only accept connections that are encrypted with TLS.
57    ///
58    /// Defaults to TLS being disabled.
59    pub fn with_tls_config(mut self, config: ServerConfig) -> Self {
60        self.tls_config = Some(config);
61        self
62    }
63
64    /// Sets the executor for the server.
65    ///
66    /// This executor will be used for spawning tasks to handle incoming connections, but _not_ for the spawn that accepts
67    /// new connections.
68    ///
69    /// Defaults to the current Tokio runtime at the time `HttpServer::new` is called.
70    pub fn with_executor(mut self, executor: Handle) -> Self {
71        self.executor = executor;
72        self
73    }
74}
75
76impl<S, B> HttpServer<S>
77where
78    S: Service<Request<Incoming>, Response = Response<B>> + Send + Clone + 'static,
79    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
80    S::Future: Send + 'static,
81    B: Body + Send + 'static,
82    B::Data: Send,
83    B::Error: std::error::Error + Send + Sync,
84{
85    /// Starts the server and listens for incoming connections.
86    ///
87    /// Returns two handles: one for shutting down the server, and one for receiving any errors that occur while the
88    /// server is running.
89    pub fn listen(self) -> (ShutdownCoordinator, ErrorHandle) {
90        let (shutdown_coordinator, shutdown) = ShutdownHandle::paired();
91        let (error_tx, error_rx) = oneshot::channel();
92
93        let Self {
94            executor,
95            mut listener,
96            conn_builder,
97            service,
98            tls_config,
99            ..
100        } = self;
101
102        spawn_traced_named("http-server-acceptor", async move {
103            let tls_enabled = tls_config.is_some();
104            let maybe_tls_acceptor = tls_config.map(|mut config| {
105                // Allow for HTTP/1.1 and HTTP/2.
106                config.alpn_protocols.push(b"h2".to_vec());
107                config.alpn_protocols.push(b"http/1.1".to_vec());
108                TlsAcceptor::from(Arc::new(config))
109            });
110
111            info!(listen_addr = %listener.listen_address(), tls_enabled, "HTTP server started.");
112
113            pin!(shutdown);
114
115            loop {
116                select! {
117                    result = listener.accept() => match result {
118                        Ok(stream) => {
119                            let service = service.clone();
120                            let conn_builder = conn_builder.clone();
121                            let listen_addr = listener.listen_address().clone();
122                            match &maybe_tls_acceptor {
123                                Some(acceptor) => {
124                                    let tls_stream = match acceptor.accept(stream).await {
125                                        Ok(stream) => stream,
126                                        Err(e) => {
127                                            error!(%listen_addr, error = %e, "Failed to complete TLS handshake.");
128                                            continue
129                                        },
130                                    };
131
132                                    executor.spawn_traced_named("http-server-tls-conn-handler", async move {
133                                        if let Err(e) = conn_builder.serve_connection(TokioIo::new(tls_stream), service).await {
134                                            error!(%listen_addr, error = %e, "Failed to serve HTTP connection.");
135                                        }
136                                    });
137                                },
138                                None => {
139                                    executor.spawn_traced_named("http-server-conn-handler", async move {
140                                        if let Err(e) = conn_builder.serve_connection(TokioIo::new(stream), service).await {
141                                            error!(%listen_addr, error = %e, "Failed to serve HTTP connection.");
142                                        }
143                                    });
144                                },
145                            }
146                        },
147                        Err(e) => {
148                            let _ = error_tx.send(e.into());
149                            break;
150                        }
151                    },
152
153                    _ = &mut shutdown => {
154                        debug!(listen_addr = %listener.listen_address(), "Received shutdown signal.");
155                        break;
156                    }
157                }
158            }
159
160            info!(listen_addr = %listener.listen_address(), "HTTP server stopped.");
161        });
162
163        (shutdown_coordinator, ErrorHandle(error_rx))
164    }
165}
166
167/// A future that resolves when [`HttpServer`] encounters an unrecoverable error.
168pub struct ErrorHandle(oneshot::Receiver<GenericError>);
169
170impl Future for ErrorHandle {
171    type Output = Option<GenericError>;
172
173    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
174        match ready!(Pin::new(&mut self.0).poll(cx)) {
175            Ok(err) => Poll::Ready(Some(err)),
176            Err(_) => Poll::Ready(None),
177        }
178    }
179}