Rust の axum で CRUD API を実装する

# Create the crate for the API server (its entry point is src/main.rs below).
$ cargo new axum_crud_api

Cargo.toml

[package]
name = "axum_crud_api"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
axum = "0.5.9"
tokio = { version = "1.0", features = ["full"] }
# `derive` is required for #[derive(Serialize, Deserialize)] on the model structs.
serde = { version = "1.0.137", features = ["derive"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

# sqlx 0.5 selects runtime and TLS backend together (`runtime-tokio-native-tls`
# or `runtime-tokio-rustls`). Default features stay on; they provide the
# `FromRow` derive used by the models.
sqlx = { version = "0.5", features = ["runtime-tokio-native-tls", "json", "postgres"] }
anyhow = "1.0.58"
serde_json = "1.0.57"
tower-http = { version = "0.3.4", features = ["trace"] }

main.rs

use axum::{
	routing::{get},
	Router,
};

use std::net::SocketAddr;

/// Step 1 entry point: serves a single plain-text route, `GET /`.
///
/// Panics (via `unwrap`) if the server returns an error.
#[tokio::main]
async fn main() {
	// NOTE(review): hard-coded interface address, presumably the author's VM;
	// 0.0.0.0 would listen on every interface — confirm intent.
	let listen_addr = SocketAddr::from(([192, 168, 56, 10], 8000));
	let router = Router::new().route("/", get(root));

	println!("listening on {}", listen_addr);

	let server = axum::Server::bind(&listen_addr).serve(router.into_make_service());
	server.await.unwrap();
}

/// Handler for `GET /`: answers with a constant plain-text greeting.
async fn root() -> &'static str {
	const GREETING: &str = "Hello, World";
	GREETING
}

src/models/task.rs

use serde::{Deserialize, Serialize};

pub struct Task {
	pub id: i32,
	pub task: String,
}

pub struct NewTask {
	pub task: String,
}

DATABASE_URL=postgresql://user:password@localhost:5432/database

# Install sqlx-cli with ONE of the following commands (pick the TLS backend you need):
$ cargo install sqlx-cli
$ cargo install sqlx-cli --no-default-features --features native-tls,postgres
$ cargo install sqlx-cli --features openssl-vendored
$ cargo install sqlx-cli --no-default-features --features rustls,postgres

# Create the database (sqlx-cli reads DATABASE_URL from the environment or .env),
# then add an empty migration file named "task" under ./migrations.
$ sqlx database create
$ sqlx migrate add task

-- Contents of the migration file created by `sqlx migrate add task`.
-- Column types match models::Task (id: i32 <-> SERIAL, task: String <-> VARCHAR).
CREATE TABLE task (
    id SERIAL PRIMARY KEY,
    task VARCHAR(255) NOT NULL
);

# Apply pending migrations, e.g. the one created with `sqlx migrate add task`.
sqlx migrate run

main.rs

use axum::{
	extract::{Extension},
	routing::{get, post},
	Router,
};

use sqlx::postgres::PgPoolOptions;
use std::net::SocketAddr;
use std::fs;
use anyhow::Context

#[tokio::main]
async fn main() -> anyhow::Result<()> {

	let env = fs::read_to_string(".env").unwrap();
	let (key, database_url) = env.split_once('=').unwrap();

	assert_eq!(key, "DATABASE_URL");

	tracing_subscriber::fmt::init();

	let pool = PgPoolOptions::new()
	.max_connections(50)
	.connect(&dtabase_url)
	.await
	.context("could not connect to database_url")?;

	let app = Router::new()
		.route("/hello", get(root));

	let addr = SocketAddr::from(([192,168,56,10], 8000));
	println!("listening on {}", addr);
	axum::Server::bind(&addr)
		.serve(app.into_make_service())
		.await?;

		Ok(())
}

src/controllers/task.rs

use axum::response::IntoResponse;
use axum::http::StatusCode;

use axum::{Extension, Json}
use sqlx::PgPool;

use crate::{
	models::task
};

pub async fn all_tasks(Extension(pool): Extension<PgPool>) -> impl IntoResponse {
	let sql = "SELECT * FROM task".to_string();

	let task = sqlx::query_as::<_, task::Task>(&sql).fetch_all(&pool).await.unwrap();

	(StatusCode::OK, Json(task))
}

error.rs

use axum::{http::StatusCode, response::IntoResponse, Json};
use serde_json::json;

pub enum CustomError {
	BadRequest,
	TaskNotFound,
	InternalServerError,
}

impl IntoResponse for CustomError {
	fn into_response(self) -> axum::response::Response {
		let (status, error_message) = match self {
			Self::InternalServerError => (
				StatusCode::INTERNAL_SERVER_ERROR,
				"Internal Server Error",
			),
			Self::BadRequest=> (StatusCode::BAD_REQUEST, "Bad Request"),
			Self::TaskNotFound => (StatusCode::NOT_FOUND, "Task Not Found"),
		};
		(status, Json(json!({"error": error_message}))).into_response()
	}
}

src/controllers/task.rs（続き：GET / POST / PUT / DELETE ハンドラ）

// GET
pub async fn task(Path(id):Path<i32>,
	Extension(pool): Extension<PgPool>)-> Result<Json<task::Task>, CustomError>{

	let sql = "SELECT * FROM task where id=$1".to_string();

	let task: task::Task = sqlx::query_as(&sql)
		.bind(id)
		.fetch_one(&pool)
		.await
		.map_err(|_| {
			CustomError::TaskNotFound
		})?;

	Ok(Json(task))
}

// POST
pub async fn new_task(Json(task): Json<task::NewTask>,
	Extension(pool): Extension<PgPool>) -> Result<(StatusCode,
	Json<task::NewTask>), CustomError> {
		if task.task.is_empty() {
			return Err(CustomError::BadRequest)
		}
		let sql = "INSERT INTO task (task) values ($1)";

		let _ = sql::query(&sql)
			.bind(&task.task)
			.execute(&pool)
			.await
			.map_err(|_| {
				CustomError::InternalServerError
			})?;

		Ok((StatusCode::CREATED, Json(task)))
	}


// PUT
/// `PUT /task/:id`: replaces the text of an existing task and echoes the body.
///
/// Responds `200 OK` with the submitted body on success.
///
/// # Errors
/// * [`CustomError::BadRequest`] if `task` is empty (same rule as `POST /task`).
/// * [`CustomError::TaskNotFound`] if no row has this `id`.
/// * [`CustomError::InternalServerError`] if the UPDATE fails.
pub async fn update_task(
	Path(id): Path<i32>,
	Json(task): Json<task::UpdateTask>,
	Extension(pool): Extension<PgPool>,
) -> Result<(StatusCode, Json<task::UpdateTask>), CustomError> {
	if task.task.is_empty() {
		return Err(CustomError::BadRequest);
	}

	// A single UPDATE; `rows_affected` reveals a missing row, so no preceding
	// SELECT round trip and no window for the row to vanish in between.
	let result = sqlx::query("UPDATE task SET task=$1 WHERE id=$2")
		.bind(&task.task)
		.bind(id)
		.execute(&pool)
		.await
		.map_err(|_| CustomError::InternalServerError)?;

	if result.rows_affected() == 0 {
		return Err(CustomError::TaskNotFound);
	}

	Ok((StatusCode::OK, Json(task)))
}

/// `DELETE /task/:id`: removes a task by primary key.
///
/// Responds `200 OK` with `{"msg": "TaskDeleted"}` on success.
///
/// # Errors
/// * [`CustomError::TaskNotFound`] if no row has this `id`.
/// * [`CustomError::InternalServerError`] if the DELETE fails.
pub async fn delete_task(
	Path(id): Path<i32>,
	Extension(pool): Extension<PgPool>,
) -> Result<(StatusCode, Json<Value>), CustomError> {
	// A single DELETE; `rows_affected` reveals a missing row, so no preceding
	// SELECT, and database failures are no longer misreported as 404.
	let result = sqlx::query("DELETE FROM task WHERE id=$1")
		.bind(id)
		.execute(&pool)
		.await
		.map_err(|_| CustomError::InternalServerError)?;

	if result.rows_affected() == 0 {
		return Err(CustomError::TaskNotFound);
	}

	Ok((StatusCode::OK, Json(json!({"msg": "TaskDeleted"}))))
}

main.rs

use axum::{
	extract::{Extension},
	routing::{get, post, put, delete},
	Router,
};

use sqlx::postgres::PgPoolOptions;
use std::net::SocketAddr;
use std::fs;
use anyhow::Context
use tower_http::trace::TraceLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SuscriberInitExt};

#[tokio::main]
async fn main() -> anyhow::Result<()> {

	let env = fs::read_to_string(".env").unwrap();
	let (key, database_url) = env.split_once('=').unwrap();

	assert_eq!(key, "DATABASE_URL");

	tracing_subscriber::registery()
		.with(tracing_subscriber::EnvFilter::new(
			std::env::var("tower_http=trace")
				.unwrap_or_else(|_| "example_tracing_aka_logging=debug,tower_http=debug".into()),
				))
				.with(tracing_subscriber::fmt::layer())
				.init();

	let pool = PgPoolOptions::new()
	.max_connections(50)
	.connect(&dtabase_url)
	.await
	.context("could not connect to database_url")?;

	let app = Router::new()
		.route("/hello", get(root))
		.route("/tasks", get(controllers::task::all_tasks))
		.route("/task", post(controllers::task::new_task))
		.route("/task/:id", get(controllers::task::task))
		.route("/task/:id", put(controllers::task::update_task))
		.route("/task/:id", delete(controllers::task::delete_task))
		.layer(Extension(pool))
		.layer(TraceLayer::new_for_http());

	let addr = SocketAddr::from(([192,168,56,10], 8000));
	println!("listening on {}", addr);
	axum::Server::bind(&addr)
		.serve(app.into_make_service())
		.await?;

		Ok(())
}

/// `GET /hello` handler: replies with a fixed plain-text greeting.
async fn root() -> &'static str {
	let greeting: &'static str = "Hello, World";
	greeting
}