【Rust】フレームワーク自作: HttpResponse, Redirectを実装する

利用側: main.rs

use std::collections::HashMap;
use pwhash::bcrypt;
mod ares;
use ares::{Router, HttpResponse, redirect, psql_connect, parse, remove_null_bytes};

fn main() {
    let addr = "192.168.33.10:8000";
    let mut app = Router::new();

    // Pages answered on GET; each one is rendered from its template when registered.
    app.get("index", handle_index);
    app.get("hello", handle_hello);
    app.get("signin", handle_signin);
    app.get("login", handle_login);

    // Form submissions; each handler receives the raw request body.
    app.post("signup", handle_signup);
    app.post("logup", handle_logup);

    app.up(addr);
}


fn handle_index() -> Option<HashMap<&'static str, &'static str>> {
    // Template values for the `index` page: a single `title` entry.
    let context: HashMap<&'static str, &'static str> =
        std::iter::once(("title", "This is index page!")).collect();
    Some(context)
}

fn handle_hello() -> Option<HashMap<&'static str, &'static str>> {
    // No placeholder values: the `hello` template is used as-is.
    None
}

fn handle_signin() -> Option<HashMap<&'static str, &'static str>> {
    // No placeholder values: the `signin` template is used as-is.
    None
}

fn handle_signup(body: String) -> HttpResponse {
    let form =  parse(&body);
    let binding = "<unknown>".to_string();
    let name = form.get("name").unwrap_or(&binding);
    let password = form.get("password").unwrap_or(&binding);

    let pass_hash = bcrypt::hash(password).unwrap();

    let mut client = psql_connect().unwrap();
    let _ = client.execute(
        "INSERT INTO test (username, password) VALUES ($1, $2)",
        &[&remove_null_bytes(&name), &remove_null_bytes(&pass_hash)],
    ).unwrap();
    HttpResponse::new(200, "<h1>Success Signin!</h1>")
}

fn handle_login() -> Option<HashMap<&'static str, &'static str>> {
    // No placeholder values: the `login` template is used as-is.
    None
}

fn handle_logup(body: String) -> HttpResponse {
    let form =  parse(&body);
    let binding = "<unknown>".to_string();
    let name = form.get("name").unwrap_or(&binding);
    let password = form.get("password").unwrap_or(&binding);

    let mut client = psql_connect().unwrap();
    let row = client.query(
        "SELECT * from test where username=$1",
        &[&name],
    ).unwrap();
    let value: String = row[0].get(2);
    if bcrypt::verify(password, &value) {
        HttpResponse::new(200, "<h1>Success Login!</h1>")
    } else {
        redirect("/login")
    }
}

フレームワーク側: `ares` モジュール（main.rs の `mod ares;` で読み込まれる）

use std::fs;
use std::io::prelude::*;
use std::net::{TcpListener, TcpStream};
use std::collections::HashMap;
use postgres::{Client, NoTls};
use std::env;
use dotenv::dotenv;

/// A response returned by a POST handler: status code, headers and body.
pub struct HttpResponse {
    /// Numeric HTTP status code, e.g. 200 or 302.
    pub status_code: u16,
    /// Header name -> value pairs to send with the response.
    pub headers: HashMap<String, String>,
    /// Response body.
    pub body: String,
}

impl HttpResponse {
    pub fn new(status_code: u16, body: &str) -> Self {
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "text/html".to_string());

        HttpResponse {
            status_code,
            headers,
            body: body.to_string(),
        }
    }
}

/// Routing table for the server: pre-rendered GET pages and POST handlers.
pub struct Router {
    // GET routes: "/<path>" -> HTML rendered once, at registration time, by `Router::get`.
    get_routes: HashMap<String, String>,
    // POST routes: "/<path>" -> handler that receives the request body as a String.
    post_routes: HashMap<String, fn(String) -> HttpResponse>,
}

impl Router {
    /// Creates a router with no routes registered.
    pub fn new() -> Self {
        Router {
            get_routes: HashMap::new(),
            post_routes: HashMap::new(),
        }
    }

    /// Registers a GET page served at `/<path>`.
    ///
    /// `f` is called immediately and `./templates/<path>.html` is read. Every `{{ key }}`
    /// placeholder in the template is replaced by the matching value from `f`'s map
    /// (`None` serves the template unchanged). The rendered HTML is stored once and returned
    /// verbatim for every later request; `f` is not re-run per request.
    ///
    /// # Panics
    /// Panics if the template file cannot be read.
    pub fn get(&mut self, path: &str, f: fn() -> Option<HashMap<&'static str, &'static str>>) {
        let content = f();

        let template_path = format!("./templates/{}.html", path);
        let mut html = fs::read_to_string(&template_path)
            .unwrap_or_else(|e| panic!("failed to read template {}: {}", template_path, e));

        if let Some(data) = content {
            for (key, value) in data {
                // Placeholder syntax is `{{ key }}` with single spaces inside the braces.
                let placeholder = format!("{{{{ {} }}}}", key);
                html = html.replace(&placeholder, value);
            }
        }

        self.get_routes.insert(format!("/{}", path), html);
    }

    /// Registers `handler` for POST requests to `/<path>`; it receives the request body.
    pub fn post(&mut self, path: &str, handler: fn(String) -> HttpResponse) {
        self.post_routes.insert(format!("/{}", path), handler);
    }

    /// Binds to `ip` (an `address:port` string) and serves connections one at a time on
    /// the calling thread, for as long as the listener yields connections.
    ///
    /// # Panics
    /// Panics if the address cannot be bound.
    pub fn up(&self, ip: &str) {
        let listener = TcpListener::bind(ip)
            .unwrap_or_else(|e| panic!("failed to bind {}: {}", ip, e));
        for stream in listener.incoming() {
            match stream {
                // handle_connection takes the route tables by value, hence the clones.
                Ok(stream) => handle_connection(
                    stream,
                    self.get_routes.clone(),
                    self.post_routes.clone(),
                ),
                Err(e) => println!("Connection failed: {}", e),
            }
        }
    }
}

impl Default for Router {
    fn default() -> Self {
        Self::new()
    }
}

/// Largest request (headers + body) that will be buffered; bytes beyond this are not read.
const MAX_REQUEST_BYTES: usize = 64 * 1024;

/// How long a single socket read may block before the request is abandoned. Without it, a
/// client that announces a longer body than it sends would stall this single-threaded server.
const READ_TIMEOUT_SECS: u64 = 5;

/// Handles one client connection: reads a request, dispatches it, writes the response.
///
/// * `GET`: answered from `get_routes` (pre-rendered HTML); unknown paths get 404.
/// * `POST`: the request body is passed to the matching handler in `post_routes`;
///   unknown paths get 404.
/// * Any other method gets 405.
///
/// The query string (`?...`) is ignored for route lookup. I/O errors are logged and the
/// connection is dropped instead of panicking, because this runs directly on the accept
/// loop and a panic would stop the whole server.
fn handle_connection(
    mut stream: TcpStream,
    get_routes: HashMap<String, String>,
    post_routes: HashMap<String, fn(String) -> HttpResponse>,
) {
    let raw = match read_request(&mut stream) {
        Ok(raw) => raw,
        Err(e) => {
            eprintln!("Failed to read request: {}", e);
            return;
        }
    };
    if raw.is_empty() {
        // The peer connected and closed without sending anything.
        return;
    }

    // Decode only the bytes actually received, so no NUL padding from a fixed-size buffer
    // leaks into the path or the form body.
    let request = String::from_utf8_lossy(&raw);
    let request_line = request.lines().next().unwrap_or("");
    let mut parts = request_line.split_whitespace();
    let method = parts.next().unwrap_or("");
    let target = parts.next().unwrap_or("/");
    // Routes are keyed by path only.
    let path = target.split_once('?').map_or(target, |(path, _query)| path);

    let response = match method {
        "GET" => match get_routes.get(path) {
            Some(body) => http_response(200, "text/html", body),
            None => http_response(404, "text/html", "<h1>404 Not Found</h1>"),
        },
        "POST" => {
            let body = request
                .split_once("\r\n\r\n")
                .map_or("", |(_headers, body)| body)
                .to_string();
            match post_routes.get(path) {
                Some(handler) => http_response_custom(handler(body)),
                None => http_response(404, "text/html", "<h1>404 Not Found</h1>"),
            }
        }
        _ => http_response(405, "text/html", "<h1>405 Method Not Allowed</h1>"),
    };

    let written = stream
        .write_all(response.as_bytes())
        .and_then(|()| stream.flush());
    if let Err(e) = written {
        eprintln!("Failed to write response: {}", e);
    }
}

/// Reads one HTTP request from `stream`: the full header section, then as many body bytes
/// as its `Content-Length` header announces (everything capped at `MAX_REQUEST_BYTES`).
///
/// Returns the raw bytes received; an empty vector means the peer closed immediately.
///
/// # Errors
/// Any socket error, including a read that exceeds `READ_TIMEOUT_SECS`.
fn read_request(stream: &mut TcpStream) -> std::io::Result<Vec<u8>> {
    stream.set_read_timeout(Some(std::time::Duration::from_secs(READ_TIMEOUT_SECS)))?;

    let mut data = Vec::new();
    let mut chunk = [0u8; 1024];

    // Phase 1: accumulate until the blank line that terminates the headers.
    let header_end = loop {
        if let Some(pos) = data.windows(4).position(|w| w == b"\r\n\r\n") {
            break pos + 4;
        }
        if data.len() >= MAX_REQUEST_BYTES {
            return Ok(data);
        }
        let n = stream.read(&mut chunk)?;
        if n == 0 {
            // Peer closed before finishing the headers; hand back whatever arrived.
            return Ok(data);
        }
        data.extend_from_slice(&chunk[..n]);
    };

    // Phase 2: read the body announced by Content-Length (absent => no body).
    let content_length = String::from_utf8_lossy(&data[..header_end])
        .lines()
        .filter_map(|line| line.split_once(':'))
        .find(|(name, _)| name.trim().eq_ignore_ascii_case("content-length"))
        .and_then(|(_, value)| value.trim().parse::<usize>().ok())
        .unwrap_or(0);
    let wanted = header_end
        .saturating_add(content_length)
        .min(MAX_REQUEST_BYTES);
    while data.len() < wanted {
        let n = stream.read(&mut chunk)?;
        if n == 0 {
            break;
        }
        data.extend_from_slice(&chunk[..n]);
    }

    Ok(data)
}
 
 /// Formats a complete HTTP/1.1 response carrying a single `Content-Type` header.
///
/// `Content-Length` is the body's length in bytes (not characters), as HTTP requires.
fn http_response(status_code: u16, content_type: &str, body: &str) -> String {
    let head = format!(
        "HTTP/1.1 {} {}\r\nContent-Type: {}\r\nContent-Length: {}\r\n\r\n",
        status_code,
        get_status_text(status_code),
        content_type,
        body.len()
    );
    head + body
}
 
 /// Serializes a handler-built `HttpResponse` into a raw HTTP/1.1 response string.
///
/// Header rules:
/// * Every header in `resp.headers` is emitted except `Content-Length`, which is always
///   recomputed from the body so a stale value can't corrupt message framing.
/// * A default `Content-Type: text/html` is added only when the handler set neither
///   `Content-Type` nor `Location`; redirects are sent without a content type.
/// * Previously, responses from `HttpResponse::new` (which already sets `Content-Type`)
///   got a duplicate `Content-Type` line.
///
/// Header names are compared case-insensitively.
fn http_response_custom(resp: HttpResponse) -> String {
    let mut response = format!(
        "HTTP/1.1 {} {}\r\n",
        resp.status_code,
        get_status_text(resp.status_code)
    );

    let mut has_content_type = false;
    let mut has_location = false;
    for (name, value) in &resp.headers {
        if name.eq_ignore_ascii_case("content-length") {
            continue;
        }
        if name.eq_ignore_ascii_case("content-type") {
            has_content_type = true;
        }
        if name.eq_ignore_ascii_case("location") {
            has_location = true;
        }
        response.push_str(&format!("{}: {}\r\n", name, value));
    }

    if !has_content_type && !has_location {
        response.push_str("Content-Type: text/html\r\n");
    }

    // Byte length, as HTTP requires.
    response.push_str(&format!("Content-Length: {}\r\n", resp.body.len()));
    response.push_str("\r\n");
    response.push_str(&resp.body);

    response
}


 /// Returns the standard HTTP reason phrase for `code`, or `"Unknown"` for codes not listed.
///
/// Covers every status this framework emits itself (200, 302 from `redirect`, 404, 405),
/// plus the common codes handlers are likely to return.
fn get_status_text(code: u16) -> &'static str {
    match code {
        200 => "OK",
        201 => "Created",
        204 => "No Content",
        301 => "Moved Permanently",
        302 => "Found",
        303 => "See Other",
        304 => "Not Modified",
        307 => "Temporary Redirect",
        308 => "Permanent Redirect",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        500 => "Internal Server Error",
        503 => "Service Unavailable",
        _ => "Unknown",
    }
}

/// Parses a form body of the shape `k1=v1&k2=v2` into a map.
///
/// Both keys and values go through `url_decode`. A segment without `=` is skipped; only the
/// first `=` splits, so `x=1=2` yields `x -> "1=2"`. When a key repeats, the last one wins.
pub fn parse(body: &str) -> HashMap<String, String> {
    body.split('&')
        .filter_map(|segment| {
            let mut halves = segment.splitn(2, '=');
            let raw_key = halves.next()?;
            let raw_value = halves.next()?;
            Some((url_decode(raw_key), url_decode(raw_value)))
        })
        .collect()
}

/// Decodes one `application/x-www-form-urlencoded` component.
///
/// Decoding rules:
/// * `+` becomes a space.
/// * Every `%XX` escape (two hex digits, either case) becomes the byte it encodes.
/// * The resulting bytes are read as UTF-8; invalid sequences become U+FFFD.
/// * A `%` not followed by two hex digits is kept literally.
///
/// The previous version only handled `+` and `%20`, leaving e.g. `%40` or multibyte UTF-8
/// escapes encoded.
///
/// NOTE(review): values stored before this fix (such as usernames containing symbols) were
/// saved in their still-encoded form.
fn url_decode(s: &str) -> String {
    let bytes = s.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'+' => {
                out.push(b' ');
                i += 1;
            }
            b'%' => {
                let hi = bytes.get(i + 1).and_then(|&b| hex_val(b));
                let lo = bytes.get(i + 2).and_then(|&b| hex_val(b));
                match (hi, lo) {
                    (Some(hi), Some(lo)) => {
                        out.push((hi << 4) | lo);
                        i += 3;
                    }
                    _ => {
                        out.push(b'%');
                        i += 1;
                    }
                }
            }
            other => {
                out.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&out).into_owned()
}

/// Value of the ASCII hex digit `b`, or `None` if `b` is not one.
fn hex_val(b: u8) -> Option<u8> {
    (b as char).to_digit(16).map(|d| d as u8)
}

/// Opens a connection to the PostgreSQL server on `localhost` as user `postgres`.
///
/// The password comes from the `PSQL_PASSWORD` environment variable. A `.env` file is
/// loaded first via `dotenv()`; failure to load it is ignored, so the variable may also
/// come from the real environment.
///
/// # Errors
/// Returns an error if `PSQL_PASSWORD` is unset or not valid Unicode (this previously
/// panicked despite the `Result` return type), or if the connection fails.
pub fn psql_connect() -> Result<Client, Box<dyn std::error::Error>> {
    let _ = dotenv();
    let password = env::var("PSQL_PASSWORD")?;
    let conn_str = format!(
        "host=localhost user=postgres password={}",
        quote_conn_value(&password)
    );
    Ok(Client::connect(&conn_str, NoTls)?)
}

/// Single-quotes `value` for a libpq-style `key=value` connection string.
///
/// Backslashes and single quotes are backslash-escaped, so spaces or quotes in the value
/// cannot break or inject other parameters.
fn quote_conn_value(value: &str) -> String {
    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('\'');
    for c in value.chars() {
        if c == '\'' || c == '\\' {
            quoted.push('\\');
        }
        quoted.push(c);
    }
    quoted.push('\'');
    quoted
}

/// Builds a `302` redirect response that sends the client to `location`.
///
/// Only a `Location` header is set. The body is left empty; a short message could be
/// added if desired.
pub fn redirect(location: &str) -> HttpResponse {
    let headers: HashMap<String, String> =
        std::iter::once((String::from("Location"), String::from(location))).collect();

    HttpResponse {
        status_code: 302,
        headers,
        body: String::new(),
    }
}

/// Returns a copy of `s` with every NUL (`'\0'`) character removed.
pub fn remove_null_bytes(s: &str) -> String {
    s.replace('\0', "")
}

これに Cookie の機能を追加したい。