Skip to content

Commit

Permalink
Implement user fetching
Browse files Browse the repository at this point in the history
After a user is authenticated, allow fetching the current user at
/api/user.

All requests can receive the user authenticated for the request by
accepting a user parameter.

Also refactored the Authenticate error into EncodeJwt (to partner with
the new DecodeJwt error).

In addition, implemented Serialize / Deserialize for Id so that it
doesn't have to be manually converted into a string anywhere.
  • Loading branch information
Aleksbgbg committed Feb 10, 2024
1 parent f1be96c commit f3a6a46
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 25 deletions.
16 changes: 12 additions & 4 deletions backend-rs/src/controllers/errors.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use crate::controllers::user::AuthError;
use crate::models::user;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
Expand Down Expand Up @@ -52,13 +51,18 @@ pub enum HandlerError {
UsernameTaken,
#[error("Email Address must not be taken.")]
EmailTaken,
#[error("Authentication token is invalid.")]
DecodeJwt(jsonwebtoken::errors::Error),

#[error("No user was logged in but a user is required.")]
UserRequired,

#[error("Database transaction failed.")]
Database(#[from] DbErr),
#[error("Could not create user.")]
CreateUser(#[from] user::CreateError),
#[error("Could not authenticate: {0}.")]
Authenticate(#[from] AuthError),
#[error("Could not encode JWT.")]
EncodeJwt(jsonwebtoken::errors::Error),
}

impl HandlerError {
Expand All @@ -78,9 +82,13 @@ impl IntoResponse for HandlerError {
match self {
HandlerError::UsernameTaken => self.failed_validation(StatusCode::BAD_REQUEST, "username"),
HandlerError::EmailTaken => self.failed_validation(StatusCode::BAD_REQUEST, "emailAddress"),
HandlerError::DecodeJwt(_) => self.into_generic(StatusCode::BAD_REQUEST),

HandlerError::UserRequired => self.into_generic(StatusCode::UNAUTHORIZED),

HandlerError::Database(_) => self.into_generic(StatusCode::INTERNAL_SERVER_ERROR),
HandlerError::CreateUser(_) => self.into_generic(StatusCode::INTERNAL_SERVER_ERROR),
HandlerError::Authenticate(_) => self.into_generic(StatusCode::INTERNAL_SERVER_ERROR),
HandlerError::EncodeJwt(_) => self.into_generic(StatusCode::INTERNAL_SERVER_ERROR),
}
.into_response()
}
Expand Down
77 changes: 64 additions & 13 deletions backend-rs/src/controllers/user.rs
Original file line number Diff line number Diff line change
@@ -1,42 +1,44 @@
use crate::app_state::AppState;
use crate::controllers::errors::HandlerError;
use crate::controllers::validate::ValidatedJson;
use crate::models::id::Id;
use crate::models::user::{self, CreateUser, User};
use axum::extract::State;
use axum::extract::{FromRequestParts, State};
use axum::http::request::Parts;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::{async_trait, Json, RequestPartsExt};
use axum_extra::extract::cookie::{Cookie, CookieJar};
use chrono::{Duration, Utc};
use jsonwebtoken::{EncodingKey, Header};
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use validator::Validate;

const AUTH_COOKIE_KEY: &str = "Authorization";
const SECRET: &str = "123456";

#[derive(Debug, Serialize, Deserialize)]
struct Claims {
uid: String,
uid: Id,
exp: usize,
}

#[derive(Error, Debug)]
pub enum AuthError {
#[error("could not encode JWT")]
EncodeJwt(#[from] jsonwebtoken::errors::Error),
}

fn authenticate(jar: CookieJar, user: &User, lifespan: Duration) -> Result<CookieJar, AuthError> {
fn authenticate(
jar: CookieJar,
user: &User,
lifespan: Duration,
) -> Result<CookieJar, HandlerError> {
let cookie = Cookie::build((
AUTH_COOKIE_KEY,
jsonwebtoken::encode(
&Header::default(),
&Claims {
uid: user.id.to_string(),
uid: user.id,
exp: (Utc::now() + lifespan).timestamp().try_into().unwrap(),
},
&EncodingKey::from_secret(SECRET.as_ref()),
)?,
)
.map_err(HandlerError::EncodeJwt)?,
))
.max_age(time::Duration::seconds(lifespan.num_seconds()))
.path("/")
Expand Down Expand Up @@ -84,3 +86,52 @@ pub async fn register(

Ok((StatusCode::CREATED, cookies))
}

#[async_trait]
impl FromRequestParts<AppState> for User {
type Rejection = HandlerError;

async fn from_request_parts(
parts: &mut Parts,
state: &AppState,
) -> Result<Self, Self::Rejection> {
let cookies = parts.extract::<CookieJar>().await.unwrap();
let token = cookies
.get(AUTH_COOKIE_KEY)
.ok_or(HandlerError::UserRequired)?
.value();
let claims = jsonwebtoken::decode::<Claims>(
token,
&DecodingKey::from_secret(SECRET.as_ref()),
&Validation::default(),
)
.map_err(HandlerError::DecodeJwt)?
.claims;

user::find(state, claims.uid)
.await?
.ok_or(HandlerError::UserRequired)
}
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct UserResponse {
id: Id,
username: String,
}

pub async fn get_user(user: Option<User>) -> Response {
if let Some(u) = user {
(
StatusCode::OK,
Json(UserResponse {
id: u.id,
username: u.visible_username().to_string(),
}),
)
.into_response()
} else {
StatusCode::NOT_FOUND.into_response()
}
}
6 changes: 4 additions & 2 deletions backend-rs/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ async fn main() -> Result<(), AppError> {
.map_err(AppError::MigrateDatabase)?;

let app = Router::new()
.route("/api/auth/register", routing::post(user::register))
.route("/auth/register", routing::post(user::register))
.route("/user", routing::get(user::get_user))
.layer(
TraceLayer::new_for_http()
.make_span_with(DefaultMakeSpan::new().level(Level::INFO))
Expand All @@ -63,6 +64,7 @@ async fn main() -> Result<(), AppError> {
connection,
user_snowflake: SnowflakeIdBucket::new(1, 1),
});
let api = Router::new().nest("/api", app);

let listener = TcpListener::bind(SocketAddr::from((config.app.host, config.app.port)))
.await
Expand All @@ -75,7 +77,7 @@ async fn main() -> Result<(), AppError> {
.map_err(AppError::GetListenerAddress)?
);

axum::serve(listener, app)
axum::serve(listener, api)
.await
.map_err(AppError::ServeApp)?;

Expand Down
65 changes: 59 additions & 6 deletions backend-rs/src/models/id.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use sea_orm::{DbErr, DeriveValueType, TryFromU64};
use serde::de::{Unexpected, Visitor};
use serde::{Deserialize, Serialize};
use std::fmt::Formatter;

#[derive(Clone, Debug, PartialEq, Eq, DeriveValueType)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, DeriveValueType)]
pub struct Id(i64);

impl Id {
Expand All @@ -21,12 +24,62 @@ impl TryFromU64 for Id {
}
}

fn encode(id: Id) -> Result<String, bs58::encode::Error> {
let mut encoded = String::new();
bs58::encode(id.0.to_be_bytes()).onto(&mut encoded)?;

Ok(encoded)
}

fn decode(str: &str) -> Result<Id, bs58::decode::Error> {
let mut decoded = [0; 8];
bs58::decode(str).onto(&mut decoded)?;

Ok(Id::from(i64::from_be_bytes(decoded)))
}

impl ToString for Id {
fn to_string(&self) -> String {
let mut string = String::new();
bs58::encode(self.0.to_be_bytes())
.onto(&mut string)
.expect("could not encode ID into base 58");
string
encode(*self).expect("could not encode ID into base 58")
}
}

impl Serialize for Id {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let string = encode(*self).map_err(|err| {
serde::ser::Error::custom(format!("could not encode ID into base58: {}", err))
})?;
serializer.serialize_str(&string)
}
}

impl<'de> Deserialize<'de> for Id {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_str(IdVisitor)
}
}

struct IdVisitor;

impl<'de> Visitor<'de> for IdVisitor {
type Value = Id;

fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
formatter.write_str("a base58-encoded string")
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
decode(v).map_err(|_| {
serde::de::Error::invalid_value(Unexpected::Str(v), &"a valid base-58 encoded string")
})
}
}
10 changes: 10 additions & 0 deletions backend-rs/src/models/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ pub struct Model {
pub password: Option<String>,
}

impl Model {
pub fn visible_username(&self) -> &str {
self.username.as_ref().map_or("Anonymous", String::as_str)
}
}

#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}

Expand Down Expand Up @@ -83,3 +89,7 @@ pub async fn create(state: &mut AppState, user: CreateUser<'_>) -> Result<User,
.await?,
)
}

pub async fn find(state: &AppState, id: Id) -> Result<Option<User>, DbErr> {
Entity::find_by_id(id).one(&state.connection).await
}

0 comments on commit f3a6a46

Please sign in to comment.