一、项目架构
整体结构:
rs-counter-study
——|migrations
————|xxxx_create_users.down.sql
————|xxxx_create_users.up.sql
——|src
————|api
——————|jwt.rs
——————|user.rs
————|api.rs
————|main.rs
————|db.rs
——|.env
——|Cargo.lock
——|Cargo.toml
——|data.db
二、核心逻辑
user.rs 用户登录逻辑
use axum::extract::State;
use axum::Json;
use jsonwebtoken::{encode, EncodingKey, Header};
use serde::{Deserialize, Serialize};
use sqlx::{Pool, Sqlite};
use crate::api::{ApiError, AuthError};
use crate::api::jwt::Claims;
use crate::db::User;
/// JSON request body accepted by the login endpoint.
#[derive(Deserialize)]
pub struct LoginPayload {
    /// Login code sent by the WeChat mini-program client; exchanged for an openid by `wx_login`.
    code: String,
}
/// Response body of a successful login: the token the client sends on later requests.
#[derive(Deserialize, Serialize)]
pub struct AuthBody {
    /// Signed JWT.
    access_token: String,
    /// Authentication scheme for the token; always `"Bearer"`.
    token_type: String,
}

impl AuthBody {
    /// Wraps `access_token` as a bearer-token response.
    pub fn new(access_token: String) -> Self {
        let token_type = String::from("Bearer");
        Self {
            access_token,
            token_type,
        }
    }
}
/// WeChat identity (openid + session key) obtained from the client's login code.
///
/// `pub` because the public `wx_login` returns it: a private type in a public signature
/// triggers the `private_interfaces` lint. `Deserialize` is presumably intended for the
/// WeChat code-to-session response — confirm once that call is implemented.
#[derive(Deserialize, Default)]
pub struct WxUser {
    pub openid: String,
    pub session_key: String,
}
//wx接口 获取tocken
// todo
pub async fn wx_login(code:String)->Result<WxUser,ApiError>{
Ok(WxUser::default())
}
///登陆API
pub async fn login(State(pool): State<Pool<Sqlite>>, Json(payload): Json<LoginPayload>) -> Result<Json<(AuthBody)>, ApiError>{
// 1. 通过微信code获取用户信息
let wx_user = wx_login(payload.code).await?;
// 2. 通过openid查询数据库是否存在该用户
let user_result = sqlx::query_as::<_,User>("select * from users where openid = ?")
.bind(&wx_user.openid)
.fetch_one(&pool)
.await;
// 3. 用户不存在时创建新用户
let user = match user_result {
Ok(user_result)=>user_result,
Err(sqlx::Error::RowNotFound)=>{ //否则进行新建用户
sqlx::query("insert into users (openid,session_key) values (?,?)")
.bind(&wx_user.openid)
.bind(&wx_user.session_key)
.execute(&pool)
.await?;
// 重新查询获取完整用户信息,返回用户数据
sqlx::query_as::<_,User>("SELECT * FROM users WHERE openid = ?")
.bind(&wx_user.openid)
.fetch_one(&pool)
.await?
},
Err(e)=>return Err(ApiError::from(e)), //兜底其余异常处理
};
//用户如果存在,则进行查询生成小程序这边的token进行后续的访问
let claims = Claims::new(user.id.to_string());
let token = encode(
&Header::default(),
&claims,
&EncodingKey::from_secret(b"secret"),
).map_err(|_| AuthError::TokenCreation)?;
Ok(Json(AuthBody::new(token))) //返回结果Result<Json<(AuthBody)>
}
jwt.rs
use std::time::{Duration, SystemTime};
use serde::{Deserialize, Serialize};
/// JWT payload.
#[derive(Deserialize, Serialize)]
pub struct Claims {
    /// Subject of the token (the login handler passes the user's id as a string).
    sub: String,
    /// Expiry time, in seconds since the Unix epoch.
    exp: usize,
}

impl Claims {
    /// Lifetime of a newly issued token: 15 days.
    const TTL: Duration = Duration::from_secs(15 * 24 * 3600);

    /// Builds claims for `sub` that expire `Self::TTL` from now.
    ///
    /// # Panics
    /// If the computed expiry time lies before the Unix epoch (a badly wrong system clock).
    pub fn new(sub: String) -> Claims {
        let expires_at = SystemTime::now() + Self::TTL;
        let secs_since_epoch = expires_at
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        Claims {
            sub,
            exp: secs_since_epoch as usize,
        }
    }
}
main.rs
use axum::{
http::StatusCode,
routing::{get, post},
Json, Router,
};
use dotenvy::dotenv;
use serde::{Deserialize, Serialize};
use tower_http::trace::TraceLayer;
use tracing::info;
mod db;
mod api;
/// Entry point of the Axum web server.
///
/// Loads `.env`, installs a `tracing` subscriber, and opens the SQLite pool from
/// `DATABASE_URL`. It then serves `GET /` and `POST /login` on `BIND_ADDR`
/// (default `127.0.0.1:3000`).
///
/// # Panics
/// If the database cannot be opened, the address cannot be bound, or the server fails.
#[tokio::main]
async fn main() {
    // Load .env first so DATABASE_URL / RUST_LOG / BIND_ADDR below can come from it.
    dotenv().ok();
    // Log filtering follows RUST_LOG, e.g. `RUST_LOG=debug cargo run`.
    tracing_subscriber::fmt::init();

    // Connection pool shared with every handler as Axum state.
    let pool = db::establish_connection().await;

    let app = Router::new()
        .route("/", get(root))
        .route("/login", post(api::user::login))
        // Emit a log line for every request/response.
        .layer(TraceLayer::new_for_http())
        .with_state(pool);

    // Loopback only by default; set BIND_ADDR=0.0.0.0:3000 to accept connections from
    // other hosts.
    let addr = std::env::var("BIND_ADDR").unwrap_or_else(|_| "127.0.0.1:3000".to_string());
    let listener = tokio::net::TcpListener::bind(addr.as_str())
        .await
        .unwrap_or_else(|e| panic!("failed to bind {addr}: {e}"));
    info!("server listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}
/// Handler for `GET /`: replies with a fixed greeting.
async fn root() -> &'static str {
    const GREETING: &str = "Hello, World!";
    GREETING
}
api.rs
use axum::http::StatusCode;
use axum::response::IntoResponse;
pub mod user;
pub mod jwt;
pub enum AuthError{
WrongCredentials,
MissingCredentials,
TokenCreation,
InvalidToken,
}
pub enum ApiError {
InternalServerError(anyhow::Error),
Auth(AuthError),
}
// 通用错误转换(如需要)
impl<E> From<E> for ApiError
where
E: Into<anyhow::Error>,
{
fn from(e: E) -> Self {
ApiError::InternalServerError(e.into())
}
}
impl IntoResponse for ApiError {
fn into_response(self) -> axum::response::Response {
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
impl From<AuthError> for ApiError {
fn from(e:AuthError) -> Self {
ApiError::Auth(e)
}
}
db.rs
use std::env;
use sqlx::{Pool, Sqlite, SqlitePool};
use time::PrimitiveDateTime;
/// Opens the SQLite connection pool for the URL in the `DATABASE_URL` environment variable.
///
/// # Panics
/// If `DATABASE_URL` is unset or the database cannot be opened.
pub async fn establish_connection() -> Pool<Sqlite> {
    let url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
    SqlitePool::connect(&url)
        .await
        .expect("can't connect to database")
}
/// One row of the `users` table.
///
/// `sqlx::FromRow` maps columns to fields by name, so every field must match a column
/// returned by `SELECT * FROM users`. The table has a `session_key TEXT` column, not a
/// `session_id` integer.
#[derive(sqlx::FromRow)]
pub struct User {
    pub id: i32,
    /// WeChat openid; unique across users.
    pub openid: String,
    /// WeChat session key, written when the row is created.
    pub session_key: String,
    // NOTE(review): decoding into PrimitiveDateTime needs a date-time value; date-only
    // defaults (CURRENT_DATE) are expected to fail to decode — confirm the migration
    // uses CURRENT_TIMESTAMP.
    pub created_at: PrimitiveDateTime,
    pub updated_at: PrimitiveDateTime,
}
三、环境配置
.env
# SQLite database file
DATABASE_URL = sqlite:data.db
RUST_LOG = debug
Cargo.toml
[package]
name = "rs-counter-study"
version = "0.1.0"
edition = "2024"
[dependencies]
axum = "0.7.9"
tokio = { version = "1.0", features = ["full"]}
tracing = "0.1"
tracing-subscriber = { version = "0.3",features = ["env-filter"]}
serde = { version = "1.0.219", features = ["derive"] }
tower-http = { version = "0.5.0",features = ["trace","request-id","util"]}
tower = "0.4.13"
sqlx = { version = "0.7", features = ["runtime-tokio","tls-rustls","sqlite","time"]}
dotenvy = "0.15"
time = { version = "0.3.31",features = ["serde-human-readable"]}
anyhow = "1.0"
jsonwebtoken = "9"
serde_json = "1.0.140"
create_users.down.sql
-- Down migration: undo create_users by dropping the index and the table it created.
drop index if exists users_openid_index;
drop table if exists users;
create_users.up.sql
-- sqlx migrate add -r create_users   (create a reversible up/down migration pair)
-- sqlx migrate run                   (apply pending migrations)
-- sqlx migrate revert                (undo the most recent migration)
-- Up migration: create the users table.
--
-- `integer primary key` must be spelled exactly so that `id` becomes an alias for
-- SQLite's rowid and is assigned automatically when an INSERT omits it.
-- CURRENT_TIMESTAMP (not CURRENT_DATE) stores 'YYYY-MM-DD HH:MM:SS', a full date-time.
-- NOTE: sqlx rejects an already-applied migration whose contents changed; a database
-- built from an earlier version of this file has to be recreated.
create table users(
    id integer primary key not null,
    openid text not null,
    session_key text not null,
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
create unique index users_openid_index on users(openid);