feat/web: Patchset 3 (#455)
Some checks failed
EasyTier Core / pre_job (push) Has been cancelled
EasyTier GUI / pre_job (push) Has been cancelled
EasyTier Mobile / pre_job (push) Has been cancelled
EasyTier Test / pre_job (push) Has been cancelled
EasyTier Core / build (freebsd-13.2-x86_64, 13.2, ubuntu-22.04, x86_64-unknown-freebsd) (push) Has been cancelled
EasyTier Core / build (linux-aarch64, ubuntu-22.04, aarch64-unknown-linux-musl) (push) Has been cancelled
EasyTier Core / build (linux-arm, ubuntu-22.04, arm-unknown-linux-musleabi) (push) Has been cancelled
EasyTier Core / build (linux-armhf, ubuntu-22.04, arm-unknown-linux-musleabihf) (push) Has been cancelled
EasyTier Core / build (linux-armv7, ubuntu-22.04, armv7-unknown-linux-musleabi) (push) Has been cancelled
EasyTier Core / build (linux-armv7hf, ubuntu-22.04, armv7-unknown-linux-musleabihf) (push) Has been cancelled
EasyTier Core / build (linux-mips, ubuntu-22.04, mips-unknown-linux-musl) (push) Has been cancelled
EasyTier Core / build (linux-mipsel, ubuntu-22.04, mipsel-unknown-linux-musl) (push) Has been cancelled
EasyTier Core / build (linux-x86_64, ubuntu-22.04, x86_64-unknown-linux-musl) (push) Has been cancelled
EasyTier Core / build (macos-aarch64, macos-latest, aarch64-apple-darwin) (push) Has been cancelled
EasyTier Core / build (macos-x86_64, macos-latest, x86_64-apple-darwin) (push) Has been cancelled
EasyTier Core / build (windows-x86_64, windows-latest, x86_64-pc-windows-msvc) (push) Has been cancelled
EasyTier Core / core-result (push) Has been cancelled
EasyTier GUI / build-gui (linux-aarch64, aarch64-unknown-linux-gnu, ubuntu-22.04, aarch64-unknown-linux-musl) (push) Has been cancelled
EasyTier GUI / build-gui (linux-x86_64, x86_64-unknown-linux-gnu, ubuntu-22.04, x86_64-unknown-linux-musl) (push) Has been cancelled
EasyTier GUI / build-gui (macos-aarch64, aarch64-apple-darwin, macos-latest, aarch64-apple-darwin) (push) Has been cancelled
EasyTier GUI / build-gui (macos-x86_64, x86_64-apple-darwin, macos-latest, x86_64-apple-darwin) (push) Has been cancelled
EasyTier GUI / build-gui (windows-x86_64, x86_64-pc-windows-msvc, windows-latest, x86_64-pc-windows-msvc) (push) Has been cancelled
EasyTier GUI / gui-result (push) Has been cancelled
EasyTier Mobile / build-mobile (android, ubuntu-22.04, android) (push) Has been cancelled
EasyTier Mobile / mobile-result (push) Has been cancelled
EasyTier Test / test (push) Has been cancelled

https://apifox.com/apidoc/shared-ceda7a60-e817-4ea8-827b-de4e874dc45e

implement all backend API
This commit is contained in:
Sijie.Sun
2024-11-02 15:13:19 +08:00
committed by GitHub
parent 18da94bf33
commit 8aca5851f2
41 changed files with 4621 additions and 217 deletions

1445
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -12,7 +12,30 @@ tokio = { version = "1", features = ["full"] }
dashmap = "6.1" dashmap = "6.1"
url = "2.2" url = "2.2"
async-trait = "0.1" async-trait = "0.1"
axum = { version = "0.7", features = ["macros"] } axum = { version = "0.7", features = ["macros"] }
axum-login = { version = "0.16" }
password-auth = { version = "1.0.0" }
axum-messages = "0.7.0"
tower-sessions-sqlx-store = { version = "0.14.1", features = ["sqlite"] }
tower-sessions = { version = "0.13.0", default-features = false, features = [
"signed",
] }
tower-http = { version = "0.6", features = ["cors"] }
sqlx = { version = "0.8", features = ["sqlite"] }
sea-orm = { version = "1.1", features = [ "sqlx-sqlite", "runtime-tokio-rustls", "macros" ] }
sea-orm-migration = { version = "1.1" }
# for captcha
rust-embed = { version = "8.5.0", features = ["debug-embed"] }
base64 = "0.22"
rand = "0.8"
image = {version="0.24", default-features = false, features = ["png"]}
rusttype = "0.9.3"
imageproc = "0.23.0"
clap = { version = "4.4.8", features = [ clap = { version = "4.4.8", features = [
"string", "string",
"unicode", "unicode",

View File

@@ -0,0 +1,85 @@
-- # Entity schema.
-- Create `users` table.
create table if not exists users (
id integer primary key autoincrement,
username text not null unique,
password text not null
);
-- Create `groups` table.
create table if not exists groups (
id integer primary key autoincrement,
name text not null unique
);
-- Create `permissions` table.
create table if not exists permissions (
id integer primary key autoincrement,
name text not null unique
);
-- # Join tables.
-- Create `users_groups` table for many-to-many relationships between users and groups.
create table if not exists users_groups (
user_id integer references users(id),
group_id integer references groups(id),
primary key (user_id, group_id)
);
-- Create `groups_permissions` table for many-to-many relationships between groups and permissions.
create table if not exists groups_permissions (
group_id integer references groups(id),
permission_id integer references permissions(id),
primary key (group_id, permission_id)
);
-- # Fixture hydration.
-- Insert "user" user. password: "user"
insert into users (username, password)
values (
'user',
'$argon2i$v=19$m=16,t=2,p=1$dHJ5dXZkYmZkYXM$UkrNqWz0BbSVBq4ykLSuJw'
);
-- Insert "admin" user. password: "admin"
insert into users (username, password)
values (
'admin',
'$argon2i$v=19$m=16,t=2,p=1$Ymd1Y2FlcnQ$x0q4oZinW9S1ZB9BcaHEpQ'
);
-- Insert "users" and "superusers" groups.
insert into groups (name) values ('users');
insert into groups (name) values ('superusers');
-- Insert individual permissions.
insert into permissions (name) values ('sessions');
insert into permissions (name) values ('devices');
-- Insert group permissions.
insert into groups_permissions (group_id, permission_id)
values (
(select id from groups where name = 'users'),
(select id from permissions where name = 'devices')
), (
(select id from groups where name = 'superusers'),
(select id from permissions where name = 'sessions')
);
-- Insert users into groups.
insert into users_groups (user_id, group_id)
values (
(select id from users where username = 'user'),
(select id from groups where name = 'users')
), (
(select id from users where username = 'admin'),
(select id from groups where name = 'users')
), (
(select id from users where username = 'admin'),
(select id from groups where name = 'superusers')
);

Binary file not shown.

View File

@@ -4,10 +4,14 @@ pub mod storage;
use std::sync::Arc; use std::sync::Arc;
use dashmap::DashMap; use dashmap::DashMap;
use easytier::{common::scoped_task::ScopedTask, tunnel::TunnelListener}; use easytier::{
common::scoped_task::ScopedTask, proto::web::HeartbeatRequest, tunnel::TunnelListener,
};
use session::Session; use session::Session;
use storage::{Storage, StorageToken}; use storage::{Storage, StorageToken};
use crate::db::Db;
#[derive(Debug)] #[derive(Debug)]
pub struct ClientManager { pub struct ClientManager {
accept_task: Option<ScopedTask<()>>, accept_task: Option<ScopedTask<()>>,
@@ -18,13 +22,13 @@ pub struct ClientManager {
} }
impl ClientManager { impl ClientManager {
pub fn new() -> Self { pub fn new(db: Db) -> Self {
ClientManager { ClientManager {
accept_task: None, accept_task: None,
clear_task: None, clear_task: None,
client_sessions: Arc::new(DashMap::new()), client_sessions: Arc::new(DashMap::new()),
storage: Storage::new(), storage: Storage::new(db),
} }
} }
@@ -41,7 +45,8 @@ impl ClientManager {
let info = tunnel.info().unwrap(); let info = tunnel.info().unwrap();
let client_url: url::Url = info.remote_addr.unwrap().into(); let client_url: url::Url = info.remote_addr.unwrap().into();
println!("New session from {:?}", tunnel.info()); println!("New session from {:?}", tunnel.info());
let session = Session::new(tunnel, storage.clone(), client_url.clone()); let mut session = Session::new(storage.clone(), client_url.clone());
session.serve(tunnel).await;
sessions.insert(client_url, Arc::new(session)); sessions.insert(client_url, Arc::new(session));
} }
}); });
@@ -87,6 +92,19 @@ impl ClientManager {
.get(&c_url) .get(&c_url)
.map(|item| item.value().clone()) .map(|item| item.value().clone())
} }
pub fn list_machine_by_token(&self, token: String) -> Vec<url::Url> {
self.storage.list_token_clients(&token)
}
pub async fn get_heartbeat_requests(&self, client_url: &url::Url) -> Option<HeartbeatRequest> {
let s = self.client_sessions.get(client_url)?.clone();
s.data().read().await.req()
}
pub fn db(&self) -> &Db {
self.storage.db()
}
} }
#[cfg(test)] #[cfg(test)]
@@ -101,12 +119,12 @@ mod tests {
web_client::WebClient, web_client::WebClient,
}; };
use crate::client_manager::ClientManager; use crate::{client_manager::ClientManager, db::Db};
#[tokio::test] #[tokio::test]
async fn test_client() { async fn test_client() {
let listener = UdpTunnelListener::new("udp://0.0.0.0:54333".parse().unwrap()); let listener = UdpTunnelListener::new("udp://0.0.0.0:54333".parse().unwrap());
let mut mgr = ClientManager::new(); let mut mgr = ClientManager::new(Db::memory_db().await);
mgr.serve(Box::new(listener)).await.unwrap(); mgr.serve(Box::new(listener)).await.unwrap();
let connector = UdpTunnelConnector::new("udp://127.0.0.1:54333".parse().unwrap()); let connector = UdpTunnelConnector::new("udp://127.0.0.1:54333".parse().unwrap());

View File

@@ -1,12 +1,13 @@
use std::{fmt::Debug, sync::Arc}; use std::{fmt::Debug, sync::Arc};
use easytier::{ use easytier::{
common::scoped_task::ScopedTask,
proto::{ proto::{
rpc_impl::bidirect::BidirectRpcManager, rpc_impl::bidirect::BidirectRpcManager,
rpc_types::{self, controller::BaseController}, rpc_types::{self, controller::BaseController},
web::{ web::{
HeartbeatRequest, HeartbeatResponse, WebClientService, WebClientServiceClientFactory, HeartbeatRequest, HeartbeatResponse, RunNetworkInstanceRequest, WebClientService,
WebServerService, WebServerServiceServer, WebClientServiceClientFactory, WebServerService, WebServerServiceServer,
}, },
}, },
tunnel::Tunnel, tunnel::Tunnel,
@@ -98,6 +99,8 @@ pub struct Session {
rpc_mgr: BidirectRpcManager, rpc_mgr: BidirectRpcManager,
data: SharedSessionData, data: SharedSessionData,
run_network_on_start_task: Option<ScopedTask<()>>,
} }
impl Debug for Session { impl Debug for Session {
@@ -106,20 +109,122 @@ impl Debug for Session {
} }
} }
type SessionRpcClient = Box<dyn WebClientService<Controller = BaseController> + Send>;
impl Session { impl Session {
pub fn new(tunnel: Box<dyn Tunnel>, storage: WeakRefStorage, client_url: url::Url) -> Self { pub fn new(storage: WeakRefStorage, client_url: url::Url) -> Self {
let session_data = SessionData::new(storage, client_url);
let data = Arc::new(RwLock::new(session_data));
let rpc_mgr = let rpc_mgr =
BidirectRpcManager::new().set_rx_timeout(Some(std::time::Duration::from_secs(30))); BidirectRpcManager::new().set_rx_timeout(Some(std::time::Duration::from_secs(30)));
rpc_mgr.run_with_tunnel(tunnel);
let data = Arc::new(RwLock::new(SessionData::new(storage, client_url)));
rpc_mgr.rpc_server().registry().register( rpc_mgr.rpc_server().registry().register(
WebServerServiceServer::new(SessionRpcService { data: data.clone() }), WebServerServiceServer::new(SessionRpcService { data: data.clone() }),
"", "",
); );
Session { rpc_mgr, data } Session {
rpc_mgr,
data,
run_network_on_start_task: None,
}
}
pub async fn serve(&mut self, tunnel: Box<dyn Tunnel>) {
self.rpc_mgr.run_with_tunnel(tunnel);
let data = self.data.read().await;
self.run_network_on_start_task.replace(
tokio::spawn(Self::run_network_on_start(
data.heartbeat_waiter(),
data.storage.clone(),
self.scoped_rpc_client(),
))
.into(),
);
}
async fn run_network_on_start(
mut heartbeat_waiter: broadcast::Receiver<HeartbeatRequest>,
storage: WeakRefStorage,
rpc_client: SessionRpcClient,
) {
loop {
heartbeat_waiter = heartbeat_waiter.resubscribe();
let req = heartbeat_waiter.recv().await;
if req.is_err() {
tracing::error!(
"Failed to receive heartbeat request, error: {:?}",
req.err()
);
return;
}
let req = req.unwrap();
let running_inst_ids = req
.running_network_instances
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>();
let Some(storage) = storage.upgrade() else {
tracing::error!("Failed to get storage");
return;
};
let user_id = match storage
.db
.get_user_id_by_token(req.user_token.clone())
.await
{
Ok(Some(user_id)) => user_id,
Ok(None) => {
tracing::info!("User not found by token: {:?}", req.user_token);
return;
}
Err(e) => {
tracing::error!("Failed to get user id by token, error: {:?}", e);
return;
}
};
let local_configs = match storage.db.list_network_configs(user_id, true).await {
Ok(configs) => configs,
Err(e) => {
tracing::error!("Failed to list network configs, error: {:?}", e);
return;
}
};
let mut has_failed = false;
for c in local_configs {
if running_inst_ids.contains(&c.network_instance_id) {
continue;
}
let ret = rpc_client
.run_network_instance(
BaseController::default(),
RunNetworkInstanceRequest {
inst_id: Some(c.network_instance_id.clone().into()),
config: c.network_config,
},
)
.await;
tracing::info!(
?user_id,
"Run network instance: {:?}, user_token: {:?}",
ret,
req.user_token
);
has_failed |= ret.is_err();
}
if !has_failed {
tracing::info!(?req, "All network instances are running");
break;
}
}
} }
pub fn is_running(&self) -> bool { pub fn is_running(&self) -> bool {
@@ -130,9 +235,7 @@ impl Session {
self.data.clone() self.data.clone()
} }
pub fn scoped_rpc_client( pub fn scoped_rpc_client(&self) -> SessionRpcClient {
&self,
) -> Box<dyn WebClientService<Controller = BaseController> + Send> {
self.rpc_mgr self.rpc_mgr
.rpc_client() .rpc_client()
.scoped_client::<WebClientServiceClientFactory<BaseController>>(1, 1, "".to_string()) .scoped_client::<WebClientServiceClientFactory<BaseController>>(1, 1, "".to_string())
@@ -141,4 +244,8 @@ impl Session {
pub async fn get_token(&self) -> Option<StorageToken> { pub async fn get_token(&self) -> Option<StorageToken> {
self.data.read().await.storage_token.clone() self.data.read().await.storage_token.clone()
} }
pub async fn get_heartbeat_req(&self) -> Option<HeartbeatRequest> {
self.data.read().await.req()
}
} }

View File

@@ -2,6 +2,8 @@ use std::sync::{Arc, Weak};
use dashmap::{DashMap, DashSet}; use dashmap::{DashMap, DashSet};
use crate::db::Db;
// use this to maintain Storage // use this to maintain Storage
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct StorageToken { pub struct StorageToken {
@@ -15,6 +17,7 @@ pub struct StorageInner {
// some map for indexing // some map for indexing
pub token_clients_map: DashMap<String, DashSet<url::Url>>, pub token_clients_map: DashMap<String, DashSet<url::Url>>,
pub machine_client_url_map: DashMap<uuid::Uuid, url::Url>, pub machine_client_url_map: DashMap<uuid::Uuid, url::Url>,
pub db: Db,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -30,10 +33,11 @@ impl TryFrom<WeakRefStorage> for Storage {
} }
impl Storage { impl Storage {
pub fn new() -> Self { pub fn new(db: Db) -> Self {
Storage(Arc::new(StorageInner { Storage(Arc::new(StorageInner {
token_clients_map: DashMap::new(), token_clients_map: DashMap::new(),
machine_client_url_map: DashMap::new(), machine_client_url_map: DashMap::new(),
db,
})) }))
} }
@@ -69,4 +73,16 @@ impl Storage {
.get(&machine_id) .get(&machine_id)
.map(|url| url.clone()) .map(|url| url.clone())
} }
pub fn list_token_clients(&self, token: &str) -> Vec<url::Url> {
self.0
.token_clients_map
.get(token)
.map(|set| set.iter().map(|url| url.clone()).collect())
.unwrap_or_default()
}
pub fn db(&self) -> &Db {
&self.0.db
}
} }

View File

@@ -0,0 +1,35 @@
//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.0
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "groups")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
#[sea_orm(unique)]
pub name: String,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(has_many = "super::groups_permissions::Entity")]
GroupsPermissions,
#[sea_orm(has_many = "super::users_groups::Entity")]
UsersGroups,
}
impl Related<super::groups_permissions::Entity> for Entity {
fn to() -> RelationDef {
Relation::GroupsPermissions.def()
}
}
impl Related<super::users_groups::Entity> for Entity {
fn to() -> RelationDef {
Relation::UsersGroups.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -0,0 +1,47 @@
//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.0
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "groups_permissions")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub group_id: i32,
pub permission_id: i32,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::groups::Entity",
from = "Column::GroupId",
to = "super::groups::Column::Id",
on_update = "Cascade",
on_delete = "Cascade"
)]
Groups,
#[sea_orm(
belongs_to = "super::permissions::Entity",
from = "Column::PermissionId",
to = "super::permissions::Column::Id",
on_update = "Cascade",
on_delete = "Cascade"
)]
Permissions,
}
impl Related<super::groups::Entity> for Entity {
fn to() -> RelationDef {
Relation::Groups.def()
}
}
impl Related<super::permissions::Entity> for Entity {
fn to() -> RelationDef {
Relation::Permissions.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -0,0 +1,11 @@
//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.0
pub mod prelude;
pub mod groups;
pub mod groups_permissions;
pub mod permissions;
pub mod tower_sessions;
pub mod user_running_network_configs;
pub mod users;
pub mod users_groups;

View File

@@ -0,0 +1,27 @@
//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.0
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "permissions")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
#[sea_orm(unique)]
pub name: String,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(has_many = "super::groups_permissions::Entity")]
GroupsPermissions,
}
impl Related<super::groups_permissions::Entity> for Entity {
fn to() -> RelationDef {
Relation::GroupsPermissions.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -0,0 +1,9 @@
//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.0
pub use super::groups::Entity as Groups;
pub use super::groups_permissions::Entity as GroupsPermissions;
pub use super::permissions::Entity as Permissions;
pub use super::tower_sessions::Entity as TowerSessions;
pub use super::user_running_network_configs::Entity as UserRunningNetworkConfigs;
pub use super::users::Entity as Users;
pub use super::users_groups::Entity as UsersGroups;

View File

@@ -0,0 +1,19 @@
//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.0
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "tower_sessions")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false, column_type = "Text")]
pub id: String,
#[sea_orm(column_type = "Blob")]
pub data: Vec<u8>,
pub expiry_date: i32,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -0,0 +1,39 @@
//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.0
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "user_running_network_configs")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub user_id: i32,
#[sea_orm(column_type = "Text", unique)]
pub network_instance_id: String,
#[sea_orm(column_type = "Text")]
pub network_config: String,
pub disabled: bool,
pub create_time: DateTimeWithTimeZone,
pub update_time: DateTimeWithTimeZone,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::users::Entity",
from = "Column::UserId",
to = "super::users::Column::Id",
on_update = "Cascade",
on_delete = "Cascade"
)]
Users,
}
impl Related<super::users::Entity> for Entity {
fn to() -> RelationDef {
Relation::Users.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -0,0 +1,36 @@
//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.0
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "users")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
#[sea_orm(unique)]
pub username: String,
pub password: String,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(has_many = "super::user_running_network_configs::Entity")]
UserRunningNetworkConfigs,
#[sea_orm(has_many = "super::users_groups::Entity")]
UsersGroups,
}
impl Related<super::user_running_network_configs::Entity> for Entity {
fn to() -> RelationDef {
Relation::UserRunningNetworkConfigs.def()
}
}
impl Related<super::users_groups::Entity> for Entity {
fn to() -> RelationDef {
Relation::UsersGroups.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -0,0 +1,47 @@
//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.0
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "users_groups")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub user_id: i32,
pub group_id: i32,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::groups::Entity",
from = "Column::GroupId",
to = "super::groups::Column::Id",
on_update = "Cascade",
on_delete = "Cascade"
)]
Groups,
#[sea_orm(
belongs_to = "super::users::Entity",
from = "Column::UserId",
to = "super::users::Column::Id",
on_update = "Cascade",
on_delete = "Cascade"
)]
Users,
}
impl Related<super::groups::Entity> for Entity {
fn to() -> RelationDef {
Relation::Groups.def()
}
}
impl Related<super::users::Entity> for Entity {
fn to() -> RelationDef {
Relation::Users.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

215
easytier-web/src/db/mod.rs Normal file
View File

@@ -0,0 +1,215 @@
// sea-orm-cli generate entity -u sqlite:./et.db -o easytier-web/src/db/entity/ --with-serde both --with-copy-enums
#[allow(unused_imports)]
pub mod entity;
use entity::user_running_network_configs;
use sea_orm::{
sea_query::OnConflict, ColumnTrait as _, DatabaseConnection, DbErr, EntityTrait as _,
QueryFilter as _, SqlxSqliteConnector, TransactionTrait as _,
};
use sea_orm_migration::MigratorTrait as _;
use sqlx::{migrate::MigrateDatabase as _, types::chrono, Sqlite, SqlitePool};
use crate::migrator;
type UserIdInDb = i32;
#[derive(Debug, Clone)]
pub struct Db {
db_path: String,
db: SqlitePool,
orm_db: DatabaseConnection,
}
impl Db {
pub async fn new<T: ToString>(db_path: T) -> anyhow::Result<Self> {
let db = Self::prepare_db(db_path.to_string().as_str()).await?;
let orm_db = SqlxSqliteConnector::from_sqlx_sqlite_pool(db.clone());
migrator::Migrator::up(&orm_db, None).await?;
Ok(Self {
db_path: db_path.to_string(),
db,
orm_db,
})
}
pub async fn memory_db() -> Self {
Self::new(":memory:").await.unwrap()
}
#[tracing::instrument(ret)]
async fn prepare_db(db_path: &str) -> anyhow::Result<SqlitePool> {
if !Sqlite::database_exists(db_path).await.unwrap_or(false) {
tracing::info!("Database not found, creating a new one");
Sqlite::create_database(db_path).await?;
}
let db = sqlx::pool::PoolOptions::new()
.max_lifetime(None)
.idle_timeout(None)
.connect(db_path)
.await?;
Ok(db)
}
pub fn inner(&self) -> SqlitePool {
self.db.clone()
}
pub fn orm_db(&self) -> &DatabaseConnection {
&self.orm_db
}
pub async fn insert_or_update_user_network_config<T: ToString>(
&self,
user_id: UserIdInDb,
network_inst_id: uuid::Uuid,
network_config: T,
) -> Result<(), DbErr> {
let txn = self.orm_db().begin().await?;
use entity::user_running_network_configs as urnc;
let on_conflict = OnConflict::column(urnc::Column::NetworkInstanceId)
.update_columns([
urnc::Column::NetworkConfig,
urnc::Column::Disabled,
urnc::Column::UpdateTime,
])
.to_owned();
let insert_m = urnc::ActiveModel {
user_id: sea_orm::Set(user_id),
network_instance_id: sea_orm::Set(network_inst_id.to_string()),
network_config: sea_orm::Set(network_config.to_string()),
disabled: sea_orm::Set(false),
create_time: sea_orm::Set(chrono::Local::now().fixed_offset()),
update_time: sea_orm::Set(chrono::Local::now().fixed_offset()),
..Default::default()
};
urnc::Entity::insert(insert_m)
.on_conflict(on_conflict)
.do_nothing()
.exec(&txn)
.await?;
txn.commit().await
}
pub async fn delete_network_config(
&self,
user_id: UserIdInDb,
network_inst_id: uuid::Uuid,
) -> Result<(), DbErr> {
use entity::user_running_network_configs as urnc;
urnc::Entity::delete_many()
.filter(urnc::Column::UserId.eq(user_id))
.filter(urnc::Column::NetworkInstanceId.eq(network_inst_id.to_string()))
.exec(self.orm_db())
.await?;
Ok(())
}
pub async fn list_network_configs(
&self,
user_id: UserIdInDb,
only_enabled: bool,
) -> Result<Vec<user_running_network_configs::Model>, DbErr> {
use entity::user_running_network_configs as urnc;
let configs = urnc::Entity::find().filter(urnc::Column::UserId.eq(user_id));
let configs = if only_enabled {
configs.filter(urnc::Column::Disabled.eq(false))
} else {
configs
};
let configs = configs.all(self.orm_db()).await?;
Ok(configs)
}
pub async fn get_user_id<T: ToString>(
&self,
user_name: T,
) -> Result<Option<UserIdInDb>, DbErr> {
use entity::users as u;
let user = u::Entity::find()
.filter(u::Column::Username.eq(user_name.to_string()))
.one(self.orm_db())
.await?;
Ok(user.map(|u| u.id))
}
// TODO: currently we don't have a token system, so we just use the user name as token
pub async fn get_user_id_by_token<T: ToString>(
&self,
token: T,
) -> Result<Option<UserIdInDb>, DbErr> {
self.get_user_id(token).await
}
}
#[cfg(test)]
mod tests {
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter as _};
use crate::db::{entity::user_running_network_configs, Db};
#[tokio::test]
async fn test_user_network_config_management() {
let db = Db::memory_db().await;
let user_id = 1;
let network_config = "test_config";
let inst_id = uuid::Uuid::new_v4();
db.insert_or_update_user_network_config(user_id, inst_id, network_config)
.await
.unwrap();
let result = user_running_network_configs::Entity::find()
.filter(user_running_network_configs::Column::UserId.eq(user_id))
.one(db.orm_db())
.await
.unwrap()
.unwrap();
println!("{:?}", result);
assert_eq!(result.network_config, network_config);
// overwrite the config
let network_config = "test_config2";
db.insert_or_update_user_network_config(user_id, inst_id, network_config)
.await
.unwrap();
let result2 = user_running_network_configs::Entity::find()
.filter(user_running_network_configs::Column::UserId.eq(user_id))
.one(db.orm_db())
.await
.unwrap()
.unwrap();
println!("{:?}", result2);
assert_eq!(result2.network_config, network_config);
assert_eq!(result.create_time, result2.create_time);
assert_ne!(result.update_time, result2.update_time);
assert_eq!(
db.list_network_configs(user_id, true).await.unwrap().len(),
1
);
db.delete_network_config(user_id, inst_id).await.unwrap();
let result3 = user_running_network_configs::Entity::find()
.filter(user_running_network_configs::Column::UserId.eq(user_id))
.one(db.orm_db())
.await
.unwrap();
assert!(result3.is_none());
}
}

View File

@@ -2,21 +2,37 @@
use std::sync::Arc; use std::sync::Arc;
use easytier::tunnel::udp::UdpTunnelListener; use easytier::{
common::config::{ConfigLoader, ConsoleLoggerConfig, TomlConfigLoader},
tunnel::udp::UdpTunnelListener,
utils::init_logger,
};
mod client_manager; mod client_manager;
mod db;
mod migrator;
mod restful; mod restful;
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
let config = TomlConfigLoader::default();
config.set_console_logger_config(ConsoleLoggerConfig {
level: Some("trace".to_string()),
});
init_logger(config, false).unwrap();
// let db = db::Db::new(":memory:").await.unwrap();
let db = db::Db::new("et.db").await.unwrap();
let listener = UdpTunnelListener::new("udp://0.0.0.0:22020".parse().unwrap()); let listener = UdpTunnelListener::new("udp://0.0.0.0:22020".parse().unwrap());
let mut mgr = client_manager::ClientManager::new(); let mut mgr = client_manager::ClientManager::new(db.clone());
mgr.serve(listener).await.unwrap(); mgr.serve(listener).await.unwrap();
let mgr = Arc::new(mgr); let mgr = Arc::new(mgr);
let mut restful_server = let mut restful_server =
restful::RestfulServer::new("0.0.0.0:11211".parse().unwrap(), mgr.clone()); restful::RestfulServer::new("0.0.0.0:11211".parse().unwrap(), mgr.clone(), db)
.await
.unwrap();
restful_server.start().await.unwrap(); restful_server.start().await.unwrap();
tokio::signal::ctrl_c().await.unwrap(); tokio::signal::ctrl_c().await.unwrap();
} }

View File

@@ -0,0 +1,450 @@
// src/migrator/m20220602_000001_create_bakery_table.rs (create new file)
use sea_orm_migration::{prelude::*, schema::*};
pub struct Migration;
/*
-- # Entity schema.
-- Create `users` table.
create table if not exists users (
id integer primary key autoincrement,
username text not null unique,
password text not null
);
-- Create `groups` table.
create table if not exists groups (
id integer primary key autoincrement,
name text not null unique
);
-- Create `permissions` table.
create table if not exists permissions (
id integer primary key autoincrement,
name text not null unique
);
-- # Join tables.
-- Create `users_groups` table for many-to-many relationships between users and groups.
create table if not exists users_groups (
user_id integer references users(id),
group_id integer references groups(id),
primary key (user_id, group_id)
);
-- Create `groups_permissions` table for many-to-many relationships between groups and permissions.
create table if not exists groups_permissions (
group_id integer references groups(id),
permission_id integer references permissions(id),
primary key (group_id, permission_id)
);
-- # Fixture hydration.
-- Insert "user" user. password: "user"
insert into users (username, password)
values (
'user',
'$argon2i$v=19$m=16,t=2,p=1$dHJ5dXZkYmZkYXM$UkrNqWz0BbSVBq4ykLSuJw'
);
-- Insert "admin" user. password: "admin"
insert into users (username, password)
values (
'admin',
'$argon2i$v=19$m=16,t=2,p=1$Ymd1Y2FlcnQ$x0q4oZinW9S1ZB9BcaHEpQ'
);
-- Insert "users" and "superusers" groups.
insert into groups (name) values ('users');
insert into groups (name) values ('superusers');
-- Insert individual permissions.
insert into permissions (name) values ('sessions');
insert into permissions (name) values ('devices');
-- Insert group permissions.
insert into groups_permissions (group_id, permission_id)
values (
(select id from groups where name = 'users'),
(select id from permissions where name = 'devices')
), (
(select id from groups where name = 'superusers'),
(select id from permissions where name = 'sessions')
);
-- Insert users into groups.
insert into users_groups (user_id, group_id)
values (
(select id from users where username = 'user'),
(select id from groups where name = 'users')
), (
(select id from users where username = 'admin'),
(select id from groups where name = 'users')
), (
(select id from users where username = 'admin'),
(select id from groups where name = 'superusers')
);
*/
impl MigrationName for Migration {
fn name(&self) -> &str {
"m20241029_000001_init"
}
}
#[derive(DeriveIden)]
pub enum Users {
Table,
Id,
Username,
Password,
}
#[derive(DeriveIden)]
enum Groups {
Table,
Id,
Name,
}
#[derive(DeriveIden)]
enum Permissions {
Table,
Id,
Name,
}
#[derive(DeriveIden)]
enum UsersGroups {
Table,
Id,
UserId,
GroupId,
}
#[derive(DeriveIden)]
enum GroupsPermissions {
Table,
Id,
GroupId,
PermissionId,
}
#[derive(DeriveIden)]
enum UserRunningNetworkConfigs {
Table,
Id,
UserId,
NetworkInstanceId,
NetworkConfig,
Disabled,
CreateTime,
UpdateTime,
}
#[async_trait::async_trait]
impl MigrationTrait for Migration {
// Define how to apply this migration: Create the Bakery table.
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
// Create the `users` table.
manager
.create_table(
Table::create()
.if_not_exists()
.table(Users::Table)
.col(pk_auto(Users::Id).not_null())
.col(string(Users::Username).not_null().unique_key())
.col(string(Users::Password).not_null())
.to_owned(),
)
.await?;
manager
.create_index(
Index::create()
.name("idx_users_username")
.table(Users::Table)
.col(Users::Username)
.to_owned(),
)
.await?;
// Create the `groups` table.
manager
.create_table(
Table::create()
.if_not_exists()
.table(Groups::Table)
.col(pk_auto(Groups::Id).not_null())
.col(string(Groups::Name).not_null().unique_key())
.to_owned(),
)
.await?;
manager
.create_index(
Index::create()
.name("idx_groups_name")
.table(Groups::Table)
.col(Groups::Name)
.to_owned(),
)
.await?;
// Create the `permissions` table.
manager
.create_table(
Table::create()
.if_not_exists()
.table(Permissions::Table)
.col(pk_auto(Permissions::Id).not_null())
.col(string(Permissions::Name).not_null().unique_key())
.to_owned(),
)
.await?;
// Create the `users_groups` table.
manager
.create_table(
Table::create()
.if_not_exists()
.table(UsersGroups::Table)
.col(pk_auto(UsersGroups::Id).not_null())
.col(integer(UsersGroups::UserId).not_null())
.col(integer(UsersGroups::GroupId).not_null())
.foreign_key(
ForeignKey::create()
.name("fk_users_groups_user_id_to_users_id")
.from(UsersGroups::Table, UsersGroups::UserId)
.to(Users::Table, Users::Id)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
)
.foreign_key(
ForeignKey::create()
.name("fk_users_groups_group_id_to_groups_id")
.from(UsersGroups::Table, UsersGroups::GroupId)
.to(Groups::Table, Groups::Id)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
)
.to_owned(),
)
.await?;
// Create the `groups_permissions` table.
manager
.create_table(
Table::create()
.if_not_exists()
.table(GroupsPermissions::Table)
.col(pk_auto(GroupsPermissions::Id).not_null())
.col(integer(GroupsPermissions::GroupId).not_null())
.col(integer(GroupsPermissions::PermissionId).not_null())
.foreign_key(
ForeignKey::create()
.name("fk_groups_permissions_group_id_to_groups_id")
.from(GroupsPermissions::Table, GroupsPermissions::GroupId)
.to(Groups::Table, Groups::Id)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
)
.foreign_key(
ForeignKey::create()
.name("fk_groups_permissions_permission_id_to_permissions_id")
.from(GroupsPermissions::Table, GroupsPermissions::PermissionId)
.to(Permissions::Table, Permissions::Id)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
)
.to_owned(),
)
.await?;
// create user running network configs table
manager
.create_table(
Table::create()
.if_not_exists()
.table(UserRunningNetworkConfigs::Table)
.col(pk_auto(UserRunningNetworkConfigs::Id).not_null())
.col(integer(UserRunningNetworkConfigs::UserId).not_null())
.col(
text(UserRunningNetworkConfigs::NetworkInstanceId)
.unique_key()
.not_null(),
)
.col(text(UserRunningNetworkConfigs::NetworkConfig).not_null())
.col(
boolean(UserRunningNetworkConfigs::Disabled)
.not_null()
.default(false),
)
.col(timestamp_with_time_zone(UserRunningNetworkConfigs::CreateTime).not_null())
.col(timestamp_with_time_zone(UserRunningNetworkConfigs::UpdateTime).not_null())
.foreign_key(
ForeignKey::create()
.name("fk_user_running_network_configs_user_id_to_users_id")
.from(
UserRunningNetworkConfigs::Table,
UserRunningNetworkConfigs::UserId,
)
.to(Users::Table, Users::Id)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
)
.to_owned(),
)
.await?;
manager
.create_index(
Index::create()
.name("idx_user_running_network_configs_user_id")
.table(UserRunningNetworkConfigs::Table)
.col(UserRunningNetworkConfigs::UserId)
.to_owned(),
)
.await?;
// prepare data
let user = Query::insert()
.into_table(Users::Table)
.columns(vec![Users::Username, Users::Password])
.values_panic(vec![
"user".into(),
"$argon2i$v=19$m=16,t=2,p=1$dHJ5dXZkYmZkYXM$UkrNqWz0BbSVBq4ykLSuJw".into(),
])
.to_owned();
manager.exec_stmt(user).await?;
let admin = Query::insert()
.into_table(Users::Table)
.columns(vec![Users::Username, Users::Password])
.values_panic(vec![
"admin".into(),
"$argon2i$v=19$m=16,t=2,p=1$Ymd1Y2FlcnQ$x0q4oZinW9S1ZB9BcaHEpQ".into(),
])
.to_owned();
manager.exec_stmt(admin).await?;
let users = Query::insert()
.into_table(Groups::Table)
.columns(vec![Groups::Name])
.values_panic(vec!["users".into()])
.to_owned();
manager.exec_stmt(users).await?;
let admins = Query::insert()
.into_table(Groups::Table)
.columns(vec![Groups::Name])
.values_panic(vec!["admins".into()])
.to_owned();
manager.exec_stmt(admins).await?;
let sessions = Query::insert()
.into_table(Permissions::Table)
.columns(vec![Permissions::Name])
.values_panic(vec!["sessions".into()])
.to_owned();
manager.exec_stmt(sessions).await?;
let devices = Query::insert()
.into_table(Permissions::Table)
.columns(vec![Permissions::Name])
.values_panic(vec!["devices".into()])
.to_owned();
manager.exec_stmt(devices).await?;
let users_devices = Query::insert()
.into_table(GroupsPermissions::Table)
.columns(vec![
GroupsPermissions::GroupId,
GroupsPermissions::PermissionId,
])
.select_from(
Query::select()
.column((Groups::Table, Groups::Id))
.column((Permissions::Table, Permissions::Id))
.from(Groups::Table)
.full_outer_join(Permissions::Table, all![])
.cond_where(any![
// users have devices permission
Expr::col((Groups::Table, Groups::Name))
.eq("users")
.and(Expr::col((Permissions::Table, Permissions::Name)).eq("devices")),
// admins have all permissions
Expr::col((Groups::Table, Groups::Name)).eq("admins"),
])
.to_owned(),
)
.unwrap()
.to_owned();
manager.exec_stmt(users_devices).await?;
let add_user_to_users = Query::insert()
.into_table(UsersGroups::Table)
.columns(vec![UsersGroups::UserId, UsersGroups::GroupId])
.select_from(
Query::select()
.column((Users::Table, Users::Id))
.column((Groups::Table, Groups::Id))
.from(Users::Table)
.full_outer_join(Groups::Table, all![])
.cond_where(
Expr::col(Users::Username)
.eq("user")
.and(Expr::col(Groups::Name).eq("users")),
)
.to_owned(),
)
.unwrap()
.to_owned();
manager.exec_stmt(add_user_to_users).await?;
let add_admin_to_admins = Query::insert()
.into_table(UsersGroups::Table)
.columns(vec![UsersGroups::UserId, UsersGroups::GroupId])
.select_from(
Query::select()
.column((Users::Table, Users::Id))
.column((Groups::Table, Groups::Id))
.from(Users::Table)
.full_outer_join(Groups::Table, all![])
.cond_where(
Expr::col(Users::Username)
.eq("admin")
.and(Expr::col(Groups::Name).eq("admins")),
)
.to_owned(),
)
.unwrap()
.to_owned();
manager.exec_stmt(add_admin_to_admins).await?;
Ok(())
}
// Define how to rollback this migration: Drop the Bakery table.
async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.drop_table(Table::drop().table(Users::Table).to_owned())
.await?;
manager
.drop_table(Table::drop().table(Groups::Table).to_owned())
.await?;
manager
.drop_table(Table::drop().table(Permissions::Table).to_owned())
.await?;
manager
.drop_table(Table::drop().table(UsersGroups::Table).to_owned())
.await?;
manager
.drop_table(Table::drop().table(GroupsPermissions::Table).to_owned())
.await?;
Ok(())
}
}

View File

@@ -0,0 +1,12 @@
use sea_orm_migration::prelude::*;
mod m20241029_000001_init;
pub struct Migrator;
#[async_trait::async_trait]
impl MigratorTrait for Migrator {
fn migrations() -> Vec<Box<dyn MigrationTrait>> {
vec![Box::new(m20241029_000001_init::Migration)]
}
}

View File

@@ -0,0 +1,171 @@
use axum::{
http::StatusCode,
routing::{get, post, put},
Router,
};
use axum_login::login_required;
use axum_messages::Message;
use serde::{Deserialize, Serialize};
use crate::restful::users::Backend;
use super::{
users::{AuthSession, Credentials},
AppStateInner,
};
#[derive(Debug, Deserialize, Serialize)]
pub struct LoginResult {
messages: Vec<Message>,
}
pub fn router() -> Router<AppStateInner> {
let r = Router::new()
.route("/api/v1/auth/password", put(self::put::change_password))
.route_layer(login_required!(Backend));
Router::new()
.merge(r)
.route("/api/v1/auth/login", post(self::post::login))
.route("/api/v1/auth/logout", get(self::get::logout))
.route("/api/v1/auth/captcha", get(self::get::get_captcha))
.route("/api/v1/auth/register", post(self::post::register))
}
mod put {
use axum::Json;
use axum_login::AuthUser;
use easytier::proto::common::Void;
use crate::restful::{other_error, users::ChangePassword, HttpHandleError};
use super::*;
pub async fn change_password(
mut auth_session: AuthSession,
Json(req): Json<ChangePassword>,
) -> Result<Json<Void>, HttpHandleError> {
if let Err(e) = auth_session
.backend
.change_password(auth_session.user.as_ref().unwrap().id(), &req)
.await
{
tracing::error!("Failed to change password: {:?}", e);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json::from(other_error(format!("{:?}", e))),
));
}
let _ = auth_session.logout().await;
Ok(Void::default().into())
}
}
mod post {
use axum::Json;
use easytier::proto::common::Void;
use crate::restful::{
captcha::extension::{axum_tower_sessions::CaptchaAxumTowerSessionStaticExt, CaptchaUtil},
other_error,
users::RegisterNewUser,
HttpHandleError,
};
use super::*;
pub async fn login(
mut auth_session: AuthSession,
Json(creds): Json<Credentials>,
) -> Result<Json<Void>, HttpHandleError> {
let user = match auth_session.authenticate(creds.clone()).await {
Ok(Some(user)) => user,
Ok(None) => {
return Err((
StatusCode::UNAUTHORIZED,
Json::from(other_error("Invalid credentials")),
));
}
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json::from(other_error(format!("{:?}", e))),
))
}
};
if let Err(e) = auth_session.login(&user).await {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json::from(other_error(format!("{:?}", e))),
));
}
Ok(Void::default().into())
}
pub async fn register(
auth_session: AuthSession,
captcha_session: tower_sessions::Session,
Json(req): Json<RegisterNewUser>,
) -> Result<Json<Void>, HttpHandleError> {
// 调用CaptchaUtil的静态方法验证验证码是否正确
if !CaptchaUtil::ver(&req.captcha, &captcha_session).await {
return Err((
StatusCode::BAD_REQUEST,
other_error(format!("captcha verify error, input: {}", req.captcha)).into(),
));
}
if let Err(e) = auth_session.backend.register_new_user(&req).await {
tracing::error!("Failed to register new user: {:?}", e);
return Err((
StatusCode::BAD_REQUEST,
other_error(format!("{:?}", e)).into(),
));
}
Ok(Void::default().into())
}
}
mod get {
use crate::restful::{
captcha::{
captcha::spec::SpecCaptcha,
extension::{axum_tower_sessions::CaptchaAxumTowerSessionExt as _, CaptchaUtil},
NewCaptcha as _,
},
other_error, HttpHandleError,
};
use axum::{response::Response, Json};
use easytier::proto::common::Void;
use tower_sessions::Session;
use super::*;
pub async fn logout(mut auth_session: AuthSession) -> Result<Json<Void>, HttpHandleError> {
match auth_session.logout().await {
Ok(_) => Ok(Json(Void::default())),
Err(e) => {
tracing::error!("Failed to logout: {:?}", e);
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json::from(other_error(format!("{:?}", e))),
))
}
}
}
pub async fn get_captcha(session: Session) -> Result<Response, HttpHandleError> {
let mut captcha: CaptchaUtil<SpecCaptcha> = CaptchaUtil::with_size_and_len(127, 48, 4);
match captcha.out(&session).await {
Ok(response) => Ok(response),
Err(e) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json::from(other_error(format!("{:?}", e))),
)),
}
}
}

View File

@@ -0,0 +1,308 @@
use super::super::base::randoms::Randoms;
use super::super::utils::color::Color;
use super::super::utils::font;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use rusttype::Font;
use std::fmt::Debug;
use std::io::Write;
use std::sync::Arc;
/// 验证码抽象类
pub(crate) struct Captcha {
/// 随机数工具类
pub(crate) randoms: Randoms,
/// 常用颜色
color: Vec<Color>,
/// 字体名称
font_names: [&'static str; 1],
/// 验证码的字体
font_name: String,
/// 验证码的字体大小
font_size: f32,
/// 验证码随机字符长度
pub len: usize,
/// 验证码显示宽度
pub width: i32,
/// 验证码显示高度
pub height: i32,
/// 验证码类型
char_type: CaptchaType,
/// 当前验证码
pub(crate) chars: Option<String>,
}
/// 验证码文本类型 The character type of the captcha
pub enum CaptchaType {
/// 字母数字混合
TypeDefault = 1,
/// 纯数字
TypeOnlyNumber,
/// 纯字母
TypeOnlyChar,
/// 纯大写字母
TypeOnlyUpper,
/// 纯小写字母
TypeOnlyLower,
/// 数字大写字母
TypeNumAndUpper,
}
/// 内置字体 Fonts shipped with the library
pub enum CaptchaFont {
/// actionj
Font1,
/// epilog
Font2,
/// fresnel
Font3,
/// headache
Font4,
/// lexo
Font5,
/// prefix
Font6,
/// progbot
Font7,
/// ransom
Font8,
/// robot
Font9,
/// scandal
Font10,
}
impl Captcha {
/// 生成随机验证码
pub fn alphas(&mut self) -> Vec<char> {
let mut cs = vec!['\0'; self.len];
for i in 0..self.len {
match self.char_type {
CaptchaType::TypeDefault => cs[i] = self.randoms.alpha(),
CaptchaType::TypeOnlyNumber => {
cs[i] = self.randoms.alpha_under(self.randoms.num_max_index)
}
CaptchaType::TypeOnlyChar => {
cs[i] = self
.randoms
.alpha_between(self.randoms.char_min_index, self.randoms.char_max_index)
}
CaptchaType::TypeOnlyUpper => {
cs[i] = self
.randoms
.alpha_between(self.randoms.upper_min_index, self.randoms.upper_max_index)
}
CaptchaType::TypeOnlyLower => {
cs[i] = self
.randoms
.alpha_between(self.randoms.lower_min_index, self.randoms.lower_max_index)
}
CaptchaType::TypeNumAndUpper => {
cs[i] = self.randoms.alpha_under(self.randoms.upper_max_index)
}
}
}
self.chars = Some(cs.iter().collect());
cs
}
/// 获取当前的验证码
pub fn text(&mut self) -> String {
self.check_alpha();
self.chars.clone().unwrap()
}
/// 获取当前验证码的字符数组
pub fn text_char(&mut self) -> Vec<char> {
self.check_alpha();
self.chars.clone().unwrap().chars().collect()
}
/// 检查验证码是否生成,没有则立即生成
pub fn check_alpha(&mut self) {
if self.chars.is_none() {
self.alphas();
}
}
pub fn get_font(&mut self) -> Arc<Font> {
if let Some(font) = font::get_font(&self.font_name) {
font
} else {
font::get_font(self.font_names[0]).unwrap()
}
}
pub fn get_font_size(&mut self) -> f32 {
self.font_size
}
pub fn set_font_by_enum(&mut self, font: CaptchaFont, size: Option<f32>) {
let font_name = self.font_names[font as usize];
self.font_name = font_name.into();
self.font_size = size.unwrap_or(32.);
}
}
/// 初始化验证码的抽象方法 Traits for initialize a Captcha instance.
pub trait NewCaptcha
where
Self: Sized,
{
/// 用默认参数初始化
///
/// Initialize the Captcha with the default properties.
fn new() -> Self;
/// 使用输出图像大小初始化
///
/// Initialize the Captcha with the size of output image.
fn with_size(width: i32, height: i32) -> Self;
/// 使用输出图像大小和验证码字符长度初始化
///
/// Initialize the Captcha with the size of output image and the character length of the Captcha.
///
/// <br/>
///
/// 特别地/In particular:
///
/// - 对算术验证码[ArithmeticCaptcha](crate::captcha::arithmetic::ArithmeticCaptcha)而言,这里的`len`是验证码中数字的数量。
/// For [ArithmeticCaptcha](crate::captcha::arithmetic::ArithmeticCaptcha), the `len` presents the count of the digits
/// in the Captcha.
fn with_size_and_len(width: i32, height: i32, len: usize) -> Self;
/// 使用完整的参数来初始化,包括输出图像大小、验证码字符长度和输出字体及其大小
///
/// Initialize the Captcha with full properties, including the size of output image, the character length of the Captcha,
/// and the font used in Captcha with the font size.
///
/// 关于`len`字段的注意事项,请参见[with_size_and_len](Self::with_size_and_len)中的说明。Refer to the document of
/// [with_size_and_len](Self::with_size_and_len) for the precautions of the `len` property.
fn with_all(width: i32, height: i32, len: usize, font: CaptchaFont, font_size: f32) -> Self;
}
impl NewCaptcha for Captcha {
fn new() -> Self {
let color = [
(0, 135, 255),
(51, 153, 51),
(255, 102, 102),
(255, 153, 0),
(153, 102, 0),
(153, 102, 153),
(51, 153, 153),
(102, 102, 255),
(0, 102, 204),
(204, 51, 51),
(0, 153, 204),
(0, 51, 102),
]
.iter()
.map(|v| (*v).into())
.collect();
let font_names = ["robot.ttf"];
let font_name = font_names[0].into();
let font_size = 32.;
let len = 5;
let width = 130;
let height = 48;
let char_type = CaptchaType::TypeDefault;
let chars = None;
Self {
randoms: Randoms::new(),
color,
font_names,
font_name,
font_size,
len,
width,
height,
char_type,
chars,
}
}
fn with_size(width: i32, height: i32) -> Self {
let mut _self = Self::new();
_self.width = width;
_self.height = height;
_self
}
fn with_size_and_len(width: i32, height: i32, len: usize) -> Self {
let mut _self = Self::new();
_self.width = width;
_self.height = height;
_self.len = len;
_self
}
fn with_all(width: i32, height: i32, len: usize, font: CaptchaFont, font_size: f32) -> Self {
let mut _self = Self::new();
_self.width = width;
_self.height = height;
_self.len = len;
_self.set_font_by_enum(font, None);
_self.font_size = font_size;
_self
}
}
/// 验证码的抽象方法 Traits which a Captcha must implements.
pub trait AbstractCaptcha: NewCaptcha {
/// 错误类型
type Error: std::error::Error + Debug + Send + Sync + 'static;
/// 输出验证码到指定位置
///
/// Write the Captcha image to the specified place.
fn out(&mut self, out: impl Write) -> Result<(), Self::Error>;
/// 获取验证码中的字符(即正确答案)
///
/// Get the characters (i.e. the correct answer) of the Captcha
fn get_chars(&mut self) -> Vec<char>;
/// 输出Base64编码。注意返回值会带编码头例如`data:image/png;base64,`),可以直接在浏览器中显示;如不需要编码头,
/// 请使用[base64_with_head](Self::base64_with_head)方法并传入空参数以去除编码头。
///
/// Get the Base64 encoded image. Reminds: the returned Base64 strings will begin with an encoding head like
/// `data:image/png;base64,`, which make it possible to display in browsers directly. If you don't need it, you may
/// use [base64_with_head](Self::base64_with_head) and pass a null string.
fn base64(&mut self) -> Result<String, Self::Error>;
/// 获取验证码的MIME类型
///
/// Get the MIME Content type of the Captcha.
fn get_content_type(&mut self) -> String;
/// 输出Base64编码指定编码头
///
/// Get the Base64 encoded image, with specified encoding head.
fn base64_with_head(&mut self, head: &str) -> Result<String, Self::Error> {
let mut output_stream = Vec::new();
self.out(&mut output_stream)?;
Ok(String::from(head) + &BASE64_STANDARD.encode(&output_stream))
}
}

View File

@@ -0,0 +1,4 @@
//! Base traits
pub(crate) mod captcha;
pub(crate) mod randoms;

View File

@@ -0,0 +1,86 @@
use rand::{random};
/// 随机数工具类
pub(crate) struct Randoms {
/// 定义验证码字符.去除了0、O、I、L等容易混淆的字母
pub alpha: [char; 54],
/// 数字的最大索引,不包括最大值
pub num_max_index: usize,
/// 字符的最小索引,包括最小值
pub char_min_index: usize,
/// 字符的最大索引,不包括最大值
pub char_max_index: usize,
/// 大写字符最小索引
pub upper_min_index: usize,
/// 大写字符最大索引
pub upper_max_index: usize,
/// 小写字母最小索引
pub lower_min_index: usize,
/// 小写字母最大索引
pub lower_max_index: usize,
}
impl Randoms {
pub fn new() -> Self {
// Defines the Captcha characters, removing characters like 0, O, I, l, etc.
let alpha = [
'2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J',
'K', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c',
'd', 'e', 'f', 'g', 'h', 'j', 'k', 'm', 'n', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w',
'x', 'y', 'z',
];
let num_max_index = 8;
let char_min_index = num_max_index;
let char_max_index = alpha.len();
let upper_min_index = char_min_index;
let upper_max_index = upper_min_index + 23;
let lower_min_index = upper_max_index;
let lower_max_index = char_max_index;
Self {
alpha,
num_max_index,
char_min_index,
char_max_index,
upper_min_index,
upper_max_index,
lower_min_index,
lower_max_index,
}
}
/// 产生两个数之间的随机数
pub fn num_between(&mut self, min: i32, max: i32) -> i32 {
min + (random::<usize>() % (max - min) as usize) as i32
}
/// 产生0-num的随机数,不包括num
pub fn num(&mut self, num: usize) -> usize {
random::<usize>() % num
}
/// 返回ALPHA中的随机字符
pub fn alpha(&mut self) -> char {
self.alpha[self.num(self.alpha.len())]
}
/// 返回ALPHA中第0位到第num位的随机字符
pub fn alpha_under(&mut self, num: usize) -> char {
self.alpha[self.num(num)]
}
/// 返回ALPHA中第min位到第max位的随机字符
pub fn alpha_between(&mut self, min: usize, max: usize) -> char {
self.alpha[self.num_between(min as i32, max as i32) as usize]
}
}

View File

@@ -0,0 +1 @@
pub mod spec;

View File

@@ -0,0 +1,318 @@
//! Static alphabetical PNG Captcha
//!
//! PNG格式验证码
//!
use super::super::base::captcha::{AbstractCaptcha, Captcha};
use super::super::{CaptchaFont, NewCaptcha};
use image::{ImageBuffer, Rgba};
use imageproc::drawing;
use rand::{rngs::ThreadRng, Rng};
use rusttype::{Font, Scale};
use std::io::{Cursor, Write};
use std::sync::Arc;
mod color {
use image::Rgba;
use rand::{rngs::ThreadRng, Rng};
pub fn gen_background_color(rng: &mut ThreadRng) -> Rgba<u8> {
let red = rng.gen_range(200..=255);
let green = rng.gen_range(200..=255);
let blue = rng.gen_range(200..=255);
//let a=rng.gen_range(0..255);
Rgba([red, green, blue, 255])
}
pub fn gen_text_color(rng: &mut ThreadRng) -> Rgba<u8> {
let red = rng.gen_range(0..=150);
let green = rng.gen_range(0..=150);
let blue = rng.gen_range(0..=150);
Rgba([red, green, blue, 255])
}
pub fn gen_line_color(rng: &mut ThreadRng) -> Rgba<u8> {
let red = rng.gen_range(100..=255);
let green = rng.gen_range(100..=255);
let blue = rng.gen_range(100..=255);
Rgba([red, green, blue, 255])
}
}
///the builder of captcha
pub struct CaptchaBuilder<'a, 'b> {
///captcha image width
pub width: u32,
///captcha image height
pub height: u32,
///random string length.
pub length: u32,
///source is a unicode which is the rand string from.
pub source: String,
///image background color (optional)
pub background_color: Option<Rgba<u8>>,
///fonts collection for text
pub fonts: &'b [Arc<Font<'a>>],
///The maximum number of lines to draw behind of the image
pub max_behind_lines: Option<u32>,
///The maximum number of lines to draw in front of the image
pub max_front_lines: Option<u32>,
///The maximum number of ellipse lines to draw in front of the image
pub max_ellipse_lines: Option<u32>,
}
impl<'a, 'b> Default for CaptchaBuilder<'a, 'b> {
fn default() -> Self {
Self {
width: 150,
height: 40,
length: 5,
source: String::from("1234567890qwertyuioplkjhgfdsazxcvbnm"),
background_color: None,
fonts: &[],
max_behind_lines: None,
max_front_lines: None,
max_ellipse_lines: None,
}
}
}
impl<'a, 'b> CaptchaBuilder<'a, 'b> {
fn write_phrase(
&self,
image: &mut ImageBuffer<Rgba<u8>, Vec<u8>>,
rng: &mut ThreadRng,
phrase: &str,
) {
//println!("phrase={}", phrase);
//println!("width={}, height={}", self.width, self.height);
let font_size = (self.width as f32) / (self.length as f32) - rng.gen_range(1.0..=4.0);
let scale = Scale::uniform(font_size);
if self.fonts.is_empty() {
panic!("no fonts loaded");
}
let font_index = rng.gen_range(0..self.fonts.len());
let font = &self.fonts[font_index];
let glyphs: Vec<_> = font
.layout(phrase, scale, rusttype::point(0.0, 0.0))
.collect();
let text_height = {
let v_metrics = font.v_metrics(scale);
(v_metrics.ascent - v_metrics.descent).ceil() as u32
};
let text_width = {
let min_x = glyphs.first().unwrap().pixel_bounding_box().unwrap().min.x;
let max_x = glyphs.last().unwrap().pixel_bounding_box().unwrap().max.x;
let last_x_pos = glyphs.last().unwrap().position().x as i32;
(max_x + last_x_pos - min_x) as u32
};
let node_width = text_width / self.length;
//println!("text_width={}, text_height={}", text_width, text_height);
let mut x = ((self.width as i32) - (text_width as i32)) / 2;
let y = ((self.height as i32) - (text_height as i32)) / 2;
//
for s in phrase.chars() {
let text_color = color::gen_text_color(rng);
let offset = rng.gen_range(-5..=5);
//println!("x={}, y={}", x, y);
drawing::draw_text_mut(
image,
text_color,
x,
y + offset,
scale,
font,
&s.to_string(),
);
x += node_width as i32;
}
}
fn draw_line(&self, image: &mut ImageBuffer<Rgba<u8>, Vec<u8>>, rng: &mut ThreadRng) {
let line_color = color::gen_line_color(rng);
let is_h = rng.gen();
let (start, end) = if is_h {
let xa = rng.gen_range(0.0..(self.width as f32) / 2.0);
let ya = rng.gen_range(0.0..(self.height as f32));
let xb = rng.gen_range((self.width as f32) / 2.0..(self.width as f32));
let yb = rng.gen_range(0.0..(self.height as f32));
((xa, ya), (xb, yb))
} else {
let xa = rng.gen_range(0.0..(self.width as f32));
let ya = rng.gen_range(0.0..(self.height as f32) / 2.0);
let xb = rng.gen_range(0.0..(self.width as f32));
let yb = rng.gen_range((self.height as f32) / 2.0..(self.height as f32));
((xa, ya), (xb, yb))
};
let thickness = rng.gen_range(2..4);
for i in 0..thickness {
let offset = i as f32;
if is_h {
drawing::draw_line_segment_mut(
image,
(start.0, start.1 + offset),
(end.0, end.1 + offset),
line_color,
);
} else {
drawing::draw_line_segment_mut(
image,
(start.0 + offset, start.1),
(end.0 + offset, end.1),
line_color,
);
}
}
}
fn draw_ellipse(&self, image: &mut ImageBuffer<Rgba<u8>, Vec<u8>>, rng: &mut ThreadRng) {
let line_color = color::gen_line_color(rng);
let thickness = rng.gen_range(2..4);
for i in 0..thickness {
let center = (
rng.gen_range(-(self.width as i32) / 4..(self.width as i32) * 5 / 4),
rng.gen_range(-(self.height as i32) / 4..(self.height as i32) * 5 / 4),
);
drawing::draw_hollow_ellipse_mut(
image,
(center.0, center.1 + i),
(self.width * 6 / 7) as i32,
(self.height * 5 / 8) as i32,
line_color,
);
}
}
fn build_image(&self, phrase: String) -> ImageBuffer<Rgba<u8>, Vec<u8>> {
let mut rng = rand::thread_rng();
let bgc = match self.background_color {
Some(v) => v,
None => color::gen_background_color(&mut rng),
};
let mut image = ImageBuffer::from_fn(self.width, self.height, |_, _| bgc);
//draw behind line
let square = self.width * self.height;
let effects = match self.max_behind_lines {
Some(s) => {
if s > 0 {
rng.gen_range(square / 3000..square / 2000).min(s)
} else {
0
}
}
None => rng.gen_range(square / 3000..square / 2000),
};
for _ in 0..effects {
self.draw_line(&mut image, &mut rng);
}
//write phrase
self.write_phrase(&mut image, &mut rng, &phrase);
//draw front line
let effects = match self.max_front_lines {
Some(s) => {
if s > 0 {
rng.gen_range(square / 3000..=square / 2000).min(s)
} else {
0
}
}
None => rng.gen_range(square / 3000..=square / 2000),
};
for _ in 0..effects {
self.draw_line(&mut image, &mut rng);
}
//draw ellipse
let effects = match self.max_front_lines {
Some(s) => {
if s > 0 {
rng.gen_range(square / 4000..=square / 3000).min(s)
} else {
0
}
}
None => rng.gen_range(square / 4000..=square / 3000),
};
for _ in 0..effects {
self.draw_ellipse(&mut image, &mut rng);
}
image
}
}
/// PNG格式验证码
pub struct SpecCaptcha {
pub(crate) captcha: Captcha,
}
impl NewCaptcha for SpecCaptcha {
fn new() -> Self {
Self {
captcha: Captcha::new(),
}
}
fn with_size(width: i32, height: i32) -> Self {
Self {
captcha: Captcha::with_size(width, height),
}
}
fn with_size_and_len(width: i32, height: i32, len: usize) -> Self {
Self {
captcha: Captcha::with_size_and_len(width, height, len),
}
}
fn with_all(width: i32, height: i32, len: usize, font: CaptchaFont, font_size: f32) -> Self {
Self {
captcha: Captcha::with_all(width, height, len, font, font_size),
}
}
}
impl AbstractCaptcha for SpecCaptcha {
type Error = image::ImageError;
fn out(&mut self, mut out: impl Write) -> Result<(), Self::Error> {
let phrase = self.captcha.text_char();
let builder = CaptchaBuilder {
width: self.captcha.width as u32,
height: self.captcha.height as u32,
length: self.captcha.len as u32,
background_color: None,
fonts: &[self.captcha.get_font()],
max_behind_lines: Some(0),
max_front_lines: Some(0),
max_ellipse_lines: Some(0),
..Default::default()
};
let image = builder.build_image(phrase.iter().collect());
let format = image::ImageOutputFormat::Png;
let mut raw_data: Vec<u8> = Vec::new();
image.write_to(&mut Cursor::new(&mut raw_data), format)?;
out.write_all(&raw_data)?;
Ok(())
}
fn get_chars(&mut self) -> Vec<char> {
self.captcha.text_char()
}
fn base64(&mut self) -> Result<String, Self::Error> {
self.base64_with_head("data:image/png;base64,")
}
fn get_content_type(&mut self) -> String {
"image/png".into()
}
}
#[cfg(test)]
mod test {
#[test]
fn it_works() {}
}

View File

@@ -0,0 +1,69 @@
//! Axum & Tower_sessions 组合
//!
//! - Axum: [axum](https://docs.rs/axum)
//! - Tower Sessions: [axum](https://docs.rs/tower-sessions)
use super::AbstractCaptcha;
use super::CaptchaUtil;
use async_trait::async_trait;
use axum::response::Response;
use std::fmt::Debug;
use tower_sessions::Session;
const CAPTCHA_KEY: &'static str = "ez-captcha";
/// Axum & Tower_Sessions
#[async_trait]
pub trait CaptchaAxumTowerSessionExt {
/// 错误类型
type Error: Debug + Send + Sync + 'static;
/// 将验证码图片写入响应并将用户的验证码信息保存至Session中
///
/// Write the Captcha Image into the response and save the Captcha information into the user's Session.
async fn out(&mut self, session: &Session) -> Result<Response, Self::Error>;
}
/// Axum & Tower_Sessions - 静态方法
#[async_trait]
pub trait CaptchaAxumTowerSessionStaticExt {
/// 验证验证码,返回的布尔值代表验证码是否正确
///
/// Verify the Captcha code, and return whether user's code is correct.
async fn ver(code: &str, session: &Session) -> bool {
match session.get::<String>(CAPTCHA_KEY).await {
Ok(Some(ans)) => ans.to_ascii_lowercase() == code.to_ascii_lowercase(),
_ => false,
}
}
/// 清除Session中的验证码
///
/// Clear the Captcha in the session.
async fn clear(session: &Session) {
if session.remove::<String>(CAPTCHA_KEY).await.is_err() {
tracing::warn!("Exception occurs during clearing the session.")
}
}
}
#[async_trait]
impl<T: AbstractCaptcha + Send> CaptchaAxumTowerSessionExt for CaptchaUtil<T> {
type Error = anyhow::Error;
async fn out(&mut self, session: &Session) -> Result<Response, Self::Error> {
let mut data = vec![];
self.captcha_instance.out(&mut data)?;
let ans: String = self.captcha_instance.get_chars().iter().collect();
session.insert(CAPTCHA_KEY, ans).await?;
let resp = Response::builder()
.header("Content-Type", self.captcha_instance.get_content_type())
.body(data.into())?;
Ok(resp)
}
}
#[async_trait]
impl CaptchaAxumTowerSessionStaticExt for CaptchaUtil {}

View File

@@ -0,0 +1,41 @@
pub mod axum_tower_sessions;
use super::base::captcha::AbstractCaptcha;
use super::captcha::spec::SpecCaptcha;
use super::{CaptchaFont, NewCaptcha};
/// 验证码工具类 - Captcha Utils
///
/// 默认使用[SpecCaptcha]静态PNG字母验证码作为验证码实现用户也可以指定其他实现了[AbstractCaptcha]的类型。
///
/// Use [SpecCaptcha] (static PNG-format alphabetical Captcha) as the default implement of the Captcha service. Users may use other implementation of [AbstractCaptcha] they prefer.
///
pub struct CaptchaUtil<T: AbstractCaptcha = SpecCaptcha> {
captcha_instance: T,
}
impl<T: AbstractCaptcha> NewCaptcha for CaptchaUtil<T> {
fn new() -> Self {
Self {
captcha_instance: T::new(),
}
}
fn with_size(width: i32, height: i32) -> Self {
Self {
captcha_instance: T::with_size(width, height),
}
}
fn with_size_and_len(width: i32, height: i32, len: usize) -> Self {
Self {
captcha_instance: T::with_size_and_len(width, height, len),
}
}
fn with_all(width: i32, height: i32, len: usize, font: CaptchaFont, font_size: f32) -> Self {
Self {
captcha_instance: T::with_all(width, height, len, font, font_size),
}
}
}

View File

@@ -0,0 +1,134 @@
//! Rust图形验证码由Java同名开源库[whvcse/EasyCaptcha](https://github.com/ele-admin/EasyCaptcha)移植而来👏100%纯Rust实现支持gif、算术等类型。
//!
//! Rust Captcha library, which is ported from Java's same-name library [whvcse/EasyCaptcha](https://github.com/ele-admin/EasyCaptcha),
//! implemented in 100% pure Rust, supporting GIF and arithmetic problems.
//!
//! <br/>
//!
//! 目前已适配框架 / Frameworks which is adapted now:
//!
//! - `axum` + `tower-sessions`
//!
//! 更多框架欢迎您提交PR参与适配🙏 PR for new frameworks are welcomed
//!
//! <br/>
//!
//! ## 安装 Install
//!
//! 请参考Github README为Linux系统安装依赖。
//!
//! If you are compiling this project in linux, please refer to README in repository to install
//! dependencies into you system.
//!
//! ## 使用 Usage
//!
//! 若您正在使用的框架已适配,您可直接通过[CaptchaUtil](extension::CaptchaUtil)类并导入相应框架的trait来使用验证码
//!
//! If your framework is adapted, you can just use [CaptchaUtil](extension::CaptchaUtil) and importing traits of your
//! framework to use the Captcha:
//!
//! ```
//! use std::collections::HashMap;
//! use axum::extract::Query;
//! use axum::response::IntoResponse;
//! use easy_captcha::captcha::gif::GifCaptcha;
//! use easy_captcha::extension::axum_tower_sessions::{
//! CaptchaAxumTowerSessionExt, CaptchaAxumTowerSessionStaticExt,
//! };
//! use easy_captcha::extension::CaptchaUtil;
//! use easy_captcha::NewCaptcha;
//!
//! /// 接口:获取验证码
//! /// Handler: Get a captcha
//! async fn get_captcha(session: tower_sessions::Session) -> Result<axum::response::Response, axum::http::StatusCode> {
//! let mut captcha: CaptchaUtil<GifCaptcha> = CaptchaUtil::new();
//! match captcha.out(&session).await {
//! Ok(response) => Ok(response),
//! Err(_) => Err(axum::http::StatusCode::INTERNAL_SERVER_ERROR),
//! }
//! }
//!
//! /// 接口:验证验证码
//! /// Handler: Verify captcha codes
//! async fn verify_captcha(
//! session: tower_sessions::Session,
//! Query(query): Query<HashMap<String, String>>,
//! ) -> axum::response::Response {
//! // 从请求中获取验证码 Getting code from the request.
//! if let Some(code) = query.get("code") {
//! // 调用CaptchaUtil的静态方法验证验证码是否正确 Use a static method in CaptchaUtil to verify.
//! if CaptchaUtil::ver(code, &session).await {
//! CaptchaUtil::clear(&session).await; // 如果愿意的话你可以从Session中清理掉验证码 You may clear the Captcha from the Session if you want
//! "Your code is valid, thank you.".into_response()
//! } else {
//! "Your code is not valid, I'm sorry.".into_response()
//! }
//! } else {
//! "You didn't provide the code.".into_response()
//! }
//! }
//! ```
//!
//! 您也可以自定义验证码的各项属性
//!
//! You can also specify properties of the Captcha.
//!
//! ```rust
//! use easy_captcha::captcha::gif::GifCaptcha;
//! use easy_captcha::extension::axum_tower_sessions::CaptchaAxumTowerSessionExt;
//! use easy_captcha::extension::CaptchaUtil;
//! use easy_captcha::NewCaptcha;
//!
//! async fn get_captcha(session: tower_sessions::Session) -> Result<axum::response::Response, axum::http::StatusCode> {
//! let mut captcha: CaptchaUtil<GifCaptcha> = CaptchaUtil::with_size_and_len(127, 48, 4);
//! match captcha.out(&session).await {
//! Ok(response) => Ok(response),
//! Err(_) => Err(axum::http::StatusCode::INTERNAL_SERVER_ERROR),
//! }
//! }
//! ```
//!
//! 项目当前提供了三种验证码实现:[SpecCaptcha](captcha::spec::SpecCaptcha)静态PNG、[GifCaptcha](captcha::gif::GifCaptcha)动态GIF
//! 、[ArithmeticCaptcha](captcha::arithmetic::ArithmeticCaptcha)算术PNG您可按需使用。
//!
//! There is three implementation of Captcha currently, which are [SpecCaptcha](captcha::spec::SpecCaptcha)(static PNG),
//! [GifCaptcha](captcha::gif::GifCaptcha)(GIF), [ArithmeticCaptcha](captcha::arithmetic::ArithmeticCaptcha)(Arithmetic problems),
//! you can use them according to your need.
//!
//! <br/>
//!
//! 自带字体效果 / Fonts shipped
//!
//! | 字体/Fonts | 效果/Preview |
//! |---------------------|------------------------------------------------|
//! | CaptchaFont::Font1 | ![](https://s2.ax1x.com/2019/08/23/msMe6U.png) |
//! | CaptchaFont::Font2 | ![](https://s2.ax1x.com/2019/08/23/msMAf0.png) |
//! | CaptchaFont::Font3 | ![](https://s2.ax1x.com/2019/08/23/msMCwj.png) |
//! | CaptchaFont::Font4 | ![](https://s2.ax1x.com/2019/08/23/msM9mQ.png) |
//! | CaptchaFont::Font5 | ![](https://s2.ax1x.com/2019/08/23/msKz6S.png) |
//! | CaptchaFont::Font6 | ![](https://s2.ax1x.com/2019/08/23/msKxl8.png) |
//! | CaptchaFont::Font7 | ![](https://s2.ax1x.com/2019/08/23/msMPTs.png) |
//! | CaptchaFont::Font8 | ![](https://s2.ax1x.com/2019/08/23/msMmXF.png) |
//! | CaptchaFont::Font9 | ![](https://s2.ax1x.com/2019/08/23/msMVpV.png) |
//! | CaptchaFont::Font10 | ![](https://s2.ax1x.com/2019/08/23/msMZlT.png) |
//!
#![warn(missing_docs)]
#![allow(dead_code)]
pub(crate) mod base;
pub mod captcha;
pub mod extension;
mod utils;
pub use base::captcha::*;
// #[cfg(test)]
// mod tests {
// use super::*;
//
// #[test]
// fn it_works() {
//
// }
// }

View File

@@ -0,0 +1,53 @@
//! RGBA颜色
use std::fmt::{Debug, Formatter};
#[derive(Clone)]
pub struct Color(f64, f64, f64, f64);
impl Color {
pub fn set_alpha(&mut self, a: f64) {
self.3 = a;
}
}
impl Debug for Color {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Color")
.field("r", &self.0)
.field("g", &self.1)
.field("b", &self.2)
.field("a", &self.3)
.finish()
}
}
impl From<(u8, u8, u8)> for Color {
fn from(value: (u8, u8, u8)) -> Self {
Self(
value.0 as f64 / 255.0,
value.1 as f64 / 255.0,
value.2 as f64 / 255.0,
1.0,
)
}
}
impl Into<(u8, u8, u8, u8)> for Color {
fn into(self) -> (u8, u8, u8, u8) {
(
(self.0 * 255.0) as u8,
(self.1 * 255.0) as u8,
(self.2 * 255.0) as u8,
(self.3 * 255.0) as u8,
)
}
}
impl Into<u32> for Color {
fn into(self) -> u32 {
let color: (u8, u8, u8, u8) = self.into();
(color.0 as u32) << 24 + (color.1 as u32) << 16 + (color.2 as u32) << 8 + (color.3 as u32)
}
}
impl Color {}

View File

@@ -0,0 +1,45 @@
use rust_embed::RustEmbed;
use rusttype::Font;
use std::error::Error;
use std::sync::Arc;
#[derive(RustEmbed)]
#[folder = "resources/"]
struct FontAssets;
// lazy_static! {
// pub(crate) static ref FONTS: RwLock<HashMap<String, Arc<Font>>> = Default::default();
// }
pub fn get_font(font_name: &str) -> Option<Arc<Font>> {
// let fonts_cell = FONTS.get_or_init(|| Default::default());
// let guard = fonts_cell.read();
//
// if guard.contains_key(font_name) {
// Some(guard.get(font_name).unwrap().clone())
// } else {
// drop(guard);
if let Ok(Some(font)) = load_font(font_name) {
// let mut guard = fonts_cell.write();
let font = Arc::new(font);
// guard.insert(String::from(font_name), font.clone());
Some(font)
} else {
None
}
// }
}
pub fn load_font(font_name: &str) -> Result<Option<Font>, Box<dyn Error>> {
match FontAssets::get(font_name) {
Some(assets) => {
let font = Font::try_from_vec(Vec::from(assets.data)).unwrap();
Ok(Some(font))
}
None => {
tracing::error!("Unable to find the specified font.");
Ok(None)
}
}
}

View File

@@ -0,0 +1,4 @@
//! Utilities
pub(crate) mod color;
pub(crate) mod font;

View File

@@ -1,23 +1,40 @@
use std::vec; mod auth;
pub(crate) mod captcha;
mod network;
mod users;
use std::{net::SocketAddr, sync::Arc}; use std::{net::SocketAddr, sync::Arc};
use axum::extract::{Path, Query};
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::routing::post;
use axum::{extract::State, routing::get, Json, Router}; use axum::{extract::State, routing::get, Json, Router};
use easytier::proto::{self, rpc_types, web::*}; use axum_login::tower_sessions::{ExpiredDeletion, SessionManagerLayer};
use easytier::{common::scoped_task::ScopedTask, proto::rpc_types::controller::BaseController}; use axum_login::{login_required, AuthManagerLayerBuilder, AuthzBackend};
use axum_messages::MessagesManagerLayer;
use easytier::common::scoped_task::ScopedTask;
use easytier::proto::{self, rpc_types};
use network::NetworkApi;
use sea_orm::DbErr;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tower_sessions::cookie::time::Duration;
use tower_sessions::cookie::Key;
use tower_sessions::Expiry;
use tower_sessions_sqlx_store::SqliteStore;
use users::{AuthSession, Backend};
use crate::client_manager::session::Session; use crate::client_manager::session::Session;
use crate::client_manager::storage::StorageToken; use crate::client_manager::storage::StorageToken;
use crate::client_manager::ClientManager; use crate::client_manager::ClientManager;
use crate::db::Db;
pub struct RestfulServer { pub struct RestfulServer {
bind_addr: SocketAddr, bind_addr: SocketAddr,
client_mgr: Arc<ClientManager>, client_mgr: Arc<ClientManager>,
db: Db,
serve_task: Option<ScopedTask<()>>, serve_task: Option<ScopedTask<()>>,
delete_task: Option<ScopedTask<tower_sessions::session_store::Result<()>>>,
network_api: NetworkApi,
} }
type AppStateInner = Arc<ClientManager>; type AppStateInner = Arc<ClientManager>;
@@ -26,52 +43,44 @@ type AppState = State<AppStateInner>;
#[derive(Debug, serde::Deserialize, serde::Serialize)] #[derive(Debug, serde::Deserialize, serde::Serialize)]
struct ListSessionJsonResp(Vec<StorageToken>); struct ListSessionJsonResp(Vec<StorageToken>);
#[derive(Debug, serde::Deserialize, serde::Serialize)] pub type Error = proto::error::Error;
struct ValidateConfigJsonReq { pub type ErrorKind = proto::error::error::ErrorKind;
config: String,
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct RunNetworkJsonReq {
config: String,
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct ColletNetworkInfoJsonReq {
inst_ids: Option<Vec<uuid::Uuid>>,
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct RemoveNetworkJsonReq {
inst_ids: Vec<uuid::Uuid>,
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct ListNetworkInstanceIdsJsonResp(Vec<uuid::Uuid>);
type Error = proto::error::Error;
type ErrorKind = proto::error::error::ErrorKind;
type RpcError = rpc_types::error::Error; type RpcError = rpc_types::error::Error;
type HttpHandleError = (StatusCode, Json<Error>); type HttpHandleError = (StatusCode, Json<Error>);
fn convert_rpc_error(e: RpcError) -> (StatusCode, Json<Error>) { pub fn other_error<T: ToString>(error_message: T) -> Error {
let status_code = match &e { Error {
RpcError::ExecutionError(_) => StatusCode::BAD_REQUEST, error_kind: Some(ErrorKind::OtherError(proto::error::OtherError {
RpcError::Timeout(_) => StatusCode::GATEWAY_TIMEOUT, error_message: error_message.to_string(),
_ => StatusCode::BAD_GATEWAY, })),
}; }
let error = Error::from(&e); }
(status_code, Json(error))
pub fn convert_db_error(e: DbErr) -> HttpHandleError {
(
StatusCode::INTERNAL_SERVER_ERROR,
other_error(format!("DB Error: {:#}", e)).into(),
)
} }
impl RestfulServer { impl RestfulServer {
pub fn new(bind_addr: SocketAddr, client_mgr: Arc<ClientManager>) -> Self { pub async fn new(
bind_addr: SocketAddr,
client_mgr: Arc<ClientManager>,
db: Db,
) -> anyhow::Result<Self> {
assert!(client_mgr.is_running()); assert!(client_mgr.is_running());
RestfulServer {
let network_api = NetworkApi::new();
Ok(RestfulServer {
bind_addr, bind_addr,
client_mgr, client_mgr,
db,
serve_task: None, serve_task: None,
} delete_task: None,
network_api,
})
} }
async fn get_session_by_machine_id( async fn get_session_by_machine_id(
@@ -79,162 +88,69 @@ impl RestfulServer {
machine_id: &uuid::Uuid, machine_id: &uuid::Uuid,
) -> Result<Arc<Session>, HttpHandleError> { ) -> Result<Arc<Session>, HttpHandleError> {
let Some(result) = client_mgr.get_session_by_machine_id(machine_id) else { let Some(result) = client_mgr.get_session_by_machine_id(machine_id) else {
return Err(( return Err((StatusCode::NOT_FOUND, other_error("No such session").into()));
StatusCode::NOT_FOUND,
Error {
error_kind: Some(ErrorKind::OtherError(proto::error::OtherError {
error_message: "No such session".to_string(),
})),
}
.into(),
));
}; };
Ok(result) Ok(result)
} }
async fn handle_list_all_sessions( async fn handle_list_all_sessions(
auth_session: AuthSession,
State(client_mgr): AppState, State(client_mgr): AppState,
) -> Result<Json<ListSessionJsonResp>, HttpHandleError> { ) -> Result<Json<ListSessionJsonResp>, HttpHandleError> {
let pers = auth_session
.backend
.get_group_permissions(auth_session.user.as_ref().unwrap())
.await
.unwrap();
println!("{:?}", pers);
let ret = client_mgr.list_sessions().await; let ret = client_mgr.list_sessions().await;
Ok(ListSessionJsonResp(ret).into()) Ok(ListSessionJsonResp(ret).into())
} }
async fn handle_validate_config(
State(client_mgr): AppState,
Path(machine_id): Path<uuid::Uuid>,
Json(payload): Json<ValidateConfigJsonReq>,
) -> Result<(), HttpHandleError> {
let config = payload.config;
let result = Self::get_session_by_machine_id(&client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
c.validate_config(BaseController::default(), ValidateConfigRequest { config })
.await
.map_err(convert_rpc_error)?;
Ok(())
}
async fn handle_run_network_instance(
State(client_mgr): AppState,
Path(machine_id): Path<uuid::Uuid>,
Json(payload): Json<RunNetworkJsonReq>,
) -> Result<(), HttpHandleError> {
let config = payload.config;
let result = Self::get_session_by_machine_id(&client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
c.run_network_instance(
BaseController::default(),
RunNetworkInstanceRequest { config },
)
.await
.map_err(convert_rpc_error)?;
Ok(())
}
async fn handle_collect_one_network_info(
State(client_mgr): AppState,
Path((machine_id, inst_id)): Path<(uuid::Uuid, uuid::Uuid)>,
) -> Result<Json<CollectNetworkInfoResponse>, HttpHandleError> {
let result = Self::get_session_by_machine_id(&client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
let ret = c
.collect_network_info(
BaseController::default(),
CollectNetworkInfoRequest {
inst_ids: vec![inst_id.into()],
},
)
.await
.map_err(convert_rpc_error)?;
Ok(ret.into())
}
async fn handle_collect_network_info(
State(client_mgr): AppState,
Path(machine_id): Path<uuid::Uuid>,
Query(payload): Query<ColletNetworkInfoJsonReq>,
) -> Result<Json<CollectNetworkInfoResponse>, HttpHandleError> {
let result = Self::get_session_by_machine_id(&client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
let ret = c
.collect_network_info(
BaseController::default(),
CollectNetworkInfoRequest {
inst_ids: payload
.inst_ids
.unwrap_or_default()
.into_iter()
.map(Into::into)
.collect(),
},
)
.await
.map_err(convert_rpc_error)?;
Ok(ret.into())
}
async fn handle_list_network_instance_ids(
State(client_mgr): AppState,
Path(machine_id): Path<uuid::Uuid>,
) -> Result<Json<ListNetworkInstanceIdsJsonResp>, HttpHandleError> {
let result = Self::get_session_by_machine_id(&client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
let ret = c
.list_network_instance(BaseController::default(), ListNetworkInstanceRequest {})
.await
.map_err(convert_rpc_error)?;
Ok(
ListNetworkInstanceIdsJsonResp(ret.inst_ids.into_iter().map(Into::into).collect())
.into(),
)
}
async fn handle_remove_network_instance(
State(client_mgr): AppState,
Path((machine_id, inst_id)): Path<(uuid::Uuid, uuid::Uuid)>,
) -> Result<(), HttpHandleError> {
let result = Self::get_session_by_machine_id(&client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
c.delete_network_instance(
BaseController::default(),
DeleteNetworkInstanceRequest {
inst_ids: vec![inst_id.into()],
},
)
.await
.map_err(convert_rpc_error)?;
Ok(())
}
pub async fn start(&mut self) -> Result<(), anyhow::Error> { pub async fn start(&mut self) -> Result<(), anyhow::Error> {
let listener = TcpListener::bind(self.bind_addr).await.unwrap(); let listener = TcpListener::bind(self.bind_addr).await?;
// Session layer.
//
// This uses `tower-sessions` to establish a layer that will provide the session
// as a request extension.
let session_store = SqliteStore::new(self.db.inner());
session_store.migrate().await?;
self.delete_task.replace(
tokio::task::spawn(
session_store
.clone()
.continuously_delete_expired(tokio::time::Duration::from_secs(60)),
)
.into(),
);
// Generate a cryptographic key to sign the session cookie.
let key = Key::generate();
let session_layer = SessionManagerLayer::new(session_store)
.with_secure(false)
.with_expiry(Expiry::OnInactivity(Duration::days(1)))
.with_signed(key);
// Auth service.
//
// This combines the session layer with our backend to establish the auth
// service which will provide the auth session as a request extension.
let backend = Backend::new(self.db.clone());
let auth_layer = AuthManagerLayerBuilder::new(backend, session_layer).build();
let app = Router::new() let app = Router::new()
.route("/api/v1/sessions", get(Self::handle_list_all_sessions)) .route("/api/v1/sessions", get(Self::handle_list_all_sessions))
.route( .merge(self.network_api.build_route())
"/api/v1/network/:machine-id/validate-config", .route_layer(login_required!(Backend))
post(Self::handle_validate_config), .merge(auth::router())
) .with_state(self.client_mgr.clone())
.route( .layer(MessagesManagerLayer)
"/api/v1/network/:machine-id", .layer(auth_layer)
post(Self::handle_run_network_instance).get(Self::handle_list_network_instance_ids), .layer(tower_http::cors::CorsLayer::very_permissive());
)
.route(
"/api/v1/network/:machine-id/info",
get(Self::handle_collect_network_info),
)
.route(
"/api/v1/network/:machine-id/:inst-id",
get(Self::handle_collect_one_network_info)
.delete(Self::handle_remove_network_instance),
)
.with_state(self.client_mgr.clone());
let task = tokio::spawn(async move { let task = tokio::spawn(async move {
axum::serve(listener, app).await.unwrap(); axum::serve(listener, app).await.unwrap();

View File

@@ -0,0 +1,321 @@
use std::sync::Arc;
use axum::extract::{Path, Query};
use axum::http::StatusCode;
use axum::routing::{delete, post};
use axum::{extract::State, routing::get, Json, Router};
use axum_login::AuthUser;
use dashmap::DashSet;
use easytier::proto::common::Void;
use easytier::proto::rpc_types::controller::BaseController;
use easytier::proto::{self, web::*};
use crate::client_manager::session::Session;
use crate::client_manager::ClientManager;
use super::users::AuthSession;
use super::{
convert_db_error, AppState, AppStateInner, Error, ErrorKind, HttpHandleError, RpcError,
};
fn convert_rpc_error(e: RpcError) -> (StatusCode, Json<Error>) {
let status_code = match &e {
RpcError::ExecutionError(_) => StatusCode::BAD_REQUEST,
RpcError::Timeout(_) => StatusCode::GATEWAY_TIMEOUT,
_ => StatusCode::BAD_GATEWAY,
};
let error = Error::from(&e);
(status_code, Json(error))
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct ValidateConfigJsonReq {
config: String,
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct RunNetworkJsonReq {
config: String,
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct ColletNetworkInfoJsonReq {
inst_ids: Option<Vec<uuid::Uuid>>,
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct RemoveNetworkJsonReq {
inst_ids: Vec<uuid::Uuid>,
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct ListNetworkInstanceIdsJsonResp(Vec<uuid::Uuid>);
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct ListMachineItem {
client_url: Option<url::Url>,
info: Option<HeartbeatRequest>,
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct ListMachineJsonResp {
machines: Vec<ListMachineItem>,
}
pub struct NetworkApi {}
impl NetworkApi {
pub fn new() -> Self {
Self {}
}
async fn get_session_by_machine_id(
auth_session: &AuthSession,
client_mgr: &ClientManager,
machine_id: &uuid::Uuid,
) -> Result<Arc<Session>, HttpHandleError> {
let Some(result) = client_mgr.get_session_by_machine_id(machine_id) else {
return Err((
StatusCode::NOT_FOUND,
Error {
error_kind: Some(ErrorKind::OtherError(proto::error::OtherError {
error_message: format!("No such session: {}", machine_id),
})),
}
.into(),
));
};
let Some(token) = result.get_token().await else {
return Err((
StatusCode::UNAUTHORIZED,
Error {
error_kind: Some(ErrorKind::OtherError(proto::error::OtherError {
error_message: "No token reported".to_string(),
})),
}
.into(),
));
};
if !auth_session
.user
.as_ref()
.map(|x| x.tokens.contains(&token.token))
.unwrap_or(false)
{
return Err((
StatusCode::FORBIDDEN,
Error {
error_kind: Some(ErrorKind::OtherError(proto::error::OtherError {
error_message: "Token mismatch".to_string(),
})),
}
.into(),
));
}
Ok(result)
}
async fn handle_validate_config(
auth_session: AuthSession,
State(client_mgr): AppState,
Path(machine_id): Path<uuid::Uuid>,
Json(payload): Json<ValidateConfigJsonReq>,
) -> Result<Json<Void>, HttpHandleError> {
let config = payload.config;
let result =
Self::get_session_by_machine_id(&auth_session, &client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
c.validate_config(BaseController::default(), ValidateConfigRequest { config })
.await
.map_err(convert_rpc_error)?;
Ok(Void::default().into())
}
async fn handle_run_network_instance(
auth_session: AuthSession,
State(client_mgr): AppState,
Path(machine_id): Path<uuid::Uuid>,
Json(payload): Json<RunNetworkJsonReq>,
) -> Result<Json<Void>, HttpHandleError> {
let config = payload.config;
let result =
Self::get_session_by_machine_id(&auth_session, &client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
let resp = c
.run_network_instance(
BaseController::default(),
RunNetworkInstanceRequest {
inst_id: None,
config: config.clone(),
},
)
.await
.map_err(convert_rpc_error)?;
client_mgr
.db()
.insert_or_update_user_network_config(
auth_session.user.as_ref().unwrap().id(),
resp.inst_id.clone().unwrap_or_default().into(),
config,
)
.await
.map_err(convert_db_error)?;
Ok(Void::default().into())
}
async fn handle_collect_one_network_info(
auth_session: AuthSession,
State(client_mgr): AppState,
Path((machine_id, inst_id)): Path<(uuid::Uuid, uuid::Uuid)>,
) -> Result<Json<CollectNetworkInfoResponse>, HttpHandleError> {
let result =
Self::get_session_by_machine_id(&auth_session, &client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
let ret = c
.collect_network_info(
BaseController::default(),
CollectNetworkInfoRequest {
inst_ids: vec![inst_id.into()],
},
)
.await
.map_err(convert_rpc_error)?;
Ok(ret.into())
}
async fn handle_collect_network_info(
auth_session: AuthSession,
State(client_mgr): AppState,
Path(machine_id): Path<uuid::Uuid>,
Query(payload): Query<ColletNetworkInfoJsonReq>,
) -> Result<Json<CollectNetworkInfoResponse>, HttpHandleError> {
let result =
Self::get_session_by_machine_id(&auth_session, &client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
let ret = c
.collect_network_info(
BaseController::default(),
CollectNetworkInfoRequest {
inst_ids: payload
.inst_ids
.unwrap_or_default()
.into_iter()
.map(Into::into)
.collect(),
},
)
.await
.map_err(convert_rpc_error)?;
Ok(ret.into())
}
async fn handle_list_network_instance_ids(
auth_session: AuthSession,
State(client_mgr): AppState,
Path(machine_id): Path<uuid::Uuid>,
) -> Result<Json<ListNetworkInstanceIdsJsonResp>, HttpHandleError> {
let result =
Self::get_session_by_machine_id(&auth_session, &client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
let ret = c
.list_network_instance(BaseController::default(), ListNetworkInstanceRequest {})
.await
.map_err(convert_rpc_error)?;
Ok(
ListNetworkInstanceIdsJsonResp(ret.inst_ids.into_iter().map(Into::into).collect())
.into(),
)
}
async fn handle_remove_network_instance(
auth_session: AuthSession,
State(client_mgr): AppState,
Path((machine_id, inst_id)): Path<(uuid::Uuid, uuid::Uuid)>,
) -> Result<(), HttpHandleError> {
let result =
Self::get_session_by_machine_id(&auth_session, &client_mgr, &machine_id).await?;
client_mgr
.db()
.delete_network_config(auth_session.user.as_ref().unwrap().id(), inst_id)
.await
.map_err(convert_db_error)?;
let c = result.scoped_rpc_client();
c.delete_network_instance(
BaseController::default(),
DeleteNetworkInstanceRequest {
inst_ids: vec![inst_id.into()],
},
)
.await
.map_err(convert_rpc_error)?;
Ok(())
}
async fn handle_list_machines(
auth_session: AuthSession,
State(client_mgr): AppState,
) -> Result<Json<ListMachineJsonResp>, HttpHandleError> {
let tokens = auth_session
.user
.as_ref()
.map(|x| x.tokens.clone())
.unwrap_or_default();
let client_urls = DashSet::new();
for token in tokens {
let urls = client_mgr.list_machine_by_token(token);
for url in urls {
client_urls.insert(url);
}
}
let mut machines = vec![];
for item in client_urls.iter() {
let client_url = item.key().clone();
let session = client_mgr.get_heartbeat_requests(&client_url).await;
machines.push(ListMachineItem {
client_url: Some(client_url),
info: session,
});
}
Ok(Json(ListMachineJsonResp { machines }))
}
pub fn build_route(&mut self) -> Router<AppStateInner> {
Router::new()
.route("/api/v1/machines", get(Self::handle_list_machines))
.route(
"/api/v1/machines/:machine-id/validate-config",
post(Self::handle_validate_config),
)
.route(
"/api/v1/machines/:machine-id/networks",
post(Self::handle_run_network_instance).get(Self::handle_list_network_instance_ids),
)
.route(
"/api/v1/machines/:machine-id/networks/:inst-id",
delete(Self::handle_remove_network_instance),
)
.route(
"/api/v1/machines/:machine-id/networks/info",
get(Self::handle_collect_network_info),
)
.route(
"/api/v1/machines/:machine-id/networks/info/:inst-id",
get(Self::handle_collect_one_network_info),
)
}
}

View File

@@ -0,0 +1,241 @@
use std::collections::HashSet;
use async_trait::async_trait;
use axum_login::{AuthUser, AuthnBackend, AuthzBackend, UserId};
use password_auth::verify_password;
use sea_orm::{
ActiveModelTrait as _, ColumnTrait, EntityTrait, FromQueryResult, IntoActiveModel, JoinType,
QueryFilter, QuerySelect as _, RelationTrait, Set, TransactionTrait,
};
use serde::{Deserialize, Serialize};
use tokio::task;
use crate::db::{self, entity};
#[derive(Clone, Serialize, Deserialize)]
pub struct User {
db_user: entity::users::Model,
pub tokens: Vec<String>,
}
// Here we've implemented `Debug` manually to avoid accidentally logging the
// password hash.
impl std::fmt::Debug for User {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("User")
.field("id", &self.db_user.id)
.field("username", &self.db_user.username)
.field("password", &"[redacted]")
.finish()
}
}
impl AuthUser for User {
type Id = i32;
fn id(&self) -> Self::Id {
self.db_user.id
}
fn session_auth_hash(&self) -> &[u8] {
self.db_user.password.as_bytes() // We use the password hash as the auth
// hash--what this means
// is when the user changes their password the
// auth session becomes invalid.
}
}
// This allows us to extract the authentication fields from forms. We use this
// to authenticate requests with the backend.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Credentials {
pub username: String,
pub password: String,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RegisterNewUser {
pub credentials: Credentials,
pub captcha: String,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChangePassword {
pub new_password: String,
}
#[derive(Debug, Clone)]
pub struct Backend {
db: db::Db,
}
impl Backend {
pub fn new(db: db::Db) -> Self {
Self { db }
}
pub async fn register_new_user(&self, new_user: &RegisterNewUser) -> anyhow::Result<()> {
let hashed_password = password_auth::generate_hash(new_user.credentials.password.as_str());
let mut txn = self.db.orm_db().begin().await?;
entity::users::ActiveModel {
username: Set(new_user.credentials.username.clone()),
password: Set(hashed_password.clone()),
..Default::default()
}
.save(&mut txn)
.await?;
entity::users_groups::ActiveModel {
user_id: Set(entity::users::Entity::find()
.filter(entity::users::Column::Username.eq(new_user.credentials.username.as_str()))
.one(&mut txn)
.await?
.unwrap()
.id),
group_id: Set(entity::groups::Entity::find()
.filter(entity::groups::Column::Name.eq("users"))
.one(&mut txn)
.await?
.unwrap()
.id),
..Default::default()
}
.save(&mut txn)
.await?;
txn.commit().await?;
Ok(())
}
pub async fn change_password(
&self,
id: <User as AuthUser>::Id,
req: &ChangePassword,
) -> anyhow::Result<()> {
let hashed_password = password_auth::generate_hash(req.new_password.as_str());
use entity::users;
let mut user = users::Entity::find_by_id(id)
.one(self.db.orm_db())
.await?
.ok_or(anyhow::anyhow!("User not found"))?
.into_active_model();
user.password = Set(hashed_password.clone());
entity::users::Entity::update(user)
.exec(self.db.orm_db())
.await?;
Ok(())
}
}
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error(transparent)]
Sqlx(#[from] sea_orm::DbErr),
#[error(transparent)]
TaskJoin(#[from] task::JoinError),
}
#[async_trait]
impl AuthnBackend for Backend {
type User = User;
type Credentials = Credentials;
type Error = Error;
async fn authenticate(
&self,
creds: Self::Credentials,
) -> Result<Option<Self::User>, Self::Error> {
let user = entity::users::Entity::find()
.filter(entity::users::Column::Username.eq(creds.username))
.one(self.db.orm_db())
.await?;
task::spawn_blocking(|| {
// We're using password-based authentication--this works by comparing our form
// input with an argon2 password hash.
Ok(user
.filter(|user| verify_password(creds.password, &user.password).is_ok())
.map(|user| User {
db_user: user.clone(),
tokens: vec![user.username.clone()],
}))
})
.await?
}
async fn get_user(&self, user_id: &UserId<Self>) -> Result<Option<Self::User>, Self::Error> {
let mut user = entity::users::Entity::find()
.filter(entity::users::Column::Id.eq(*user_id))
.one(self.db.orm_db())
.await?;
if let Some(u) = &mut user {
let mut user = User {
db_user: u.clone(),
tokens: vec![],
};
// username is a token
user.tokens.push(u.username.clone());
Ok(Some(user))
} else {
Ok(None)
}
}
}
#[derive(Debug, Clone, Eq, PartialEq, Hash, FromQueryResult)]
pub struct Permission {
pub name: String,
}
impl From<&str> for Permission {
fn from(name: &str) -> Self {
Permission {
name: name.to_string(),
}
}
}
#[async_trait]
impl AuthzBackend for Backend {
type Permission = Permission;
async fn get_group_permissions(
&self,
_user: &Self::User,
) -> Result<HashSet<Self::Permission>, Self::Error> {
let permissions = entity::users::Entity::find()
.column_as(entity::permissions::Column::Name, "name")
.join(
JoinType::LeftJoin,
entity::users::Relation::UsersGroups.def(),
)
.join(
JoinType::LeftJoin,
entity::users_groups::Relation::Groups.def(),
)
.join(
JoinType::LeftJoin,
entity::groups::Relation::GroupsPermissions.def(),
)
.join(
JoinType::LeftJoin,
entity::groups_permissions::Relation::Permissions.def(),
)
.into_model::<Self::Permission>()
.all(self.db.orm_db())
.await?;
Ok(permissions.into_iter().collect())
}
}
// We use a type alias for convenience.
//
// Note that we've supplied our concrete backend here.
pub type AuthSession = axum_login::AuthSession<Backend>;

View File

@@ -55,7 +55,6 @@ pub fn join_joinset_background<T: Debug + Send + Sync + 'static>(
} }
future::poll_fn(|cx| { future::poll_fn(|cx| {
tracing::debug!("try join joinset tasks");
let Some(js) = js.upgrade() else { let Some(js) = js.upgrade() else {
return std::task::Poll::Ready(()); return std::task::Poll::Ready(());
}; };

View File

@@ -17,6 +17,12 @@ impl From<Uuid> for uuid::Uuid {
} }
} }
impl From<String> for Uuid {
fn from(value: String) -> Self {
uuid::Uuid::parse_str(&value).unwrap().into()
}
}
impl Display for Uuid { impl Display for Uuid {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", uuid::Uuid::from(self.clone())) write!(f, "{}", uuid::Uuid::from(self.clone()))

View File

@@ -36,6 +36,12 @@ message HeartbeatRequest {
common.UUID machine_id = 1; common.UUID machine_id = 1;
common.UUID inst_id = 2; common.UUID inst_id = 2;
string user_token = 3; string user_token = 3;
string easytier_version = 4;
string report_time = 5;
string hostname = 6;
repeated common.UUID running_network_instances = 7;
} }
message HeartbeatResponse { message HeartbeatResponse {
@@ -53,10 +59,12 @@ message ValidateConfigResponse {
} }
message RunNetworkInstanceRequest { message RunNetworkInstanceRequest {
string config = 1; common.UUID inst_id = 1;
string config = 2;
} }
message RunNetworkInstanceResponse { message RunNetworkInstanceResponse {
common.UUID inst_id = 1;
} }
message RetainNetworkInstanceRequest { message RetainNetworkInstanceRequest {

View File

@@ -547,7 +547,7 @@ impl ZCPacket {
ZCPacketType::NIC => unreachable!(), ZCPacketType::NIC => unreachable!(),
}; };
tracing::debug!(?self.packet_type, ?target_packet_type, ?new_offset, "convert zc packet type"); tracing::trace!(?self.packet_type, ?target_packet_type, ?new_offset, "convert zc packet type");
if new_offset == INVALID_OFFSET { if new_offset == INVALID_OFFSET {
// copy peer manager header and payload to new buffer // copy peer manager header and payload to new buffer

View File

@@ -101,8 +101,14 @@ impl WebClientService for Controller {
req: RunNetworkInstanceRequest, req: RunNetworkInstanceRequest,
) -> Result<RunNetworkInstanceResponse, rpc_types::error::Error> { ) -> Result<RunNetworkInstanceResponse, rpc_types::error::Error> {
let cfg = TomlConfigLoader::new_from_str(&req.config)?; let cfg = TomlConfigLoader::new_from_str(&req.config)?;
let id = cfg.get_id();
if let Some(inst_id) = req.inst_id {
cfg.set_id(inst_id.into());
}
self.run_network_instance(cfg)?; self.run_network_instance(cfg)?;
Ok(RunNetworkInstanceResponse {}) Ok(RunNetworkInstanceResponse {
inst_id: Some(id.into()),
})
} }
async fn retain_network_instance( async fn retain_network_instance(

View File

@@ -1,4 +1,4 @@
use std::sync::Arc; use std::sync::{Arc, Weak};
use tokio::{ use tokio::{
sync::{broadcast, Mutex}, sync::{broadcast, Mutex},
@@ -7,7 +7,7 @@ use tokio::{
}; };
use crate::{ use crate::{
common::get_machine_id, common::{constants::EASYTIER_VERSION, get_machine_id},
proto::{ proto::{
rpc_impl::bidirect::BidirectRpcManager, rpc_impl::bidirect::BidirectRpcManager,
rpc_types::controller::BaseController, rpc_types::controller::BaseController,
@@ -47,7 +47,8 @@ impl Session {
.register(WebClientServiceServer::new(controller.clone()), ""); .register(WebClientServiceServer::new(controller.clone()), "");
let mut tasks: JoinSet<()> = JoinSet::new(); let mut tasks: JoinSet<()> = JoinSet::new();
let heartbeat_ctx = Self::heartbeat_routine(&rpc_mgr, controller.token(), &mut tasks); let heartbeat_ctx =
Self::heartbeat_routine(&rpc_mgr, Arc::downgrade(&controller), &mut tasks);
Session { Session {
rpc_mgr, rpc_mgr,
@@ -59,7 +60,7 @@ impl Session {
fn heartbeat_routine( fn heartbeat_routine(
rpc_mgr: &BidirectRpcManager, rpc_mgr: &BidirectRpcManager,
token: String, controller: Weak<Controller>,
tasks: &mut JoinSet<()>, tasks: &mut JoinSet<()>,
) -> HeartbeatCtx { ) -> HeartbeatCtx {
let (tx, _rx1) = broadcast::channel(2); let (tx, _rx1) = broadcast::channel(2);
@@ -71,7 +72,8 @@ impl Session {
let mid = get_machine_id(); let mid = get_machine_id();
let inst_id = uuid::Uuid::new_v4(); let inst_id = uuid::Uuid::new_v4();
let token = token; let token = controller.upgrade().unwrap().token();
let hostname = gethostname::gethostname().to_string_lossy().to_string();
let ctx_clone = ctx.clone(); let ctx_clone = ctx.clone();
let mut tick = interval(std::time::Duration::from_secs(1)); let mut tick = interval(std::time::Duration::from_secs(1));
@@ -79,13 +81,29 @@ impl Session {
.rpc_client() .rpc_client()
.scoped_client::<WebServerServiceClientFactory<BaseController>>(1, 1, "".to_string()); .scoped_client::<WebServerServiceClientFactory<BaseController>>(1, 1, "".to_string());
tasks.spawn(async move { tasks.spawn(async move {
loop {
tick.tick().await;
let Some(controller) = controller.upgrade() else {
break;
};
let req = HeartbeatRequest { let req = HeartbeatRequest {
machine_id: Some(mid.into()), machine_id: Some(mid.into()),
inst_id: Some(inst_id.into()), inst_id: Some(inst_id.into()),
user_token: token.to_string(), user_token: token.to_string(),
easytier_version: EASYTIER_VERSION.to_string(),
hostname: hostname.clone(),
report_time: chrono::Local::now().to_string(),
running_network_instances: controller
.list_network_instance_ids()
.into_iter()
.map(Into::into)
.collect(),
}; };
loop {
tick.tick().await;
match client match client
.heartbeat(BaseController::default(), req.clone()) .heartbeat(BaseController::default(), req.clone())
.await .await