1use std::{convert::Infallible, error::Error, future::Future, io::BufReader};
4
5use axum::Router;
6use http::{Request, Response};
7use rcgen::{generate_simple_self_signed, CertifiedKey};
8use rustls::ServerConfig;
9use rustls_pemfile::{certs, pkcs8_private_keys};
10use saluki_api::APIHandler;
11use saluki_error::GenericError;
12use saluki_io::net::{
13 listener::ConnectionOrientedListener,
14 server::{http::HttpServer, multiplex_service::MultiplexService},
15 util::hyper::TowerToHyperService,
16 ListenAddress,
17};
18use tokio::select;
19use tonic::{body::Body, server::NamedService, service::RoutesBuilder};
20use tower::Service;
21
22#[derive(Default)]
33pub struct APIBuilder {
34 http_router: Router,
35 grpc_router: RoutesBuilder,
36 tls_config: Option<ServerConfig>,
37}
38
39impl APIBuilder {
40 pub fn new() -> Self {
44 Self {
45 http_router: Router::new(),
46 grpc_router: RoutesBuilder::default(),
47 tls_config: None,
48 }
49 }
50
51 pub fn with_handler<H>(mut self, handler: H) -> Self
55 where
56 H: APIHandler,
57 {
58 let handler_router = handler.generate_routes();
59 let handler_state = handler.generate_initial_state();
60 self.http_router = self.http_router.merge(handler_router.with_state(handler_state));
61
62 self
63 }
64
65 pub fn with_optional_handler<H>(self, handler: Option<H>) -> Self
70 where
71 H: APIHandler,
72 {
73 if let Some(handler) = handler {
74 self.with_handler(handler)
75 } else {
76 self
77 }
78 }
79
80 pub fn with_grpc_service<S>(mut self, svc: S) -> Self
82 where
83 S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
84 + NamedService
85 + Clone
86 + Send
87 + Sync
88 + 'static,
89 S::Future: Send + 'static,
90 S::Error: Into<Box<dyn Error + Send + Sync>> + Send,
91 {
92 self.grpc_router.add_service(svc);
93 self
94 }
95
96 pub fn with_tls_config(mut self, config: ServerConfig) -> Self {
102 self.tls_config = Some(config);
103 self
104 }
105
106 pub fn with_self_signed_tls(self) -> Self {
110 let CertifiedKey { cert, key_pair } = generate_simple_self_signed(["localhost".to_owned()]).unwrap();
111 let cert_file = cert.pem();
112 let key_file = key_pair.serialize_pem();
113
114 let cert_file = &mut BufReader::new(cert_file.as_bytes());
115 let key_file = &mut BufReader::new(key_file.as_bytes());
116
117 let cert_chain = certs(cert_file).collect::<Result<Vec<_>, _>>().unwrap();
118 let mut keys = pkcs8_private_keys(key_file).collect::<Result<Vec<_>, _>>().unwrap();
119
120 let config = ServerConfig::builder()
121 .with_no_client_auth()
122 .with_single_cert(cert_chain, rustls::pki_types::PrivateKeyDer::Pkcs8(keys.remove(0)))
123 .unwrap();
124
125 self.with_tls_config(config)
126 }
127
128 pub async fn serve<F>(self, listen_address: ListenAddress, shutdown: F) -> Result<(), GenericError>
137 where
138 F: Future<Output = ()> + Send + 'static,
139 {
140 let listener = ConnectionOrientedListener::from_listen_address(listen_address).await?;
141
142 let multiplexed_service = TowerToHyperService::new(MultiplexService::new(
145 self.http_router,
146 self.grpc_router.routes().into_axum_router(),
147 ));
148
149 let mut http_server = HttpServer::from_listener(listener, multiplexed_service);
151 if let Some(tls_config) = self.tls_config {
152 http_server = http_server.with_tls_config(tls_config);
153 }
154 let (shutdown_handle, error_handle) = http_server.listen();
155
156 select! {
159 _ = shutdown => shutdown_handle.shutdown(),
160 maybe_err = error_handle => if let Some(e) = maybe_err {
161 return Err(GenericError::from(e))
162 },
163 }
164
165 Ok(())
166 }
167}