1use std::{
8 convert::Infallible,
9 future::Future,
10 net::SocketAddr,
11 panic::{catch_unwind, AssertUnwindSafe},
12 pin::Pin,
13 sync::Arc,
14 task::{Context, Poll},
15};
16
17use arc_swap::ArcSwap;
18use async_trait::async_trait;
19use axum::{body::Body as AxumBody, Router};
20use http::Response;
21use rcgen::{generate_simple_self_signed, CertifiedKey};
22use rustls::{pki_types::PrivateKeyDer, ServerConfig};
23use rustls_pki_types::PrivatePkcs8KeyDer;
24use saluki_api::{DynamicRoute, EndpointProtocol, EndpointType};
25use saluki_common::collections::FastIndexMap;
26use saluki_core::runtime::{
27 state::{AssertionUpdate, DataspaceRegistry, Identifier, IdentifierFilter, Subscription},
28 InitializationError, ProcessShutdown, Supervisable, SupervisorFuture,
29};
30use saluki_error::{generic_error, GenericError};
31use saluki_io::net::{
32 listener::ConnectionOrientedListener,
33 server::{
34 http::{ErrorHandle, HttpServer, ShutdownHandle},
35 multiplex_service::MultiplexService,
36 },
37 util::hyper::TowerToHyperService,
38 ListenAddress,
39};
40use tokio::{pin, select};
41use tower::Service;
42use tracing::{debug, info, warn};
43
/// The socket address a dynamic API server is actually listening on.
///
/// The dynamic API worker asserts this into the dataspace right after binding its listener, under an identifier
/// named after the worker (its `name()`). The value comes from the listener's `local_addr()`, so when port `0` was
/// requested it carries the concrete port that was assigned.
#[derive(Debug, Clone)]
pub struct BoundApiAddress(pub SocketAddr);
49
50pub struct DynamicAPIBuilder {
69 endpoint_type: EndpointType,
70 listen_address: ListenAddress,
71 tls_config: Option<ServerConfig>,
72}
73
74impl DynamicAPIBuilder {
75 pub fn new(endpoint_type: EndpointType, listen_address: ListenAddress) -> Self {
77 Self {
78 endpoint_type,
79 listen_address,
80 tls_config: None,
81 }
82 }
83
84 pub fn with_tls_config(mut self, config: ServerConfig) -> Self {
86 self.tls_config = Some(config);
87 self
88 }
89
90 pub fn with_self_signed_tls(self) -> Self {
92 let CertifiedKey { cert, signing_key } = generate_simple_self_signed(["localhost".to_owned()]).unwrap();
93 let cert_chain = vec![cert.der().clone()];
94 let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(signing_key.serialize_der()));
95
96 let config = ServerConfig::builder()
97 .with_no_client_auth()
98 .with_single_cert(cert_chain, key)
99 .unwrap();
100
101 self.with_tls_config(config)
102 }
103}
104
105#[async_trait]
106impl Supervisable for DynamicAPIBuilder {
107 fn name(&self) -> &str {
108 match self.endpoint_type {
109 EndpointType::Unprivileged => "dynamic-unprivileged-api",
110 EndpointType::Privileged => "dynamic-privileged-api",
111 }
112 }
113
114 async fn initialize(&self, process_shutdown: ProcessShutdown) -> Result<SupervisorFuture, InitializationError> {
115 let (inner_http, outer_http) = create_dynamic_router();
117 let (inner_grpc, outer_grpc) = create_dynamic_router();
118
119 let dataspace = DataspaceRegistry::try_current().ok_or_else(|| generic_error!("Dataspace not available."))?;
120
121 let route_assertions = dataspace.subscribe::<DynamicRoute>(IdentifierFilter::All);
123
124 let listener = ConnectionOrientedListener::from_listen_address(self.listen_address.clone())
126 .await
127 .map_err(|e| InitializationError::Failed { source: e.into() })?;
128
129 let bound_addr = listener
131 .local_addr()
132 .map_err(|e| InitializationError::Failed { source: e.into() })?;
133 dataspace.assert(BoundApiAddress(bound_addr), Identifier::named(self.name()));
134
135 let multiplexed_service = TowerToHyperService::new(MultiplexService::new(outer_http, outer_grpc));
136
137 let mut http_server = HttpServer::from_listener(listener, multiplexed_service);
138 if let Some(tls_config) = self.tls_config.clone() {
139 http_server = http_server.with_tls_config(tls_config);
140 }
141 let (shutdown_handle, error_handle) = http_server.listen();
142
143 let endpoint_type = self.endpoint_type;
144 let listen_address = self.listen_address.clone();
145
146 Ok(Box::pin(async move {
147 info!("Serving {} API on {}.", endpoint_type.name(), listen_address);
148
149 run_event_loop(
150 inner_http,
151 inner_grpc,
152 route_assertions,
153 endpoint_type,
154 process_shutdown,
155 shutdown_handle,
156 error_handle,
157 )
158 .await
159 }))
160 }
161}
162
163#[derive(Clone)]
170struct DynamicRouterService {
171 inner_router: Arc<ArcSwap<Router>>,
172}
173
174impl DynamicRouterService {
175 fn from_inner(inner_router: &Arc<ArcSwap<Router>>) -> Self {
176 Self {
177 inner_router: Arc::clone(inner_router),
178 }
179 }
180}
181
182impl Service<http::Request<AxumBody>> for DynamicRouterService {
183 type Response = Response<AxumBody>;
184 type Error = Infallible;
185 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
186
187 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
188 Poll::Ready(Ok(()))
189 }
190
191 fn call(&mut self, request: http::Request<AxumBody>) -> Self::Future {
192 let mut router = Arc::unwrap_or_clone(self.inner_router.load_full());
193 Box::pin(async move { router.call(request).await })
194 }
195}
196
197async fn run_event_loop(
199 inner_http: Arc<ArcSwap<Router>>, inner_grpc: Arc<ArcSwap<Router>>,
200 mut route_assertions: Subscription<DynamicRoute>, endpoint_type: EndpointType,
201 mut process_shutdown: ProcessShutdown, shutdown_handle: ShutdownHandle, error_handle: ErrorHandle,
202) -> Result<(), GenericError> {
203 let mut http_handlers = FastIndexMap::default();
204 let mut grpc_handlers = FastIndexMap::default();
205
206 let shutdown = process_shutdown.wait_for_shutdown();
207 pin!(shutdown);
208 pin!(error_handle);
209
210 loop {
211 select! {
212 _ = &mut shutdown => {
213 debug!("Dynamic API shutting down.");
214 shutdown_handle.shutdown();
215 break;
216 }
217
218 maybe_err = &mut error_handle => {
219 if let Some(e) = maybe_err {
220 return Err(GenericError::from(e));
221 }
222 break;
223 }
224
225 maybe_update = route_assertions.recv() => {
226 let Some(update) = maybe_update else {
227 warn!("Route subscription channel closed.");
228 break;
229 };
230
231 let mut rebuild_http = false;
232 let mut rebuild_grpc = false;
233
234 match update {
235 AssertionUpdate::Asserted(id, route) => {
236 if route.endpoint_type() != endpoint_type {
237 continue;
238 }
239
240 match route.endpoint_protocol() {
241 EndpointProtocol::Http => {
242 debug!(?id, "Registering dynamic HTTP handler.");
243 http_handlers.insert(id, route.into_router());
244
245 rebuild_http = true;
246 },
247 EndpointProtocol::Grpc => {
248 debug!(?id, "Registering dynamic gRPC handler.");
249 grpc_handlers.insert(id, route.into_router());
250
251 rebuild_grpc = true;
252 },
253 }
254 }
255 AssertionUpdate::Retracted(id) => {
256 if http_handlers.swap_remove(&id).is_some() {
257 debug!(?id, "Withdrawing dynamic HTTP handler.");
258 rebuild_http = true;
259 }
260
261 if grpc_handlers.swap_remove(&id).is_some() {
262 debug!(?id, "Withdrawing dynamic gRPC handler.");
263 rebuild_grpc = true;
264 }
265 }
266 }
267
268 if rebuild_http {
269 rebuild_router(&inner_http, &http_handlers);
270 }
271
272 if rebuild_grpc {
273 rebuild_router(&inner_grpc, &grpc_handlers);
274 }
275 }
276 }
277 }
278
279 Ok(())
280}
281
282fn create_dynamic_router() -> (Arc<ArcSwap<Router>>, Router) {
284 let inner = Arc::new(ArcSwap::from_pointee(Router::new()));
285 let outer = Router::new().fallback_service(DynamicRouterService::from_inner(&inner));
286 (inner, outer)
287}
288
289fn try_merge_router(base: &Router, id: &Identifier, other: &Router) -> Result<Router, String> {
306 let candidate = base.clone();
307 match catch_unwind(AssertUnwindSafe(|| candidate.merge(other.clone()))) {
308 Ok(merged) => Ok(merged),
309 Err(payload) => {
310 let reason = payload
311 .downcast_ref::<String>()
312 .map(|s| s.as_str())
313 .or_else(|| payload.downcast_ref::<&str>().copied())
314 .unwrap_or("unknown");
315 Err(format!("failed to merge dynamic handler {id:?}: {reason}"))
316 }
317 }
318}
319
320fn rebuild_router(inner_router: &Arc<ArcSwap<Router>>, handlers: &FastIndexMap<Identifier, Router>) {
322 let mut merged = Router::new();
323 let mut skipped = 0usize;
324
325 for (id, router) in handlers.iter() {
326 match try_merge_router(&merged, id, router) {
327 Ok(new_merged) => merged = new_merged,
328 Err(reason) => {
329 warn!(%reason, "Skipping dynamic handler due to overlapping route.");
330 skipped += 1;
331 }
332 }
333 }
334
335 inner_router.store(Arc::new(merged));
336 debug!(handler_count = handlers.len(), skipped, "Rebuilt inner router.");
337}
338
#[cfg(test)]
mod tests {
    //! End-to-end tests for the dynamic API: a real supervisor runs the API worker alongside a helper worker that
    //! asserts/retracts routes on command, and a real HTTP client checks what the server serves.

    use std::{net::SocketAddr, time::Duration};

    use async_trait::async_trait;
    use axum::Router;
    use http_body_util::{BodyExt as _, Empty};
    use hyper::{body::Bytes, StatusCode};
    use hyper_util::{client::legacy::Client, rt::TokioExecutor};
    use saluki_api::{APIHandler, DynamicRoute, EndpointType};
    use saluki_core::runtime::{
        state::{AssertionUpdate, DataspaceRegistry, Identifier, IdentifierFilter},
        InitializationError, ProcessShutdown, Supervisable, Supervisor, SupervisorFuture,
    };
    use tokio::{
        sync::{mpsc, oneshot},
        task::JoinHandle,
        time::{sleep, timeout, Instant},
    };

    use super::*;

    /// Minimal handler that serves a fixed `body` from a single GET route at `path`.
    struct SimpleHandler {
        path: &'static str,
        body: &'static str,
    }

    impl APIHandler for SimpleHandler {
        type State = ();

        fn generate_initial_state(&self) -> Self::State {}

        fn generate_routes(&self) -> Router<Self::State> {
            let body = self.body;
            Router::new().route(self.path, axum::routing::get(move || async move { body }))
        }
    }

    /// Instructions sent from a test to the `RouteAsserter` worker.
    enum RouteCommand {
        Assert { id: Identifier, route: DynamicRoute },
        Retract { id: Identifier },
    }

    /// Supervised helper worker that forwards `RouteCommand`s into the dataspace and reports the API server's bound
    /// address back to the test.
    ///
    /// The channel ends sit behind `Mutex<Option<..>>` because `initialize` takes `&self` and must move them out; it
    /// therefore succeeds only once.
    struct RouteAsserter {
        commands_rx: std::sync::Mutex<Option<mpsc::Receiver<RouteCommand>>>,
        addr_tx: std::sync::Mutex<Option<oneshot::Sender<SocketAddr>>>,
        endpoint_type: EndpointType,
    }

    #[async_trait]
    impl Supervisable for RouteAsserter {
        fn name(&self) -> &str {
            "route-asserter"
        }

        async fn initialize(
            &self, mut process_shutdown: ProcessShutdown,
        ) -> Result<SupervisorFuture, InitializationError> {
            let mut commands_rx =
                self.commands_rx
                    .lock()
                    .unwrap()
                    .take()
                    .ok_or_else(|| InitializationError::Failed {
                        source: generic_error!("RouteAsserter can only be initialized once"),
                    })?;
            let addr_tx = self.addr_tx.lock().unwrap().take();
            let endpoint_type = self.endpoint_type;

            Ok(Box::pin(async move {
                let dataspace =
                    DataspaceRegistry::try_current().ok_or_else(|| generic_error!("Dataspace not available."))?;

                // Must mirror `DynamicAPIBuilder::name()`, which is the identifier the API worker asserts its
                // `BoundApiAddress` under.
                let bound_addr_name = match endpoint_type {
                    EndpointType::Unprivileged => "dynamic-unprivileged-api",
                    EndpointType::Privileged => "dynamic-privileged-api",
                };
                let mut addr_sub =
                    dataspace.subscribe::<BoundApiAddress>(IdentifierFilter::exact(Identifier::named(bound_addr_name)));

                let addr = match addr_sub.recv().await {
                    Some(AssertionUpdate::Asserted(_, BoundApiAddress(mut addr))) => {
                        // The server binds a wildcard address (`ListenAddress::any_tcp`); give the client a concrete
                        // loopback address to dial instead.
                        if addr.ip().is_unspecified() {
                            addr.set_ip(std::net::Ipv4Addr::LOCALHOST.into());
                        }
                        addr
                    }
                    other => return Err(generic_error!("unexpected bound address update: {:?}", other)),
                };

                // Unblock `setup_test_harness`, which waits for this address. A send error only means the test
                // already gave up waiting.
                if let Some(tx) = addr_tx {
                    let _ = tx.send(addr);
                }

                let shutdown = process_shutdown.wait_for_shutdown();
                tokio::pin!(shutdown);

                // Apply commands in the order the test sent them. The API worker's overlap handling depends on the
                // order routes are registered.
                loop {
                    tokio::select! {
                        _ = &mut shutdown => break,
                        cmd = commands_rx.recv() => {
                            let Some(cmd) = cmd else { break };
                            match cmd {
                                RouteCommand::Assert { id, route } => {
                                    dataspace.assert(route, id);
                                }
                                RouteCommand::Retract { id } => {
                                    dataspace.retract::<DynamicRoute>(id);
                                }
                            }
                        }
                    }
                }

                Ok(())
            }))
        }
    }

    /// Handles a test uses to drive a running dynamic API server.
    struct TestHarness {
        /// Loopback address the API server can be reached at.
        addr: SocketAddr,
        /// Sends route commands to the `RouteAsserter` worker.
        commands: mpsc::Sender<RouteCommand>,
        /// Held only so the supervisor's shutdown signal stays pending for the harness's lifetime (presumably
        /// dropping it triggers shutdown — TODO confirm `run_with_shutdown` semantics).
        _shutdown: oneshot::Sender<()>,
        /// Task running the supervisor.
        _handle: JoinHandle<()>,
    }

    impl TestHarness {
        /// Asks the `RouteAsserter` to assert `route` under `id`. Returns once the command is queued, not once the
        /// server has picked it up; pair with `assert_status_eventually`.
        async fn assert_route(&self, id: impl Into<Identifier>, route: DynamicRoute) {
            self.commands
                .send(RouteCommand::Assert { id: id.into(), route })
                .await
                .unwrap();
        }

        /// Asks the `RouteAsserter` to retract the route asserted under `id`. Returns once the command is queued.
        async fn retract_route(&self, id: impl Into<Identifier>) {
            self.commands
                .send(RouteCommand::Retract { id: id.into() })
                .await
                .unwrap();
        }
    }

    /// Starts a supervisor running a `DynamicAPIBuilder` for `endpoint_type` on an ephemeral port plus a
    /// `RouteAsserter`, and waits (up to 5s) for the server's bound address.
    async fn setup_test_harness(endpoint_type: EndpointType) -> TestHarness {
        let (commands_tx, commands_rx) = mpsc::channel(16);
        let (addr_tx, addr_rx) = oneshot::channel();

        let api_builder = DynamicAPIBuilder::new(endpoint_type, ListenAddress::any_tcp(0));
        let route_asserter = RouteAsserter {
            commands_rx: std::sync::Mutex::new(Some(commands_rx)),
            addr_tx: std::sync::Mutex::new(Some(addr_tx)),
            endpoint_type,
        };

        let mut sup = Supervisor::new("test-dynamic-api").unwrap();
        sup.add_worker(api_builder);
        sup.add_worker(route_asserter);

        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
        let handle = tokio::spawn(async move {
            let _ = sup.run_with_shutdown(shutdown_rx).await;
        });

        let addr = timeout(Duration::from_secs(5), addr_rx)
            .await
            .expect("timed out waiting for bound address")
            .expect("addr channel closed");

        TestHarness {
            addr,
            commands: commands_tx,
            _shutdown: shutdown_tx,
            _handle: handle,
        }
    }

    /// Issues a plain-HTTP GET for `path` against `addr` and returns the status and body (lossily decoded as UTF-8).
    async fn http_get(addr: SocketAddr, path: &str) -> (StatusCode, String) {
        let client: Client<_, Empty<Bytes>> = Client::builder(TokioExecutor::new()).build_http();
        let uri = format!("http://{}{}", addr, path);
        let resp = client.get(uri.parse().unwrap()).await.unwrap();
        let status = resp.status();
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let body_str = String::from_utf8_lossy(&body).into_owned();
        (status, body_str)
    }

    /// Polls `path` every 50ms until it returns `expected`, returning that response's body.
    ///
    /// Route changes are applied asynchronously, so a single request right after `assert_route`/`retract_route`
    /// could see the old state. Panics if `expected` is not seen within about 2s.
    async fn assert_status_eventually(addr: SocketAddr, path: &str, expected: StatusCode) -> String {
        let deadline = Instant::now() + Duration::from_secs(2);
        loop {
            let (status, body) = http_get(addr, path).await;
            if status == expected {
                return body;
            }
            if Instant::now() > deadline {
                panic!("expected {} for {} but got {}", expected, path, status);
            }
            sleep(Duration::from_millis(50)).await;
        }
    }

    /// An asserted HTTP route becomes reachable and serves its handler's body.
    #[tokio::test]
    async fn serves_asserted_http_route() {
        let harness = setup_test_harness(EndpointType::Unprivileged).await;

        let route = DynamicRoute::http(
            EndpointType::Unprivileged,
            SimpleHandler {
                path: "/health",
                body: "ok",
            },
        );
        harness.assert_route("health", route).await;

        let body = assert_status_eventually(harness.addr, "/health", StatusCode::OK).await;
        assert_eq!(body, "ok");
    }

    /// With no routes registered, requests fall through to a 404.
    #[tokio::test]
    async fn returns_404_for_unknown_route() {
        let harness = setup_test_harness(EndpointType::Unprivileged).await;
        let (status, _) = http_get(harness.addr, "/nonexistent").await;
        assert_eq!(status, StatusCode::NOT_FOUND);
    }

    /// Retracting a route makes its path return 404 again.
    #[tokio::test]
    async fn route_retraction_removes_route() {
        let harness = setup_test_harness(EndpointType::Unprivileged).await;

        let route = DynamicRoute::http(
            EndpointType::Unprivileged,
            SimpleHandler {
                path: "/temp",
                body: "temporary",
            },
        );
        harness.assert_route("temp", route).await;
        assert_status_eventually(harness.addr, "/temp", StatusCode::OK).await;

        harness.retract_route("temp").await;
        assert_status_eventually(harness.addr, "/temp", StatusCode::NOT_FOUND).await;
    }

    /// Retracting one route leaves other routes untouched.
    #[tokio::test]
    async fn multiple_routes_independent_lifecycle() {
        let harness = setup_test_harness(EndpointType::Unprivileged).await;

        let route_a = DynamicRoute::http(
            EndpointType::Unprivileged,
            SimpleHandler {
                path: "/a",
                body: "alpha",
            },
        );
        let route_b = DynamicRoute::http(
            EndpointType::Unprivileged,
            SimpleHandler {
                path: "/b",
                body: "bravo",
            },
        );
        harness.assert_route("a", route_a).await;
        harness.assert_route("b", route_b).await;

        assert_status_eventually(harness.addr, "/a", StatusCode::OK).await;
        assert_status_eventually(harness.addr, "/b", StatusCode::OK).await;

        harness.retract_route("a").await;
        assert_status_eventually(harness.addr, "/a", StatusCode::NOT_FOUND).await;

        let body = assert_status_eventually(harness.addr, "/b", StatusCode::OK).await;
        assert_eq!(body, "bravo");
    }

    /// Routes asserted for a different endpoint type are not served by this server.
    #[tokio::test]
    async fn ignores_routes_for_different_endpoint_type() {
        let harness = setup_test_harness(EndpointType::Unprivileged).await;

        // A privileged route must not be served by the unprivileged server.
        let wrong_route = DynamicRoute::http(
            EndpointType::Privileged,
            SimpleHandler {
                path: "/secret",
                body: "secret",
            },
        );
        harness.assert_route("secret", wrong_route).await;

        // NOTE(review): checked immediately, so on its own this could pass before the assertion is processed. The
        // check below is the stronger one: had the privileged route been registered, it would have been registered
        // first and won `/secret`, making the body "secret".
        let (status, _) = http_get(harness.addr, "/secret").await;
        assert_eq!(status, StatusCode::NOT_FOUND);

        // The same path asserted for the matching endpoint type is served.
        let right_route = DynamicRoute::http(
            EndpointType::Unprivileged,
            SimpleHandler {
                path: "/secret",
                body: "not secret",
            },
        );
        harness.assert_route("secret-unpriv", right_route).await;

        let body = assert_status_eventually(harness.addr, "/secret", StatusCode::OK).await;
        assert_eq!(body, "not secret");
    }

    /// Overlapping routes are skipped rather than crashing the server.
    ///
    /// The first-registered route keeps serving the path, unrelated routes still work, and the skipped route takes
    /// over once the first is retracted.
    #[tokio::test]
    async fn overlapping_routes_do_not_crash_server() {
        let harness = setup_test_harness(EndpointType::Unprivileged).await;

        // First handler for /health.
        let route_1 = DynamicRoute::http(
            EndpointType::Unprivileged,
            SimpleHandler {
                path: "/health",
                body: "health-1",
            },
        );
        harness.assert_route("health-1", route_1).await;
        let body = assert_status_eventually(harness.addr, "/health", StatusCode::OK).await;
        assert_eq!(body, "health-1");

        // Second handler for the same path: merging it into the router panics inside axum and must be skipped.
        let route_2 = DynamicRoute::http(
            EndpointType::Unprivileged,
            SimpleHandler {
                path: "/health",
                body: "health-2",
            },
        );
        harness.assert_route("health-2", route_2).await;

        // The expected response does not change, so there is nothing to poll for; give the event loop time to
        // process (and skip) the overlapping route.
        sleep(Duration::from_millis(200)).await;

        // Server still up, and the earlier-registered handler still wins.
        let (status, body) = http_get(harness.addr, "/health").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "health-1");

        // Non-overlapping routes asserted afterwards are still registered.
        let route_info = DynamicRoute::http(
            EndpointType::Unprivileged,
            SimpleHandler {
                path: "/info",
                body: "info",
            },
        );
        harness.assert_route("info", route_info).await;
        let body = assert_status_eventually(harness.addr, "/info", StatusCode::OK).await;
        assert_eq!(body, "info");

        // Once the winning handler is retracted, the previously skipped one serves the path.
        harness.retract_route("health-1").await;
        let body = assert_status_eventually(harness.addr, "/health", StatusCode::OK).await;
        assert_eq!(body, "health-2");
    }

    /// After every overlapping route for a path is retracted, the path can be cleanly re-registered.
    #[tokio::test]
    async fn overlapping_route_retraction_then_reassertion() {
        let harness = setup_test_harness(EndpointType::Unprivileged).await;

        // Two handlers for the same path; "ov-a" is asserted first.
        let route_a = DynamicRoute::http(
            EndpointType::Unprivileged,
            SimpleHandler {
                path: "/overlap",
                body: "a",
            },
        );
        let route_b = DynamicRoute::http(
            EndpointType::Unprivileged,
            SimpleHandler {
                path: "/overlap",
                body: "b",
            },
        );
        harness.assert_route("ov-a", route_a).await;
        harness.assert_route("ov-b", route_b).await;

        // The first-registered handler wins the overlap.
        let body = assert_status_eventually(harness.addr, "/overlap", StatusCode::OK).await;
        assert_eq!(body, "a");

        // Retract both, including the one that was skipped, so the path becomes unrouted.
        harness.retract_route("ov-a").await;
        harness.retract_route("ov-b").await;
        assert_status_eventually(harness.addr, "/overlap", StatusCode::NOT_FOUND).await;

        // A new handler for the same path is registered without conflict.
        let route_c = DynamicRoute::http(
            EndpointType::Unprivileged,
            SimpleHandler {
                path: "/overlap",
                body: "c",
            },
        );
        harness.assert_route("ov-c", route_c).await;
        let body = assert_status_eventually(harness.addr, "/overlap", StatusCode::OK).await;
        assert_eq!(body, "c");
    }
}