diff --git a/src/storage_mapping.rs b/src/storage_mapping.rs index c094e56..9eb9c72 100644 --- a/src/storage_mapping.rs +++ b/src/storage_mapping.rs @@ -13,7 +13,7 @@ use log::debug; use rand::{thread_rng, Rng}; use sqlx::any::AnyConnectOptions; use sqlx::{query_as, Any, AnyPool, FromRow}; -use std::time::Instant; +use std::{sync::RwLock, time::Instant}; use tokio::time::Duration; #[derive(Debug, Clone, FromRow)] @@ -27,6 +27,7 @@ pub struct UserStorageAccess { struct CachedAccess { access: Vec, valid_till: Instant, + updating: RwLock, } impl CachedAccess { @@ -36,11 +37,18 @@ impl CachedAccess { access, valid_till: Instant::now() + Duration::from_millis(rng.gen_range((4 * 60 * 1000)..(5 * 60 * 1000))), + updating: RwLock::new(false), } } pub fn is_valid(&self) -> bool { - self.valid_till > Instant::now() + self.valid_till > Instant::now() || self.updating.try_read().is_ok_and(|value| *value) + } + + pub fn prepare_update(&self, value: bool) { + if let Ok(mut updating) = self.updating.try_write() { + *updating = value + } } } @@ -71,14 +79,27 @@ impl StorageMapping { &self, storage: u32, ) -> Result, DatabaseError> { - if let Some(cached) = self.cache.get(&storage).filter(|cached| cached.is_valid()) { - Ok(cached) - } else { - let users = self.load_storage_mapping(storage).await?; + if let Some(cached) = self.cache.get(&storage) { + if cached.is_valid() { + return Ok(cached); + } - self.cache.insert(storage, CachedAccess::new(users)); - Ok(self.cache.get(&storage).unwrap()) + cached.prepare_update(true); + let users = self + .load_storage_mapping(storage) + .await + .inspect_err(|_| cached.prepare_update(false))?; + + drop(cached); + let cached = CachedAccess::new(users); + self.cache.insert(storage, cached); + return Ok(self.cache.get(&storage).unwrap()); } + + let users = self.load_storage_mapping(storage).await?; + + self.cache.insert(storage, CachedAccess::new(users)); + Ok(self.cache.get(&storage).unwrap()) } pub async fn get_users_for_storage_path(