Skip to content

Commit

Permalink
Implement custom thread-safe snowflake IDs
Browse files Browse the repository at this point in the history
rs-snowflake generated IDs that count down (in reverse to what was
expected), which was not satisfactory. We then found another crate,
idgenerator, which seemed good but wasn't thread-safe. In conclusion we
decided that writing our own simple thread-safe implementation would be
less hassle than dealing with external crates.
  • Loading branch information
Aleksbgbg committed Feb 18, 2024
1 parent 4eb3a8d commit 717a26f
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 13 deletions.
7 changes: 0 additions & 7 deletions backend-rs/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion backend-rs/backend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ fs-file = { version = "0.0.0", path = "../fs-file", default-features = false }
jsonwebtoken = { version = "9.2.0", default-features = false }
lazy_static = "1.4.0"
ring = "0.17.7"
rs-snowflake = "0.6.0"
sea-orm = { version = "0.12.14", features = ["runtime-tokio-rustls", "sqlx-postgres"] }
sea-orm-migration = "0.12.12"
serde = { version = "1.0.196", features = ["derive"] }
Expand Down
8 changes: 4 additions & 4 deletions backend-rs/backend/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@ mod config;
mod controllers;
mod models;
mod secure;
mod snowflake;

use crate::config::{Config, ConfigError};
use crate::controllers::user;
use crate::models::migrations::migrator::Migrator;
use crate::models::user::CreateDefaultUsersError;
use crate::snowflake::SnowflakeGenerator;
use axum::{routing, Router};
use cascade::cascade;
use fs::filesystem;
use sea_orm::{Database, DatabaseConnection, DbErr};
use sea_orm_migration::MigratorTrait;
use snowflake::SnowflakeIdBucket;
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use thiserror::Error;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tower_http::trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer};
use tracing::{info, Level};

Expand Down Expand Up @@ -56,7 +56,7 @@ pub struct AppState {
}

/// Snowflake ID generators for the application, one per ID-issuing table.
/// Held in `AppState` behind an `Arc`, so each generator must be usable
/// from concurrent request handlers without external locking.
pub struct Snowflakes {
/// Issues IDs for newly created users (see `models::user::create`).
pub user_snowflake: SnowflakeGenerator,
}

#[tokio::main]
Expand Down Expand Up @@ -98,7 +98,7 @@ async fn main() -> Result<(), AppError> {
config,
connection,
snowflakes: Arc::new(Snowflakes {
user_snowflake: Mutex::new(SnowflakeIdBucket::new(1, 1)),
user_snowflake: SnowflakeGenerator::new(0),
}),
});
let api = Router::new().nest("/api", app);
Expand Down
2 changes: 1 addition & 1 deletion backend-rs/backend/src/models/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ pub async fn create(
) -> Result<User, CreateError> {
create_base(
connection,
Id::from(snowflakes.user_snowflake.lock().await.get_id()),
Id::from(snowflakes.user_snowflake.generate_id().await),
Some(user.username),
Some(user.email_address),
Some(user.password),
Expand Down
113 changes: 113 additions & 0 deletions backend-rs/backend/src/snowflake.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use chrono::{TimeZone, Utc};
use lazy_static::lazy_static;
use std::sync::atomic::{AtomicI64, Ordering};
use std::time::Duration;
use tokio::time;

/// Builds an `i64` mask of `bits` consecutive one-bits starting at bit
/// `start` (counting from the least significant bit).
///
/// # Panics
/// Fails (at const-evaluation time when used in a const context) if the
/// mask would touch bit 63, the sign bit.
const fn mask(start: usize, bits: usize) -> i64 {
  assert!((start + bits) <= 63); // Ensure sign bit is empty
  // Compute in u64: with `bits == 63` (permitted by the assert) the i64
  // expression `(1 << 63) - 1` overflows at const evaluation. The cast back
  // is lossless because bit 63 is never set.
  (((1u64 << bits) - 1) << start) as i64
}

/// Mask of the lowest `bits` bits; equivalently, the maximum value a
/// `bits`-wide field can hold.
const fn base_mask(bits: usize) -> i64 {
mask(0, bits)
}

// Field widths: 63 usable bits in total, leaving the i64 sign bit clear.
const WORKER_BITS: usize = 8;
const SEQUENCE_BITS: usize = 12;
const TIMESTAMP_BITS: usize = 43;

// Field offsets, low to high: | timestamp | sequence | worker |.
const WORKER_SHIFT: usize = 0;
const SEQUENCE_SHIFT: usize = WORKER_SHIFT + WORKER_BITS;
const TIMESTAMP_SHIFT: usize = SEQUENCE_SHIFT + SEQUENCE_BITS;

// In-place masks used by unpack() to extract each field from a packed ID.
const WORKER_MASK: i64 = mask(WORKER_SHIFT, WORKER_BITS);
const SEQUENCE_MASK: i64 = mask(SEQUENCE_SHIFT, SEQUENCE_BITS);
const TIMESTAMP_MASK: i64 = mask(TIMESTAMP_SHIFT, TIMESTAMP_BITS);

// Largest value each bounded field can hold.
const MAX_WORKER: i64 = base_mask(WORKER_BITS);
const MAX_SEQUENCE: i64 = base_mask(SEQUENCE_BITS);

// How long generate_id() sleeps before retrying once every sequence number
// for the current timestamp has been consumed.
const SEQUENCE_EXHAUSTED_DELAY: Duration = Duration::from_micros(100);

lazy_static! {
// Custom epoch (2020 Mar 12 21:33:54 UTC, the first Streamfox commit) as
// seconds since the Unix epoch. A recent epoch maximises the usable range
// of the 43-bit timestamp field.
static ref EPOCH: i64 = Utc
.with_ymd_and_hms(2020, 3, 12, 21, 33, 54)
.unwrap()
.timestamp();
}

/// Whole seconds elapsed since the custom `EPOCH`.
///
/// NOTE(review): second (not millisecond) resolution caps throughput at
/// `MAX_SEQUENCE + 1` IDs per second per worker — confirm this is intended.
fn timestamp() -> i64 {
  let now = Utc::now().timestamp();
  now - *EPOCH
}

/// Packs the three fields into a single ID, timestamp in the highest bits.
fn pack(timestamp: i64, sequence: i64, worker: i64) -> i64 {
  let timestamp_field = timestamp << TIMESTAMP_SHIFT;
  let sequence_field = sequence << SEQUENCE_SHIFT;
  let worker_field = worker << WORKER_SHIFT;
  timestamp_field | sequence_field | worker_field
}

/// Splits a packed ID back into `(timestamp, sequence, worker)`.
fn unpack(id: i64) -> (i64, i64, i64) {
  let field = |mask: i64, shift: usize| (id & mask) >> shift;
  (
    field(TIMESTAMP_MASK, TIMESTAMP_SHIFT),
    field(SEQUENCE_MASK, SEQUENCE_SHIFT),
    field(WORKER_MASK, WORKER_SHIFT),
  )
}

/// Thread-safe snowflake ID generator with the following layout:
/// | 1 bit | 43 bits | 12 bits | 8 bits |
/// | sign | timestamp | sequence | worker |
///
/// The epoch is 2020 Mar 12 21:33:54, the first commit of Streamfox. The
/// timestamp field counts *seconds* since that epoch (see `timestamp()`), so
/// 43 bits suffice for roughly 2^43 / 31,557,600 ≈ 278,000 years.
/// NOTE(review): the original comment claimed "~278 years, until ~2300",
/// which is the figure for *millisecond* resolution — confirm whether
/// second-resolution timestamps are intended.
///
/// The timestamp must use the highest bits to preserve monotonicity, otherwise
/// a previous timestamp with a larger sequence or worker number would create an
/// ID larger than the ID with the current timestamp.
///
/// If we want to avoid generating consecutive IDs for objects generated at the
/// same time, we want to avoid putting the sequence at the lowest bits,
/// therefore we put it in the middle and leave the worker at the lowest bits.
pub struct SnowflakeGenerator {
// The last ID handed out; all generator state lives in this one atomic.
id: AtomicI64,
}

impl SnowflakeGenerator {
  /// Creates a generator whose IDs carry `table_index` in the worker field.
  ///
  /// # Panics
  /// Panics if `table_index` is negative or exceeds `MAX_WORKER`.
  pub fn new(table_index: i64) -> Self {
    // A negative index would smear sign bits across the other fields when
    // packed, so reject it up front alongside the upper bound.
    assert!(table_index >= 0);
    assert!(table_index <= MAX_WORKER);

    Self {
      id: AtomicI64::new(pack(timestamp(), 0, table_index)),
    }
  }

  /// Generates the next unique, monotonically increasing ID.
  ///
  /// Lock-free: concurrent callers race on a compare-exchange over the last
  /// issued ID and retry on failure. When the sequence for the current
  /// timestamp is exhausted, the caller sleeps briefly and starts over.
  pub async fn generate_id(&self) -> i64 {
    let mut last_id = self.id.load(Ordering::Relaxed);
    loop {
      let (last_timestamp, last_sequence, last_worker) = unpack(last_id);

      // Clamp to the last observed timestamp so a backwards clock step can
      // never produce an ID smaller than one already handed out; while the
      // clock lags we keep consuming the previous timestamp's sequence.
      let timestamp = timestamp().max(last_timestamp);
      let sequence = if timestamp == last_timestamp {
        last_sequence + 1
      } else {
        0
      };

      if sequence > MAX_SEQUENCE {
        // Every sequence number for this timestamp is taken: wait for the
        // clock to advance, then re-read the freshest state and retry.
        time::sleep(SEQUENCE_EXHAUSTED_DELAY).await;
        last_id = self.id.load(Ordering::Relaxed);
        continue;
      }

      let new_id = pack(timestamp, sequence, last_worker);

      // Relaxed suffices: the ID value itself carries all generator state
      // and no other memory is published alongside it.
      match self
        .id
        .compare_exchange(last_id, new_id, Ordering::Relaxed, Ordering::Relaxed)
      {
        Ok(_) => break new_id,
        Err(id) => last_id = id,
      }
    }
  }
}

0 comments on commit 717a26f

Please sign in to comment.