1use std::{convert::Infallible, error::Error, future::Future};
4
5use axum::Router;
6use http::{Request, Response};
7use rcgen::{generate_simple_self_signed, CertifiedKey};
8use rustls::{pki_types::PrivateKeyDer, ServerConfig};
9use rustls_pki_types::PrivatePkcs8KeyDer;
10use saluki_api::APIHandler;
11use saluki_error::GenericError;
12use saluki_io::net::{
13 listener::ConnectionOrientedListener,
14 server::{http::HttpServer, multiplex_service::MultiplexService},
15 util::hyper::TowerToHyperService,
16 ListenAddress,
17};
18use tokio::select;
19use tonic::{body::Body, server::NamedService, service::RoutesBuilder};
20use tower::Service;
21
22#[derive(Default)]
33pub struct APIBuilder {
34 http_router: Router,
35 grpc_router: RoutesBuilder,
36 tls_config: Option<ServerConfig>,
37}
38
39impl APIBuilder {
40 pub fn new() -> Self {
44 Self {
45 http_router: Router::new(),
46 grpc_router: RoutesBuilder::default(),
47 tls_config: None,
48 }
49 }
50
51 pub fn with_handler<H>(mut self, handler: H) -> Self
55 where
56 H: APIHandler,
57 {
58 let handler_router = handler.generate_routes();
59 let handler_state = handler.generate_initial_state();
60 self.http_router = self.http_router.merge(handler_router.with_state(handler_state));
61
62 self
63 }
64
65 pub fn with_optional_handler<H>(self, handler: Option<H>) -> Self
70 where
71 H: APIHandler,
72 {
73 if let Some(handler) = handler {
74 self.with_handler(handler)
75 } else {
76 self
77 }
78 }
79
80 pub fn with_grpc_service<S>(mut self, svc: S) -> Self
82 where
83 S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
84 + NamedService
85 + Clone
86 + Send
87 + Sync
88 + 'static,
89 S::Future: Send + 'static,
90 S::Error: Into<Box<dyn Error + Send + Sync>> + Send,
91 {
92 self.grpc_router.add_service(svc);
93 self
94 }
95
96 pub fn with_tls_config(mut self, config: ServerConfig) -> Self {
102 self.tls_config = Some(config);
103 self
104 }
105
106 pub fn with_self_signed_tls(self) -> Self {
110 let CertifiedKey { cert, key_pair } = generate_simple_self_signed(["localhost".to_owned()]).unwrap();
111 let cert_chain = vec![cert.der().clone()];
112 let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));
113
114 let config = ServerConfig::builder()
115 .with_no_client_auth()
116 .with_single_cert(cert_chain, key)
117 .unwrap();
118
119 self.with_tls_config(config)
120 }
121
122 pub async fn serve<F>(self, listen_address: ListenAddress, shutdown: F) -> Result<(), GenericError>
131 where
132 F: Future<Output = ()> + Send + 'static,
133 {
134 let listener = ConnectionOrientedListener::from_listen_address(listen_address).await?;
135
136 let multiplexed_service = TowerToHyperService::new(MultiplexService::new(
139 self.http_router,
140 self.grpc_router.routes().into_axum_router(),
141 ));
142
143 let mut http_server = HttpServer::from_listener(listener, multiplexed_service);
145 if let Some(tls_config) = self.tls_config {
146 http_server = http_server.with_tls_config(tls_config);
147 }
148 let (shutdown_handle, error_handle) = http_server.listen();
149
150 select! {
153 _ = shutdown => shutdown_handle.shutdown(),
154 maybe_err = error_handle => if let Some(e) = maybe_err {
155 return Err(GenericError::from(e))
156 },
157 }
158
159 Ok(())
160 }
161}