abridged/src/supervisor.rs

60 lines
1.7 KiB
Rust

use std::{any::Any, time::Duration};
use std::error::Error;
use std::panic::AssertUnwindSafe;
use futures::FutureExt;
use async_trait::async_trait;
use futures::StreamExt;
use futures::stream::FuturesUnordered;
use log::{warn, info, error};
use crate::message::{Sender, Receiver, Id};
/// Outcome of one run of a task's `start` method: `Ok(())` when it returns
/// cleanly, otherwise the boxed error it reported.
pub type TaskResult = std::result::Result<(), Box<dyn std::error::Error>>;
/// A unit of work that the supervisor (`run_tasks`) runs and may restart.
#[async_trait]
pub trait Task {
    /// Runs one instance of the task.
    ///
    /// * `origin` – identifier built from the task's index in the list given
    ///   to `run_tasks`.
    /// * `tx` – a clone of the supervisor's shared broadcast sender.
    /// * `rc` – a receiver freshly subscribed to `tx` for this run.
    ///
    /// The run is over once this future completes (with `Ok` or `Err`) or
    /// panics; the supervisor then asks `restart_timeout` whether to start
    /// it again.
    async fn start(&self, origin: Id, tx: Sender, rc: Receiver) -> TaskResult;

    /// How long to wait before starting the task again after a run ends, or
    /// `None` to leave it stopped.
    ///
    /// The default waits 15 seconds.
    fn restart_timeout(&self) -> Option<Duration> {
        let delay = Duration::from_secs(15);
        Some(delay)
    }
}
/// How a single run of a task ended, as reported by `start_task`.
enum ExitStatus {
    /// The task's `start` future resolved to `Ok(())`.
    Success,
    /// The task's `start` future resolved to `Err`; holds that error.
    Error(Box<dyn Error>),
    /// Polling the task's `start` future panicked; holds the caught payload.
    Panic(Box<dyn Any>),
}
async fn start_task(id: usize, task: &dyn Task, tx: Sender, timeout: Duration) -> (usize, ExitStatus) {
tokio::time::sleep(timeout).await;
let rx = tx.subscribe();
let future = AssertUnwindSafe(task.start(Id::new(id), tx, rx)).catch_unwind();
let result = match future.await {
Ok(Ok(_)) => ExitStatus::Success,
Ok(Err(e)) => ExitStatus::Error(e),
Err(e) => ExitStatus::Panic(e),
};
(id, result)
}
pub async fn run_tasks(tasks: Vec<Box<dyn Task>>) {
let mut futures = FuturesUnordered::new();
let (tx, _) = tokio::sync::broadcast::channel(64);
for (id, task) in tasks.iter().enumerate() {
futures.push(start_task(id, task.as_ref(), tx.clone(), Duration::ZERO));
}
while let Some((id, result)) = futures.next().await {
let task = &tasks[id];
match &result {
ExitStatus::Success => warn!("task {id:?} exited successfully"),
ExitStatus::Error(e) => warn!("task {id:?}: exited with error: {e}"),
ExitStatus::Panic(_) => error!("task {id:?}: panicked"),
}
if let Some(dur) = task.restart_timeout() {
info!("task {id:?}: retrying in {dur:?}");
futures.push(start_task(id, task.as_ref(), tx.clone(), dur));
}
}
}