【Rust】axumのauthとmiddleware

use std::{collections::HashMap, sync::Arc};

use axum::{
    async_trait,
    extract::{FromRequestParts, Request, State},
    middleware::Next,
    response::Response,
    RequestExt as _,
};
use axum_extra::{
    headers::{authorization::Bearer, Authorization},
    TypedHeader,
};
use http::{request::Parts, StatusCode};

/// Identity of an authenticated caller; handlers can take it as an argument.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct User {
    /// Display name of the user.
    pub username: String,
}

/// Bearer token string, used as the lookup key for a [`User`].
pub type Token = String;

/// Table from bearer token to the user that token authenticates.
pub type UserMap = HashMap<Token, User>;

pub fn build_user_map() -> UserMap {
    let mut user_map = HashMap::new();
    user_map.insert(
        "aaa".to_string(),
        User {
            username: "Andy".to_string(),
        },
    );
    user_map.insert(
        "bbb".to_string(),
        User {
            username: "Bella".to_string(),
        },
    );
    user_map.insert(
        "ccc".to_string(),
        User {
            username: "Callie".to_string(),
        },
    );
    user_map.insert(
        "ddd".to_string(),
        User {
            username: "Daren".to_string(),
        },
    );
    user_map
}

/// Lets handlers declare a `User` parameter directly.
///
/// Nothing is parsed here: the `User` is read back from the request
/// extensions, where `auth_middleware` inserts it after a successful token
/// lookup.
///
/// # Panics
///
/// Panics if no `User` is present in the extensions. That means the handler is
/// reachable without `auth_middleware` in front of it, which is a routing bug.
#[async_trait]
impl<S> FromRequestParts<S> for User
where
    S: Send + Sync,
{
    type Rejection = StatusCode;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let found = parts.extensions.get::<User>().cloned();
        Ok(found.expect("User not found. Did you add auth_middleware?"))
    }
}

pub async fn auth_middleware(
    State(user_map): State<Arc<UserMap>>,
    mut request: Request,
    next: Next,
) -> axum::response::Result<Response> {
    let bearer = request
        .extract_parts::<TypedHeader<Authorization<Bearer>>>()
        .await
        .map_err(|_| StatusCode::BAD_REQUEST)?;
    let token = bearer.token();

    let user = user_map.get(token).ok_or(StatusCode::UNAUTHORIZED)?;
    request.extensions_mut().insert(user.clone());

    Ok(next.run(request).await)
}
mod auth;

use std::sync::Arc;
use axum::{middleware::from_fn_with_state, routing::get, Router};
use auth::{auth_middleware, build_user_map, User, UserMap};


/// Builds the user table and serves the app on `0.0.0.0:3000`.
///
/// I/O failures (e.g. the port is already in use, or the server errors while
/// running) are returned as `std::io::Error` instead of panicking via `unwrap`.
#[tokio::main]
async fn main() -> std::io::Result<()> {
    let user_map = Arc::new(build_user_map());
    let app = build_app(user_map);
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
    axum::serve(listener, app).await
}

/// Assembles the application router.
///
/// * `/public` is open to everyone.
/// * `/private` and `/your-name` sit behind `auth_middleware`, which receives
///   `user_map` as its own state. It is attached with `route_layer`, so it
///   runs only for requests that match one of these routes.
///
/// Both routers serve paths at the same level, so they are combined with
/// `merge`. `nest` is for mounting a router under a path prefix, and nesting
/// at `/` is rejected by newer axum versions.
#[rustfmt::skip]
fn build_app(user_map: Arc<UserMap>) -> Router {
    let public_router = Router::new()
        .route("/public", get(public));

    let private_router = Router::new()
        .route("/private", get(private))
        .route("/your-name", get(your_name))
        .route_layer(from_fn_with_state(user_map.clone(), auth_middleware));

    Router::new()
        .merge(public_router)
        .merge(private_router)
        .with_state(user_map)
}

/// Handler for `GET /public`; registered without any auth layer.
async fn public() -> &'static str {
    const BODY: &str = "This is public.";
    BODY
}

/// Handler for `GET /private`; registered behind `auth_middleware`.
///
/// Previously returned "This is public." (copy-paste from `public`), which
/// made the protected route indistinguishable from the open one.
async fn private() -> &'static str {
    "This is private."
}

async fn your_name(user: User) -> String {
    format!("Your name is {}.", user.username)
}

# Public route: no credentials needed; body is "This is public."
curl http://192.168.33.10:3000/public
# Private route without an Authorization header: rejected by auth_middleware.
# -I sends a HEAD request and prints only the status line and headers.
curl -I http://192.168.33.10:3000/private
# Token "aaa" belongs to "Andy" in build_user_map; ASCII quotes keep the header
# a single argument (typographic quotes would split it apart).
curl -H 'Authorization: Bearer aaa' http://192.168.33.10:3000/your-name

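やりたいことはなんとなくわかるが、次の3点がいまいち理解できていない…：(1) Bearer認証の仕組み、(2) `FromRequestParts` traitの実装で、リクエストの`extensions`から`User`を取り出している部分、(3) `auth_middleware`が何を受け取り、`next.run(request)`の先で何が起きているのか。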