diff --git a/backend-rs/Cargo.lock b/backend-rs/Cargo.lock index 4921621..3450277 100644 --- a/backend-rs/Cargo.lock +++ b/backend-rs/Cargo.lock @@ -1840,12 +1840,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "rs-snowflake" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e60ef3b82994702bbe4e134d98aadca4b49ed04440148985678d415c68127666" - [[package]] name = "rsa" version = "0.9.6" @@ -2574,7 +2568,6 @@ dependencies = [ "jsonwebtoken", "lazy_static", "ring", - "rs-snowflake", "sea-orm", "sea-orm-migration", "serde", diff --git a/backend-rs/backend/Cargo.toml b/backend-rs/backend/Cargo.toml index f38084d..19afe5b 100644 --- a/backend-rs/backend/Cargo.toml +++ b/backend-rs/backend/Cargo.toml @@ -17,7 +17,6 @@ fs-file = { version = "0.0.0", path = "../fs-file", default-features = false } jsonwebtoken = { version = "9.2.0", default-features = false } lazy_static = "1.4.0" ring = "0.17.7" -rs-snowflake = "0.6.0" sea-orm = { version = "0.12.14", features = ["runtime-tokio-rustls", "sqlx-postgres"] } sea-orm-migration = "0.12.12" serde = { version = "1.0.196", features = ["derive"] } diff --git a/backend-rs/backend/src/main.rs b/backend-rs/backend/src/main.rs index f6b3bce..8decb19 100644 --- a/backend-rs/backend/src/main.rs +++ b/backend-rs/backend/src/main.rs @@ -2,23 +2,23 @@ mod config; mod controllers; mod models; mod secure; +mod snowflake; use crate::config::{Config, ConfigError}; use crate::controllers::user; use crate::models::migrations::migrator::Migrator; use crate::models::user::CreateDefaultUsersError; +use crate::snowflake::SnowflakeGenerator; use axum::{routing, Router}; use cascade::cascade; use fs::filesystem; use sea_orm::{Database, DatabaseConnection, DbErr}; use sea_orm_migration::MigratorTrait; -use snowflake::SnowflakeIdBucket; use std::io; use std::net::SocketAddr; use std::sync::Arc; use thiserror::Error; use tokio::net::TcpListener; -use tokio::sync::Mutex; use 
tower_http::trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer}; use tracing::{info, Level}; @@ -56,7 +56,7 @@ pub struct AppState { } pub struct Snowflakes { - pub user_snowflake: Mutex, + pub user_snowflake: SnowflakeGenerator, } #[tokio::main] @@ -98,7 +98,7 @@ async fn main() -> Result<(), AppError> { config, connection, snowflakes: Arc::new(Snowflakes { - user_snowflake: Mutex::new(SnowflakeIdBucket::new(1, 1)), + user_snowflake: SnowflakeGenerator::new(0), }), }); let api = Router::new().nest("/api", app); diff --git a/backend-rs/backend/src/models/user.rs b/backend-rs/backend/src/models/user.rs index 859f6b1..48303d1 100644 --- a/backend-rs/backend/src/models/user.rs +++ b/backend-rs/backend/src/models/user.rs @@ -128,7 +128,7 @@ pub async fn create( ) -> Result { create_base( connection, - Id::from(snowflakes.user_snowflake.lock().await.get_id()), + Id::from(snowflakes.user_snowflake.generate_id().await), Some(user.username), Some(user.email_address), Some(user.password), diff --git a/backend-rs/backend/src/snowflake.rs b/backend-rs/backend/src/snowflake.rs new file mode 100644 index 0000000..5e8b33e --- /dev/null +++ b/backend-rs/backend/src/snowflake.rs @@ -0,0 +1,113 @@ +use chrono::{TimeZone, Utc}; +use lazy_static::lazy_static; +use std::sync::atomic::{AtomicI64, Ordering}; +use std::time::Duration; +use tokio::time; + +const fn mask(start: usize, bits: usize) -> i64 { + assert!((start + bits) <= 63); // Ensure sign bit is empty + ((1 << bits) - 1) << start +} + +const fn base_mask(bits: usize) -> i64 { + mask(0, bits) +} + +const WORKER_BITS: usize = 8; +const SEQUENCE_BITS: usize = 12; +const TIMESTAMP_BITS: usize = 43; + +const WORKER_SHIFT: usize = 0; +const SEQUENCE_SHIFT: usize = WORKER_SHIFT + WORKER_BITS; +const TIMESTAMP_SHIFT: usize = SEQUENCE_SHIFT + SEQUENCE_BITS; + +const WORKER_MASK: i64 = mask(WORKER_SHIFT, WORKER_BITS); +const SEQUENCE_MASK: i64 = mask(SEQUENCE_SHIFT, SEQUENCE_BITS); +const TIMESTAMP_MASK: i64 = 
mask(TIMESTAMP_SHIFT, TIMESTAMP_BITS);

const MAX_WORKER: i64 = base_mask(WORKER_BITS);
const MAX_SEQUENCE: i64 = base_mask(SEQUENCE_BITS);

// Back-off applied while the per-millisecond sequence space is exhausted.
const SEQUENCE_EXHAUSTED_DELAY: Duration = Duration::from_micros(100);

lazy_static! {
  // Custom epoch: 2020 Mar 12 21:33:54 UTC, the first commit of Streamfox,
  // expressed in milliseconds since the Unix epoch.
  static ref EPOCH: i64 = Utc
    .with_ymd_and_hms(2020, 3, 12, 21, 33, 54)
    .unwrap()
    .timestamp_millis();
}

/// Milliseconds elapsed since the custom epoch.
///
/// Millisecond (not second) resolution is what the documented layout assumes:
/// 2^43 milliseconds last ~278 years (until ~2300), and the 12-bit sequence
/// then provides up to 4096 IDs per worker per millisecond. With second
/// resolution the "~278 years" figure would be wrong by three orders of
/// magnitude and throughput would be capped at 4096 IDs/s per worker.
fn timestamp() -> i64 {
  Utc::now().timestamp_millis() - *EPOCH
}

/// Packs `(timestamp, sequence, worker)` into a single `i64` ID.
fn pack(timestamp: i64, sequence: i64, worker: i64) -> i64 {
  (timestamp << TIMESTAMP_SHIFT) | (sequence << SEQUENCE_SHIFT) | (worker << WORKER_SHIFT)
}

/// Splits an ID back into its `(timestamp, sequence, worker)` fields.
fn unpack(id: i64) -> (i64, i64, i64) {
  (
    (id & TIMESTAMP_MASK) >> TIMESTAMP_SHIFT,
    (id & SEQUENCE_MASK) >> SEQUENCE_SHIFT,
    (id & WORKER_MASK) >> WORKER_SHIFT,
  )
}

/// Thread-safe snowflake ID generator with the following layout:
/// | 1 bit | 43 bits   | 12 bits  | 8 bits |
/// | sign  | timestamp | sequence | worker |
///
/// The epoch is 2020 Mar 12 21:33:54, the first commit of Streamfox. A 43 bit
/// millisecond timestamp allows enough space to generate valid timestamps for
/// another ~278 years, until the year ~2300.
///
/// The timestamp must use the highest bits to preserve monotonicity, otherwise
/// a previous timestamp with a larger sequence or worker number would create an
/// ID larger than the ID with the current timestamp.
///
/// If we want to avoid generating consecutive IDs for objects generated at the
/// same time, we want to avoid putting the sequence at the lowest bits,
/// therefore we put it in the middle and leave the worker at the lowest bits.
+pub struct SnowflakeGenerator { + id: AtomicI64, +} + +impl SnowflakeGenerator { + pub fn new(table_index: i64) -> Self { + assert!(table_index <= MAX_WORKER); + + Self { + id: AtomicI64::new(pack(timestamp(), 0, table_index)), + } + } + + pub async fn generate_id(&self) -> i64 { + let mut last_id = self.id.load(Ordering::Relaxed); + loop { + let (last_timestamp, last_sequence, last_worker) = unpack(last_id); + + let timestamp = timestamp(); + let sequence = if timestamp == last_timestamp { + last_sequence + 1 + } else { + 0 + }; + + if sequence > MAX_SEQUENCE { + time::sleep(SEQUENCE_EXHAUSTED_DELAY).await; + last_id = self.id.load(Ordering::Relaxed); + continue; + } + + let new_id = pack(timestamp, sequence, last_worker); + + match self + .id + .compare_exchange(last_id, new_id, Ordering::Relaxed, Ordering::Relaxed) + { + Ok(_) => break new_id, + Err(id) => last_id = id, + } + } + } +}