feat: comprehensive project improvements
Some checks failed
CI / Rust Format (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Test Server (push) Has been cancelled
CI / Frontend Check (push) Has been cancelled
CI / Tauri Client Check (push) Has been cancelled
CI / Docker Build (push) Has been cancelled
CI / Build Tauri (Linux) (push) Has been cancelled

- Fix 14 Clippy warnings across server and bot-sdk
- Add 67 unit tests (32 bot-sdk, 34 server, 1 doctest)
- Add Prometheus metrics endpoint (/api/metrics)
- Add structured JSON logging (EIFELDC_LOG_FORMAT=json)
- Add release workflow (Docker push + GitHub Release + Tauri builds)
- Add rate limiting middleware (EIFELDC_RATE_LIMIT)
- Add CORS restriction (EIFELDC_CORS_ORIGINS)
- Add session token expiry (EIFELDC_SESSION_TTL)
- Add input validation (username/password/homeserver length limits)
- Add upload size limit (EIFELDC_MAX_UPLOAD_MB)
- Upgrade Tauri client from v1 to v2
- Add session store with SQLite persistence
- Add proper error types and cleanup across all crates
- Format all code with cargo fmt
- Update CI pipeline with fmt, clippy, test, frontend, and Tauri checks
- Add README with full API reference and setup guide
This commit is contained in:
root
2026-04-29 13:08:01 +02:00
parent 0978d0c2e9
commit cacd2b04a7
80 changed files with 18307 additions and 1724 deletions

View File

@@ -15,8 +15,17 @@ tracing = { workspace = true }
tracing-subscriber = { workspace = true }
reqwest = { workspace = true }
url = { workspace = true }
axum = { version = "0.7", features = ["ws"] }
axum = { version = "0.7", features = ["ws", "multipart"] }
tower = "0.4"
tower-http = { version = "0.5", features = ["cors", "fs", "trace"] }
chrono = { version = "0.4", features = ["serde"] }
uuid = { version = "1", features = ["v4"] }
uuid = { version = "1", features = ["v4"] }
mime = "0.3"
futures = "0.3"
rusqlite = { version = "0.30", features = ["bundled"] }
livekit-api = { version = "0.4", features = ["access-token"] }
prometheus = "0.13"
lazy_static = "1"
[dev-dependencies]
tempfile = "3"

View File

@@ -1,4 +1,6 @@
pub mod routes;
pub mod session_store;
pub mod state;
pub use state::ServerState;
pub use session_store::SessionStore;
pub use state::ServerState;

View File

@@ -1,20 +1,66 @@
use eifeldc_server::routes::api_router;
use eifeldc_server::routes::metrics;
use eifeldc_server::session_store::SessionStore;
use eifeldc_server::state::{LiveKitConfig, Session};
use eifeldc_server::ServerState;
use tower_http::services::ServeDir;
use std::net::SocketAddr;
use std::sync::Arc;
use tower_http::services::ServeDir;
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
init_tracing();
metrics::init_metrics();
tracing::info!("Metrics initialized");
let db_path = std::env::var("EIFELDC_DB").unwrap_or_else(|_| "eifeldc.db".to_string());
let session_store =
Arc::new(SessionStore::new(&db_path).expect("Failed to open session database"));
let livekit = LiveKitConfig {
api_key: std::env::var("LIVEKIT_API_KEY").unwrap_or_else(|_| "devkey".to_string()),
api_secret: std::env::var("LIVEKIT_API_SECRET").unwrap_or_else(|_| "devsecret".to_string()),
url: std::env::var("LIVEKIT_URL").unwrap_or_else(|_| "ws://localhost:7880".to_string()),
};
tracing::info!("LiveKit URL: {}", livekit.url);
let session_ttl: Option<std::time::Duration> =
std::env::var("EIFELDC_SESSION_TTL").ok().map(|v| {
let secs: u64 = v.parse().unwrap_or(86400);
std::time::Duration::from_secs(secs)
});
if let Some(ttl) = session_ttl {
tracing::info!("Session TTL: {}s", ttl.as_secs());
} else {
tracing::info!("Session TTL: unlimited");
}
let state = ServerState::new(session_store.clone(), livekit, session_ttl);
{
let stored_sessions = session_store.get_all_sessions().await.unwrap_or_default();
if !stored_sessions.is_empty() {
tracing::info!("Restoring {} session(s)...", stored_sessions.len());
for stored in &stored_sessions {
match restore_session(stored, &state).await {
Ok(()) => tracing::info!("Restored session for user {}", stored.user_id),
Err(e) => {
tracing::warn!("Failed to restore session for {}: {}", stored.user_id, e);
session_store.delete_session(&stored.token).await.ok();
}
}
}
}
}
let state = ServerState::new();
let api = api_router(state);
let static_dir = std::env::var("EIFELDC_STATIC_DIR")
.unwrap_or_else(|_| "client/src-ui/dist".to_string());
let static_dir =
std::env::var("EIFELDC_STATIC_DIR").unwrap_or_else(|_| "client/src-ui/dist".to_string());
let app = api
.fallback_service(ServeDir::new(&static_dir));
let app = api.fallback_service(ServeDir::new(&static_dir));
let addr: SocketAddr = ([0, 0, 0, 0], 3000).into();
tracing::info!("EifelDC Web Server listening on http://{}", addr);
@@ -22,4 +68,80 @@ async fn main() {
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
axum::serve(listener, app).await.unwrap();
}
}
async fn restore_session(
stored: &eifeldc_server::session_store::StoredSession,
state: &ServerState,
) -> Result<(), Box<dyn std::error::Error>> {
use matrix_sdk::matrix_auth::{MatrixSession, MatrixSessionTokens};
use matrix_sdk::ruma::device_id;
use matrix_sdk::Client;
let client = Client::builder()
.homeserver_url(&stored.homeserver)
.build()
.await?;
let user_id: matrix_sdk::ruma::OwnedUserId = stored.user_id.parse()?;
let device_id = stored
.device_id
.as_deref()
.map(matrix_sdk::ruma::OwnedDeviceId::from)
.unwrap_or_else(|| device_id!("EIFELDC").to_owned());
let session = MatrixSession {
meta: matrix_sdk::SessionMeta { user_id, device_id },
tokens: MatrixSessionTokens {
access_token: stored.access_token.clone(),
refresh_token: stored.refresh_token.clone(),
},
};
client.matrix_auth().restore_session(session).await?;
let mut s = Session::new(
client.clone(),
stored.user_id.clone(),
stored.homeserver.clone(),
);
let sender = s.event_sender.clone();
let sync_client = client.clone();
let handle = tokio::spawn(async move {
eifeldc_server::routes::auth::start_sync(sync_client, sender).await;
});
s.sync_handle = Some(handle);
let mut state_inner = state.write().await;
state_inner.sessions.insert(stored.token.clone(), s);
Ok(())
}
fn init_tracing() {
use tracing_subscriber::EnvFilter;
let log_format = std::env::var("EIFELDC_LOG_FORMAT").unwrap_or_else(|_| "pretty".to_string());
let env_filter = EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new("eifeldc_server=info,tower_http=info"));
match log_format.as_str() {
"json" => {
tracing_subscriber::fmt()
.json()
.with_env_filter(env_filter)
.with_target(true)
.with_thread_ids(false)
.with_file(false)
.with_line_number(false)
.init();
}
_ => {
tracing_subscriber::fmt()
.with_env_filter(env_filter)
.with_target(true)
.init();
}
}
}

View File

@@ -1,11 +1,9 @@
use axum::{
extract::State,
http::{HeaderMap, StatusCode},
Json,
};
use crate::session_store::StoredSession;
use crate::state::{Session, WsEvent};
use axum::{extract::State, http::HeaderMap, Json};
use matrix_sdk::config::SyncSettings;
use matrix_sdk::Client;
use serde::{Deserialize, Serialize};
use crate::state::VoiceManager;
#[derive(Deserialize)]
pub struct LoginRequest {
@@ -14,6 +12,21 @@ pub struct LoginRequest {
pub password: String,
}
impl LoginRequest {
pub fn validate(&self) -> Result<(), &'static str> {
if self.homeserver.is_empty() || self.homeserver.len() > 2048 {
return Err("Invalid homeserver");
}
if self.username.is_empty() || self.username.len() > 256 {
return Err("Invalid username");
}
if self.password.is_empty() || self.password.len() > 4096 {
return Err("Invalid password");
}
Ok(())
}
}
#[derive(Serialize)]
pub struct LoginResult {
pub success: bool,
@@ -29,36 +42,83 @@ pub struct RegisterRequest {
pub password: String,
}
impl RegisterRequest {
pub fn validate(&self) -> Result<(), &'static str> {
if self.homeserver.is_empty() || self.homeserver.len() > 2048 {
return Err("Invalid homeserver");
}
if self.username.is_empty() || self.username.len() > 256 {
return Err("Invalid username");
}
if self.password.is_empty() || self.password.len() > 4096 {
return Err("Invalid password");
}
Ok(())
}
}
pub async fn login(
State(state): State<crate::state::ServerState>,
Json(req): Json<LoginRequest>,
) -> Result<Json<LoginResult>, StatusCode> {
) -> Result<Json<LoginResult>, axum::http::StatusCode> {
req.validate()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let client = Client::builder()
.homeserver_url(&req.homeserver)
.build()
.await
.map_err(|_| StatusCode::BAD_REQUEST)?;
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
client
.matrix_auth()
.login_username(&req.username, &req.password)
.send()
.await
.map_err(|_| StatusCode::UNAUTHORIZED)?;
.map_err(|_| axum::http::StatusCode::UNAUTHORIZED)?;
let user_id = client
.user_id()
.map(|u| u.to_string())
.unwrap_or_default();
let user_id = client.user_id().map(|u| u.to_string()).unwrap_or_default();
let token = uuid::Uuid::new_v4().to_string();
let mut s = state.write().await;
s.sessions.insert(token.clone(), crate::state::Session {
client,
user_id: user_id.clone(),
voice_manager: VoiceManager::new(),
let matrix_session = client
.matrix_auth()
.session()
.ok_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
let mut session = Session::new(client.clone(), user_id.clone(), req.homeserver.clone());
{
let s = state.read().await;
if let Some(ttl) = s.session_ttl {
session = session.with_ttl(ttl);
}
}
let sender = session.event_sender.clone();
let sync_client = client.clone();
let handle = tokio::spawn(async move {
start_sync(sync_client, sender).await;
});
session.sync_handle = Some(handle);
let stored = StoredSession {
token: token.clone(),
user_id: user_id.clone(),
homeserver: req.homeserver.clone(),
access_token: matrix_session.tokens.access_token.clone(),
device_id: Some(matrix_session.meta.device_id.to_string()),
refresh_token: matrix_session.tokens.refresh_token.clone(),
};
{
let s = state.read().await;
s.session_store
.save_session(&stored)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
}
let mut s = state.write().await;
s.sessions.insert(token.clone(), session);
Ok(Json(LoginResult {
success: true,
@@ -71,12 +131,14 @@ pub async fn login(
pub async fn register(
State(state): State<crate::state::ServerState>,
Json(req): Json<RegisterRequest>,
) -> Result<Json<LoginResult>, StatusCode> {
) -> Result<Json<LoginResult>, axum::http::StatusCode> {
req.validate()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let client = Client::builder()
.homeserver_url(&req.homeserver)
.build()
.await
.map_err(|_| StatusCode::BAD_REQUEST)?;
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let mut request = matrix_sdk::ruma::api::client::account::register::v3::Request::new();
request.username = Some(req.username);
@@ -86,21 +148,51 @@ pub async fn register(
.matrix_auth()
.register(request)
.await
.map_err(|_| StatusCode::BAD_REQUEST)?;
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let user_id = client
.user_id()
.map(|u| u.to_string())
.unwrap_or_default();
let user_id = client.user_id().map(|u| u.to_string()).unwrap_or_default();
let token = uuid::Uuid::new_v4().to_string();
let mut s = state.write().await;
s.sessions.insert(token.clone(), crate::state::Session {
client,
user_id: user_id.clone(),
voice_manager: VoiceManager::new(),
let matrix_session = client
.matrix_auth()
.session()
.ok_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
let mut session = Session::new(client.clone(), user_id.clone(), req.homeserver.clone());
{
let s = state.read().await;
if let Some(ttl) = s.session_ttl {
session = session.with_ttl(ttl);
}
}
let sender = session.event_sender.clone();
let sync_client = client.clone();
let handle = tokio::spawn(async move {
start_sync(sync_client, sender).await;
});
session.sync_handle = Some(handle);
let stored = StoredSession {
token: token.clone(),
user_id: user_id.clone(),
homeserver: req.homeserver.clone(),
access_token: matrix_session.tokens.access_token.clone(),
device_id: Some(matrix_session.meta.device_id.to_string()),
refresh_token: matrix_session.tokens.refresh_token.clone(),
};
{
let s = state.read().await;
s.session_store
.save_session(&stored)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
}
let mut s = state.write().await;
s.sessions.insert(token.clone(), session);
Ok(Json(LoginResult {
success: true,
@@ -113,22 +205,34 @@ pub async fn register(
pub async fn logout(
State(state): State<crate::state::ServerState>,
headers: HeaderMap,
) -> Result<Json<bool>, StatusCode> {
let token = extract_token(&headers).ok_or(StatusCode::UNAUTHORIZED)?;
) -> Result<Json<bool>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let mut s = state.write().await;
if let Some(session) = s.sessions.remove(&token) {
if let Some(mut session) = s.sessions.remove(&token) {
if let Some(handle) = session.sync_handle.take() {
handle.abort();
}
let _ = session.client.matrix_auth().logout().await;
}
{
let s_read = state.read().await;
s_read
.session_store
.delete_session(&token)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
}
Ok(Json(true))
}
pub async fn get_current_user(
State(state): State<crate::state::ServerState>,
headers: HeaderMap,
) -> Result<Json<Option<String>>, StatusCode> {
let token = extract_token(&headers).ok_or(StatusCode::UNAUTHORIZED)?;
) -> Result<Json<Option<String>>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let user_id = s.sessions.get(&token).map(|s| s.user_id.clone());
@@ -147,14 +251,377 @@ pub async fn auth_middleware(
headers: HeaderMap,
State(state): State<crate::state::ServerState>,
request: axum::extract::Request,
next: middleware::Next,
) -> Result<axum::response::Response, StatusCode> {
let token = extract_token(&headers).ok_or(StatusCode::UNAUTHORIZED)?;
next: axum::middleware::Next,
) -> Result<axum::response::Response, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
if !s.sessions.contains_key(&token) {
return Err(StatusCode::UNAUTHORIZED);
let expired = {
let s = state.read().await;
match s.sessions.get(&token) {
Some(session) => session.is_expired(),
None => return Err(axum::http::StatusCode::UNAUTHORIZED),
}
};
if expired {
let mut s = state.write().await;
if let Some(mut session) = s.sessions.remove(&token) {
if let Some(handle) = session.sync_handle.take() {
handle.abort();
}
}
drop(s);
{
let s = state.read().await;
let _ = s.session_store.delete_session(&token).await;
}
return Err(axum::http::StatusCode::UNAUTHORIZED);
}
Ok(next.run(request).await)
}
}
pub async fn start_sync(client: Client, sender: tokio::sync::broadcast::Sender<WsEvent>) {
let mut sync_token: Option<String> = None;
loop {
let mut settings = SyncSettings::new();
if let Some(token) = sync_token.as_ref() {
settings = settings.token(token.clone());
}
match client.sync_once(settings).await {
Ok(response) => {
sync_token = Some(response.next_batch);
for (room_id, joined) in &response.rooms.join {
let room = match client.get_room(room_id) {
Some(r) => r,
None => continue,
};
let name = match room.display_name().await {
Ok(n) => n.to_string(),
Err(_) => room_id.to_string(),
};
let _ = sender.send(WsEvent::RoomJoined {
room_id: room_id.to_string(),
name,
});
for event in &joined.timeline.events {
if let Some(event_id) = event.event_id() {
let raw_json: serde_json::Value = match event.event.deserialize_as() {
Ok(v) => v,
Err(_) => continue,
};
let sender_user = raw_json
.get("sender")
.and_then(|v| v.as_str())
.unwrap_or("");
let event_type =
raw_json.get("type").and_then(|v| v.as_str()).unwrap_or("");
if event_type == "m.room.message" {
let content =
raw_json.get("content").unwrap_or(&serde_json::Value::Null);
let new_content = content.get("m.new_content");
if new_content.is_some() {
let relates_to = content.get("m.relates_to");
let original_id = relates_to
.and_then(|r| r.get("event_id"))
.and_then(|v| v.as_str())
.unwrap_or("");
let new_body = new_content
.and_then(|nc| nc.get("body"))
.and_then(|v| v.as_str())
.unwrap_or("");
if !original_id.is_empty() && !new_body.is_empty() {
let _ = sender.send(WsEvent::MessageEdited {
room_id: room_id.to_string(),
event_id: original_id.to_string(),
new_body: new_body.to_string(),
});
}
} else {
let body =
content.get("body").and_then(|v| v.as_str()).unwrap_or("");
let timestamp = raw_json
.get("origin_server_ts")
.and_then(|v| v.as_u64())
.unwrap_or(0);
let msgtype = content
.get("msgtype")
.and_then(|v| v.as_str())
.unwrap_or("m.text");
let media_url = content
.get("url")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.or_else(|| {
content
.get("file")
.and_then(|f| f.get("url"))
.and_then(|v| v.as_str())
.map(|s| s.to_string())
});
let filename = if msgtype == "m.image"
|| msgtype == "m.file"
|| msgtype == "m.video"
|| msgtype == "m.audio"
{
Some(body.to_string())
} else {
None
};
let mimetype = content
.get("info")
.and_then(|i| i.get("mimetype"))
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let relates_to = content.get("m.relates_to");
let reply_to = relates_to
.and_then(|r| r.get("m.in_reply_to"))
.and_then(|r| r.get("event_id"))
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let is_thread = relates_to
.and_then(|r| r.get("rel_type"))
.and_then(|v| v.as_str())
== Some("m.thread");
if !body.is_empty() || media_url.is_some() {
if is_thread {
if let Some(root_id) = relates_to
.and_then(|r| r.get("event_id"))
.and_then(|v| v.as_str())
{
let _ = sender.send(WsEvent::ThreadReply {
room_id: room_id.to_string(),
root_event_id: root_id.to_string(),
event_id: event_id.to_string(),
sender: sender_user.to_string(),
body: body.to_string(),
timestamp,
});
}
} else {
let _ = sender.send(WsEvent::Message {
room_id: room_id.to_string(),
event_id: event_id.to_string(),
sender: sender_user.to_string(),
body: body.to_string(),
timestamp,
reply_to,
msgtype: Some(msgtype.to_string()),
media_url: media_url.clone(),
filename,
mimetype,
});
}
}
}
} else if event_type == "m.reaction" {
let content =
raw_json.get("content").unwrap_or(&serde_json::Value::Null);
if let Some(relates_to) = content.get("m.relates_to") {
let target_id = relates_to
.get("event_id")
.and_then(|v| v.as_str())
.unwrap_or("");
let key = relates_to
.get("key")
.and_then(|v| v.as_str())
.unwrap_or("");
if !target_id.is_empty() && !key.is_empty() {
let _ = sender.send(WsEvent::Reaction {
room_id: room_id.to_string(),
event_id: target_id.to_string(),
key: key.to_string(),
sender: sender_user.to_string(),
});
}
}
} else if event_type == "m.room.redaction" {
let redacts = raw_json
.get("redacts")
.and_then(|v| v.as_str())
.unwrap_or("");
if !redacts.is_empty() {
let _ = sender.send(WsEvent::MessageDeleted {
room_id: room_id.to_string(),
redacts: redacts.to_string(),
});
}
}
}
}
for raw_ephemeral in &joined.ephemeral {
let val: serde_json::Value = match raw_ephemeral.deserialize_as() {
Ok(v) => v,
Err(_) => continue,
};
let event_type = val.get("type").and_then(|v| v.as_str()).unwrap_or("");
if event_type == "m.typing" {
let user_ids = val
.get("content")
.and_then(|c| c.get("user_ids"))
.and_then(|v| v.as_array());
if let Some(ids) = user_ids {
for uid in ids {
if let Some(user_str) = uid.as_str() {
let _ = sender.send(WsEvent::Typing {
room_id: room_id.to_string(),
user_id: user_str.to_string(),
typing: true,
});
}
}
}
}
}
}
for room_id in response.rooms.leave.keys() {
let _ = sender.send(WsEvent::RoomLeft {
room_id: room_id.to_string(),
});
}
for raw_presence in &response.presence {
let val: serde_json::Value = match raw_presence.deserialize_as() {
Ok(v) => v,
Err(_) => continue,
};
let user_id = val.get("sender").and_then(|v| v.as_str()).unwrap_or("");
let presence = val
.get("content")
.and_then(|c| c.get("presence"))
.and_then(|v| v.as_str())
.unwrap_or("offline");
let status_msg = val
.get("content")
.and_then(|c| c.get("status_msg"))
.and_then(|v| v.as_str())
.map(|s| s.to_string());
if !user_id.is_empty() {
let status = match presence {
"online" => "online",
"unavailable" => "idle",
_ => "offline",
};
let _ = sender.send(WsEvent::Presence {
user_id: user_id.to_string(),
status: status.to_string(),
status_msg,
});
}
}
}
Err(e) => {
tracing::error!("Sync error: {}", e);
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn login_request_validate_ok() {
let req = LoginRequest {
homeserver: "https://matrix.example.org".to_string(),
username: "alice".to_string(),
password: "secret123".to_string(),
};
assert!(req.validate().is_ok());
}
#[test]
fn login_request_reject_empty_homeserver() {
let req = LoginRequest {
homeserver: "".to_string(),
username: "alice".to_string(),
password: "secret".to_string(),
};
assert!(req.validate().is_err());
}
#[test]
fn login_request_reject_empty_username() {
let req = LoginRequest {
homeserver: "https://matrix.example.org".to_string(),
username: "".to_string(),
password: "secret".to_string(),
};
assert!(req.validate().is_err());
}
#[test]
fn login_request_reject_empty_password() {
let req = LoginRequest {
homeserver: "https://matrix.example.org".to_string(),
username: "alice".to_string(),
password: "".to_string(),
};
assert!(req.validate().is_err());
}
#[test]
fn login_request_reject_long_username() {
let req = LoginRequest {
homeserver: "https://matrix.example.org".to_string(),
username: "a".repeat(300),
password: "secret".to_string(),
};
assert!(req.validate().is_err());
}
#[test]
fn register_request_validate_ok() {
let req = RegisterRequest {
homeserver: "https://matrix.example.org".to_string(),
username: "bob".to_string(),
password: "password".to_string(),
};
assert!(req.validate().is_ok());
}
#[test]
fn register_request_reject_empty_fields() {
let req = RegisterRequest {
homeserver: "".to_string(),
username: "".to_string(),
password: "".to_string(),
};
assert!(req.validate().is_err());
}
#[test]
fn extract_token_from_bearer() {
let mut headers = HeaderMap::new();
headers.insert("Authorization", "Bearer my-token-123".parse().unwrap());
assert_eq!(extract_token(&headers), Some("my-token-123".to_string()));
}
#[test]
fn extract_token_missing_header() {
let headers = HeaderMap::new();
assert_eq!(extract_token(&headers), None);
}
#[test]
fn extract_token_invalid_scheme() {
let mut headers = HeaderMap::new();
headers.insert("Authorization", "Basic dXNlcjpwYXNz".parse().unwrap());
assert_eq!(extract_token(&headers), None);
}
}

View File

@@ -1,13 +1,13 @@
use super::auth::extract_token;
use crate::ServerState;
use axum::{
extract::{Path, State},
http::HeaderMap,
Json,
};
use serde::{Deserialize, Serialize};
use crate::ServerState;
use super::auth::extract_token;
#[derive(Serialize)]
#[derive(Serialize, Clone)]
pub struct CustomEmoji {
pub id: String,
pub name: String,
@@ -33,10 +33,69 @@ pub struct Sticker {
pub async fn get_custom_emoji(
State(state): State<ServerState>,
headers: HeaderMap,
Path(_room_id): Path<String>,
Path(room_id): Path<String>,
) -> Result<Json<Vec<CustomEmoji>>, axum::http::StatusCode> {
let _token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
Ok(Json(Vec::new()))
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
let mut emojis = Vec::new();
let event_type: matrix_sdk::ruma::events::StateEventType = "im.ponies.room_emotes".into();
if let Ok(Some(raw_event)) = room.get_state_event(event_type, "").await {
let raw_json_str = match &raw_event {
matrix_sdk::deserialized_responses::RawAnySyncOrStrippedState::Sync(s) => {
s.json().get().to_string()
}
matrix_sdk::deserialized_responses::RawAnySyncOrStrippedState::Stripped(s) => {
s.json().get().to_string()
}
};
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&raw_json_str) {
if let Some(content) = value.get("content") {
if let Some(images) = content.get("images").and_then(|v| v.as_object()) {
for (key, entry) in images {
let url = entry
.get("url")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
if url.is_empty() {
continue;
}
let name = entry
.get("body")
.and_then(|v| v.as_str())
.unwrap_or(key)
.to_string();
let animated = url.contains(".gif");
emojis.push(CustomEmoji {
id: key.clone(),
name,
url,
category: "custom".to_string(),
animated,
});
}
}
}
}
}
Ok(Json(emojis))
}
#[derive(Deserialize)]
@@ -53,34 +112,50 @@ pub async fn upload_emoji(
) -> Result<Json<CustomEmoji>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let _room = session.client.get_room(&rid).ok_or(axum::http::StatusCode::NOT_FOUND)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let _room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
let path = std::path::Path::new(&req.image_path);
if !path.exists() {
return Err(axum::http::StatusCode::BAD_REQUEST);
}
let mime_type = match path.extension().and_then(|e| e.to_str()) {
let mime_type: mime::Mime = match path.extension().and_then(|e| e.to_str()) {
Some("png") => "image/png",
Some("jpg") | Some("jpeg") => "image/jpeg",
Some("gif") => "image/gif",
Some("webp") => "image/webp",
_ => "image/png",
};
}
.parse()
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
let data = std::fs::read(path).map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
let content_type = mime_type.parse::<matrix_sdk::ruma::mime::Mime>().map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
let is_animated = mime_type.subtype() == mime::IMAGE_GIF.subtype();
let response = session.client.media().upload(&content_type, data).await.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
let response = session
.client
.media()
.upload(&mime_type, data)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(CustomEmoji {
id: format!("emoji_{}", chrono::Utc::now().timestamp()),
name: req.name,
url: response.content_uri.to_string(),
category: "custom".to_string(),
animated: mime_type == "image/gif",
animated: is_animated,
}))
}
}

View File

@@ -0,0 +1,47 @@
use super::auth::extract_token;
use crate::ServerState;
use axum::{
body::Body,
extract::{Path, State},
http::HeaderMap,
response::Response,
};
use matrix_sdk::ruma::OwnedMxcUri;
pub async fn get_media(
State(state): State<ServerState>,
headers: HeaderMap,
Path(mxc_path): Path<String>,
) -> Result<Response<Body>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let mxc_uri: OwnedMxcUri = mxc_path.as_str().into();
let request =
matrix_sdk::ruma::api::client::media::get_content::v3::Request::from_url(&mxc_uri)
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let response = session
.client
.send(request, None)
.await
.map_err(|_| axum::http::StatusCode::NOT_FOUND)?;
let content_type = response
.content_type
.unwrap_or_else(|| "application/octet-stream".to_string());
let body = response.file;
Ok(Response::builder()
.status(200)
.header("Content-Type", content_type)
.header("Cache-Control", "public, max-age=86400")
.body(Body::from(body))
.unwrap())
}

View File

@@ -1,11 +1,13 @@
use super::auth::extract_token;
use crate::routes::metrics::MESSAGES_SENT_TOTAL;
use crate::ServerState;
use axum::{
extract::{Path, State, Query},
extract::{Path, Query, State},
http::HeaderMap,
Json,
};
use serde::{Serialize, Deserialize};
use crate::ServerState;
use super::auth::extract_token;
use matrix_sdk::room::MessagesOptions;
use serde::{Deserialize, Serialize};
#[derive(Serialize)]
pub struct MessageInfo {
@@ -14,6 +16,15 @@ pub struct MessageInfo {
pub body: String,
pub timestamp: u64,
pub reply_to: Option<String>,
pub edited: bool,
pub reactions: std::collections::HashMap<String, Vec<String>>,
pub msgtype: String,
pub media_url: Option<String>,
pub filename: Option<String>,
pub mimetype: Option<String>,
pub width: Option<u64>,
pub height: Option<u64>,
pub size: Option<u64>,
}
#[derive(Deserialize)]
@@ -30,29 +41,196 @@ pub async fn get_room_messages(
) -> Result<Json<Vec<MessageInfo>>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session.client.get_room(&rid).ok_or(axum::http::StatusCode::NOT_FOUND)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
let limit = query.limit.unwrap_or(50);
let mut options = matrix_sdk::ruma::api::client::message::get_message_events::v3::Request::new();
options.limit = limit.into();
options.from = query.from.map(|t| t.into());
let mut options = MessagesOptions::backward();
options.limit = matrix_sdk::ruma::uint!(50);
if let Some(from) = query.from {
options.from = Some(from);
}
let messages = room.messages(options).await.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
let messages = room
.messages(options)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
let mut reactions_map: std::collections::HashMap<
String,
std::collections::HashMap<String, Vec<String>>,
> = std::collections::HashMap::new();
let mut result = Vec::new();
for msg in messages.chunk {
if let matrix_sdk::ruma::events::AnySyncMessageLikeEvent::RoomMessage(ev) = msg {
result.push(MessageInfo {
event_id: ev.event_id().to_string(),
sender: ev.sender().to_string(),
body: ev.content().body().to_string(),
timestamp: ev.origin_server_ts().0,
reply_to: ev.content().in_reply_to().map(|r| r.event_id.to_string()),
});
for msg in &messages.chunk {
let event_value: serde_json::Value = match msg.event.deserialize_as() {
Ok(v) => v,
Err(_) => continue,
};
let event_type = event_value
.get("type")
.and_then(|v| v.as_str())
.unwrap_or("");
if event_type == "m.reaction" {
if let Some(content) = event_value.get("content") {
if let Some(relates_to) = content.get("m.relates_to") {
if let (Some(target_id), Some(key)) = (
relates_to.get("event_id").and_then(|v| v.as_str()),
relates_to.get("key").and_then(|v| v.as_str()),
) {
let sender = event_value
.get("sender")
.and_then(|v| v.as_str())
.unwrap_or("");
reactions_map
.entry(target_id.to_string())
.or_default()
.entry(key.to_string())
.or_default()
.push(sender.to_string());
}
}
}
continue;
}
if event_type == "m.room.redaction" {
continue;
}
let event_id = match event_value.get("event_id").and_then(|v| v.as_str()) {
Some(id) => id.to_string(),
None => continue,
};
let sender = match event_value.get("sender").and_then(|v| v.as_str()) {
Some(s) => s.to_string(),
None => continue,
};
let timestamp = event_value
.get("origin_server_ts")
.and_then(|v| v.as_u64())
.unwrap_or(0);
let content = event_value
.get("content")
.unwrap_or(&serde_json::Value::Null);
let msgtype = content
.get("msgtype")
.and_then(|v| v.as_str())
.unwrap_or("");
if msgtype != "m.text"
&& msgtype != "m.notice"
&& msgtype != "m.emote"
&& msgtype != "m.image"
&& msgtype != "m.file"
&& msgtype != "m.video"
&& msgtype != "m.audio"
{
continue;
}
let body = content
.get("body")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let new_content = content.get("m.new_content");
let (display_body, edited) = if let Some(nc) = new_content {
let new_body = nc.get("body").and_then(|v| v.as_str()).unwrap_or(&body);
(new_body.to_string(), true)
} else {
(body.clone(), false)
};
if body.is_empty() && (msgtype == "m.text" || msgtype == "m.notice" || msgtype == "m.emote")
{
continue;
}
let reply_to = content
.get("m.relates_to")
.and_then(|r| r.get("m.in_reply_to"))
.and_then(|r| r.get("event_id"))
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let reactions = reactions_map.remove(&event_id).unwrap_or_default();
let media_url = content
.get("url")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.or_else(|| {
content
.get("file")
.and_then(|f| f.get("url"))
.and_then(|v| v.as_str())
.map(|s| s.to_string())
});
let filename = if msgtype == "m.image"
|| msgtype == "m.file"
|| msgtype == "m.video"
|| msgtype == "m.audio"
{
Some(body.clone())
} else {
None
};
let mimetype = content
.get("info")
.and_then(|i| i.get("mimetype"))
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let width = content
.get("info")
.and_then(|i| i.get("w"))
.and_then(|v| v.as_u64());
let height = content
.get("info")
.and_then(|i| i.get("h"))
.and_then(|v| v.as_u64());
let size = content
.get("info")
.and_then(|i| i.get("size"))
.and_then(|v| v.as_u64());
result.push(MessageInfo {
event_id,
sender,
body: display_body,
timestamp,
reply_to,
edited,
reactions,
msgtype: msgtype.to_string(),
media_url,
filename,
mimetype,
width,
height,
size,
});
}
Ok(Json(result))
@@ -71,14 +249,192 @@ pub async fn send_message(
) -> Result<Json<String>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session.client.get_room(&rid).ok_or(axum::http::StatusCode::NOT_FOUND)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
let content = matrix_sdk::ruma::events::room::message::RoomMessageEventContent::text_plain(&req.message);
let txn_id = matrix_sdk::ruma::TransactionId::new();
let response = room.send(content, Some(&txn_id)).await.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
let content =
matrix_sdk::ruma::events::room::message::RoomMessageEventContent::text_plain(&req.message);
let response = room
.send(content)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
MESSAGES_SENT_TOTAL.inc();
Ok(Json(response.event_id.to_string()))
}
#[derive(Deserialize)]
pub struct EditMessageRequest {
pub event_id: String,
pub new_content: String,
}
pub async fn edit_message(
State(state): State<ServerState>,
headers: HeaderMap,
Path(room_id): Path<String>,
Json(req): Json<EditMessageRequest>,
) -> Result<Json<String>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
let original_event_id: matrix_sdk::ruma::OwnedEventId = req
.event_id
.parse()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let content = matrix_sdk::ruma::events::room::message::RoomMessageEventContent::text_plain(
&req.new_content,
)
.make_replacement(
matrix_sdk::ruma::events::room::message::ReplacementMetadata::new(original_event_id, None),
None,
);
let response = room
.send(content)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(response.event_id.to_string()))
}
#[derive(Deserialize)]
pub struct DeleteMessageRequest {
pub reason: Option<String>,
}
pub async fn delete_message(
State(state): State<ServerState>,
headers: HeaderMap,
Path((room_id, event_id)): Path<(String, String)>,
Json(req): Json<DeleteMessageRequest>,
) -> Result<Json<String>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
let eid: matrix_sdk::ruma::OwnedEventId = event_id
.parse()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let response = room
.redact(&eid, req.reason.as_deref(), None)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(response.event_id.to_string()))
}
}
#[derive(Deserialize)]
pub struct ReactRequest {
pub event_id: String,
pub key: String,
}
pub async fn react_to_message(
State(state): State<ServerState>,
headers: HeaderMap,
Path(room_id): Path<String>,
Json(req): Json<ReactRequest>,
) -> Result<Json<String>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
let event_id: matrix_sdk::ruma::OwnedEventId = req
.event_id
.parse()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let content = matrix_sdk::ruma::events::reaction::ReactionEventContent::new(
matrix_sdk::ruma::events::relation::Annotation::new(event_id, req.key),
);
let response = room
.send(content)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(response.event_id.to_string()))
}
#[derive(Deserialize)]
pub struct TypingRequest {
pub typing: bool,
}
pub async fn set_typing(
State(state): State<ServerState>,
headers: HeaderMap,
Path(room_id): Path<String>,
Json(req): Json<TypingRequest>,
) -> Result<Json<bool>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
room.typing_notice(req.typing)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(true))
}

View File

@@ -0,0 +1,82 @@
use axum::{body::Body, extract::State, http::HeaderMap, response::IntoResponse};
use lazy_static::lazy_static;
use prometheus::{Encoder, IntCounter, IntGauge, Registry, TextEncoder};
lazy_static! {
pub static ref REGISTRY: Registry = Registry::new();
pub static ref HTTP_REQUESTS_TOTAL: IntCounter =
IntCounter::new("eifeldc_http_requests_total", "Total HTTP requests")
.expect("metric can be created");
pub static ref ACTIVE_SESSIONS: IntGauge =
IntGauge::new("eifeldc_active_sessions", "Number of active sessions")
.expect("metric can be created");
pub static ref ACTIVE_WEBSOCKETS: IntGauge = IntGauge::new(
"eifeldc_active_websockets",
"Number of active WebSocket connections"
)
.expect("metric can be created");
pub static ref MESSAGES_SENT_TOTAL: IntCounter =
IntCounter::new("eifeldc_messages_sent_total", "Total messages sent")
.expect("metric can be created");
pub static ref ROOMS_JOINED_TOTAL: IntCounter =
IntCounter::new("eifeldc_rooms_joined_total", "Total room join operations")
.expect("metric can be created");
pub static ref UPLOADS_TOTAL: IntCounter =
IntCounter::new("eifeldc_uploads_total", "Total file uploads")
.expect("metric can be created");
pub static ref VOICE_PARTICIPANTS: IntGauge = IntGauge::new(
"eifeldc_voice_participants",
"Current voice channel participants"
)
.expect("metric can be created");
}
pub fn init_metrics() {
REGISTRY
.register(Box::new(HTTP_REQUESTS_TOTAL.clone()))
.expect("metric can be registered");
REGISTRY
.register(Box::new(ACTIVE_SESSIONS.clone()))
.expect("metric can be registered");
REGISTRY
.register(Box::new(ACTIVE_WEBSOCKETS.clone()))
.expect("metric can be registered");
REGISTRY
.register(Box::new(MESSAGES_SENT_TOTAL.clone()))
.expect("metric can be registered");
REGISTRY
.register(Box::new(ROOMS_JOINED_TOTAL.clone()))
.expect("metric can be registered");
REGISTRY
.register(Box::new(UPLOADS_TOTAL.clone()))
.expect("metric can be registered");
REGISTRY
.register(Box::new(VOICE_PARTICIPANTS.clone()))
.expect("metric can be registered");
}
pub async fn get_metrics(
State(state): State<crate::ServerState>,
_headers: HeaderMap,
) -> impl IntoResponse {
ACTIVE_SESSIONS.set(state.read().await.sessions.len() as i64);
let mut participants: i64 = 0;
let s = state.read().await;
for room in s.voice_rooms.rooms.values() {
participants += room.participants.len() as i64;
}
VOICE_PARTICIPANTS.set(participants);
drop(s);
let metric_families = REGISTRY.gather();
let mut buffer = Vec::new();
let encoder = TextEncoder::new();
encoder.encode(&metric_families, &mut buffer).unwrap();
axum::http::Response::builder()
.status(200)
.header("Content-Type", "text/plain; version=0.0.4; charset=utf-8")
.body(Body::from(buffer))
.unwrap()
}

View File

@@ -1,37 +1,79 @@
pub mod auth;
pub mod rooms;
pub mod messages;
pub mod presence;
pub mod voice;
pub mod emoji;
pub mod media;
pub mod messages;
pub mod metrics;
pub mod presence;
pub mod profile;
pub mod rate_limit;
pub mod roles;
pub mod rooms;
pub mod threads;
pub mod upload;
pub mod voice;
pub mod ws;
use axum::{
Router,
routing::{get, post},
middleware,
http::HeaderMap,
};
use tower_http::cors::{CorsLayer, Any};
use crate::ServerState;
use crate::routes::auth::auth_middleware;
use crate::routes::rate_limit::{RateLimitConfig, RateLimiter};
use crate::ServerState;
use axum::{
routing::{get, post},
Router,
};
use std::time::Duration;
use tower_http::cors::{AllowHeaders, AllowMethods, CorsLayer};
pub fn api_router(state: ServerState) -> Router {
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
let allowed_origins = std::env::var("EIFELDC_CORS_ORIGINS").unwrap_or_else(|_| "*".to_string());
let cors = if allowed_origins == "*" {
CorsLayer::permissive()
} else {
let origins: Vec<_> = allowed_origins
.split(',')
.filter_map(|s| s.trim().parse().ok())
.collect();
CorsLayer::new()
.allow_origin(origins)
.allow_methods(AllowMethods::any())
.allow_headers(AllowHeaders::any())
.max_age(Duration::from_secs(3600))
};
let max_rpm = std::env::var("EIFELDC_RATE_LIMIT")
.ok()
.map(|v| v.parse().unwrap_or(60))
.unwrap_or(60);
let rate_limiter = RateLimiter::new(RateLimitConfig {
max_requests: max_rpm,
window_secs: 60,
});
Router::new()
.nest("/api", api_routes(state))
.layer(cors)
.layer(axum::middleware::from_fn_with_state(
rate_limiter,
rate_limit::rate_limit_middleware,
))
}
fn api_routes(state: ServerState) -> Router {
let max_upload_bytes: usize = std::env::var("EIFELDC_MAX_UPLOAD_MB")
.ok()
.map(|v| v.parse().unwrap_or(50))
.unwrap_or(50)
* 1024
* 1024;
let default_body_limit = axum::extract::DefaultBodyLimit::max(max_upload_bytes);
let public = Router::new()
.route("/login", post(auth::login))
.route("/register", post(auth::register))
.route("/current-user", get(auth::get_current_user));
.route("/current-user", get(auth::get_current_user))
.route("/ws", get(ws::ws_handler))
.route("/metrics", get(metrics::get_metrics));
let protected = Router::new()
.route("/logout", post(auth::logout))
@@ -40,24 +82,63 @@ fn api_routes(state: ServerState) -> Router {
.route("/rooms/join", post(rooms::join_room))
.route("/rooms/{room_id}/leave", post(rooms::leave_room))
.route("/rooms/{room_id}/members", get(rooms::get_room_members))
.route("/rooms/{room_id}/messages", get(messages::get_room_messages))
.route("/rooms/{room_id}/name", post(rooms::set_room_name))
.route("/rooms/{room_id}/topic", post(rooms::set_room_topic))
.route("/rooms/{room_id}/avatar", post(rooms::set_room_avatar))
.route(
"/rooms/{room_id}/messages",
get(messages::get_room_messages),
)
.route("/rooms/{room_id}/send", post(messages::send_message))
.route("/rooms/{room_id}/edit", post(messages::edit_message))
.route(
"/rooms/{room_id}/delete/{event_id}",
post(messages::delete_message),
)
.route("/rooms/{room_id}/react", post(messages::react_to_message))
.route("/rooms/{room_id}/typing", post(messages::set_typing))
.route("/presence/set", post(presence::set_presence))
.route("/presence/{user_id}", get(presence::get_presence))
.route("/voice/join", post(voice::join_voice_channel))
.route("/voice/leave", post(voice::leave_voice_channel))
.route("/voice/toggle-mute", post(voice::toggle_mute))
.route("/voice/toggle-deafen", post(voice::toggle_deafen))
.route("/voice/participants", get(voice::get_voice_participants))
.route("/rooms/{room_id}/roles", get(roles::get_roles))
.route("/rooms/{room_id}/roles/assign", post(roles::assign_role))
.route("/rooms/{room_id}/roles/remove", post(roles::remove_role))
.route("/rooms/{room_id}/permissions/{user_id}", get(roles::get_permissions))
.route(
"/rooms/{room_id}/permissions/{user_id}",
get(roles::get_permissions),
)
.route("/rooms/{room_id}/emoji", get(emoji::get_custom_emoji))
.route("/rooms/{room_id}/emoji/upload", post(emoji::upload_emoji))
.layer(middleware::from_fn_with_state(state.clone(), auth_middleware));
.route("/rooms/{room_id}/threads", get(threads::get_threads))
.route(
"/rooms/{room_id}/threads/{thread_id}",
get(threads::get_thread_messages),
)
.route(
"/rooms/{room_id}/threads/{thread_id}/reply",
post(threads::send_thread_reply),
)
.route("/rooms/{room_id}/reply", post(threads::send_reply))
.route("/rooms/{room_id}/upload", post(upload::upload_file))
.route("/rooms/unread", get(rooms::get_unread_counts))
.route("/rooms/{room_id}/read", post(rooms::mark_room_read))
.route("/media/{mxc_path}", get(media::get_media))
.route("/profile/me", get(profile::get_own_profile))
.route("/profile/{user_id}", get(profile::get_profile))
.route("/profile/displayname", post(profile::set_display_name))
.route("/profile/avatar", post(profile::upload_avatar))
.layer(axum::middleware::from_fn_with_state(
state.clone(),
auth_middleware,
));
Router::new()
.merge(public)
.merge(protected)
.with_state(state)
}
.layer(default_body_limit)
}

View File

@@ -1,11 +1,11 @@
use super::auth::extract_token;
use crate::ServerState;
use axum::{
extract::{Path, State},
http::HeaderMap,
Json,
};
use serde::{Deserialize, Serialize};
use crate::ServerState;
use super::auth::extract_token;
#[derive(Deserialize)]
pub struct SetPresenceRequest {
@@ -28,21 +28,31 @@ pub async fn set_presence(
) -> Result<Json<bool>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let presence_state = match req.status.as_str() {
"online" => matrix_sdk::ruma::presence::PresenceState::Online,
"away" => matrix_sdk::ruma::presence::PresenceState::Away,
"unavailable" => matrix_sdk::ruma::presence::PresenceState::Unavailable,
_ => matrix_sdk::ruma::presence::PresenceState::Online,
};
let user_id = session.client.user_id().ok_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
let mut request = matrix_sdk::ruma::api::client::presence::set_presence::v3::Request::new(user_id.to_owned());
request.presence = presence_state;
request.status_msg = req.status_msg;
let user_id = session
.client
.user_id()
.ok_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
let request = matrix_sdk::ruma::api::client::presence::set_presence::v3::Request::new(
user_id.to_owned(),
presence_state,
);
session.client.send(request, None).await.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
session
.client
.send(request, None)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(true))
}
@@ -54,16 +64,26 @@ pub async fn get_presence(
) -> Result<Json<PresenceInfo>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let uid: matrix_sdk::ruma::OwnedUserId = user_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let request = matrix_sdk::ruma::api::client::presence::get_presence::v3::Request::new(uid.to_owned());
let uid: matrix_sdk::ruma::OwnedUserId = user_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let request =
matrix_sdk::ruma::api::client::presence::get_presence::v3::Request::new(uid.to_owned());
let response = session.client.send(request, None).await.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
let response = session
.client
.send(request, None)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
let status_str = match response.presence {
matrix_sdk::ruma::presence::PresenceState::Online => "online",
matrix_sdk::ruma::presence::PresenceState::Away => "away",
matrix_sdk::ruma::presence::PresenceState::Unavailable => "unavailable",
_ => "offline",
};
@@ -74,4 +94,4 @@ pub async fn get_presence(
status_msg: response.status_msg,
last_active: response.last_active_ago.map(|d| d.as_secs()),
}))
}
}

View File

@@ -0,0 +1,156 @@
use super::auth::extract_token;
use crate::ServerState;
use axum::{
extract::{Multipart, Path, State},
http::HeaderMap,
Json,
};
use serde::{Deserialize, Serialize};
#[derive(Serialize)]
pub struct UserProfile {
pub user_id: String,
pub display_name: Option<String>,
pub avatar_url: Option<String>,
}
pub async fn get_profile(
State(state): State<ServerState>,
headers: HeaderMap,
Path(user_id): Path<String>,
) -> Result<Json<UserProfile>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let uid: matrix_sdk::ruma::OwnedUserId = user_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let profile = session
.client
.get_profile(&uid)
.await
.map_err(|_| axum::http::StatusCode::NOT_FOUND)?;
Ok(Json(UserProfile {
user_id,
display_name: profile.displayname,
avatar_url: profile.avatar_url.map(|u| u.to_string()),
}))
}
pub async fn get_own_profile(
State(state): State<ServerState>,
headers: HeaderMap,
) -> Result<Json<UserProfile>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let user_id = session.user_id.clone();
let uid: matrix_sdk::ruma::OwnedUserId = user_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let profile = session
.client
.get_profile(&uid)
.await
.map_err(|_| axum::http::StatusCode::NOT_FOUND)?;
Ok(Json(UserProfile {
user_id,
display_name: profile.displayname,
avatar_url: profile.avatar_url.map(|u| u.to_string()),
}))
}
#[derive(Deserialize)]
pub struct SetDisplayNameRequest {
pub display_name: String,
}
pub async fn set_display_name(
State(state): State<ServerState>,
headers: HeaderMap,
Json(req): Json<SetDisplayNameRequest>,
) -> Result<Json<bool>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
session
.client
.account()
.set_display_name(Some(&req.display_name))
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(true))
}
pub async fn upload_avatar(
State(state): State<ServerState>,
headers: HeaderMap,
mut multipart: Multipart,
) -> Result<Json<UserProfile>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let field = multipart
.next_field()
.await
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?
.ok_or(axum::http::StatusCode::BAD_REQUEST)?;
let content_type = field
.content_type()
.unwrap_or("application/octet-stream")
.to_string();
let data = field
.bytes()
.await
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let mime_type: mime::Mime = content_type.parse().unwrap_or(mime::IMAGE_PNG);
let upload = session
.client
.media()
.upload(&mime_type, data.to_vec())
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
let mxc_uri = upload.content_uri;
session
.client
.account()
.set_avatar_url(Some(&mxc_uri))
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
let user_id = session.user_id.clone();
Ok(Json(UserProfile {
user_id,
display_name: None,
avatar_url: Some(mxc_uri.to_string()),
}))
}

View File

@@ -0,0 +1,155 @@
use axum::{
body::Body,
extract::State,
http::{Request, StatusCode},
middleware::Next,
response::Response,
};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
struct RateLimitEntry {
count: u32,
window_start: Instant,
}
#[derive(Clone)]
pub struct RateLimitConfig {
pub max_requests: u32,
pub window_secs: u64,
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
max_requests: 60,
window_secs: 60,
}
}
}
#[derive(Clone)]
pub struct RateLimiter {
entries: Arc<RwLock<HashMap<String, RateLimitEntry>>>,
config: RateLimitConfig,
}
impl RateLimiter {
pub fn new(config: RateLimitConfig) -> Self {
Self {
entries: Arc::new(RwLock::new(HashMap::new())),
config,
}
}
pub async fn check(&self, key: &str) -> bool {
let mut entries = self.entries.write().await;
let now = Instant::now();
let window = std::time::Duration::from_secs(self.config.window_secs);
match entries.get_mut(key) {
Some(entry) => {
if now.duration_since(entry.window_start) > window {
entry.count = 1;
entry.window_start = now;
true
} else if entry.count < self.config.max_requests {
entry.count += 1;
true
} else {
false
}
}
None => {
entries.insert(
key.to_string(),
RateLimitEntry {
count: 1,
window_start: now,
},
);
true
}
}
}
}
pub async fn rate_limit_middleware(
State(limiter): State<RateLimiter>,
headers: axum::http::HeaderMap,
request: Request<Body>,
next: Next,
) -> Result<Response, StatusCode> {
let key = headers
.get("x-forwarded-for")
.or_else(|| headers.get("x-real-ip"))
.or_else(|| headers.get("authorization"))
.and_then(|v| v.to_str().ok())
.unwrap_or("unknown")
.to_string();
if limiter.check(&key).await {
Ok(next.run(request).await)
} else {
Err(StatusCode::TOO_MANY_REQUESTS)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn rate_limit_allows_within_limit() {
let limiter = RateLimiter::new(RateLimitConfig {
max_requests: 5,
window_secs: 60,
});
for _ in 0..5 {
assert!(limiter.check("user1").await);
}
}
#[tokio::test]
async fn rate_limit_blocks_over_limit() {
let limiter = RateLimiter::new(RateLimitConfig {
max_requests: 3,
window_secs: 60,
});
assert!(limiter.check("user1").await);
assert!(limiter.check("user1").await);
assert!(limiter.check("user1").await);
assert!(!limiter.check("user1").await);
}
#[tokio::test]
async fn rate_limit_per_key() {
let limiter = RateLimiter::new(RateLimitConfig {
max_requests: 2,
window_secs: 60,
});
assert!(limiter.check("user1").await);
assert!(limiter.check("user1").await);
assert!(!limiter.check("user1").await);
assert!(limiter.check("user2").await);
}
#[tokio::test]
async fn rate_limit_default_config() {
let config = RateLimitConfig::default();
assert_eq!(config.max_requests, 60);
assert_eq!(config.window_secs, 60);
}
#[tokio::test]
async fn rate_limit_new_key_always_allowed() {
let limiter = RateLimiter::new(RateLimitConfig {
max_requests: 1,
window_secs: 60,
});
assert!(limiter.check("new_user").await);
assert!(!limiter.check("new_user").await);
}
}

View File

@@ -1,11 +1,12 @@
use super::auth::extract_token;
use crate::ServerState;
use axum::{
extract::{Path, State},
http::HeaderMap,
Json,
};
use matrix_sdk::ruma::Int;
use serde::{Deserialize, Serialize};
use crate::ServerState;
use super::auth::extract_token;
#[derive(Serialize, Deserialize, Clone)]
pub struct Role {
@@ -52,22 +53,65 @@ pub async fn get_roles(
) -> Result<Json<Vec<Role>>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session.client.get_room(&rid).ok_or(axum::http::StatusCode::NOT_FOUND)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
let mut roles = Vec::new();
if let Ok(power_levels) = room.power_levels().await {
for (uid, power_level) in &power_levels.users {
roles.push(Role {
id: format!("role_{}", uid),
name: uid.to_string(),
color: if *power_level >= 100 { "#ed4245".to_string() } else if *power_level >= 50 { "#fee75c".to_string() } else { "#5865f2".to_string() },
permissions: vec![],
position: *power_level as i32,
});
}
let members = room
.members(matrix_sdk::RoomMemberships::JOIN)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
let power_levels_content: Option<
matrix_sdk::ruma::events::room::power_levels::RoomPowerLevelsEventContent,
> = room
.get_state_event_static()
.await
.ok()
.flatten()
.and_then(|e| e.deserialize().ok())
.and_then(|e| e.as_sync().cloned())
.and_then(|e| e.as_original().cloned())
.map(|e| e.content);
for member in members {
let pl: i64 = if let Some(ref pl_content) = power_levels_content {
pl_content
.users
.get(member.user_id())
.map_or(0, |v| i64::from(*v))
} else {
0
};
let color = if pl >= 100 {
"#ed4245".to_string()
} else if pl >= 50 {
"#fee75c".to_string()
} else {
"#5865f2".to_string()
};
let display_name = member
.display_name()
.map(|s| s.to_string())
.unwrap_or_else(|| member.user_id().to_string());
roles.push(Role {
id: format!("role_{}", member.user_id()),
name: display_name,
color,
permissions: vec![],
position: pl as i32,
});
}
Ok(Json(roles))
@@ -87,22 +131,34 @@ pub async fn assign_role(
) -> Result<Json<bool>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session.client.get_room(&rid).ok_or(axum::http::StatusCode::NOT_FOUND)?;
let uid: matrix_sdk::ruma::OwnedUserId = req.user_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
let uid: matrix_sdk::ruma::OwnedUserId = req
.user_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let power_level: i64 = match req.role_id.as_str() {
let power_level: i32 = match req.role_id.as_str() {
"admin" => 100,
"moderator" => 50,
_ => 0,
};
let mut content = matrix_sdk::ruma::events::room::power_levels::RoomPowerLevelsEventContent::new();
content.users.insert(uid.to_owned(), power_level.into());
room.send_state_event(content).await.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
room.update_power_levels(vec![(uid.as_ref(), Int::from(power_level))])
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(true))
}
@@ -121,17 +177,28 @@ pub async fn remove_role(
) -> Result<Json<bool>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session.client.get_room(&rid).ok_or(axum::http::StatusCode::NOT_FOUND)?;
let uid: matrix_sdk::ruma::OwnedUserId = req.user_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
let uid: matrix_sdk::ruma::OwnedUserId = req
.user_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
if let Ok(mut power_levels) = room.power_levels().await {
power_levels.users.remove(&uid);
let content = matrix_sdk::ruma::events::room::power_levels::RoomPowerLevelsEventContent::from(power_levels);
room.send_state_event(content).await.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
}
room.update_power_levels(vec![(uid.as_ref(), Int::from(0))])
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(true))
}
@@ -143,17 +210,41 @@ pub async fn get_permissions(
) -> Result<Json<Permissions>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session.client.get_room(&rid).ok_or(axum::http::StatusCode::NOT_FOUND)?;
let uid: matrix_sdk::ruma::OwnedUserId = user_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
let uid: matrix_sdk::ruma::OwnedUserId = user_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let user_power = if let Ok(power_levels) = room.power_levels().await {
power_levels.users.get(&uid).copied().map(|p| p.into()).unwrap_or(power_levels.users_default as i64)
let power_levels_content: Option<
matrix_sdk::ruma::events::room::power_levels::RoomPowerLevelsEventContent,
> = room
.get_state_event_static()
.await
.ok()
.flatten()
.and_then(|e| e.deserialize().ok())
.and_then(|e| e.as_sync().cloned())
.and_then(|e| e.as_original().cloned())
.map(|e| e.content);
let user_power = if let Some(pl) = power_levels_content {
pl.users.get(&uid).map_or(0, |v| i64::from(*v))
} else {
0
};
Ok(Json(power_level_to_permissions(user_power)))
}
}

View File

@@ -1,11 +1,15 @@
use super::auth::extract_token;
use crate::routes::metrics::ROOMS_JOINED_TOTAL;
use crate::ServerState;
use axum::{
extract::{Path, State},
extract::{Multipart, Path, State},
http::HeaderMap,
Json,
};
use serde::Serialize;
use crate::ServerState;
use super::auth::extract_token;
use matrix_sdk::ruma::api::client::receipt::create_receipt::v3::ReceiptType;
use matrix_sdk::ruma::events::receipt::ReceiptThread;
use matrix_sdk::RoomMemberships;
use serde::{Deserialize, Serialize};
#[derive(Serialize)]
pub struct RoomInfo {
@@ -15,6 +19,8 @@ pub struct RoomInfo {
pub is_encrypted: bool,
pub member_count: u64,
pub topic: Option<String>,
pub unread_notifications: u64,
pub unread_messages: u64,
}
pub async fn get_joined_rooms(
@@ -23,16 +29,29 @@ pub async fn get_joined_rooms(
) -> Result<Json<Vec<RoomInfo>>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rooms = session.client.joined_rooms();
let mut result = Vec::new();
for room in rooms {
let name = room.display_name().await.map(|n| n.to_string()).unwrap_or_default();
let name = room
.display_name()
.await
.map(|n| n.to_string())
.unwrap_or_default();
let avatar_url = room.avatar_url().map(|u| u.to_string());
let member_count = room.joined_members().len() as u64;
let member_count = room
.members(RoomMemberships::JOIN)
.await
.map(|m| m.len() as u64)
.unwrap_or(0);
let topic = room.topic().map(|t| t.to_string());
let is_encrypted = room.is_encrypted().await.unwrap_or(false);
let unread_notifications = room.unread_notification_counts().notification_count;
let unread_messages = room.num_unread_messages();
result.push(RoomInfo {
room_id: room.room_id().to_string(),
name,
@@ -40,6 +59,8 @@ pub async fn get_joined_rooms(
is_encrypted,
member_count,
topic,
unread_notifications,
unread_messages,
});
}
@@ -60,11 +81,14 @@ pub async fn create_room(
) -> Result<Json<String>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let vis = match req.visibility.as_deref() {
Some("public") => matrix_sdk::ruma::Space::Public,
_ => matrix_sdk::ruma::Space::Private,
Some("public") => matrix_sdk::ruma::api::client::room::Visibility::Public,
_ => matrix_sdk::ruma::api::client::room::Visibility::Private,
};
let mut request = matrix_sdk::ruma::api::client::room::create_room::v3::Request::new();
@@ -72,8 +96,12 @@ pub async fn create_room(
request.topic = req.topic;
request.visibility = vis;
let response = session.client.create_room(request).await.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(response.room_id.to_string()))
let response = session
.client
.create_room(request)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(response.room_id().to_string()))
}
#[derive(serde::Deserialize)]
@@ -88,17 +116,24 @@ pub async fn join_room(
) -> Result<Json<String>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let room_id = session.client
.join_room_by_id_or_alias(
&req.room_id_or_alias.try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?,
&[],
)
let alias: matrix_sdk::ruma::OwnedRoomOrAliasId = req
.room_id_or_alias
.parse()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.join_room_by_id_or_alias(&alias, &[])
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(room_id.to_string()))
ROOMS_JOINED_TOTAL.inc();
Ok(Json(room.room_id().to_string()))
}
pub async fn leave_room(
@@ -108,11 +143,22 @@ pub async fn leave_room(
) -> Result<Json<bool>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session.client.get_room(&rid).ok_or(axum::http::StatusCode::NOT_FOUND)?;
room.leave().await.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
room.leave()
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(true))
}
@@ -124,11 +170,219 @@ pub async fn get_room_members(
) -> Result<Json<Vec<String>>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session.client.get_room(&rid).ok_or(axum::http::StatusCode::NOT_FOUND)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
let members = room.joined_members();
Ok(Json(members.iter().map(|m| m.user_id().to_string()).collect()))
}
let members = room
.members(RoomMemberships::JOIN)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(
members.iter().map(|m| m.user_id().to_string()).collect(),
))
}
#[derive(Deserialize)]
pub struct SetRoomNameRequest {
pub name: String,
}
pub async fn set_room_name(
State(state): State<ServerState>,
headers: HeaderMap,
Path(room_id): Path<String>,
Json(req): Json<SetRoomNameRequest>,
) -> Result<Json<bool>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
room.set_name(req.name)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(true))
}
#[derive(Deserialize)]
pub struct SetRoomTopicRequest {
pub topic: String,
}
pub async fn set_room_topic(
State(state): State<ServerState>,
headers: HeaderMap,
Path(room_id): Path<String>,
Json(req): Json<SetRoomTopicRequest>,
) -> Result<Json<bool>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
room.set_room_topic(&req.topic)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(true))
}
pub async fn set_room_avatar(
State(state): State<ServerState>,
headers: HeaderMap,
Path(room_id): Path<String>,
mut multipart: Multipart,
) -> Result<Json<bool>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
let field = multipart
.next_field()
.await
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?
.ok_or(axum::http::StatusCode::BAD_REQUEST)?;
let content_type = field
.content_type()
.unwrap_or("application/octet-stream")
.to_string();
let data = field
.bytes()
.await
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let mime_type: mime::Mime = content_type.parse().unwrap_or(mime::IMAGE_PNG);
let upload = session
.client
.media()
.upload(&mime_type, data.to_vec())
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
room.set_avatar_url(&upload.content_uri, None)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(true))
}
#[derive(Serialize)]
pub struct UnreadInfo {
pub room_id: String,
pub unread_notifications: u64,
pub unread_messages: u64,
}
pub async fn get_unread_counts(
State(state): State<ServerState>,
headers: HeaderMap,
) -> Result<Json<Vec<UnreadInfo>>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rooms = session.client.joined_rooms();
let mut result = Vec::new();
for room in rooms {
let unread_notifications = room.unread_notification_counts().notification_count;
let unread_messages = room.num_unread_messages();
result.push(UnreadInfo {
room_id: room.room_id().to_string(),
unread_notifications,
unread_messages,
});
}
Ok(Json(result))
}
#[derive(Deserialize)]
pub struct MarkReadRequest {
pub event_id: String,
}
pub async fn mark_room_read(
State(state): State<ServerState>,
headers: HeaderMap,
Path(room_id): Path<String>,
Json(req): Json<MarkReadRequest>,
) -> Result<Json<bool>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
let event_id: matrix_sdk::ruma::OwnedEventId = req
.event_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
room.send_single_receipt(ReceiptType::Read, ReceiptThread::Main, event_id)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(true))
}

View File

@@ -0,0 +1,379 @@
use super::auth::extract_token;
use crate::ServerState;
use axum::{
extract::{Path, Query, State},
http::HeaderMap,
Json,
};
use matrix_sdk::room::MessagesOptions;
use serde::{Deserialize, Serialize};
#[derive(Serialize)]
pub struct ThreadInfo {
pub root_event_id: String,
pub root_sender: String,
pub root_body: String,
pub root_timestamp: u64,
pub reply_count: u32,
pub last_reply_event_id: Option<String>,
pub last_reply_sender: Option<String>,
pub last_reply_body: Option<String>,
pub last_reply_timestamp: Option<u64>,
}
#[derive(Serialize)]
pub struct ThreadMessageInfo {
pub event_id: String,
pub sender: String,
pub body: String,
pub timestamp: u64,
}
#[derive(Deserialize)]
pub struct ThreadQuery {
pub limit: Option<u32>,
pub from: Option<String>,
}
pub async fn get_threads(
State(state): State<ServerState>,
headers: HeaderMap,
Path(room_id): Path<String>,
) -> Result<Json<Vec<ThreadInfo>>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
let mut options = MessagesOptions::backward();
options.limit = matrix_sdk::ruma::uint!(100);
let messages = room
.messages(options)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
let mut thread_roots: std::collections::HashMap<String, ThreadInfo> =
std::collections::HashMap::new();
for msg in &messages.chunk {
let event_value: serde_json::Value = match msg.event.deserialize_as() {
Ok(v) => v,
Err(_) => continue,
};
let event_type = event_value
.get("type")
.and_then(|v| v.as_str())
.unwrap_or("");
if event_type != "m.room.message" {
continue;
}
let event_id = match event_value.get("event_id").and_then(|v| v.as_str()) {
Some(id) => id.to_string(),
None => continue,
};
let sender = event_value
.get("sender")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let timestamp = event_value
.get("origin_server_ts")
.and_then(|v| v.as_u64())
.unwrap_or(0);
let content = event_value
.get("content")
.unwrap_or(&serde_json::Value::Null);
let body = content
.get("body")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let new_content = content.get("m.new_content");
let display_body = if let Some(nc) = new_content {
nc.get("body")
.and_then(|v| v.as_str())
.unwrap_or(&body)
.to_string()
} else {
body.clone()
};
let relates_to = content.get("m.relates_to");
let is_thread = relates_to
.and_then(|r| r.get("rel_type"))
.and_then(|v| v.as_str())
== Some("m.thread");
if is_thread {
if let Some(thread_info) = relates_to {
let root_id = thread_info
.get("event_id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
if root_id.is_empty() {
continue;
}
let is_falling_back = thread_info
.get("is_falling_back")
.and_then(|v| v.as_bool())
.unwrap_or(false);
if is_falling_back {
let reply_count = thread_roots
.get(&root_id)
.map(|t: &ThreadInfo| t.reply_count)
.unwrap_or(0);
thread_roots
.entry(root_id.clone())
.or_insert_with(|| ThreadInfo {
root_event_id: root_id.clone(),
root_sender: sender.clone(),
root_body: display_body.clone(),
root_timestamp: timestamp,
reply_count,
last_reply_event_id: None,
last_reply_sender: None,
last_reply_body: None,
last_reply_timestamp: None,
});
} else {
let entry = thread_roots
.entry(root_id.clone())
.or_insert_with(|| ThreadInfo {
root_event_id: root_id.clone(),
root_sender: String::new(),
root_body: String::new(),
root_timestamp: 0,
reply_count: 0,
last_reply_event_id: None,
last_reply_sender: None,
last_reply_body: None,
last_reply_timestamp: None,
});
entry.reply_count += 1;
entry.last_reply_event_id = Some(event_id);
entry.last_reply_sender = Some(sender.clone());
entry.last_reply_body = Some(display_body);
entry.last_reply_timestamp = Some(timestamp);
}
}
}
}
Ok(Json(thread_roots.into_values().collect()))
}
pub async fn get_thread_messages(
State(state): State<ServerState>,
headers: HeaderMap,
Path((room_id, thread_id)): Path<(String, String)>,
Query(query): Query<ThreadQuery>,
) -> Result<Json<Vec<ThreadMessageInfo>>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
let limit = query.limit.unwrap_or(50);
let mut options = MessagesOptions::backward();
options.limit = matrix_sdk::ruma::uint!(100);
let messages = room
.messages(options)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
let mut result = Vec::new();
for msg in &messages.chunk {
let event_value: serde_json::Value = match msg.event.deserialize_as() {
Ok(v) => v,
Err(_) => continue,
};
let event_type = event_value
.get("type")
.and_then(|v| v.as_str())
.unwrap_or("");
if event_type != "m.room.message" {
continue;
}
let content = event_value
.get("content")
.unwrap_or(&serde_json::Value::Null);
let relates_to = content.get("m.relates_to");
let is_in_thread = relates_to
.and_then(|r| {
let rel_type = r.get("rel_type").and_then(|v| v.as_str());
let event_id = r.get("event_id").and_then(|v| v.as_str());
if rel_type == Some("m.thread") && event_id == Some(thread_id.as_str()) {
Some(true)
} else {
None
}
})
.unwrap_or(false);
let is_root =
event_value.get("event_id").and_then(|v| v.as_str()) == Some(thread_id.as_str());
if !is_in_thread && !is_root {
continue;
}
let event_id = match event_value.get("event_id").and_then(|v| v.as_str()) {
Some(id) => id.to_string(),
None => continue,
};
let sender = event_value
.get("sender")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let timestamp = event_value
.get("origin_server_ts")
.and_then(|v| v.as_u64())
.unwrap_or(0);
let body = content
.get("body")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
if body.is_empty() {
continue;
}
result.push(ThreadMessageInfo {
event_id,
sender,
body,
timestamp,
});
}
result.truncate(limit as usize);
Ok(Json(result))
}
#[derive(Deserialize)]
pub struct ThreadReplyRequest {
pub message: String,
}
pub async fn send_thread_reply(
State(state): State<ServerState>,
headers: HeaderMap,
Path((room_id, thread_id)): Path<(String, String)>,
Json(req): Json<ThreadReplyRequest>,
) -> Result<Json<String>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
let root_event_id: matrix_sdk::ruma::OwnedEventId = thread_id
.parse()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let mut content =
matrix_sdk::ruma::events::room::message::RoomMessageEventContent::text_plain(&req.message);
content.relates_to = Some(matrix_sdk::ruma::events::room::message::Relation::Thread(
matrix_sdk::ruma::events::relation::Thread::without_fallback(root_event_id),
));
let response = room
.send(content)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(response.event_id.to_string()))
}
#[derive(Deserialize)]
pub struct SendReplyRequest {
pub message: String,
pub reply_to: String,
}
pub async fn send_reply(
State(state): State<ServerState>,
headers: HeaderMap,
Path(room_id): Path<String>,
Json(req): Json<SendReplyRequest>,
) -> Result<Json<String>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
let reply_to_event_id: matrix_sdk::ruma::OwnedEventId = req
.reply_to
.parse()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let mut content =
matrix_sdk::ruma::events::room::message::RoomMessageEventContent::text_plain(&req.message);
content.relates_to = Some(matrix_sdk::ruma::events::room::message::Relation::Reply {
in_reply_to: matrix_sdk::ruma::events::relation::InReplyTo::new(reply_to_event_id),
});
let response = room
.send(content)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(response.event_id.to_string()))
}

View File

@@ -0,0 +1,97 @@
use super::auth::extract_token;
use crate::routes::metrics::UPLOADS_TOTAL;
use crate::ServerState;
use axum::{
extract::{Multipart, Path, State},
http::HeaderMap,
Json,
};
use matrix_sdk::ruma::events::room::message::{
FileMessageEventContent, ImageMessageEventContent, MessageType, RoomMessageEventContent,
};
use matrix_sdk::ruma::OwnedMxcUri;
use serde::Serialize;
#[derive(Serialize)]
pub struct UploadResult {
pub event_id: String,
pub media_url: Option<String>,
pub filename: String,
pub mimetype: Option<String>,
pub size: Option<u64>,
}
pub async fn upload_file(
State(state): State<ServerState>,
headers: HeaderMap,
Path(room_id): Path<String>,
mut multipart: Multipart,
) -> Result<Json<UploadResult>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
.as_str()
.try_into()
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let room = session
.client
.get_room(&rid)
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
let field = multipart
.next_field()
.await
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?
.ok_or(axum::http::StatusCode::BAD_REQUEST)?;
let filename = field.file_name().unwrap_or("upload").to_string();
let content_type = field
.content_type()
.unwrap_or("application/octet-stream")
.to_string();
let data = field
.bytes()
.await
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let mime_type: mime::Mime = content_type
.parse()
.unwrap_or(mime::APPLICATION_OCTET_STREAM);
let media_response = session
.client
.media()
.upload(&mime_type, data.to_vec())
.await
.map_err(|e| {
tracing::error!("Media upload failed: {}", e);
axum::http::StatusCode::INTERNAL_SERVER_ERROR
})?;
let media_uri: OwnedMxcUri = media_response.content_uri.clone();
let content: RoomMessageEventContent = if mime_type.type_() == mime::IMAGE {
MessageType::Image(ImageMessageEventContent::plain(filename.clone(), media_uri)).into()
} else {
MessageType::File(FileMessageEventContent::plain(filename.clone(), media_uri)).into()
};
let response = room
.send(content)
.await
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
UPLOADS_TOTAL.inc();
Ok(Json(UploadResult {
event_id: response.event_id.to_string(),
media_url: Some(media_response.content_uri.to_string()),
filename,
mimetype: Some(content_type),
size: Some(data.len() as u64),
}))
}

View File

@@ -1,80 +1,318 @@
use axum::{
extract::State,
http::HeaderMap,
Json,
};
use serde::{Deserialize, Serialize};
use crate::ServerState;
use super::auth::extract_token;
use crate::state::WsEvent;
use crate::ServerState;
use axum::{extract::State, http::HeaderMap, Json};
use livekit_api::access_token::{AccessToken, VideoGrants};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Deserialize)]
pub struct VoiceRequest {
pub struct VoiceJoinRequest {
pub room_id: String,
}
#[derive(Serialize)]
pub struct VoiceStateInfo {
pub struct VoiceJoinResponse {
pub room_id: String,
pub livekit_url: String,
pub livekit_token: String,
}
#[derive(Serialize)]
pub struct VoiceToggleResponse {
pub muted: bool,
pub deafened: bool,
pub streaming: bool,
}
#[derive(Serialize)]
pub struct VoiceParticipantInfo {
pub user_id: String,
pub muted: bool,
pub deafened: bool,
}
fn sanitize_room_id(room_id: &str) -> String {
room_id
.replace(':', "-")
.replace("!", "")
.replace("/", "-")
.replace(" ", "_")
}
pub async fn join_voice_channel(
State(state): State<ServerState>,
headers: HeaderMap,
Json(req): Json<VoiceRequest>,
) -> Result<Json<VoiceStateInfo>, axum::http::StatusCode> {
Json(req): Json<VoiceJoinRequest>,
) -> Result<Json<VoiceJoinResponse>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let mut s = state.write().await;
let session = s.sessions.get_mut(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
session.voice_manager.join_channel(req.room_id.clone(), session.user_id.clone());
let (user_id, livekit_url, livekit_api_key, livekit_api_secret) = {
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
(
session.user_id.clone(),
s.livekit.url.clone(),
s.livekit.api_key.clone(),
s.livekit.api_secret.clone(),
)
};
Ok(Json(VoiceStateInfo {
let lk_room = sanitize_room_id(&req.room_id);
let lk_identity = sanitize_room_id(&user_id);
let grants = VideoGrants {
room_join: true,
room: lk_room.clone(),
can_publish: true,
can_subscribe: true,
can_publish_data: true,
..Default::default()
};
let lk_token = AccessToken::with_api_key(&livekit_api_key, &livekit_api_secret)
.with_identity(&lk_identity)
.with_name(&user_id)
.with_grants(grants)
.to_jwt()
.map_err(|e| {
tracing::error!("Failed to generate LiveKit token: {}", e);
axum::http::StatusCode::INTERNAL_SERVER_ERROR
})?;
let old_channel;
{
let mut s = state.write().await;
let session = s
.sessions
.get_mut(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
old_channel = session.voice_manager.active_channel.clone();
session.voice_manager.join_channel(&req.room_id);
}
if let Some(ref old) = old_channel {
if old != &req.room_id {
let mut s = state.write().await;
if let Some(room) = s.voice_rooms.rooms.get_mut(old) {
room.participants.remove(&user_id);
if room.participants.is_empty() {
s.voice_rooms.rooms.remove(old);
}
}
drop(s);
let s = state.read().await;
if let Some(session) = s.sessions.get(&token) {
let _ = session.event_sender.send(WsEvent::VoiceUserLeft {
room_id: old.clone(),
user_id: user_id.clone(),
});
}
}
}
{
let mut s = state.write().await;
let room_entry = s
.voice_rooms
.rooms
.entry(req.room_id.clone())
.or_insert_with(|| crate::state::VoiceRoom {
participants: HashMap::new(),
});
room_entry.participants.insert(
user_id.clone(),
crate::state::VoiceParticipant {
user_id: user_id.clone(),
muted: false,
deafened: false,
},
);
}
{
let s = state.read().await;
if let Some(session) = s.sessions.get(&token) {
let _ = session.event_sender.send(WsEvent::VoiceUserJoined {
room_id: req.room_id.clone(),
user_id: user_id.clone(),
});
}
}
Ok(Json(VoiceJoinResponse {
room_id: req.room_id,
muted: false,
deafened: false,
streaming: false,
livekit_url,
livekit_token: lk_token,
}))
}
pub async fn leave_voice_channel(
State(state): State<ServerState>,
headers: HeaderMap,
Json(req): Json<VoiceRequest>,
Json(req): Json<VoiceJoinRequest>,
) -> Result<Json<bool>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let mut s = state.write().await;
let session = s.sessions.get_mut(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
Ok(Json(session.voice_manager.leave_channel(&req.room_id, &session.user_id)))
let user_id;
{
let mut s = state.write().await;
let session = s
.sessions
.get_mut(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
user_id = session.user_id.clone();
session.voice_manager.leave_channel();
if let Some(room) = s.voice_rooms.rooms.get_mut(&req.room_id) {
room.participants.remove(&user_id);
if room.participants.is_empty() {
s.voice_rooms.rooms.remove(&req.room_id);
}
}
}
{
let s = state.read().await;
if let Some(session) = s.sessions.get(&token) {
let _ = session.event_sender.send(WsEvent::VoiceUserLeft {
room_id: req.room_id,
user_id,
});
}
}
Ok(Json(true))
}
pub async fn toggle_mute(
State(state): State<ServerState>,
headers: HeaderMap,
Json(req): Json<VoiceRequest>,
) -> Result<Json<bool>, axum::http::StatusCode> {
Json(_req): Json<VoiceJoinRequest>,
) -> Result<Json<VoiceToggleResponse>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let mut s = state.write().await;
let session = s.sessions.get_mut(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
session.voice_manager.toggle_mute(&req.room_id, &session.user_id)
.ok_or(axum::http::StatusCode::BAD_REQUEST)
.map(Json)
let (user_id, room_id, new_muted, deafened);
{
let mut s = state.write().await;
let session = s
.sessions
.get_mut(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
new_muted = session.voice_manager.toggle_mute();
deafened = session.voice_manager.deafened;
room_id = session
.voice_manager
.active_channel
.clone()
.ok_or(axum::http::StatusCode::BAD_REQUEST)?;
user_id = session.user_id.clone();
if let Some(room) = s.voice_rooms.rooms.get_mut(&room_id) {
if let Some(participant) = room.participants.get_mut(&user_id) {
participant.muted = new_muted;
}
}
}
{
let s = state.read().await;
if let Some(session) = s.sessions.get(&token) {
let _ = session.event_sender.send(WsEvent::VoiceStateUpdate {
room_id,
user_id,
muted: new_muted,
deafened,
});
}
}
Ok(Json(VoiceToggleResponse {
muted: new_muted,
deafened,
}))
}
pub async fn toggle_deafen(
State(state): State<ServerState>,
headers: HeaderMap,
Json(req): Json<VoiceRequest>,
) -> Result<Json<bool>, axum::http::StatusCode> {
Json(_req): Json<VoiceJoinRequest>,
) -> Result<Json<VoiceToggleResponse>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let mut s = state.write().await;
let session = s.sessions.get_mut(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
session.voice_manager.toggle_deafen(&req.room_id, &session.user_id)
.ok_or(axum::http::StatusCode::BAD_REQUEST)
.map(Json)
}
let (user_id, room_id, muted, new_deafened);
{
let mut s = state.write().await;
let session = s
.sessions
.get_mut(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
new_deafened = session.voice_manager.toggle_deafen();
muted = session.voice_manager.muted;
room_id = session
.voice_manager
.active_channel
.clone()
.ok_or(axum::http::StatusCode::BAD_REQUEST)?;
user_id = session.user_id.clone();
if let Some(room) = s.voice_rooms.rooms.get_mut(&room_id) {
if let Some(participant) = room.participants.get_mut(&user_id) {
participant.deafened = new_deafened;
participant.muted = muted;
}
}
}
{
let s = state.read().await;
if let Some(session) = s.sessions.get(&token) {
let _ = session.event_sender.send(WsEvent::VoiceStateUpdate {
room_id,
user_id,
muted,
deafened: new_deafened,
});
}
}
Ok(Json(VoiceToggleResponse {
muted,
deafened: new_deafened,
}))
}
pub async fn get_voice_participants(
State(state): State<ServerState>,
headers: HeaderMap,
) -> Result<Json<Vec<VoiceParticipantInfo>>, axum::http::StatusCode> {
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let s = state.read().await;
let session = s
.sessions
.get(&token)
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
let room_id = match &session.voice_manager.active_channel {
Some(r) => r.clone(),
None => return Ok(Json(Vec::new())),
};
let participants = match s.voice_rooms.rooms.get(&room_id) {
Some(room) => room
.participants
.values()
.map(|p| VoiceParticipantInfo {
user_id: p.user_id.clone(),
muted: p.muted,
deafened: p.deafened,
})
.collect(),
None => Vec::new(),
};
Ok(Json(participants))
}

84
server/src/routes/ws.rs Normal file
View File

@@ -0,0 +1,84 @@
use crate::ServerState;
use axum::{
extract::{
ws::{Message, WebSocket},
Query, State, WebSocketUpgrade,
},
response::IntoResponse,
};
use futures::{SinkExt, StreamExt};
use serde::Deserialize;
#[derive(Deserialize)]
pub struct WsQuery {
pub token: Option<String>,
}
pub async fn ws_handler(
ws: WebSocketUpgrade,
Query(query): Query<WsQuery>,
State(state): State<ServerState>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_ws(socket, state, query.token.unwrap_or_default()))
}
async fn handle_ws(socket: WebSocket, state: ServerState, token: String) {
let (user_id, mut event_rx) = {
let s = state.read().await;
match s.sessions.get(&token) {
Some(session) => (session.user_id.clone(), session.event_sender.subscribe()),
None => {
let _ = socket.close().await;
return;
}
}
};
let connected_msg = serde_json::json!({"type": "connected", "user_id": user_id}).to_string();
let (mut ws_sender, mut ws_receiver) = socket.split();
if ws_sender.send(Message::Text(connected_msg)).await.is_err() {
return;
}
let (out_tx, mut out_rx) = tokio::sync::mpsc::channel::<String>(64);
let sender_task = tokio::spawn(async move {
while let Some(text) = out_rx.recv().await {
if ws_sender.send(Message::Text(text)).await.is_err() {
break;
}
}
});
let receiver_task = tokio::spawn(async move {
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(Message::Text(text)) if text.as_str() == "ping" => {}
Ok(Message::Close(_)) | Err(_) => break,
_ => {}
}
}
});
let forward_task = tokio::spawn(async move {
while let Ok(event) = event_rx.recv().await {
match serde_json::to_string(&event) {
Ok(json) => {
if out_tx.send(json).await.is_err() {
break;
}
}
Err(e) => {
tracing::error!("Failed to serialize WS event: {}", e);
}
}
}
});
tokio::select! {
_ = sender_task => {},
_ = receiver_task => {},
_ = forward_task => {},
}
}

181
server/src/session_store.rs Normal file
View File

@@ -0,0 +1,181 @@
use rusqlite::Connection;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::Mutex;
pub struct SessionStore {
conn: Arc<Mutex<Connection>>,
}
#[derive(Clone)]
pub struct StoredSession {
pub token: String,
pub user_id: String,
pub homeserver: String,
pub access_token: String,
pub device_id: Option<String>,
pub refresh_token: Option<String>,
}
impl SessionStore {
pub fn new(path: &str) -> Result<Self, rusqlite::Error> {
let conn = if Path::new(path).exists() {
Connection::open(path)?
} else {
let conn = Connection::open(path)?;
conn.execute_batch(
"CREATE TABLE IF NOT EXISTS sessions (
token TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
homeserver TEXT NOT NULL,
access_token TEXT NOT NULL,
device_id TEXT,
refresh_token TEXT
);",
)?;
conn
};
Ok(Self {
conn: Arc::new(Mutex::new(conn)),
})
}
pub async fn save_session(&self, session: &StoredSession) -> Result<(), rusqlite::Error> {
let conn = self.conn.lock().await;
conn.execute(
"INSERT OR REPLACE INTO sessions (token, user_id, homeserver, access_token, device_id, refresh_token)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
rusqlite::params![
session.token,
session.user_id,
session.homeserver,
session.access_token,
session.device_id,
session.refresh_token,
],
)?;
Ok(())
}
pub async fn get_all_sessions(&self) -> Result<Vec<StoredSession>, rusqlite::Error> {
let conn = self.conn.lock().await;
let mut stmt = conn.prepare(
"SELECT token, user_id, homeserver, access_token, device_id, refresh_token FROM sessions"
)?;
let rows = stmt.query_map([], |row| {
Ok(StoredSession {
token: row.get(0)?,
user_id: row.get(1)?,
homeserver: row.get(2)?,
access_token: row.get(3)?,
device_id: row.get(4)?,
refresh_token: row.get(5)?,
})
})?;
rows.collect()
}
pub async fn delete_session(&self, token: &str) -> Result<(), rusqlite::Error> {
let conn = self.conn.lock().await;
conn.execute("DELETE FROM sessions WHERE token = ?1", [token])?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
async fn create_test_store() -> (SessionStore, tempfile::TempDir) {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("test.db");
let path_str = path.to_str().unwrap().to_string();
let store = SessionStore::new(&path_str).unwrap();
(store, dir)
}
#[tokio::test]
async fn save_and_retrieve_session() {
let (store, _dir) = create_test_store().await;
let session = StoredSession {
token: "test-token".to_string(),
user_id: "@alice:server".to_string(),
homeserver: "https://matrix.server".to_string(),
access_token: "access-123".to_string(),
device_id: Some("DEVICE1".to_string()),
refresh_token: Some("refresh-456".to_string()),
};
store.save_session(&session).await.unwrap();
let sessions = store.get_all_sessions().await.unwrap();
assert_eq!(sessions.len(), 1);
assert_eq!(sessions[0].token, "test-token");
assert_eq!(sessions[0].user_id, "@alice:server");
assert_eq!(sessions[0].access_token, "access-123");
assert_eq!(sessions[0].device_id, Some("DEVICE1".to_string()));
}
#[tokio::test]
async fn save_session_upserts() {
let (store, _dir) = create_test_store().await;
let session = StoredSession {
token: "test-token".to_string(),
user_id: "@alice:server".to_string(),
homeserver: "https://matrix.server".to_string(),
access_token: "access-123".to_string(),
device_id: None,
refresh_token: None,
};
store.save_session(&session).await.unwrap();
let mut updated = session.clone();
updated.access_token = "access-456".to_string();
store.save_session(&updated).await.unwrap();
let sessions = store.get_all_sessions().await.unwrap();
assert_eq!(sessions.len(), 1);
assert_eq!(sessions[0].access_token, "access-456");
}
#[tokio::test]
async fn delete_session_removes_it() {
let (store, _dir) = create_test_store().await;
let session = StoredSession {
token: "test-token".to_string(),
user_id: "@alice:server".to_string(),
homeserver: "https://matrix.server".to_string(),
access_token: "access-123".to_string(),
device_id: None,
refresh_token: None,
};
store.save_session(&session).await.unwrap();
store.delete_session("test-token").await.unwrap();
let sessions = store.get_all_sessions().await.unwrap();
assert!(sessions.is_empty());
}
#[tokio::test]
async fn delete_nonexistent_session_is_ok() {
let (store, _dir) = create_test_store().await;
store.delete_session("nonexistent").await.unwrap();
}
#[tokio::test]
async fn save_multiple_sessions() {
let (store, _dir) = create_test_store().await;
for i in 0..3 {
let session = StoredSession {
token: format!("token-{}", i),
user_id: format!("@user{}:server", i),
homeserver: "https://matrix.server".to_string(),
access_token: format!("access-{}", i),
device_id: None,
refresh_token: None,
};
store.save_session(&session).await.unwrap();
}
let sessions = store.get_all_sessions().await.unwrap();
assert_eq!(sessions.len(), 3);
}
}

View File

@@ -1,117 +1,381 @@
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::session_store::SessionStore;
use matrix_sdk::Client;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::broadcast;
use tokio::sync::RwLock;
#[derive(Clone, Debug, serde::Serialize)]
#[serde(tag = "type")]
pub enum WsEvent {
#[serde(rename = "message")]
Message {
room_id: String,
event_id: String,
sender: String,
body: String,
timestamp: u64,
reply_to: Option<String>,
msgtype: Option<String>,
media_url: Option<String>,
filename: Option<String>,
mimetype: Option<String>,
},
#[serde(rename = "message_edited")]
MessageEdited {
room_id: String,
event_id: String,
new_body: String,
},
#[serde(rename = "message_deleted")]
MessageDeleted { room_id: String, redacts: String },
#[serde(rename = "reaction")]
Reaction {
room_id: String,
event_id: String,
key: String,
sender: String,
},
#[serde(rename = "room_joined")]
RoomJoined { room_id: String, name: String },
#[serde(rename = "room_left")]
RoomLeft { room_id: String },
#[serde(rename = "presence")]
Presence {
user_id: String,
status: String,
status_msg: Option<String>,
},
#[serde(rename = "typing")]
Typing {
room_id: String,
user_id: String,
typing: bool,
},
#[serde(rename = "voice_state_update")]
VoiceStateUpdate {
room_id: String,
user_id: String,
muted: bool,
deafened: bool,
},
#[serde(rename = "voice_user_joined")]
VoiceUserJoined { room_id: String, user_id: String },
#[serde(rename = "voice_user_left")]
VoiceUserLeft { room_id: String, user_id: String },
#[serde(rename = "thread_reply")]
ThreadReply {
room_id: String,
root_event_id: String,
event_id: String,
sender: String,
body: String,
timestamp: u64,
},
}
pub struct LiveKitConfig {
pub api_key: String,
pub api_secret: String,
pub url: String,
}
pub struct Session {
pub client: Client,
pub user_id: String,
pub homeserver: String,
pub voice_manager: VoiceManager,
pub event_sender: broadcast::Sender<WsEvent>,
pub sync_handle: Option<tokio::task::JoinHandle<()>>,
pub created_at: Instant,
pub expires_at: Option<Instant>,
}
impl Session {
pub fn new(client: Client, user_id: String, homeserver: String) -> Self {
let (sender, _) = broadcast::channel(256);
Self {
client,
user_id,
homeserver,
voice_manager: VoiceManager::new(),
event_sender: sender,
sync_handle: None,
created_at: Instant::now(),
expires_at: None,
}
}
pub fn with_ttl(mut self, ttl: std::time::Duration) -> Self {
self.expires_at = Some(self.created_at + ttl);
self
}
pub fn is_expired(&self) -> bool {
self.expires_at
.map(|expires| Instant::now() > expires)
.unwrap_or(false)
}
}
#[derive(Default)]
pub struct VoiceManager {
channels: HashMap<String, VoiceChannel>,
active_channel: Option<String>,
}
pub struct VoiceChannel {
pub room_id: String,
pub participants: Vec<VoiceParticipant>,
}
pub struct VoiceParticipant {
pub user_id: String,
pub active_channel: Option<String>,
pub muted: bool,
pub deafened: bool,
pub streaming: bool,
}
impl VoiceManager {
pub fn new() -> Self {
Self {
channels: HashMap::new(),
active_channel: None,
}
}
pub fn join_channel(&mut self, room_id: String, user_id: String) -> bool {
if let Some(ref old_channel) = self.active_channel {
self.leave_channel_internal(old_channel, &user_id);
}
let channel = self.channels.entry(room_id.clone()).or_insert_with(|| VoiceChannel {
room_id: room_id.clone(),
participants: Vec::new(),
});
if channel.participants.iter().any(|p| p.user_id == user_id) {
return false;
}
channel.participants.push(VoiceParticipant {
user_id,
muted: false,
deafened: false,
streaming: false,
});
self.active_channel = Some(room_id);
true
}
fn leave_channel_internal(&mut self, room_id: &str, user_id: &str) {
if let Some(channel) = self.channels.get_mut(room_id) {
channel.participants.retain(|p| p.user_id != user_id);
if channel.participants.is_empty() {
self.channels.remove(room_id);
}
}
}
pub fn leave_channel(&mut self, room_id: &str, user_id: &str) -> bool {
self.leave_channel_internal(room_id, user_id);
if self.active_channel.as_deref() == Some(room_id) {
self.active_channel = None;
}
true
pub fn join_channel(&mut self, room_id: &str) {
self.active_channel = Some(room_id.to_string());
self.muted = false;
self.deafened = false;
}
pub fn toggle_mute(&mut self, room_id: &str, user_id: &str) -> Option<bool> {
if let Some(channel) = self.channels.get_mut(room_id) {
if let Some(participant) = channel.participants.iter_mut().find(|p| p.user_id == user_id) {
participant.muted = !participant.muted;
return Some(participant.muted);
}
}
None
pub fn leave_channel(&mut self) {
self.active_channel = None;
self.muted = false;
self.deafened = false;
}
pub fn toggle_deafen(&mut self, room_id: &str, user_id: &str) -> Option<bool> {
if let Some(channel) = self.channels.get_mut(room_id) {
if let Some(participant) = channel.participants.iter_mut().find(|p| p.user_id == user_id) {
participant.deafened = !participant.deafened;
if participant.deafened {
participant.muted = true;
}
return Some(participant.deafened);
}
pub fn toggle_mute(&mut self) -> bool {
self.muted = !self.muted;
self.muted
}
pub fn toggle_deafen(&mut self) -> bool {
self.deafened = !self.deafened;
if self.deafened {
self.muted = true;
}
None
self.deafened
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn voice_manager_new_is_idle() {
let vm = VoiceManager::new();
assert!(vm.active_channel.is_none());
assert!(!vm.muted);
assert!(!vm.deafened);
}
#[test]
fn voice_manager_default_is_idle() {
let vm = VoiceManager::default();
assert!(vm.active_channel.is_none());
assert!(!vm.muted);
assert!(!vm.deafened);
}
#[test]
fn join_channel_sets_active() {
let mut vm = VoiceManager::new();
vm.join_channel("!room:server");
assert_eq!(vm.active_channel, Some("!room:server".to_string()));
assert!(!vm.muted);
assert!(!vm.deafened);
}
#[test]
fn join_channel_resets_mute_deafen() {
let mut vm = VoiceManager::new();
vm.muted = true;
vm.deafened = true;
vm.join_channel("!room:server");
assert!(!vm.muted);
assert!(!vm.deafened);
}
#[test]
fn leave_channel_clears_state() {
let mut vm = VoiceManager::new();
vm.join_channel("!room:server");
vm.muted = true;
vm.leave_channel();
assert!(vm.active_channel.is_none());
assert!(!vm.muted);
assert!(!vm.deafened);
}
#[test]
fn toggle_mute_flips_state() {
let mut vm = VoiceManager::new();
assert!(!vm.muted);
let result = vm.toggle_mute();
assert!(result);
assert!(vm.muted);
let result = vm.toggle_mute();
assert!(!result);
assert!(!vm.muted);
}
#[test]
fn toggle_deafen_sets_muted_when_deafened() {
let mut vm = VoiceManager::new();
let result = vm.toggle_deafen();
assert!(result);
assert!(vm.deafened);
assert!(vm.muted);
}
#[test]
fn toggle_deafen_off_keeps_muted() {
let mut vm = VoiceManager::new();
vm.muted = true;
vm.toggle_deafen();
assert!(vm.muted);
vm.toggle_deafen();
assert!(vm.deafened == false);
assert!(vm.muted);
}
#[test]
fn ws_event_serialization() {
let event = WsEvent::Message {
room_id: "!room:server".to_string(),
event_id: "$event".to_string(),
sender: "@user:server".to_string(),
body: "hello".to_string(),
timestamp: 123456,
reply_to: None,
msgtype: Some("m.text".to_string()),
media_url: None,
filename: None,
mimetype: None,
};
let json = serde_json::to_string(&event).unwrap();
assert!(json.contains("\"type\":\"message\""));
assert!(json.contains("\"body\":\"hello\""));
}
#[test]
fn ws_event_presence_serialization() {
let event = WsEvent::Presence {
user_id: "@alice:server".to_string(),
status: "online".to_string(),
status_msg: Some("working".to_string()),
};
let json = serde_json::to_string(&event).unwrap();
assert!(json.contains("\"type\":\"presence\""));
assert!(json.contains("\"status\":\"online\""));
}
#[test]
fn is_expired_true_when_past() {
let created = Instant::now() - std::time::Duration::from_secs(120);
let expires = created + std::time::Duration::from_secs(60);
assert!(Instant::now() > expires);
}
#[test]
fn is_expired_none_means_not_expired() {
let expires_at: Option<Instant> = None;
let expired = expires_at
.map(|expires| Instant::now() > expires)
.unwrap_or(false);
assert!(!expired);
}
#[test]
fn is_expired_future_means_not_expired() {
let expires_at: Option<Instant> =
Some(Instant::now() + std::time::Duration::from_secs(3600));
let expired = expires_at
.map(|expires| Instant::now() > expires)
.unwrap_or(false);
assert!(!expired);
}
#[test]
fn is_expired_past_means_expired() {
let expires_at: Option<Instant> = Some(Instant::now() - std::time::Duration::from_secs(1));
let expired = expires_at
.map(|expires| Instant::now() > expires)
.unwrap_or(false);
assert!(expired);
}
}
pub struct ServerStateInner {
pub sessions: HashMap<String, Session>,
pub session_store: Arc<SessionStore>,
pub livekit: LiveKitConfig,
pub voice_rooms: VoiceRooms,
pub session_ttl: Option<std::time::Duration>,
}
pub struct VoiceRooms {
pub rooms: HashMap<String, VoiceRoom>,
}
pub struct VoiceRoom {
pub participants: HashMap<String, VoiceParticipant>,
}
#[derive(Clone, serde::Serialize)]
pub struct VoiceParticipant {
pub user_id: String,
pub muted: bool,
pub deafened: bool,
}
impl ServerStateInner {
pub fn new() -> Self {
pub fn new(
session_store: Arc<SessionStore>,
livekit: LiveKitConfig,
session_ttl: Option<std::time::Duration>,
) -> Self {
Self {
sessions: HashMap::new(),
session_store,
livekit,
voice_rooms: VoiceRooms {
rooms: HashMap::new(),
},
session_ttl,
}
}
}
pub type ServerState = Arc<RwLock<ServerStateInner>>;
#[derive(Clone)]
pub struct ServerState {
inner: Arc<RwLock<ServerStateInner>>,
}
impl ServerState {
pub fn new() -> Self {
Arc::new(RwLock::new(ServerStateInner::new()))
pub fn new(
session_store: Arc<SessionStore>,
livekit: LiveKitConfig,
session_ttl: Option<std::time::Duration>,
) -> Self {
Self {
inner: Arc::new(RwLock::new(ServerStateInner::new(
session_store,
livekit,
session_ttl,
))),
}
}
}
pub async fn read(&self) -> tokio::sync::RwLockReadGuard<'_, ServerStateInner> {
self.inner.read().await
}
pub async fn write(&self) -> tokio::sync::RwLockWriteGuard<'_, ServerStateInner> {
self.inner.write().await
}
}