saluki_app/
api.rs

1//! API server.
2
3use std::{convert::Infallible, error::Error, future::Future, io::BufReader};
4
5use axum::Router;
6use http::{Request, Response};
7use rcgen::{generate_simple_self_signed, CertifiedKey};
8use rustls::ServerConfig;
9use rustls_pemfile::{certs, pkcs8_private_keys};
10use saluki_api::APIHandler;
11use saluki_error::GenericError;
12use saluki_io::net::{
13    listener::ConnectionOrientedListener,
14    server::{http::HttpServer, multiplex_service::MultiplexService},
15    util::hyper::TowerToHyperService,
16    ListenAddress,
17};
18use tokio::select;
19use tonic::{body::Body, server::NamedService, service::RoutesBuilder};
20use tower::Service;
21
22/// An API builder.
23///
24/// `APIBuilder` provides a simple and ergonomic builder pattern for constructing an API server from multiple handlers.
25/// This allows composing portions of an API from individual building blocks.
26///
27/// ## Missing
28///
29/// - TLS support
30/// - API-wide authentication support (can be added at the per-handler level)
31/// - graceful shutdown (shutdown stops new connections, but does not wait for existing connections to close)
32#[derive(Default)]
33pub struct APIBuilder {
34    http_router: Router,
35    grpc_router: RoutesBuilder,
36    tls_config: Option<ServerConfig>,
37}
38
39impl APIBuilder {
40    /// Create a new `APIBuilder` with an empty router.
41    ///
42    /// A fallback route will be provided that returns a 404 Not Found response for any route that isn't explicitly handled.
43    pub fn new() -> Self {
44        Self {
45            http_router: Router::new(),
46            grpc_router: RoutesBuilder::default(),
47            tls_config: None,
48        }
49    }
50
51    /// Adds the given handler to this builder.
52    ///
53    /// The initial state and routes provided by the handler will be merged into this builder.
54    pub fn with_handler<H>(mut self, handler: H) -> Self
55    where
56        H: APIHandler,
57    {
58        let handler_router = handler.generate_routes();
59        let handler_state = handler.generate_initial_state();
60        self.http_router = self.http_router.merge(handler_router.with_state(handler_state));
61
62        self
63    }
64
65    /// Adds the given optional handler to this builder.
66    ///
67    /// If the handler is `Some`, the initial state and routes provided by the handler will be merged into this builder.
68    /// Otherwise, this builder will be returned unchanged.
69    pub fn with_optional_handler<H>(self, handler: Option<H>) -> Self
70    where
71        H: APIHandler,
72    {
73        if let Some(handler) = handler {
74            self.with_handler(handler)
75        } else {
76            self
77        }
78    }
79
80    /// Add the given gRPC service to this builder.
81    pub fn with_grpc_service<S>(mut self, svc: S) -> Self
82    where
83        S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
84            + NamedService
85            + Clone
86            + Send
87            + Sync
88            + 'static,
89        S::Future: Send + 'static,
90        S::Error: Into<Box<dyn Error + Send + Sync>> + Send,
91    {
92        self.grpc_router.add_service(svc);
93        self
94    }
95
96    /// Sets the TLS configuration for the server.
97    ///
98    /// This will enable TLS for the server, and the server will only accept connections that are encrypted with TLS.
99    ///
100    /// Defaults to TLS being disabled.
101    pub fn with_tls_config(mut self, config: ServerConfig) -> Self {
102        self.tls_config = Some(config);
103        self
104    }
105
106    /// Sets the TLS configuration for the server based on a dynamically generated self-signed certificate.
107    ///
108    /// This will enable TLS for the server, and the server will only accept connections that are encrypted with TLS.
109    pub fn with_self_signed_tls(self) -> Self {
110        let CertifiedKey { cert, key_pair } = generate_simple_self_signed(["localhost".to_owned()]).unwrap();
111        let cert_file = cert.pem();
112        let key_file = key_pair.serialize_pem();
113
114        let cert_file = &mut BufReader::new(cert_file.as_bytes());
115        let key_file = &mut BufReader::new(key_file.as_bytes());
116
117        let cert_chain = certs(cert_file).collect::<Result<Vec<_>, _>>().unwrap();
118        let mut keys = pkcs8_private_keys(key_file).collect::<Result<Vec<_>, _>>().unwrap();
119
120        let config = ServerConfig::builder()
121            .with_no_client_auth()
122            .with_single_cert(cert_chain, rustls::pki_types::PrivateKeyDer::Pkcs8(keys.remove(0)))
123            .unwrap();
124
125        self.with_tls_config(config)
126    }
127
128    /// Serves the API on the given listen address until `shutdown` resolves.
129    ///
130    /// The listen address must be a connection-oriented address (TCP or Unix domain socket in SOCK_STREAM mode).
131    ///
132    /// # Errors
133    ///
134    /// If the given listen address is not connection-oriented, or if the server fails to bind to the address, or if
135    /// there is an error while accepting for new connections, an error will be returned.
136    pub async fn serve<F>(self, listen_address: ListenAddress, shutdown: F) -> Result<(), GenericError>
137    where
138        F: Future<Output = ()> + Send + 'static,
139    {
140        let listener = ConnectionOrientedListener::from_listen_address(listen_address).await?;
141
142        // Wrap up our HTTP and gRPC routers in a multiplexed service, allowing us to handle both types of requests on
143        // the same port. Additionally, we have to wrap the service to translate from `tower::Service` to `hyper::Service`.
144        let multiplexed_service = TowerToHyperService::new(MultiplexService::new(
145            self.http_router,
146            self.grpc_router.routes().into_axum_router(),
147        ));
148
149        // Create and spawn the HTTP server.
150        let mut http_server = HttpServer::from_listener(listener, multiplexed_service);
151        if let Some(tls_config) = self.tls_config {
152            http_server = http_server.with_tls_config(tls_config);
153        }
154        let (shutdown_handle, error_handle) = http_server.listen();
155
156        // Wait for our shutdown signal, which we'll forward to the listener to stop accepting new connections... or
157        // capture any errors thrown by the listener itself.
158        select! {
159            _ = shutdown =>  shutdown_handle.shutdown(),
160            maybe_err = error_handle => if let Some(e) = maybe_err {
161                return Err(GenericError::from(e))
162            },
163        }
164
165        Ok(())
166    }
167}