feat: comprehensive project improvements
Some checks failed
CI / Rust Format (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Test Server (push) Has been cancelled
CI / Frontend Check (push) Has been cancelled
CI / Tauri Client Check (push) Has been cancelled
CI / Docker Build (push) Has been cancelled
CI / Build Tauri (Linux) (push) Has been cancelled
Some checks failed
CI / Rust Format (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Test Server (push) Has been cancelled
CI / Frontend Check (push) Has been cancelled
CI / Tauri Client Check (push) Has been cancelled
CI / Docker Build (push) Has been cancelled
CI / Build Tauri (Linux) (push) Has been cancelled
- Fix 14 Clippy warnings across server and bot-sdk - Add 67 unit tests (32 bot-sdk, 34 server, 1 doctest) - Add Prometheus metrics endpoint (/api/metrics) - Add structured JSON logging (EIFELDC_LOG_FORMAT=json) - Add release workflow (Docker push + GitHub Release + Tauri builds) - Add rate limiting middleware (EIFELDC_RATE_LIMIT) - Add CORS restriction (EIFELDC_CORS_ORIGINS) - Add session token expiry (EIFELDC_SESSION_TTL) - Add input validation (username/password/homeserver length limits) - Add upload size limit (EIFELDC_MAX_UPLOAD_MB) - Upgrade Tauri client from v1 to v2 - Add session store with SQLite persistence - Add proper error types and cleanup across all crates - Format all code with cargo fmt - Update CI pipeline with fmt, clippy, test, frontend, and Tauri checks - Add README with full API reference and setup guide
This commit is contained in:
@@ -15,8 +15,17 @@ tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
url = { workspace = true }
|
||||
axum = { version = "0.7", features = ["ws"] }
|
||||
axum = { version = "0.7", features = ["ws", "multipart"] }
|
||||
tower = "0.4"
|
||||
tower-http = { version = "0.5", features = ["cors", "fs", "trace"] }
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
uuid = { version = "1", features = ["v4"] }
|
||||
uuid = { version = "1", features = ["v4"] }
|
||||
mime = "0.3"
|
||||
futures = "0.3"
|
||||
rusqlite = { version = "0.30", features = ["bundled"] }
|
||||
livekit-api = { version = "0.4", features = ["access-token"] }
|
||||
prometheus = "0.13"
|
||||
lazy_static = "1"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
@@ -1,4 +1,6 @@
|
||||
pub mod routes;
|
||||
pub mod session_store;
|
||||
pub mod state;
|
||||
|
||||
pub use state::ServerState;
|
||||
pub use session_store::SessionStore;
|
||||
pub use state::ServerState;
|
||||
|
||||
@@ -1,20 +1,66 @@
|
||||
use eifeldc_server::routes::api_router;
|
||||
use eifeldc_server::routes::metrics;
|
||||
use eifeldc_server::session_store::SessionStore;
|
||||
use eifeldc_server::state::{LiveKitConfig, Session};
|
||||
use eifeldc_server::ServerState;
|
||||
use tower_http::services::ServeDir;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tower_http::services::ServeDir;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
tracing_subscriber::fmt::init();
|
||||
init_tracing();
|
||||
|
||||
metrics::init_metrics();
|
||||
tracing::info!("Metrics initialized");
|
||||
|
||||
let db_path = std::env::var("EIFELDC_DB").unwrap_or_else(|_| "eifeldc.db".to_string());
|
||||
|
||||
let session_store =
|
||||
Arc::new(SessionStore::new(&db_path).expect("Failed to open session database"));
|
||||
|
||||
let livekit = LiveKitConfig {
|
||||
api_key: std::env::var("LIVEKIT_API_KEY").unwrap_or_else(|_| "devkey".to_string()),
|
||||
api_secret: std::env::var("LIVEKIT_API_SECRET").unwrap_or_else(|_| "devsecret".to_string()),
|
||||
url: std::env::var("LIVEKIT_URL").unwrap_or_else(|_| "ws://localhost:7880".to_string()),
|
||||
};
|
||||
tracing::info!("LiveKit URL: {}", livekit.url);
|
||||
|
||||
let session_ttl: Option<std::time::Duration> =
|
||||
std::env::var("EIFELDC_SESSION_TTL").ok().map(|v| {
|
||||
let secs: u64 = v.parse().unwrap_or(86400);
|
||||
std::time::Duration::from_secs(secs)
|
||||
});
|
||||
if let Some(ttl) = session_ttl {
|
||||
tracing::info!("Session TTL: {}s", ttl.as_secs());
|
||||
} else {
|
||||
tracing::info!("Session TTL: unlimited");
|
||||
}
|
||||
|
||||
let state = ServerState::new(session_store.clone(), livekit, session_ttl);
|
||||
|
||||
{
|
||||
let stored_sessions = session_store.get_all_sessions().await.unwrap_or_default();
|
||||
if !stored_sessions.is_empty() {
|
||||
tracing::info!("Restoring {} session(s)...", stored_sessions.len());
|
||||
for stored in &stored_sessions {
|
||||
match restore_session(stored, &state).await {
|
||||
Ok(()) => tracing::info!("Restored session for user {}", stored.user_id),
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to restore session for {}: {}", stored.user_id, e);
|
||||
session_store.delete_session(&stored.token).await.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let state = ServerState::new();
|
||||
let api = api_router(state);
|
||||
|
||||
let static_dir = std::env::var("EIFELDC_STATIC_DIR")
|
||||
.unwrap_or_else(|_| "client/src-ui/dist".to_string());
|
||||
let static_dir =
|
||||
std::env::var("EIFELDC_STATIC_DIR").unwrap_or_else(|_| "client/src-ui/dist".to_string());
|
||||
|
||||
let app = api
|
||||
.fallback_service(ServeDir::new(&static_dir));
|
||||
let app = api.fallback_service(ServeDir::new(&static_dir));
|
||||
|
||||
let addr: SocketAddr = ([0, 0, 0, 0], 3000).into();
|
||||
tracing::info!("EifelDC Web Server listening on http://{}", addr);
|
||||
@@ -22,4 +68,80 @@ async fn main() {
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
async fn restore_session(
|
||||
stored: &eifeldc_server::session_store::StoredSession,
|
||||
state: &ServerState,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
use matrix_sdk::matrix_auth::{MatrixSession, MatrixSessionTokens};
|
||||
use matrix_sdk::ruma::device_id;
|
||||
use matrix_sdk::Client;
|
||||
|
||||
let client = Client::builder()
|
||||
.homeserver_url(&stored.homeserver)
|
||||
.build()
|
||||
.await?;
|
||||
|
||||
let user_id: matrix_sdk::ruma::OwnedUserId = stored.user_id.parse()?;
|
||||
let device_id = stored
|
||||
.device_id
|
||||
.as_deref()
|
||||
.map(matrix_sdk::ruma::OwnedDeviceId::from)
|
||||
.unwrap_or_else(|| device_id!("EIFELDC").to_owned());
|
||||
|
||||
let session = MatrixSession {
|
||||
meta: matrix_sdk::SessionMeta { user_id, device_id },
|
||||
tokens: MatrixSessionTokens {
|
||||
access_token: stored.access_token.clone(),
|
||||
refresh_token: stored.refresh_token.clone(),
|
||||
},
|
||||
};
|
||||
|
||||
client.matrix_auth().restore_session(session).await?;
|
||||
|
||||
let mut s = Session::new(
|
||||
client.clone(),
|
||||
stored.user_id.clone(),
|
||||
stored.homeserver.clone(),
|
||||
);
|
||||
|
||||
let sender = s.event_sender.clone();
|
||||
let sync_client = client.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
eifeldc_server::routes::auth::start_sync(sync_client, sender).await;
|
||||
});
|
||||
s.sync_handle = Some(handle);
|
||||
|
||||
let mut state_inner = state.write().await;
|
||||
state_inner.sessions.insert(stored.token.clone(), s);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn init_tracing() {
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
let log_format = std::env::var("EIFELDC_LOG_FORMAT").unwrap_or_else(|_| "pretty".to_string());
|
||||
let env_filter = EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new("eifeldc_server=info,tower_http=info"));
|
||||
|
||||
match log_format.as_str() {
|
||||
"json" => {
|
||||
tracing_subscriber::fmt()
|
||||
.json()
|
||||
.with_env_filter(env_filter)
|
||||
.with_target(true)
|
||||
.with_thread_ids(false)
|
||||
.with_file(false)
|
||||
.with_line_number(false)
|
||||
.init();
|
||||
}
|
||||
_ => {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(env_filter)
|
||||
.with_target(true)
|
||||
.init();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::{HeaderMap, StatusCode},
|
||||
Json,
|
||||
};
|
||||
use crate::session_store::StoredSession;
|
||||
use crate::state::{Session, WsEvent};
|
||||
use axum::{extract::State, http::HeaderMap, Json};
|
||||
use matrix_sdk::config::SyncSettings;
|
||||
use matrix_sdk::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::state::VoiceManager;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct LoginRequest {
|
||||
@@ -14,6 +12,21 @@ pub struct LoginRequest {
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
impl LoginRequest {
|
||||
pub fn validate(&self) -> Result<(), &'static str> {
|
||||
if self.homeserver.is_empty() || self.homeserver.len() > 2048 {
|
||||
return Err("Invalid homeserver");
|
||||
}
|
||||
if self.username.is_empty() || self.username.len() > 256 {
|
||||
return Err("Invalid username");
|
||||
}
|
||||
if self.password.is_empty() || self.password.len() > 4096 {
|
||||
return Err("Invalid password");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct LoginResult {
|
||||
pub success: bool,
|
||||
@@ -29,36 +42,83 @@ pub struct RegisterRequest {
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
impl RegisterRequest {
|
||||
pub fn validate(&self) -> Result<(), &'static str> {
|
||||
if self.homeserver.is_empty() || self.homeserver.len() > 2048 {
|
||||
return Err("Invalid homeserver");
|
||||
}
|
||||
if self.username.is_empty() || self.username.len() > 256 {
|
||||
return Err("Invalid username");
|
||||
}
|
||||
if self.password.is_empty() || self.password.len() > 4096 {
|
||||
return Err("Invalid password");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn login(
|
||||
State(state): State<crate::state::ServerState>,
|
||||
Json(req): Json<LoginRequest>,
|
||||
) -> Result<Json<LoginResult>, StatusCode> {
|
||||
) -> Result<Json<LoginResult>, axum::http::StatusCode> {
|
||||
req.validate()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let client = Client::builder()
|
||||
.homeserver_url(&req.homeserver)
|
||||
.build()
|
||||
.await
|
||||
.map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
client
|
||||
.matrix_auth()
|
||||
.login_username(&req.username, &req.password)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
.map_err(|_| axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let user_id = client
|
||||
.user_id()
|
||||
.map(|u| u.to_string())
|
||||
.unwrap_or_default();
|
||||
let user_id = client.user_id().map(|u| u.to_string()).unwrap_or_default();
|
||||
|
||||
let token = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
let mut s = state.write().await;
|
||||
s.sessions.insert(token.clone(), crate::state::Session {
|
||||
client,
|
||||
user_id: user_id.clone(),
|
||||
voice_manager: VoiceManager::new(),
|
||||
let matrix_session = client
|
||||
.matrix_auth()
|
||||
.session()
|
||||
.ok_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let mut session = Session::new(client.clone(), user_id.clone(), req.homeserver.clone());
|
||||
{
|
||||
let s = state.read().await;
|
||||
if let Some(ttl) = s.session_ttl {
|
||||
session = session.with_ttl(ttl);
|
||||
}
|
||||
}
|
||||
|
||||
let sender = session.event_sender.clone();
|
||||
let sync_client = client.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
start_sync(sync_client, sender).await;
|
||||
});
|
||||
session.sync_handle = Some(handle);
|
||||
|
||||
let stored = StoredSession {
|
||||
token: token.clone(),
|
||||
user_id: user_id.clone(),
|
||||
homeserver: req.homeserver.clone(),
|
||||
access_token: matrix_session.tokens.access_token.clone(),
|
||||
device_id: Some(matrix_session.meta.device_id.to_string()),
|
||||
refresh_token: matrix_session.tokens.refresh_token.clone(),
|
||||
};
|
||||
|
||||
{
|
||||
let s = state.read().await;
|
||||
s.session_store
|
||||
.save_session(&stored)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
}
|
||||
|
||||
let mut s = state.write().await;
|
||||
s.sessions.insert(token.clone(), session);
|
||||
|
||||
Ok(Json(LoginResult {
|
||||
success: true,
|
||||
@@ -71,12 +131,14 @@ pub async fn login(
|
||||
pub async fn register(
|
||||
State(state): State<crate::state::ServerState>,
|
||||
Json(req): Json<RegisterRequest>,
|
||||
) -> Result<Json<LoginResult>, StatusCode> {
|
||||
) -> Result<Json<LoginResult>, axum::http::StatusCode> {
|
||||
req.validate()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let client = Client::builder()
|
||||
.homeserver_url(&req.homeserver)
|
||||
.build()
|
||||
.await
|
||||
.map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let mut request = matrix_sdk::ruma::api::client::account::register::v3::Request::new();
|
||||
request.username = Some(req.username);
|
||||
@@ -86,21 +148,51 @@ pub async fn register(
|
||||
.matrix_auth()
|
||||
.register(request)
|
||||
.await
|
||||
.map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let user_id = client
|
||||
.user_id()
|
||||
.map(|u| u.to_string())
|
||||
.unwrap_or_default();
|
||||
let user_id = client.user_id().map(|u| u.to_string()).unwrap_or_default();
|
||||
|
||||
let token = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
let mut s = state.write().await;
|
||||
s.sessions.insert(token.clone(), crate::state::Session {
|
||||
client,
|
||||
user_id: user_id.clone(),
|
||||
voice_manager: VoiceManager::new(),
|
||||
let matrix_session = client
|
||||
.matrix_auth()
|
||||
.session()
|
||||
.ok_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let mut session = Session::new(client.clone(), user_id.clone(), req.homeserver.clone());
|
||||
{
|
||||
let s = state.read().await;
|
||||
if let Some(ttl) = s.session_ttl {
|
||||
session = session.with_ttl(ttl);
|
||||
}
|
||||
}
|
||||
|
||||
let sender = session.event_sender.clone();
|
||||
let sync_client = client.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
start_sync(sync_client, sender).await;
|
||||
});
|
||||
session.sync_handle = Some(handle);
|
||||
|
||||
let stored = StoredSession {
|
||||
token: token.clone(),
|
||||
user_id: user_id.clone(),
|
||||
homeserver: req.homeserver.clone(),
|
||||
access_token: matrix_session.tokens.access_token.clone(),
|
||||
device_id: Some(matrix_session.meta.device_id.to_string()),
|
||||
refresh_token: matrix_session.tokens.refresh_token.clone(),
|
||||
};
|
||||
|
||||
{
|
||||
let s = state.read().await;
|
||||
s.session_store
|
||||
.save_session(&stored)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
}
|
||||
|
||||
let mut s = state.write().await;
|
||||
s.sessions.insert(token.clone(), session);
|
||||
|
||||
Ok(Json(LoginResult {
|
||||
success: true,
|
||||
@@ -113,22 +205,34 @@ pub async fn register(
|
||||
pub async fn logout(
|
||||
State(state): State<crate::state::ServerState>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<bool>, StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
) -> Result<Json<bool>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let mut s = state.write().await;
|
||||
if let Some(session) = s.sessions.remove(&token) {
|
||||
if let Some(mut session) = s.sessions.remove(&token) {
|
||||
if let Some(handle) = session.sync_handle.take() {
|
||||
handle.abort();
|
||||
}
|
||||
let _ = session.client.matrix_auth().logout().await;
|
||||
}
|
||||
|
||||
{
|
||||
let s_read = state.read().await;
|
||||
s_read
|
||||
.session_store
|
||||
.delete_session(&token)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
}
|
||||
|
||||
Ok(Json(true))
|
||||
}
|
||||
|
||||
pub async fn get_current_user(
|
||||
State(state): State<crate::state::ServerState>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<Option<String>>, StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
) -> Result<Json<Option<String>>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let s = state.read().await;
|
||||
let user_id = s.sessions.get(&token).map(|s| s.user_id.clone());
|
||||
@@ -147,14 +251,377 @@ pub async fn auth_middleware(
|
||||
headers: HeaderMap,
|
||||
State(state): State<crate::state::ServerState>,
|
||||
request: axum::extract::Request,
|
||||
next: middleware::Next,
|
||||
) -> Result<axum::response::Response, StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
next: axum::middleware::Next,
|
||||
) -> Result<axum::response::Response, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let s = state.read().await;
|
||||
if !s.sessions.contains_key(&token) {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
let expired = {
|
||||
let s = state.read().await;
|
||||
match s.sessions.get(&token) {
|
||||
Some(session) => session.is_expired(),
|
||||
None => return Err(axum::http::StatusCode::UNAUTHORIZED),
|
||||
}
|
||||
};
|
||||
|
||||
if expired {
|
||||
let mut s = state.write().await;
|
||||
if let Some(mut session) = s.sessions.remove(&token) {
|
||||
if let Some(handle) = session.sync_handle.take() {
|
||||
handle.abort();
|
||||
}
|
||||
}
|
||||
drop(s);
|
||||
{
|
||||
let s = state.read().await;
|
||||
let _ = s.session_store.delete_session(&token).await;
|
||||
}
|
||||
return Err(axum::http::StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
Ok(next.run(request).await)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_sync(client: Client, sender: tokio::sync::broadcast::Sender<WsEvent>) {
|
||||
let mut sync_token: Option<String> = None;
|
||||
loop {
|
||||
let mut settings = SyncSettings::new();
|
||||
if let Some(token) = sync_token.as_ref() {
|
||||
settings = settings.token(token.clone());
|
||||
}
|
||||
match client.sync_once(settings).await {
|
||||
Ok(response) => {
|
||||
sync_token = Some(response.next_batch);
|
||||
|
||||
for (room_id, joined) in &response.rooms.join {
|
||||
let room = match client.get_room(room_id) {
|
||||
Some(r) => r,
|
||||
None => continue,
|
||||
};
|
||||
let name = match room.display_name().await {
|
||||
Ok(n) => n.to_string(),
|
||||
Err(_) => room_id.to_string(),
|
||||
};
|
||||
let _ = sender.send(WsEvent::RoomJoined {
|
||||
room_id: room_id.to_string(),
|
||||
name,
|
||||
});
|
||||
|
||||
for event in &joined.timeline.events {
|
||||
if let Some(event_id) = event.event_id() {
|
||||
let raw_json: serde_json::Value = match event.event.deserialize_as() {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let sender_user = raw_json
|
||||
.get("sender")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
let event_type =
|
||||
raw_json.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
||||
|
||||
if event_type == "m.room.message" {
|
||||
let content =
|
||||
raw_json.get("content").unwrap_or(&serde_json::Value::Null);
|
||||
let new_content = content.get("m.new_content");
|
||||
|
||||
if new_content.is_some() {
|
||||
let relates_to = content.get("m.relates_to");
|
||||
let original_id = relates_to
|
||||
.and_then(|r| r.get("event_id"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
let new_body = new_content
|
||||
.and_then(|nc| nc.get("body"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
if !original_id.is_empty() && !new_body.is_empty() {
|
||||
let _ = sender.send(WsEvent::MessageEdited {
|
||||
room_id: room_id.to_string(),
|
||||
event_id: original_id.to_string(),
|
||||
new_body: new_body.to_string(),
|
||||
});
|
||||
}
|
||||
} else {
|
||||
let body =
|
||||
content.get("body").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let timestamp = raw_json
|
||||
.get("origin_server_ts")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
let msgtype = content
|
||||
.get("msgtype")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("m.text");
|
||||
|
||||
let media_url = content
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
.or_else(|| {
|
||||
content
|
||||
.get("file")
|
||||
.and_then(|f| f.get("url"))
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
});
|
||||
|
||||
let filename = if msgtype == "m.image"
|
||||
|| msgtype == "m.file"
|
||||
|| msgtype == "m.video"
|
||||
|| msgtype == "m.audio"
|
||||
{
|
||||
Some(body.to_string())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mimetype = content
|
||||
.get("info")
|
||||
.and_then(|i| i.get("mimetype"))
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let relates_to = content.get("m.relates_to");
|
||||
let reply_to = relates_to
|
||||
.and_then(|r| r.get("m.in_reply_to"))
|
||||
.and_then(|r| r.get("event_id"))
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let is_thread = relates_to
|
||||
.and_then(|r| r.get("rel_type"))
|
||||
.and_then(|v| v.as_str())
|
||||
== Some("m.thread");
|
||||
|
||||
if !body.is_empty() || media_url.is_some() {
|
||||
if is_thread {
|
||||
if let Some(root_id) = relates_to
|
||||
.and_then(|r| r.get("event_id"))
|
||||
.and_then(|v| v.as_str())
|
||||
{
|
||||
let _ = sender.send(WsEvent::ThreadReply {
|
||||
room_id: room_id.to_string(),
|
||||
root_event_id: root_id.to_string(),
|
||||
event_id: event_id.to_string(),
|
||||
sender: sender_user.to_string(),
|
||||
body: body.to_string(),
|
||||
timestamp,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
let _ = sender.send(WsEvent::Message {
|
||||
room_id: room_id.to_string(),
|
||||
event_id: event_id.to_string(),
|
||||
sender: sender_user.to_string(),
|
||||
body: body.to_string(),
|
||||
timestamp,
|
||||
reply_to,
|
||||
msgtype: Some(msgtype.to_string()),
|
||||
media_url: media_url.clone(),
|
||||
filename,
|
||||
mimetype,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if event_type == "m.reaction" {
|
||||
let content =
|
||||
raw_json.get("content").unwrap_or(&serde_json::Value::Null);
|
||||
if let Some(relates_to) = content.get("m.relates_to") {
|
||||
let target_id = relates_to
|
||||
.get("event_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
let key = relates_to
|
||||
.get("key")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
if !target_id.is_empty() && !key.is_empty() {
|
||||
let _ = sender.send(WsEvent::Reaction {
|
||||
room_id: room_id.to_string(),
|
||||
event_id: target_id.to_string(),
|
||||
key: key.to_string(),
|
||||
sender: sender_user.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if event_type == "m.room.redaction" {
|
||||
let redacts = raw_json
|
||||
.get("redacts")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
if !redacts.is_empty() {
|
||||
let _ = sender.send(WsEvent::MessageDeleted {
|
||||
room_id: room_id.to_string(),
|
||||
redacts: redacts.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for raw_ephemeral in &joined.ephemeral {
|
||||
let val: serde_json::Value = match raw_ephemeral.deserialize_as() {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let event_type = val.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
||||
if event_type == "m.typing" {
|
||||
let user_ids = val
|
||||
.get("content")
|
||||
.and_then(|c| c.get("user_ids"))
|
||||
.and_then(|v| v.as_array());
|
||||
if let Some(ids) = user_ids {
|
||||
for uid in ids {
|
||||
if let Some(user_str) = uid.as_str() {
|
||||
let _ = sender.send(WsEvent::Typing {
|
||||
room_id: room_id.to_string(),
|
||||
user_id: user_str.to_string(),
|
||||
typing: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for room_id in response.rooms.leave.keys() {
|
||||
let _ = sender.send(WsEvent::RoomLeft {
|
||||
room_id: room_id.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
for raw_presence in &response.presence {
|
||||
let val: serde_json::Value = match raw_presence.deserialize_as() {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let user_id = val.get("sender").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let presence = val
|
||||
.get("content")
|
||||
.and_then(|c| c.get("presence"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("offline");
|
||||
let status_msg = val
|
||||
.get("content")
|
||||
.and_then(|c| c.get("status_msg"))
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
if !user_id.is_empty() {
|
||||
let status = match presence {
|
||||
"online" => "online",
|
||||
"unavailable" => "idle",
|
||||
_ => "offline",
|
||||
};
|
||||
let _ = sender.send(WsEvent::Presence {
|
||||
user_id: user_id.to_string(),
|
||||
status: status.to_string(),
|
||||
status_msg,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Sync error: {}", e);
|
||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn login_request_validate_ok() {
|
||||
let req = LoginRequest {
|
||||
homeserver: "https://matrix.example.org".to_string(),
|
||||
username: "alice".to_string(),
|
||||
password: "secret123".to_string(),
|
||||
};
|
||||
assert!(req.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn login_request_reject_empty_homeserver() {
|
||||
let req = LoginRequest {
|
||||
homeserver: "".to_string(),
|
||||
username: "alice".to_string(),
|
||||
password: "secret".to_string(),
|
||||
};
|
||||
assert!(req.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn login_request_reject_empty_username() {
|
||||
let req = LoginRequest {
|
||||
homeserver: "https://matrix.example.org".to_string(),
|
||||
username: "".to_string(),
|
||||
password: "secret".to_string(),
|
||||
};
|
||||
assert!(req.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn login_request_reject_empty_password() {
|
||||
let req = LoginRequest {
|
||||
homeserver: "https://matrix.example.org".to_string(),
|
||||
username: "alice".to_string(),
|
||||
password: "".to_string(),
|
||||
};
|
||||
assert!(req.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn login_request_reject_long_username() {
|
||||
let req = LoginRequest {
|
||||
homeserver: "https://matrix.example.org".to_string(),
|
||||
username: "a".repeat(300),
|
||||
password: "secret".to_string(),
|
||||
};
|
||||
assert!(req.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn register_request_validate_ok() {
|
||||
let req = RegisterRequest {
|
||||
homeserver: "https://matrix.example.org".to_string(),
|
||||
username: "bob".to_string(),
|
||||
password: "password".to_string(),
|
||||
};
|
||||
assert!(req.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn register_request_reject_empty_fields() {
|
||||
let req = RegisterRequest {
|
||||
homeserver: "".to_string(),
|
||||
username: "".to_string(),
|
||||
password: "".to_string(),
|
||||
};
|
||||
assert!(req.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_token_from_bearer() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("Authorization", "Bearer my-token-123".parse().unwrap());
|
||||
assert_eq!(extract_token(&headers), Some("my-token-123".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_token_missing_header() {
|
||||
let headers = HeaderMap::new();
|
||||
assert_eq!(extract_token(&headers), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_token_invalid_scheme() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("Authorization", "Basic dXNlcjpwYXNz".parse().unwrap());
|
||||
assert_eq!(extract_token(&headers), None);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
use super::auth::extract_token;
|
||||
use crate::ServerState;
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::HeaderMap,
|
||||
Json,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::ServerState;
|
||||
use super::auth::extract_token;
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct CustomEmoji {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
@@ -33,10 +33,69 @@ pub struct Sticker {
|
||||
pub async fn get_custom_emoji(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Path(_room_id): Path<String>,
|
||||
Path(room_id): Path<String>,
|
||||
) -> Result<Json<Vec<CustomEmoji>>, axum::http::StatusCode> {
|
||||
let _token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
Ok(Json(Vec::new()))
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
let mut emojis = Vec::new();
|
||||
|
||||
let event_type: matrix_sdk::ruma::events::StateEventType = "im.ponies.room_emotes".into();
|
||||
|
||||
if let Ok(Some(raw_event)) = room.get_state_event(event_type, "").await {
|
||||
let raw_json_str = match &raw_event {
|
||||
matrix_sdk::deserialized_responses::RawAnySyncOrStrippedState::Sync(s) => {
|
||||
s.json().get().to_string()
|
||||
}
|
||||
matrix_sdk::deserialized_responses::RawAnySyncOrStrippedState::Stripped(s) => {
|
||||
s.json().get().to_string()
|
||||
}
|
||||
};
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&raw_json_str) {
|
||||
if let Some(content) = value.get("content") {
|
||||
if let Some(images) = content.get("images").and_then(|v| v.as_object()) {
|
||||
for (key, entry) in images {
|
||||
let url = entry
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
if url.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let name = entry
|
||||
.get("body")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or(key)
|
||||
.to_string();
|
||||
let animated = url.contains(".gif");
|
||||
emojis.push(CustomEmoji {
|
||||
id: key.clone(),
|
||||
name,
|
||||
url,
|
||||
category: "custom".to_string(),
|
||||
animated,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(emojis))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -53,34 +112,50 @@ pub async fn upload_emoji(
|
||||
) -> Result<Json<CustomEmoji>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let _room = session.client.get_room(&rid).ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let _room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
let path = std::path::Path::new(&req.image_path);
|
||||
if !path.exists() {
|
||||
return Err(axum::http::StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
let mime_type = match path.extension().and_then(|e| e.to_str()) {
|
||||
let mime_type: mime::Mime = match path.extension().and_then(|e| e.to_str()) {
|
||||
Some("png") => "image/png",
|
||||
Some("jpg") | Some("jpeg") => "image/jpeg",
|
||||
Some("gif") => "image/gif",
|
||||
Some("webp") => "image/webp",
|
||||
_ => "image/png",
|
||||
};
|
||||
}
|
||||
.parse()
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let data = std::fs::read(path).map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let content_type = mime_type.parse::<matrix_sdk::ruma::mime::Mime>().map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let is_animated = mime_type.subtype() == mime::IMAGE_GIF.subtype();
|
||||
|
||||
let response = session.client.media().upload(&content_type, data).await.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let response = session
|
||||
.client
|
||||
.media()
|
||||
.upload(&mime_type, data)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(CustomEmoji {
|
||||
id: format!("emoji_{}", chrono::Utc::now().timestamp()),
|
||||
name: req.name,
|
||||
url: response.content_uri.to_string(),
|
||||
category: "custom".to_string(),
|
||||
animated: mime_type == "image/gif",
|
||||
animated: is_animated,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
47
server/src/routes/media.rs
Normal file
47
server/src/routes/media.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
use super::auth::extract_token;
|
||||
use crate::ServerState;
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Path, State},
|
||||
http::HeaderMap,
|
||||
response::Response,
|
||||
};
|
||||
use matrix_sdk::ruma::OwnedMxcUri;
|
||||
|
||||
pub async fn get_media(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Path(mxc_path): Path<String>,
|
||||
) -> Result<Response<Body>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let mxc_uri: OwnedMxcUri = mxc_path.as_str().into();
|
||||
|
||||
let request =
|
||||
matrix_sdk::ruma::api::client::media::get_content::v3::Request::from_url(&mxc_uri)
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let response = session
|
||||
.client
|
||||
.send(request, None)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
let content_type = response
|
||||
.content_type
|
||||
.unwrap_or_else(|| "application/octet-stream".to_string());
|
||||
|
||||
let body = response.file;
|
||||
|
||||
Ok(Response::builder()
|
||||
.status(200)
|
||||
.header("Content-Type", content_type)
|
||||
.header("Cache-Control", "public, max-age=86400")
|
||||
.body(Body::from(body))
|
||||
.unwrap())
|
||||
}
|
||||
@@ -1,11 +1,13 @@
|
||||
use super::auth::extract_token;
|
||||
use crate::routes::metrics::MESSAGES_SENT_TOTAL;
|
||||
use crate::ServerState;
|
||||
use axum::{
|
||||
extract::{Path, State, Query},
|
||||
extract::{Path, Query, State},
|
||||
http::HeaderMap,
|
||||
Json,
|
||||
};
|
||||
use serde::{Serialize, Deserialize};
|
||||
use crate::ServerState;
|
||||
use super::auth::extract_token;
|
||||
use matrix_sdk::room::MessagesOptions;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct MessageInfo {
|
||||
@@ -14,6 +16,15 @@ pub struct MessageInfo {
|
||||
pub body: String,
|
||||
pub timestamp: u64,
|
||||
pub reply_to: Option<String>,
|
||||
pub edited: bool,
|
||||
pub reactions: std::collections::HashMap<String, Vec<String>>,
|
||||
pub msgtype: String,
|
||||
pub media_url: Option<String>,
|
||||
pub filename: Option<String>,
|
||||
pub mimetype: Option<String>,
|
||||
pub width: Option<u64>,
|
||||
pub height: Option<u64>,
|
||||
pub size: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -30,29 +41,196 @@ pub async fn get_room_messages(
|
||||
) -> Result<Json<Vec<MessageInfo>>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session.client.get_room(&rid).ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
let limit = query.limit.unwrap_or(50);
|
||||
let mut options = matrix_sdk::ruma::api::client::message::get_message_events::v3::Request::new();
|
||||
options.limit = limit.into();
|
||||
options.from = query.from.map(|t| t.into());
|
||||
let mut options = MessagesOptions::backward();
|
||||
options.limit = matrix_sdk::ruma::uint!(50);
|
||||
if let Some(from) = query.from {
|
||||
options.from = Some(from);
|
||||
}
|
||||
|
||||
let messages = room.messages(options).await.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let messages = room
|
||||
.messages(options)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let mut reactions_map: std::collections::HashMap<
|
||||
String,
|
||||
std::collections::HashMap<String, Vec<String>>,
|
||||
> = std::collections::HashMap::new();
|
||||
let mut result = Vec::new();
|
||||
for msg in messages.chunk {
|
||||
if let matrix_sdk::ruma::events::AnySyncMessageLikeEvent::RoomMessage(ev) = msg {
|
||||
result.push(MessageInfo {
|
||||
event_id: ev.event_id().to_string(),
|
||||
sender: ev.sender().to_string(),
|
||||
body: ev.content().body().to_string(),
|
||||
timestamp: ev.origin_server_ts().0,
|
||||
reply_to: ev.content().in_reply_to().map(|r| r.event_id.to_string()),
|
||||
});
|
||||
|
||||
for msg in &messages.chunk {
|
||||
let event_value: serde_json::Value = match msg.event.deserialize_as() {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let event_type = event_value
|
||||
.get("type")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
if event_type == "m.reaction" {
|
||||
if let Some(content) = event_value.get("content") {
|
||||
if let Some(relates_to) = content.get("m.relates_to") {
|
||||
if let (Some(target_id), Some(key)) = (
|
||||
relates_to.get("event_id").and_then(|v| v.as_str()),
|
||||
relates_to.get("key").and_then(|v| v.as_str()),
|
||||
) {
|
||||
let sender = event_value
|
||||
.get("sender")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
reactions_map
|
||||
.entry(target_id.to_string())
|
||||
.or_default()
|
||||
.entry(key.to_string())
|
||||
.or_default()
|
||||
.push(sender.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if event_type == "m.room.redaction" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let event_id = match event_value.get("event_id").and_then(|v| v.as_str()) {
|
||||
Some(id) => id.to_string(),
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let sender = match event_value.get("sender").and_then(|v| v.as_str()) {
|
||||
Some(s) => s.to_string(),
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let timestamp = event_value
|
||||
.get("origin_server_ts")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
let content = event_value
|
||||
.get("content")
|
||||
.unwrap_or(&serde_json::Value::Null);
|
||||
let msgtype = content
|
||||
.get("msgtype")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
if msgtype != "m.text"
|
||||
&& msgtype != "m.notice"
|
||||
&& msgtype != "m.emote"
|
||||
&& msgtype != "m.image"
|
||||
&& msgtype != "m.file"
|
||||
&& msgtype != "m.video"
|
||||
&& msgtype != "m.audio"
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let body = content
|
||||
.get("body")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
let new_content = content.get("m.new_content");
|
||||
let (display_body, edited) = if let Some(nc) = new_content {
|
||||
let new_body = nc.get("body").and_then(|v| v.as_str()).unwrap_or(&body);
|
||||
(new_body.to_string(), true)
|
||||
} else {
|
||||
(body.clone(), false)
|
||||
};
|
||||
|
||||
if body.is_empty() && (msgtype == "m.text" || msgtype == "m.notice" || msgtype == "m.emote")
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let reply_to = content
|
||||
.get("m.relates_to")
|
||||
.and_then(|r| r.get("m.in_reply_to"))
|
||||
.and_then(|r| r.get("event_id"))
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let reactions = reactions_map.remove(&event_id).unwrap_or_default();
|
||||
|
||||
let media_url = content
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
.or_else(|| {
|
||||
content
|
||||
.get("file")
|
||||
.and_then(|f| f.get("url"))
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
});
|
||||
|
||||
let filename = if msgtype == "m.image"
|
||||
|| msgtype == "m.file"
|
||||
|| msgtype == "m.video"
|
||||
|| msgtype == "m.audio"
|
||||
{
|
||||
Some(body.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mimetype = content
|
||||
.get("info")
|
||||
.and_then(|i| i.get("mimetype"))
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let width = content
|
||||
.get("info")
|
||||
.and_then(|i| i.get("w"))
|
||||
.and_then(|v| v.as_u64());
|
||||
|
||||
let height = content
|
||||
.get("info")
|
||||
.and_then(|i| i.get("h"))
|
||||
.and_then(|v| v.as_u64());
|
||||
|
||||
let size = content
|
||||
.get("info")
|
||||
.and_then(|i| i.get("size"))
|
||||
.and_then(|v| v.as_u64());
|
||||
|
||||
result.push(MessageInfo {
|
||||
event_id,
|
||||
sender,
|
||||
body: display_body,
|
||||
timestamp,
|
||||
reply_to,
|
||||
edited,
|
||||
reactions,
|
||||
msgtype: msgtype.to_string(),
|
||||
media_url,
|
||||
filename,
|
||||
mimetype,
|
||||
width,
|
||||
height,
|
||||
size,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(Json(result))
|
||||
@@ -71,14 +249,192 @@ pub async fn send_message(
|
||||
) -> Result<Json<String>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session.client.get_room(&rid).ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
let content = matrix_sdk::ruma::events::room::message::RoomMessageEventContent::text_plain(&req.message);
|
||||
let txn_id = matrix_sdk::ruma::TransactionId::new();
|
||||
let response = room.send(content, Some(&txn_id)).await.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let content =
|
||||
matrix_sdk::ruma::events::room::message::RoomMessageEventContent::text_plain(&req.message);
|
||||
let response = room
|
||||
.send(content)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
MESSAGES_SENT_TOTAL.inc();
|
||||
Ok(Json(response.event_id.to_string()))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct EditMessageRequest {
|
||||
pub event_id: String,
|
||||
pub new_content: String,
|
||||
}
|
||||
|
||||
pub async fn edit_message(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Path(room_id): Path<String>,
|
||||
Json(req): Json<EditMessageRequest>,
|
||||
) -> Result<Json<String>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
let original_event_id: matrix_sdk::ruma::OwnedEventId = req
|
||||
.event_id
|
||||
.parse()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let content = matrix_sdk::ruma::events::room::message::RoomMessageEventContent::text_plain(
|
||||
&req.new_content,
|
||||
)
|
||||
.make_replacement(
|
||||
matrix_sdk::ruma::events::room::message::ReplacementMetadata::new(original_event_id, None),
|
||||
None,
|
||||
);
|
||||
|
||||
let response = room
|
||||
.send(content)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
Ok(Json(response.event_id.to_string()))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct DeleteMessageRequest {
|
||||
pub reason: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn delete_message(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Path((room_id, event_id)): Path<(String, String)>,
|
||||
Json(req): Json<DeleteMessageRequest>,
|
||||
) -> Result<Json<String>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
let eid: matrix_sdk::ruma::OwnedEventId = event_id
|
||||
.parse()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let response = room
|
||||
.redact(&eid, req.reason.as_deref(), None)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(response.event_id.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct ReactRequest {
|
||||
pub event_id: String,
|
||||
pub key: String,
|
||||
}
|
||||
|
||||
pub async fn react_to_message(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Path(room_id): Path<String>,
|
||||
Json(req): Json<ReactRequest>,
|
||||
) -> Result<Json<String>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
let event_id: matrix_sdk::ruma::OwnedEventId = req
|
||||
.event_id
|
||||
.parse()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let content = matrix_sdk::ruma::events::reaction::ReactionEventContent::new(
|
||||
matrix_sdk::ruma::events::relation::Annotation::new(event_id, req.key),
|
||||
);
|
||||
|
||||
let response = room
|
||||
.send(content)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
Ok(Json(response.event_id.to_string()))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct TypingRequest {
|
||||
pub typing: bool,
|
||||
}
|
||||
|
||||
pub async fn set_typing(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Path(room_id): Path<String>,
|
||||
Json(req): Json<TypingRequest>,
|
||||
) -> Result<Json<bool>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
room.typing_notice(req.typing)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(true))
|
||||
}
|
||||
|
||||
82
server/src/routes/metrics.rs
Normal file
82
server/src/routes/metrics.rs
Normal file
@@ -0,0 +1,82 @@
|
||||
use axum::{body::Body, extract::State, http::HeaderMap, response::IntoResponse};
|
||||
use lazy_static::lazy_static;
|
||||
use prometheus::{Encoder, IntCounter, IntGauge, Registry, TextEncoder};
|
||||
|
||||
lazy_static! {
|
||||
pub static ref REGISTRY: Registry = Registry::new();
|
||||
pub static ref HTTP_REQUESTS_TOTAL: IntCounter =
|
||||
IntCounter::new("eifeldc_http_requests_total", "Total HTTP requests")
|
||||
.expect("metric can be created");
|
||||
pub static ref ACTIVE_SESSIONS: IntGauge =
|
||||
IntGauge::new("eifeldc_active_sessions", "Number of active sessions")
|
||||
.expect("metric can be created");
|
||||
pub static ref ACTIVE_WEBSOCKETS: IntGauge = IntGauge::new(
|
||||
"eifeldc_active_websockets",
|
||||
"Number of active WebSocket connections"
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref MESSAGES_SENT_TOTAL: IntCounter =
|
||||
IntCounter::new("eifeldc_messages_sent_total", "Total messages sent")
|
||||
.expect("metric can be created");
|
||||
pub static ref ROOMS_JOINED_TOTAL: IntCounter =
|
||||
IntCounter::new("eifeldc_rooms_joined_total", "Total room join operations")
|
||||
.expect("metric can be created");
|
||||
pub static ref UPLOADS_TOTAL: IntCounter =
|
||||
IntCounter::new("eifeldc_uploads_total", "Total file uploads")
|
||||
.expect("metric can be created");
|
||||
pub static ref VOICE_PARTICIPANTS: IntGauge = IntGauge::new(
|
||||
"eifeldc_voice_participants",
|
||||
"Current voice channel participants"
|
||||
)
|
||||
.expect("metric can be created");
|
||||
}
|
||||
|
||||
pub fn init_metrics() {
|
||||
REGISTRY
|
||||
.register(Box::new(HTTP_REQUESTS_TOTAL.clone()))
|
||||
.expect("metric can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(ACTIVE_SESSIONS.clone()))
|
||||
.expect("metric can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(ACTIVE_WEBSOCKETS.clone()))
|
||||
.expect("metric can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(MESSAGES_SENT_TOTAL.clone()))
|
||||
.expect("metric can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(ROOMS_JOINED_TOTAL.clone()))
|
||||
.expect("metric can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(UPLOADS_TOTAL.clone()))
|
||||
.expect("metric can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(VOICE_PARTICIPANTS.clone()))
|
||||
.expect("metric can be registered");
|
||||
}
|
||||
|
||||
pub async fn get_metrics(
|
||||
State(state): State<crate::ServerState>,
|
||||
_headers: HeaderMap,
|
||||
) -> impl IntoResponse {
|
||||
ACTIVE_SESSIONS.set(state.read().await.sessions.len() as i64);
|
||||
|
||||
let mut participants: i64 = 0;
|
||||
let s = state.read().await;
|
||||
for room in s.voice_rooms.rooms.values() {
|
||||
participants += room.participants.len() as i64;
|
||||
}
|
||||
VOICE_PARTICIPANTS.set(participants);
|
||||
drop(s);
|
||||
|
||||
let metric_families = REGISTRY.gather();
|
||||
let mut buffer = Vec::new();
|
||||
let encoder = TextEncoder::new();
|
||||
encoder.encode(&metric_families, &mut buffer).unwrap();
|
||||
|
||||
axum::http::Response::builder()
|
||||
.status(200)
|
||||
.header("Content-Type", "text/plain; version=0.0.4; charset=utf-8")
|
||||
.body(Body::from(buffer))
|
||||
.unwrap()
|
||||
}
|
||||
@@ -1,37 +1,79 @@
|
||||
pub mod auth;
|
||||
pub mod rooms;
|
||||
pub mod messages;
|
||||
pub mod presence;
|
||||
pub mod voice;
|
||||
pub mod emoji;
|
||||
pub mod media;
|
||||
pub mod messages;
|
||||
pub mod metrics;
|
||||
pub mod presence;
|
||||
pub mod profile;
|
||||
pub mod rate_limit;
|
||||
pub mod roles;
|
||||
pub mod rooms;
|
||||
pub mod threads;
|
||||
pub mod upload;
|
||||
pub mod voice;
|
||||
pub mod ws;
|
||||
|
||||
use axum::{
|
||||
Router,
|
||||
routing::{get, post},
|
||||
middleware,
|
||||
http::HeaderMap,
|
||||
};
|
||||
use tower_http::cors::{CorsLayer, Any};
|
||||
use crate::ServerState;
|
||||
use crate::routes::auth::auth_middleware;
|
||||
use crate::routes::rate_limit::{RateLimitConfig, RateLimiter};
|
||||
use crate::ServerState;
|
||||
use axum::{
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use std::time::Duration;
|
||||
use tower_http::cors::{AllowHeaders, AllowMethods, CorsLayer};
|
||||
|
||||
pub fn api_router(state: ServerState) -> Router {
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any);
|
||||
let allowed_origins = std::env::var("EIFELDC_CORS_ORIGINS").unwrap_or_else(|_| "*".to_string());
|
||||
|
||||
let cors = if allowed_origins == "*" {
|
||||
CorsLayer::permissive()
|
||||
} else {
|
||||
let origins: Vec<_> = allowed_origins
|
||||
.split(',')
|
||||
.filter_map(|s| s.trim().parse().ok())
|
||||
.collect();
|
||||
CorsLayer::new()
|
||||
.allow_origin(origins)
|
||||
.allow_methods(AllowMethods::any())
|
||||
.allow_headers(AllowHeaders::any())
|
||||
.max_age(Duration::from_secs(3600))
|
||||
};
|
||||
|
||||
let max_rpm = std::env::var("EIFELDC_RATE_LIMIT")
|
||||
.ok()
|
||||
.map(|v| v.parse().unwrap_or(60))
|
||||
.unwrap_or(60);
|
||||
let rate_limiter = RateLimiter::new(RateLimitConfig {
|
||||
max_requests: max_rpm,
|
||||
window_secs: 60,
|
||||
});
|
||||
|
||||
Router::new()
|
||||
.nest("/api", api_routes(state))
|
||||
.layer(cors)
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
rate_limiter,
|
||||
rate_limit::rate_limit_middleware,
|
||||
))
|
||||
}
|
||||
|
||||
fn api_routes(state: ServerState) -> Router {
|
||||
let max_upload_bytes: usize = std::env::var("EIFELDC_MAX_UPLOAD_MB")
|
||||
.ok()
|
||||
.map(|v| v.parse().unwrap_or(50))
|
||||
.unwrap_or(50)
|
||||
* 1024
|
||||
* 1024;
|
||||
|
||||
let default_body_limit = axum::extract::DefaultBodyLimit::max(max_upload_bytes);
|
||||
|
||||
let public = Router::new()
|
||||
.route("/login", post(auth::login))
|
||||
.route("/register", post(auth::register))
|
||||
.route("/current-user", get(auth::get_current_user));
|
||||
.route("/current-user", get(auth::get_current_user))
|
||||
.route("/ws", get(ws::ws_handler))
|
||||
.route("/metrics", get(metrics::get_metrics));
|
||||
|
||||
let protected = Router::new()
|
||||
.route("/logout", post(auth::logout))
|
||||
@@ -40,24 +82,63 @@ fn api_routes(state: ServerState) -> Router {
|
||||
.route("/rooms/join", post(rooms::join_room))
|
||||
.route("/rooms/{room_id}/leave", post(rooms::leave_room))
|
||||
.route("/rooms/{room_id}/members", get(rooms::get_room_members))
|
||||
.route("/rooms/{room_id}/messages", get(messages::get_room_messages))
|
||||
.route("/rooms/{room_id}/name", post(rooms::set_room_name))
|
||||
.route("/rooms/{room_id}/topic", post(rooms::set_room_topic))
|
||||
.route("/rooms/{room_id}/avatar", post(rooms::set_room_avatar))
|
||||
.route(
|
||||
"/rooms/{room_id}/messages",
|
||||
get(messages::get_room_messages),
|
||||
)
|
||||
.route("/rooms/{room_id}/send", post(messages::send_message))
|
||||
.route("/rooms/{room_id}/edit", post(messages::edit_message))
|
||||
.route(
|
||||
"/rooms/{room_id}/delete/{event_id}",
|
||||
post(messages::delete_message),
|
||||
)
|
||||
.route("/rooms/{room_id}/react", post(messages::react_to_message))
|
||||
.route("/rooms/{room_id}/typing", post(messages::set_typing))
|
||||
.route("/presence/set", post(presence::set_presence))
|
||||
.route("/presence/{user_id}", get(presence::get_presence))
|
||||
.route("/voice/join", post(voice::join_voice_channel))
|
||||
.route("/voice/leave", post(voice::leave_voice_channel))
|
||||
.route("/voice/toggle-mute", post(voice::toggle_mute))
|
||||
.route("/voice/toggle-deafen", post(voice::toggle_deafen))
|
||||
.route("/voice/participants", get(voice::get_voice_participants))
|
||||
.route("/rooms/{room_id}/roles", get(roles::get_roles))
|
||||
.route("/rooms/{room_id}/roles/assign", post(roles::assign_role))
|
||||
.route("/rooms/{room_id}/roles/remove", post(roles::remove_role))
|
||||
.route("/rooms/{room_id}/permissions/{user_id}", get(roles::get_permissions))
|
||||
.route(
|
||||
"/rooms/{room_id}/permissions/{user_id}",
|
||||
get(roles::get_permissions),
|
||||
)
|
||||
.route("/rooms/{room_id}/emoji", get(emoji::get_custom_emoji))
|
||||
.route("/rooms/{room_id}/emoji/upload", post(emoji::upload_emoji))
|
||||
.layer(middleware::from_fn_with_state(state.clone(), auth_middleware));
|
||||
.route("/rooms/{room_id}/threads", get(threads::get_threads))
|
||||
.route(
|
||||
"/rooms/{room_id}/threads/{thread_id}",
|
||||
get(threads::get_thread_messages),
|
||||
)
|
||||
.route(
|
||||
"/rooms/{room_id}/threads/{thread_id}/reply",
|
||||
post(threads::send_thread_reply),
|
||||
)
|
||||
.route("/rooms/{room_id}/reply", post(threads::send_reply))
|
||||
.route("/rooms/{room_id}/upload", post(upload::upload_file))
|
||||
.route("/rooms/unread", get(rooms::get_unread_counts))
|
||||
.route("/rooms/{room_id}/read", post(rooms::mark_room_read))
|
||||
.route("/media/{mxc_path}", get(media::get_media))
|
||||
.route("/profile/me", get(profile::get_own_profile))
|
||||
.route("/profile/{user_id}", get(profile::get_profile))
|
||||
.route("/profile/displayname", post(profile::set_display_name))
|
||||
.route("/profile/avatar", post(profile::upload_avatar))
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
auth_middleware,
|
||||
));
|
||||
|
||||
Router::new()
|
||||
.merge(public)
|
||||
.merge(protected)
|
||||
.with_state(state)
|
||||
}
|
||||
.layer(default_body_limit)
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use super::auth::extract_token;
|
||||
use crate::ServerState;
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::HeaderMap,
|
||||
Json,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::ServerState;
|
||||
use super::auth::extract_token;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SetPresenceRequest {
|
||||
@@ -28,21 +28,31 @@ pub async fn set_presence(
|
||||
) -> Result<Json<bool>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let presence_state = match req.status.as_str() {
|
||||
"online" => matrix_sdk::ruma::presence::PresenceState::Online,
|
||||
"away" => matrix_sdk::ruma::presence::PresenceState::Away,
|
||||
"unavailable" => matrix_sdk::ruma::presence::PresenceState::Unavailable,
|
||||
_ => matrix_sdk::ruma::presence::PresenceState::Online,
|
||||
};
|
||||
|
||||
let user_id = session.client.user_id().ok_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let mut request = matrix_sdk::ruma::api::client::presence::set_presence::v3::Request::new(user_id.to_owned());
|
||||
request.presence = presence_state;
|
||||
request.status_msg = req.status_msg;
|
||||
let user_id = session
|
||||
.client
|
||||
.user_id()
|
||||
.ok_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let request = matrix_sdk::ruma::api::client::presence::set_presence::v3::Request::new(
|
||||
user_id.to_owned(),
|
||||
presence_state,
|
||||
);
|
||||
|
||||
session.client.send(request, None).await.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
session
|
||||
.client
|
||||
.send(request, None)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(true))
|
||||
}
|
||||
@@ -54,16 +64,26 @@ pub async fn get_presence(
|
||||
) -> Result<Json<PresenceInfo>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let uid: matrix_sdk::ruma::OwnedUserId = user_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let request = matrix_sdk::ruma::api::client::presence::get_presence::v3::Request::new(uid.to_owned());
|
||||
let uid: matrix_sdk::ruma::OwnedUserId = user_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let request =
|
||||
matrix_sdk::ruma::api::client::presence::get_presence::v3::Request::new(uid.to_owned());
|
||||
|
||||
let response = session.client.send(request, None).await.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let response = session
|
||||
.client
|
||||
.send(request, None)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let status_str = match response.presence {
|
||||
matrix_sdk::ruma::presence::PresenceState::Online => "online",
|
||||
matrix_sdk::ruma::presence::PresenceState::Away => "away",
|
||||
matrix_sdk::ruma::presence::PresenceState::Unavailable => "unavailable",
|
||||
_ => "offline",
|
||||
};
|
||||
@@ -74,4 +94,4 @@ pub async fn get_presence(
|
||||
status_msg: response.status_msg,
|
||||
last_active: response.last_active_ago.map(|d| d.as_secs()),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
156
server/src/routes/profile.rs
Normal file
156
server/src/routes/profile.rs
Normal file
@@ -0,0 +1,156 @@
|
||||
use super::auth::extract_token;
|
||||
use crate::ServerState;
|
||||
use axum::{
|
||||
extract::{Multipart, Path, State},
|
||||
http::HeaderMap,
|
||||
Json,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct UserProfile {
|
||||
pub user_id: String,
|
||||
pub display_name: Option<String>,
|
||||
pub avatar_url: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn get_profile(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Path(user_id): Path<String>,
|
||||
) -> Result<Json<UserProfile>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let uid: matrix_sdk::ruma::OwnedUserId = user_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let profile = session
|
||||
.client
|
||||
.get_profile(&uid)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
Ok(Json(UserProfile {
|
||||
user_id,
|
||||
display_name: profile.displayname,
|
||||
avatar_url: profile.avatar_url.map(|u| u.to_string()),
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn get_own_profile(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<UserProfile>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let user_id = session.user_id.clone();
|
||||
let uid: matrix_sdk::ruma::OwnedUserId = user_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let profile = session
|
||||
.client
|
||||
.get_profile(&uid)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
Ok(Json(UserProfile {
|
||||
user_id,
|
||||
display_name: profile.displayname,
|
||||
avatar_url: profile.avatar_url.map(|u| u.to_string()),
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SetDisplayNameRequest {
|
||||
pub display_name: String,
|
||||
}
|
||||
|
||||
pub async fn set_display_name(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Json(req): Json<SetDisplayNameRequest>,
|
||||
) -> Result<Json<bool>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
session
|
||||
.client
|
||||
.account()
|
||||
.set_display_name(Some(&req.display_name))
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(true))
|
||||
}
|
||||
|
||||
pub async fn upload_avatar(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
mut multipart: Multipart,
|
||||
) -> Result<Json<UserProfile>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let field = multipart
|
||||
.next_field()
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?
|
||||
.ok_or(axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let content_type = field
|
||||
.content_type()
|
||||
.unwrap_or("application/octet-stream")
|
||||
.to_string();
|
||||
let data = field
|
||||
.bytes()
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let mime_type: mime::Mime = content_type.parse().unwrap_or(mime::IMAGE_PNG);
|
||||
|
||||
let upload = session
|
||||
.client
|
||||
.media()
|
||||
.upload(&mime_type, data.to_vec())
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let mxc_uri = upload.content_uri;
|
||||
|
||||
session
|
||||
.client
|
||||
.account()
|
||||
.set_avatar_url(Some(&mxc_uri))
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let user_id = session.user_id.clone();
|
||||
|
||||
Ok(Json(UserProfile {
|
||||
user_id,
|
||||
display_name: None,
|
||||
avatar_url: Some(mxc_uri.to_string()),
|
||||
}))
|
||||
}
|
||||
155
server/src/routes/rate_limit.rs
Normal file
155
server/src/routes/rate_limit.rs
Normal file
@@ -0,0 +1,155 @@
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::State,
|
||||
http::{Request, StatusCode},
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
struct RateLimitEntry {
|
||||
count: u32,
|
||||
window_start: Instant,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RateLimitConfig {
|
||||
pub max_requests: u32,
|
||||
pub window_secs: u64,
|
||||
}
|
||||
|
||||
impl Default for RateLimitConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_requests: 60,
|
||||
window_secs: 60,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RateLimiter {
|
||||
entries: Arc<RwLock<HashMap<String, RateLimitEntry>>>,
|
||||
config: RateLimitConfig,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
pub fn new(config: RateLimitConfig) -> Self {
|
||||
Self {
|
||||
entries: Arc::new(RwLock::new(HashMap::new())),
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn check(&self, key: &str) -> bool {
|
||||
let mut entries = self.entries.write().await;
|
||||
let now = Instant::now();
|
||||
let window = std::time::Duration::from_secs(self.config.window_secs);
|
||||
|
||||
match entries.get_mut(key) {
|
||||
Some(entry) => {
|
||||
if now.duration_since(entry.window_start) > window {
|
||||
entry.count = 1;
|
||||
entry.window_start = now;
|
||||
true
|
||||
} else if entry.count < self.config.max_requests {
|
||||
entry.count += 1;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
None => {
|
||||
entries.insert(
|
||||
key.to_string(),
|
||||
RateLimitEntry {
|
||||
count: 1,
|
||||
window_start: now,
|
||||
},
|
||||
);
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn rate_limit_middleware(
|
||||
State(limiter): State<RateLimiter>,
|
||||
headers: axum::http::HeaderMap,
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let key = headers
|
||||
.get("x-forwarded-for")
|
||||
.or_else(|| headers.get("x-real-ip"))
|
||||
.or_else(|| headers.get("authorization"))
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
if limiter.check(&key).await {
|
||||
Ok(next.run(request).await)
|
||||
} else {
|
||||
Err(StatusCode::TOO_MANY_REQUESTS)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn rate_limit_allows_within_limit() {
|
||||
let limiter = RateLimiter::new(RateLimitConfig {
|
||||
max_requests: 5,
|
||||
window_secs: 60,
|
||||
});
|
||||
for _ in 0..5 {
|
||||
assert!(limiter.check("user1").await);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn rate_limit_blocks_over_limit() {
|
||||
let limiter = RateLimiter::new(RateLimitConfig {
|
||||
max_requests: 3,
|
||||
window_secs: 60,
|
||||
});
|
||||
assert!(limiter.check("user1").await);
|
||||
assert!(limiter.check("user1").await);
|
||||
assert!(limiter.check("user1").await);
|
||||
assert!(!limiter.check("user1").await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn rate_limit_per_key() {
|
||||
let limiter = RateLimiter::new(RateLimitConfig {
|
||||
max_requests: 2,
|
||||
window_secs: 60,
|
||||
});
|
||||
assert!(limiter.check("user1").await);
|
||||
assert!(limiter.check("user1").await);
|
||||
assert!(!limiter.check("user1").await);
|
||||
assert!(limiter.check("user2").await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn rate_limit_default_config() {
|
||||
let config = RateLimitConfig::default();
|
||||
assert_eq!(config.max_requests, 60);
|
||||
assert_eq!(config.window_secs, 60);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn rate_limit_new_key_always_allowed() {
|
||||
let limiter = RateLimiter::new(RateLimitConfig {
|
||||
max_requests: 1,
|
||||
window_secs: 60,
|
||||
});
|
||||
assert!(limiter.check("new_user").await);
|
||||
assert!(!limiter.check("new_user").await);
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,12 @@
|
||||
use super::auth::extract_token;
|
||||
use crate::ServerState;
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::HeaderMap,
|
||||
Json,
|
||||
};
|
||||
use matrix_sdk::ruma::Int;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::ServerState;
|
||||
use super::auth::extract_token;
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
pub struct Role {
|
||||
@@ -52,22 +53,65 @@ pub async fn get_roles(
|
||||
) -> Result<Json<Vec<Role>>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session.client.get_room(&rid).ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
let mut roles = Vec::new();
|
||||
if let Ok(power_levels) = room.power_levels().await {
|
||||
for (uid, power_level) in &power_levels.users {
|
||||
roles.push(Role {
|
||||
id: format!("role_{}", uid),
|
||||
name: uid.to_string(),
|
||||
color: if *power_level >= 100 { "#ed4245".to_string() } else if *power_level >= 50 { "#fee75c".to_string() } else { "#5865f2".to_string() },
|
||||
permissions: vec![],
|
||||
position: *power_level as i32,
|
||||
});
|
||||
}
|
||||
let members = room
|
||||
.members(matrix_sdk::RoomMemberships::JOIN)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let power_levels_content: Option<
|
||||
matrix_sdk::ruma::events::room::power_levels::RoomPowerLevelsEventContent,
|
||||
> = room
|
||||
.get_state_event_static()
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.and_then(|e| e.deserialize().ok())
|
||||
.and_then(|e| e.as_sync().cloned())
|
||||
.and_then(|e| e.as_original().cloned())
|
||||
.map(|e| e.content);
|
||||
|
||||
for member in members {
|
||||
let pl: i64 = if let Some(ref pl_content) = power_levels_content {
|
||||
pl_content
|
||||
.users
|
||||
.get(member.user_id())
|
||||
.map_or(0, |v| i64::from(*v))
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let color = if pl >= 100 {
|
||||
"#ed4245".to_string()
|
||||
} else if pl >= 50 {
|
||||
"#fee75c".to_string()
|
||||
} else {
|
||||
"#5865f2".to_string()
|
||||
};
|
||||
let display_name = member
|
||||
.display_name()
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| member.user_id().to_string());
|
||||
roles.push(Role {
|
||||
id: format!("role_{}", member.user_id()),
|
||||
name: display_name,
|
||||
color,
|
||||
permissions: vec![],
|
||||
position: pl as i32,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(Json(roles))
|
||||
@@ -87,22 +131,34 @@ pub async fn assign_role(
|
||||
) -> Result<Json<bool>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session.client.get_room(&rid).ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
let uid: matrix_sdk::ruma::OwnedUserId = req.user_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
let uid: matrix_sdk::ruma::OwnedUserId = req
|
||||
.user_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let power_level: i64 = match req.role_id.as_str() {
|
||||
let power_level: i32 = match req.role_id.as_str() {
|
||||
"admin" => 100,
|
||||
"moderator" => 50,
|
||||
_ => 0,
|
||||
};
|
||||
|
||||
let mut content = matrix_sdk::ruma::events::room::power_levels::RoomPowerLevelsEventContent::new();
|
||||
content.users.insert(uid.to_owned(), power_level.into());
|
||||
|
||||
room.send_state_event(content).await.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
room.update_power_levels(vec![(uid.as_ref(), Int::from(power_level))])
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(true))
|
||||
}
|
||||
@@ -121,17 +177,28 @@ pub async fn remove_role(
|
||||
) -> Result<Json<bool>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session.client.get_room(&rid).ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
let uid: matrix_sdk::ruma::OwnedUserId = req.user_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
let uid: matrix_sdk::ruma::OwnedUserId = req
|
||||
.user_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
if let Ok(mut power_levels) = room.power_levels().await {
|
||||
power_levels.users.remove(&uid);
|
||||
let content = matrix_sdk::ruma::events::room::power_levels::RoomPowerLevelsEventContent::from(power_levels);
|
||||
room.send_state_event(content).await.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
}
|
||||
room.update_power_levels(vec![(uid.as_ref(), Int::from(0))])
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(true))
|
||||
}
|
||||
@@ -143,17 +210,41 @@ pub async fn get_permissions(
|
||||
) -> Result<Json<Permissions>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session.client.get_room(&rid).ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
let uid: matrix_sdk::ruma::OwnedUserId = user_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
let uid: matrix_sdk::ruma::OwnedUserId = user_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let user_power = if let Ok(power_levels) = room.power_levels().await {
|
||||
power_levels.users.get(&uid).copied().map(|p| p.into()).unwrap_or(power_levels.users_default as i64)
|
||||
let power_levels_content: Option<
|
||||
matrix_sdk::ruma::events::room::power_levels::RoomPowerLevelsEventContent,
|
||||
> = room
|
||||
.get_state_event_static()
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.and_then(|e| e.deserialize().ok())
|
||||
.and_then(|e| e.as_sync().cloned())
|
||||
.and_then(|e| e.as_original().cloned())
|
||||
.map(|e| e.content);
|
||||
|
||||
let user_power = if let Some(pl) = power_levels_content {
|
||||
pl.users.get(&uid).map_or(0, |v| i64::from(*v))
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
Ok(Json(power_level_to_permissions(user_power)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
use super::auth::extract_token;
|
||||
use crate::routes::metrics::ROOMS_JOINED_TOTAL;
|
||||
use crate::ServerState;
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
extract::{Multipart, Path, State},
|
||||
http::HeaderMap,
|
||||
Json,
|
||||
};
|
||||
use serde::Serialize;
|
||||
use crate::ServerState;
|
||||
use super::auth::extract_token;
|
||||
use matrix_sdk::ruma::api::client::receipt::create_receipt::v3::ReceiptType;
|
||||
use matrix_sdk::ruma::events::receipt::ReceiptThread;
|
||||
use matrix_sdk::RoomMemberships;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct RoomInfo {
|
||||
@@ -15,6 +19,8 @@ pub struct RoomInfo {
|
||||
pub is_encrypted: bool,
|
||||
pub member_count: u64,
|
||||
pub topic: Option<String>,
|
||||
pub unread_notifications: u64,
|
||||
pub unread_messages: u64,
|
||||
}
|
||||
|
||||
pub async fn get_joined_rooms(
|
||||
@@ -23,16 +29,29 @@ pub async fn get_joined_rooms(
|
||||
) -> Result<Json<Vec<RoomInfo>>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rooms = session.client.joined_rooms();
|
||||
let mut result = Vec::new();
|
||||
for room in rooms {
|
||||
let name = room.display_name().await.map(|n| n.to_string()).unwrap_or_default();
|
||||
let name = room
|
||||
.display_name()
|
||||
.await
|
||||
.map(|n| n.to_string())
|
||||
.unwrap_or_default();
|
||||
let avatar_url = room.avatar_url().map(|u| u.to_string());
|
||||
let member_count = room.joined_members().len() as u64;
|
||||
let member_count = room
|
||||
.members(RoomMemberships::JOIN)
|
||||
.await
|
||||
.map(|m| m.len() as u64)
|
||||
.unwrap_or(0);
|
||||
let topic = room.topic().map(|t| t.to_string());
|
||||
let is_encrypted = room.is_encrypted().await.unwrap_or(false);
|
||||
let unread_notifications = room.unread_notification_counts().notification_count;
|
||||
let unread_messages = room.num_unread_messages();
|
||||
result.push(RoomInfo {
|
||||
room_id: room.room_id().to_string(),
|
||||
name,
|
||||
@@ -40,6 +59,8 @@ pub async fn get_joined_rooms(
|
||||
is_encrypted,
|
||||
member_count,
|
||||
topic,
|
||||
unread_notifications,
|
||||
unread_messages,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -60,11 +81,14 @@ pub async fn create_room(
|
||||
) -> Result<Json<String>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let vis = match req.visibility.as_deref() {
|
||||
Some("public") => matrix_sdk::ruma::Space::Public,
|
||||
_ => matrix_sdk::ruma::Space::Private,
|
||||
Some("public") => matrix_sdk::ruma::api::client::room::Visibility::Public,
|
||||
_ => matrix_sdk::ruma::api::client::room::Visibility::Private,
|
||||
};
|
||||
|
||||
let mut request = matrix_sdk::ruma::api::client::room::create_room::v3::Request::new();
|
||||
@@ -72,8 +96,12 @@ pub async fn create_room(
|
||||
request.topic = req.topic;
|
||||
request.visibility = vis;
|
||||
|
||||
let response = session.client.create_room(request).await.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
Ok(Json(response.room_id.to_string()))
|
||||
let response = session
|
||||
.client
|
||||
.create_room(request)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
Ok(Json(response.room_id().to_string()))
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
@@ -88,17 +116,24 @@ pub async fn join_room(
|
||||
) -> Result<Json<String>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let room_id = session.client
|
||||
.join_room_by_id_or_alias(
|
||||
&req.room_id_or_alias.try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?,
|
||||
&[],
|
||||
)
|
||||
let alias: matrix_sdk::ruma::OwnedRoomOrAliasId = req
|
||||
.room_id_or_alias
|
||||
.parse()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let room = session
|
||||
.client
|
||||
.join_room_by_id_or_alias(&alias, &[])
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(room_id.to_string()))
|
||||
ROOMS_JOINED_TOTAL.inc();
|
||||
Ok(Json(room.room_id().to_string()))
|
||||
}
|
||||
|
||||
pub async fn leave_room(
|
||||
@@ -108,11 +143,22 @@ pub async fn leave_room(
|
||||
) -> Result<Json<bool>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session.client.get_room(&rid).ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
room.leave().await.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
room.leave()
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(true))
|
||||
}
|
||||
@@ -124,11 +170,219 @@ pub async fn get_room_members(
|
||||
) -> Result<Json<Vec<String>>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s.sessions.get(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id.as_str().try_into().map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session.client.get_room(&rid).ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
let members = room.joined_members();
|
||||
Ok(Json(members.iter().map(|m| m.user_id().to_string()).collect()))
|
||||
}
|
||||
let members = room
|
||||
.members(RoomMemberships::JOIN)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
Ok(Json(
|
||||
members.iter().map(|m| m.user_id().to_string()).collect(),
|
||||
))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SetRoomNameRequest {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
pub async fn set_room_name(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Path(room_id): Path<String>,
|
||||
Json(req): Json<SetRoomNameRequest>,
|
||||
) -> Result<Json<bool>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
room.set_name(req.name)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(true))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SetRoomTopicRequest {
|
||||
pub topic: String,
|
||||
}
|
||||
|
||||
pub async fn set_room_topic(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Path(room_id): Path<String>,
|
||||
Json(req): Json<SetRoomTopicRequest>,
|
||||
) -> Result<Json<bool>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
room.set_room_topic(&req.topic)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(true))
|
||||
}
|
||||
|
||||
pub async fn set_room_avatar(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Path(room_id): Path<String>,
|
||||
mut multipart: Multipart,
|
||||
) -> Result<Json<bool>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
let field = multipart
|
||||
.next_field()
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?
|
||||
.ok_or(axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let content_type = field
|
||||
.content_type()
|
||||
.unwrap_or("application/octet-stream")
|
||||
.to_string();
|
||||
let data = field
|
||||
.bytes()
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let mime_type: mime::Mime = content_type.parse().unwrap_or(mime::IMAGE_PNG);
|
||||
|
||||
let upload = session
|
||||
.client
|
||||
.media()
|
||||
.upload(&mime_type, data.to_vec())
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
room.set_avatar_url(&upload.content_uri, None)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(true))
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct UnreadInfo {
|
||||
pub room_id: String,
|
||||
pub unread_notifications: u64,
|
||||
pub unread_messages: u64,
|
||||
}
|
||||
|
||||
pub async fn get_unread_counts(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<Vec<UnreadInfo>>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rooms = session.client.joined_rooms();
|
||||
let mut result = Vec::new();
|
||||
for room in rooms {
|
||||
let unread_notifications = room.unread_notification_counts().notification_count;
|
||||
let unread_messages = room.num_unread_messages();
|
||||
result.push(UnreadInfo {
|
||||
room_id: room.room_id().to_string(),
|
||||
unread_notifications,
|
||||
unread_messages,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(Json(result))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct MarkReadRequest {
|
||||
pub event_id: String,
|
||||
}
|
||||
|
||||
pub async fn mark_room_read(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Path(room_id): Path<String>,
|
||||
Json(req): Json<MarkReadRequest>,
|
||||
) -> Result<Json<bool>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
let event_id: matrix_sdk::ruma::OwnedEventId = req
|
||||
.event_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
room.send_single_receipt(ReceiptType::Read, ReceiptThread::Main, event_id)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(true))
|
||||
}
|
||||
|
||||
379
server/src/routes/threads.rs
Normal file
379
server/src/routes/threads.rs
Normal file
@@ -0,0 +1,379 @@
|
||||
use super::auth::extract_token;
|
||||
use crate::ServerState;
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::HeaderMap,
|
||||
Json,
|
||||
};
|
||||
use matrix_sdk::room::MessagesOptions;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ThreadInfo {
|
||||
pub root_event_id: String,
|
||||
pub root_sender: String,
|
||||
pub root_body: String,
|
||||
pub root_timestamp: u64,
|
||||
pub reply_count: u32,
|
||||
pub last_reply_event_id: Option<String>,
|
||||
pub last_reply_sender: Option<String>,
|
||||
pub last_reply_body: Option<String>,
|
||||
pub last_reply_timestamp: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ThreadMessageInfo {
|
||||
pub event_id: String,
|
||||
pub sender: String,
|
||||
pub body: String,
|
||||
pub timestamp: u64,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct ThreadQuery {
|
||||
pub limit: Option<u32>,
|
||||
pub from: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn get_threads(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Path(room_id): Path<String>,
|
||||
) -> Result<Json<Vec<ThreadInfo>>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
let mut options = MessagesOptions::backward();
|
||||
options.limit = matrix_sdk::ruma::uint!(100);
|
||||
|
||||
let messages = room
|
||||
.messages(options)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let mut thread_roots: std::collections::HashMap<String, ThreadInfo> =
|
||||
std::collections::HashMap::new();
|
||||
|
||||
for msg in &messages.chunk {
|
||||
let event_value: serde_json::Value = match msg.event.deserialize_as() {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let event_type = event_value
|
||||
.get("type")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
if event_type != "m.room.message" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let event_id = match event_value.get("event_id").and_then(|v| v.as_str()) {
|
||||
Some(id) => id.to_string(),
|
||||
None => continue,
|
||||
};
|
||||
let sender = event_value
|
||||
.get("sender")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let timestamp = event_value
|
||||
.get("origin_server_ts")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0);
|
||||
let content = event_value
|
||||
.get("content")
|
||||
.unwrap_or(&serde_json::Value::Null);
|
||||
let body = content
|
||||
.get("body")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
let new_content = content.get("m.new_content");
|
||||
let display_body = if let Some(nc) = new_content {
|
||||
nc.get("body")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or(&body)
|
||||
.to_string()
|
||||
} else {
|
||||
body.clone()
|
||||
};
|
||||
|
||||
let relates_to = content.get("m.relates_to");
|
||||
|
||||
let is_thread = relates_to
|
||||
.and_then(|r| r.get("rel_type"))
|
||||
.and_then(|v| v.as_str())
|
||||
== Some("m.thread");
|
||||
|
||||
if is_thread {
|
||||
if let Some(thread_info) = relates_to {
|
||||
let root_id = thread_info
|
||||
.get("event_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
if root_id.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let is_falling_back = thread_info
|
||||
.get("is_falling_back")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
if is_falling_back {
|
||||
let reply_count = thread_roots
|
||||
.get(&root_id)
|
||||
.map(|t: &ThreadInfo| t.reply_count)
|
||||
.unwrap_or(0);
|
||||
thread_roots
|
||||
.entry(root_id.clone())
|
||||
.or_insert_with(|| ThreadInfo {
|
||||
root_event_id: root_id.clone(),
|
||||
root_sender: sender.clone(),
|
||||
root_body: display_body.clone(),
|
||||
root_timestamp: timestamp,
|
||||
reply_count,
|
||||
last_reply_event_id: None,
|
||||
last_reply_sender: None,
|
||||
last_reply_body: None,
|
||||
last_reply_timestamp: None,
|
||||
});
|
||||
} else {
|
||||
let entry = thread_roots
|
||||
.entry(root_id.clone())
|
||||
.or_insert_with(|| ThreadInfo {
|
||||
root_event_id: root_id.clone(),
|
||||
root_sender: String::new(),
|
||||
root_body: String::new(),
|
||||
root_timestamp: 0,
|
||||
reply_count: 0,
|
||||
last_reply_event_id: None,
|
||||
last_reply_sender: None,
|
||||
last_reply_body: None,
|
||||
last_reply_timestamp: None,
|
||||
});
|
||||
entry.reply_count += 1;
|
||||
entry.last_reply_event_id = Some(event_id);
|
||||
entry.last_reply_sender = Some(sender.clone());
|
||||
entry.last_reply_body = Some(display_body);
|
||||
entry.last_reply_timestamp = Some(timestamp);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(thread_roots.into_values().collect()))
|
||||
}
|
||||
|
||||
pub async fn get_thread_messages(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Path((room_id, thread_id)): Path<(String, String)>,
|
||||
Query(query): Query<ThreadQuery>,
|
||||
) -> Result<Json<Vec<ThreadMessageInfo>>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
let limit = query.limit.unwrap_or(50);
|
||||
|
||||
let mut options = MessagesOptions::backward();
|
||||
options.limit = matrix_sdk::ruma::uint!(100);
|
||||
|
||||
let messages = room
|
||||
.messages(options)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let mut result = Vec::new();
|
||||
|
||||
for msg in &messages.chunk {
|
||||
let event_value: serde_json::Value = match msg.event.deserialize_as() {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let event_type = event_value
|
||||
.get("type")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
if event_type != "m.room.message" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let content = event_value
|
||||
.get("content")
|
||||
.unwrap_or(&serde_json::Value::Null);
|
||||
let relates_to = content.get("m.relates_to");
|
||||
|
||||
let is_in_thread = relates_to
|
||||
.and_then(|r| {
|
||||
let rel_type = r.get("rel_type").and_then(|v| v.as_str());
|
||||
let event_id = r.get("event_id").and_then(|v| v.as_str());
|
||||
if rel_type == Some("m.thread") && event_id == Some(thread_id.as_str()) {
|
||||
Some(true)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
let is_root =
|
||||
event_value.get("event_id").and_then(|v| v.as_str()) == Some(thread_id.as_str());
|
||||
|
||||
if !is_in_thread && !is_root {
|
||||
continue;
|
||||
}
|
||||
|
||||
let event_id = match event_value.get("event_id").and_then(|v| v.as_str()) {
|
||||
Some(id) => id.to_string(),
|
||||
None => continue,
|
||||
};
|
||||
let sender = event_value
|
||||
.get("sender")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let timestamp = event_value
|
||||
.get("origin_server_ts")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0);
|
||||
let body = content
|
||||
.get("body")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
if body.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
result.push(ThreadMessageInfo {
|
||||
event_id,
|
||||
sender,
|
||||
body,
|
||||
timestamp,
|
||||
});
|
||||
}
|
||||
|
||||
result.truncate(limit as usize);
|
||||
Ok(Json(result))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct ThreadReplyRequest {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
pub async fn send_thread_reply(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Path((room_id, thread_id)): Path<(String, String)>,
|
||||
Json(req): Json<ThreadReplyRequest>,
|
||||
) -> Result<Json<String>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
let root_event_id: matrix_sdk::ruma::OwnedEventId = thread_id
|
||||
.parse()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let mut content =
|
||||
matrix_sdk::ruma::events::room::message::RoomMessageEventContent::text_plain(&req.message);
|
||||
content.relates_to = Some(matrix_sdk::ruma::events::room::message::Relation::Thread(
|
||||
matrix_sdk::ruma::events::relation::Thread::without_fallback(root_event_id),
|
||||
));
|
||||
|
||||
let response = room
|
||||
.send(content)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
Ok(Json(response.event_id.to_string()))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SendReplyRequest {
|
||||
pub message: String,
|
||||
pub reply_to: String,
|
||||
}
|
||||
|
||||
pub async fn send_reply(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Path(room_id): Path<String>,
|
||||
Json(req): Json<SendReplyRequest>,
|
||||
) -> Result<Json<String>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
let reply_to_event_id: matrix_sdk::ruma::OwnedEventId = req
|
||||
.reply_to
|
||||
.parse()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let mut content =
|
||||
matrix_sdk::ruma::events::room::message::RoomMessageEventContent::text_plain(&req.message);
|
||||
content.relates_to = Some(matrix_sdk::ruma::events::room::message::Relation::Reply {
|
||||
in_reply_to: matrix_sdk::ruma::events::relation::InReplyTo::new(reply_to_event_id),
|
||||
});
|
||||
|
||||
let response = room
|
||||
.send(content)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
Ok(Json(response.event_id.to_string()))
|
||||
}
|
||||
97
server/src/routes/upload.rs
Normal file
97
server/src/routes/upload.rs
Normal file
@@ -0,0 +1,97 @@
|
||||
use super::auth::extract_token;
|
||||
use crate::routes::metrics::UPLOADS_TOTAL;
|
||||
use crate::ServerState;
|
||||
use axum::{
|
||||
extract::{Multipart, Path, State},
|
||||
http::HeaderMap,
|
||||
Json,
|
||||
};
|
||||
use matrix_sdk::ruma::events::room::message::{
|
||||
FileMessageEventContent, ImageMessageEventContent, MessageType, RoomMessageEventContent,
|
||||
};
|
||||
use matrix_sdk::ruma::OwnedMxcUri;
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct UploadResult {
|
||||
pub event_id: String,
|
||||
pub media_url: Option<String>,
|
||||
pub filename: String,
|
||||
pub mimetype: Option<String>,
|
||||
pub size: Option<u64>,
|
||||
}
|
||||
|
||||
pub async fn upload_file(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Path(room_id): Path<String>,
|
||||
mut multipart: Multipart,
|
||||
) -> Result<Json<UploadResult>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let rid: matrix_sdk::ruma::OwnedRoomId = room_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
let room = session
|
||||
.client
|
||||
.get_room(&rid)
|
||||
.ok_or(axum::http::StatusCode::NOT_FOUND)?;
|
||||
|
||||
let field = multipart
|
||||
.next_field()
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?
|
||||
.ok_or(axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let filename = field.file_name().unwrap_or("upload").to_string();
|
||||
let content_type = field
|
||||
.content_type()
|
||||
.unwrap_or("application/octet-stream")
|
||||
.to_string();
|
||||
let data = field
|
||||
.bytes()
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let mime_type: mime::Mime = content_type
|
||||
.parse()
|
||||
.unwrap_or(mime::APPLICATION_OCTET_STREAM);
|
||||
|
||||
let media_response = session
|
||||
.client
|
||||
.media()
|
||||
.upload(&mime_type, data.to_vec())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Media upload failed: {}", e);
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
let media_uri: OwnedMxcUri = media_response.content_uri.clone();
|
||||
|
||||
let content: RoomMessageEventContent = if mime_type.type_() == mime::IMAGE {
|
||||
MessageType::Image(ImageMessageEventContent::plain(filename.clone(), media_uri)).into()
|
||||
} else {
|
||||
MessageType::File(FileMessageEventContent::plain(filename.clone(), media_uri)).into()
|
||||
};
|
||||
|
||||
let response = room
|
||||
.send(content)
|
||||
.await
|
||||
.map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
UPLOADS_TOTAL.inc();
|
||||
Ok(Json(UploadResult {
|
||||
event_id: response.event_id.to_string(),
|
||||
media_url: Some(media_response.content_uri.to_string()),
|
||||
filename,
|
||||
mimetype: Some(content_type),
|
||||
size: Some(data.len() as u64),
|
||||
}))
|
||||
}
|
||||
@@ -1,80 +1,318 @@
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::HeaderMap,
|
||||
Json,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::ServerState;
|
||||
use super::auth::extract_token;
|
||||
use crate::state::WsEvent;
|
||||
use crate::ServerState;
|
||||
use axum::{extract::State, http::HeaderMap, Json};
|
||||
use livekit_api::access_token::{AccessToken, VideoGrants};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct VoiceRequest {
|
||||
pub struct VoiceJoinRequest {
|
||||
pub room_id: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct VoiceStateInfo {
|
||||
pub struct VoiceJoinResponse {
|
||||
pub room_id: String,
|
||||
pub livekit_url: String,
|
||||
pub livekit_token: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct VoiceToggleResponse {
|
||||
pub muted: bool,
|
||||
pub deafened: bool,
|
||||
pub streaming: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct VoiceParticipantInfo {
|
||||
pub user_id: String,
|
||||
pub muted: bool,
|
||||
pub deafened: bool,
|
||||
}
|
||||
|
||||
fn sanitize_room_id(room_id: &str) -> String {
|
||||
room_id
|
||||
.replace(':', "-")
|
||||
.replace("!", "")
|
||||
.replace("/", "-")
|
||||
.replace(" ", "_")
|
||||
}
|
||||
|
||||
pub async fn join_voice_channel(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Json(req): Json<VoiceRequest>,
|
||||
) -> Result<Json<VoiceStateInfo>, axum::http::StatusCode> {
|
||||
Json(req): Json<VoiceJoinRequest>,
|
||||
) -> Result<Json<VoiceJoinResponse>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let mut s = state.write().await;
|
||||
let session = s.sessions.get_mut(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
session.voice_manager.join_channel(req.room_id.clone(), session.user_id.clone());
|
||||
let (user_id, livekit_url, livekit_api_key, livekit_api_secret) = {
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
(
|
||||
session.user_id.clone(),
|
||||
s.livekit.url.clone(),
|
||||
s.livekit.api_key.clone(),
|
||||
s.livekit.api_secret.clone(),
|
||||
)
|
||||
};
|
||||
|
||||
Ok(Json(VoiceStateInfo {
|
||||
let lk_room = sanitize_room_id(&req.room_id);
|
||||
let lk_identity = sanitize_room_id(&user_id);
|
||||
|
||||
let grants = VideoGrants {
|
||||
room_join: true,
|
||||
room: lk_room.clone(),
|
||||
can_publish: true,
|
||||
can_subscribe: true,
|
||||
can_publish_data: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let lk_token = AccessToken::with_api_key(&livekit_api_key, &livekit_api_secret)
|
||||
.with_identity(&lk_identity)
|
||||
.with_name(&user_id)
|
||||
.with_grants(grants)
|
||||
.to_jwt()
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to generate LiveKit token: {}", e);
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
let old_channel;
|
||||
{
|
||||
let mut s = state.write().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get_mut(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
old_channel = session.voice_manager.active_channel.clone();
|
||||
session.voice_manager.join_channel(&req.room_id);
|
||||
}
|
||||
|
||||
if let Some(ref old) = old_channel {
|
||||
if old != &req.room_id {
|
||||
let mut s = state.write().await;
|
||||
if let Some(room) = s.voice_rooms.rooms.get_mut(old) {
|
||||
room.participants.remove(&user_id);
|
||||
if room.participants.is_empty() {
|
||||
s.voice_rooms.rooms.remove(old);
|
||||
}
|
||||
}
|
||||
drop(s);
|
||||
|
||||
let s = state.read().await;
|
||||
if let Some(session) = s.sessions.get(&token) {
|
||||
let _ = session.event_sender.send(WsEvent::VoiceUserLeft {
|
||||
room_id: old.clone(),
|
||||
user_id: user_id.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let mut s = state.write().await;
|
||||
let room_entry = s
|
||||
.voice_rooms
|
||||
.rooms
|
||||
.entry(req.room_id.clone())
|
||||
.or_insert_with(|| crate::state::VoiceRoom {
|
||||
participants: HashMap::new(),
|
||||
});
|
||||
room_entry.participants.insert(
|
||||
user_id.clone(),
|
||||
crate::state::VoiceParticipant {
|
||||
user_id: user_id.clone(),
|
||||
muted: false,
|
||||
deafened: false,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let s = state.read().await;
|
||||
if let Some(session) = s.sessions.get(&token) {
|
||||
let _ = session.event_sender.send(WsEvent::VoiceUserJoined {
|
||||
room_id: req.room_id.clone(),
|
||||
user_id: user_id.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(VoiceJoinResponse {
|
||||
room_id: req.room_id,
|
||||
muted: false,
|
||||
deafened: false,
|
||||
streaming: false,
|
||||
livekit_url,
|
||||
livekit_token: lk_token,
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn leave_voice_channel(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Json(req): Json<VoiceRequest>,
|
||||
Json(req): Json<VoiceJoinRequest>,
|
||||
) -> Result<Json<bool>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let mut s = state.write().await;
|
||||
let session = s.sessions.get_mut(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
Ok(Json(session.voice_manager.leave_channel(&req.room_id, &session.user_id)))
|
||||
let user_id;
|
||||
{
|
||||
let mut s = state.write().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get_mut(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
user_id = session.user_id.clone();
|
||||
session.voice_manager.leave_channel();
|
||||
|
||||
if let Some(room) = s.voice_rooms.rooms.get_mut(&req.room_id) {
|
||||
room.participants.remove(&user_id);
|
||||
if room.participants.is_empty() {
|
||||
s.voice_rooms.rooms.remove(&req.room_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let s = state.read().await;
|
||||
if let Some(session) = s.sessions.get(&token) {
|
||||
let _ = session.event_sender.send(WsEvent::VoiceUserLeft {
|
||||
room_id: req.room_id,
|
||||
user_id,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(true))
|
||||
}
|
||||
|
||||
pub async fn toggle_mute(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Json(req): Json<VoiceRequest>,
|
||||
) -> Result<Json<bool>, axum::http::StatusCode> {
|
||||
Json(_req): Json<VoiceJoinRequest>,
|
||||
) -> Result<Json<VoiceToggleResponse>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let mut s = state.write().await;
|
||||
let session = s.sessions.get_mut(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
session.voice_manager.toggle_mute(&req.room_id, &session.user_id)
|
||||
.ok_or(axum::http::StatusCode::BAD_REQUEST)
|
||||
.map(Json)
|
||||
let (user_id, room_id, new_muted, deafened);
|
||||
{
|
||||
let mut s = state.write().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get_mut(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
new_muted = session.voice_manager.toggle_mute();
|
||||
deafened = session.voice_manager.deafened;
|
||||
room_id = session
|
||||
.voice_manager
|
||||
.active_channel
|
||||
.clone()
|
||||
.ok_or(axum::http::StatusCode::BAD_REQUEST)?;
|
||||
user_id = session.user_id.clone();
|
||||
|
||||
if let Some(room) = s.voice_rooms.rooms.get_mut(&room_id) {
|
||||
if let Some(participant) = room.participants.get_mut(&user_id) {
|
||||
participant.muted = new_muted;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let s = state.read().await;
|
||||
if let Some(session) = s.sessions.get(&token) {
|
||||
let _ = session.event_sender.send(WsEvent::VoiceStateUpdate {
|
||||
room_id,
|
||||
user_id,
|
||||
muted: new_muted,
|
||||
deafened,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(VoiceToggleResponse {
|
||||
muted: new_muted,
|
||||
deafened,
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn toggle_deafen(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
Json(req): Json<VoiceRequest>,
|
||||
) -> Result<Json<bool>, axum::http::StatusCode> {
|
||||
Json(_req): Json<VoiceJoinRequest>,
|
||||
) -> Result<Json<VoiceToggleResponse>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let mut s = state.write().await;
|
||||
let session = s.sessions.get_mut(&token).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
session.voice_manager.toggle_deafen(&req.room_id, &session.user_id)
|
||||
.ok_or(axum::http::StatusCode::BAD_REQUEST)
|
||||
.map(Json)
|
||||
}
|
||||
let (user_id, room_id, muted, new_deafened);
|
||||
{
|
||||
let mut s = state.write().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get_mut(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
new_deafened = session.voice_manager.toggle_deafen();
|
||||
muted = session.voice_manager.muted;
|
||||
room_id = session
|
||||
.voice_manager
|
||||
.active_channel
|
||||
.clone()
|
||||
.ok_or(axum::http::StatusCode::BAD_REQUEST)?;
|
||||
user_id = session.user_id.clone();
|
||||
|
||||
if let Some(room) = s.voice_rooms.rooms.get_mut(&room_id) {
|
||||
if let Some(participant) = room.participants.get_mut(&user_id) {
|
||||
participant.deafened = new_deafened;
|
||||
participant.muted = muted;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let s = state.read().await;
|
||||
if let Some(session) = s.sessions.get(&token) {
|
||||
let _ = session.event_sender.send(WsEvent::VoiceStateUpdate {
|
||||
room_id,
|
||||
user_id,
|
||||
muted,
|
||||
deafened: new_deafened,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(VoiceToggleResponse {
|
||||
muted,
|
||||
deafened: new_deafened,
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn get_voice_participants(
|
||||
State(state): State<ServerState>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<Vec<VoiceParticipantInfo>>, axum::http::StatusCode> {
|
||||
let token = extract_token(&headers).ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
let s = state.read().await;
|
||||
let session = s
|
||||
.sessions
|
||||
.get(&token)
|
||||
.ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let room_id = match &session.voice_manager.active_channel {
|
||||
Some(r) => r.clone(),
|
||||
None => return Ok(Json(Vec::new())),
|
||||
};
|
||||
|
||||
let participants = match s.voice_rooms.rooms.get(&room_id) {
|
||||
Some(room) => room
|
||||
.participants
|
||||
.values()
|
||||
.map(|p| VoiceParticipantInfo {
|
||||
user_id: p.user_id.clone(),
|
||||
muted: p.muted,
|
||||
deafened: p.deafened,
|
||||
})
|
||||
.collect(),
|
||||
None => Vec::new(),
|
||||
};
|
||||
|
||||
Ok(Json(participants))
|
||||
}
|
||||
|
||||
84
server/src/routes/ws.rs
Normal file
84
server/src/routes/ws.rs
Normal file
@@ -0,0 +1,84 @@
|
||||
use crate::ServerState;
|
||||
use axum::{
|
||||
extract::{
|
||||
ws::{Message, WebSocket},
|
||||
Query, State, WebSocketUpgrade,
|
||||
},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct WsQuery {
|
||||
pub token: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
Query(query): Query<WsQuery>,
|
||||
State(state): State<ServerState>,
|
||||
) -> impl IntoResponse {
|
||||
ws.on_upgrade(move |socket| handle_ws(socket, state, query.token.unwrap_or_default()))
|
||||
}
|
||||
|
||||
async fn handle_ws(socket: WebSocket, state: ServerState, token: String) {
|
||||
let (user_id, mut event_rx) = {
|
||||
let s = state.read().await;
|
||||
match s.sessions.get(&token) {
|
||||
Some(session) => (session.user_id.clone(), session.event_sender.subscribe()),
|
||||
None => {
|
||||
let _ = socket.close().await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let connected_msg = serde_json::json!({"type": "connected", "user_id": user_id}).to_string();
|
||||
let (mut ws_sender, mut ws_receiver) = socket.split();
|
||||
|
||||
if ws_sender.send(Message::Text(connected_msg)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
let (out_tx, mut out_rx) = tokio::sync::mpsc::channel::<String>(64);
|
||||
|
||||
let sender_task = tokio::spawn(async move {
|
||||
while let Some(text) = out_rx.recv().await {
|
||||
if ws_sender.send(Message::Text(text)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let receiver_task = tokio::spawn(async move {
|
||||
while let Some(msg) = ws_receiver.next().await {
|
||||
match msg {
|
||||
Ok(Message::Text(text)) if text.as_str() == "ping" => {}
|
||||
Ok(Message::Close(_)) | Err(_) => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let forward_task = tokio::spawn(async move {
|
||||
while let Ok(event) = event_rx.recv().await {
|
||||
match serde_json::to_string(&event) {
|
||||
Ok(json) => {
|
||||
if out_tx.send(json).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to serialize WS event: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = sender_task => {},
|
||||
_ = receiver_task => {},
|
||||
_ = forward_task => {},
|
||||
}
|
||||
}
|
||||
181
server/src/session_store.rs
Normal file
181
server/src/session_store.rs
Normal file
@@ -0,0 +1,181 @@
|
||||
use rusqlite::Connection;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
pub struct SessionStore {
|
||||
conn: Arc<Mutex<Connection>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct StoredSession {
|
||||
pub token: String,
|
||||
pub user_id: String,
|
||||
pub homeserver: String,
|
||||
pub access_token: String,
|
||||
pub device_id: Option<String>,
|
||||
pub refresh_token: Option<String>,
|
||||
}
|
||||
|
||||
impl SessionStore {
|
||||
pub fn new(path: &str) -> Result<Self, rusqlite::Error> {
|
||||
let conn = if Path::new(path).exists() {
|
||||
Connection::open(path)?
|
||||
} else {
|
||||
let conn = Connection::open(path)?;
|
||||
conn.execute_batch(
|
||||
"CREATE TABLE IF NOT EXISTS sessions (
|
||||
token TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
homeserver TEXT NOT NULL,
|
||||
access_token TEXT NOT NULL,
|
||||
device_id TEXT,
|
||||
refresh_token TEXT
|
||||
);",
|
||||
)?;
|
||||
conn
|
||||
};
|
||||
Ok(Self {
|
||||
conn: Arc::new(Mutex::new(conn)),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn save_session(&self, session: &StoredSession) -> Result<(), rusqlite::Error> {
|
||||
let conn = self.conn.lock().await;
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO sessions (token, user_id, homeserver, access_token, device_id, refresh_token)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
|
||||
rusqlite::params![
|
||||
session.token,
|
||||
session.user_id,
|
||||
session.homeserver,
|
||||
session.access_token,
|
||||
session.device_id,
|
||||
session.refresh_token,
|
||||
],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_all_sessions(&self) -> Result<Vec<StoredSession>, rusqlite::Error> {
|
||||
let conn = self.conn.lock().await;
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT token, user_id, homeserver, access_token, device_id, refresh_token FROM sessions"
|
||||
)?;
|
||||
let rows = stmt.query_map([], |row| {
|
||||
Ok(StoredSession {
|
||||
token: row.get(0)?,
|
||||
user_id: row.get(1)?,
|
||||
homeserver: row.get(2)?,
|
||||
access_token: row.get(3)?,
|
||||
device_id: row.get(4)?,
|
||||
refresh_token: row.get(5)?,
|
||||
})
|
||||
})?;
|
||||
rows.collect()
|
||||
}
|
||||
|
||||
pub async fn delete_session(&self, token: &str) -> Result<(), rusqlite::Error> {
|
||||
let conn = self.conn.lock().await;
|
||||
conn.execute("DELETE FROM sessions WHERE token = ?1", [token])?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
async fn create_test_store() -> (SessionStore, tempfile::TempDir) {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let path = dir.path().join("test.db");
|
||||
let path_str = path.to_str().unwrap().to_string();
|
||||
let store = SessionStore::new(&path_str).unwrap();
|
||||
(store, dir)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn save_and_retrieve_session() {
|
||||
let (store, _dir) = create_test_store().await;
|
||||
let session = StoredSession {
|
||||
token: "test-token".to_string(),
|
||||
user_id: "@alice:server".to_string(),
|
||||
homeserver: "https://matrix.server".to_string(),
|
||||
access_token: "access-123".to_string(),
|
||||
device_id: Some("DEVICE1".to_string()),
|
||||
refresh_token: Some("refresh-456".to_string()),
|
||||
};
|
||||
|
||||
store.save_session(&session).await.unwrap();
|
||||
let sessions = store.get_all_sessions().await.unwrap();
|
||||
assert_eq!(sessions.len(), 1);
|
||||
assert_eq!(sessions[0].token, "test-token");
|
||||
assert_eq!(sessions[0].user_id, "@alice:server");
|
||||
assert_eq!(sessions[0].access_token, "access-123");
|
||||
assert_eq!(sessions[0].device_id, Some("DEVICE1".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn save_session_upserts() {
|
||||
let (store, _dir) = create_test_store().await;
|
||||
let session = StoredSession {
|
||||
token: "test-token".to_string(),
|
||||
user_id: "@alice:server".to_string(),
|
||||
homeserver: "https://matrix.server".to_string(),
|
||||
access_token: "access-123".to_string(),
|
||||
device_id: None,
|
||||
refresh_token: None,
|
||||
};
|
||||
|
||||
store.save_session(&session).await.unwrap();
|
||||
let mut updated = session.clone();
|
||||
updated.access_token = "access-456".to_string();
|
||||
store.save_session(&updated).await.unwrap();
|
||||
|
||||
let sessions = store.get_all_sessions().await.unwrap();
|
||||
assert_eq!(sessions.len(), 1);
|
||||
assert_eq!(sessions[0].access_token, "access-456");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn delete_session_removes_it() {
|
||||
let (store, _dir) = create_test_store().await;
|
||||
let session = StoredSession {
|
||||
token: "test-token".to_string(),
|
||||
user_id: "@alice:server".to_string(),
|
||||
homeserver: "https://matrix.server".to_string(),
|
||||
access_token: "access-123".to_string(),
|
||||
device_id: None,
|
||||
refresh_token: None,
|
||||
};
|
||||
|
||||
store.save_session(&session).await.unwrap();
|
||||
store.delete_session("test-token").await.unwrap();
|
||||
let sessions = store.get_all_sessions().await.unwrap();
|
||||
assert!(sessions.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn delete_nonexistent_session_is_ok() {
|
||||
let (store, _dir) = create_test_store().await;
|
||||
store.delete_session("nonexistent").await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn save_multiple_sessions() {
|
||||
let (store, _dir) = create_test_store().await;
|
||||
for i in 0..3 {
|
||||
let session = StoredSession {
|
||||
token: format!("token-{}", i),
|
||||
user_id: format!("@user{}:server", i),
|
||||
homeserver: "https://matrix.server".to_string(),
|
||||
access_token: format!("access-{}", i),
|
||||
device_id: None,
|
||||
refresh_token: None,
|
||||
};
|
||||
store.save_session(&session).await.unwrap();
|
||||
}
|
||||
let sessions = store.get_all_sessions().await.unwrap();
|
||||
assert_eq!(sessions.len(), 3);
|
||||
}
|
||||
}
|
||||
@@ -1,117 +1,381 @@
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use crate::session_store::SessionStore;
|
||||
use matrix_sdk::Client;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::broadcast;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
#[derive(Clone, Debug, serde::Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum WsEvent {
|
||||
#[serde(rename = "message")]
|
||||
Message {
|
||||
room_id: String,
|
||||
event_id: String,
|
||||
sender: String,
|
||||
body: String,
|
||||
timestamp: u64,
|
||||
reply_to: Option<String>,
|
||||
msgtype: Option<String>,
|
||||
media_url: Option<String>,
|
||||
filename: Option<String>,
|
||||
mimetype: Option<String>,
|
||||
},
|
||||
#[serde(rename = "message_edited")]
|
||||
MessageEdited {
|
||||
room_id: String,
|
||||
event_id: String,
|
||||
new_body: String,
|
||||
},
|
||||
#[serde(rename = "message_deleted")]
|
||||
MessageDeleted { room_id: String, redacts: String },
|
||||
#[serde(rename = "reaction")]
|
||||
Reaction {
|
||||
room_id: String,
|
||||
event_id: String,
|
||||
key: String,
|
||||
sender: String,
|
||||
},
|
||||
#[serde(rename = "room_joined")]
|
||||
RoomJoined { room_id: String, name: String },
|
||||
#[serde(rename = "room_left")]
|
||||
RoomLeft { room_id: String },
|
||||
#[serde(rename = "presence")]
|
||||
Presence {
|
||||
user_id: String,
|
||||
status: String,
|
||||
status_msg: Option<String>,
|
||||
},
|
||||
#[serde(rename = "typing")]
|
||||
Typing {
|
||||
room_id: String,
|
||||
user_id: String,
|
||||
typing: bool,
|
||||
},
|
||||
#[serde(rename = "voice_state_update")]
|
||||
VoiceStateUpdate {
|
||||
room_id: String,
|
||||
user_id: String,
|
||||
muted: bool,
|
||||
deafened: bool,
|
||||
},
|
||||
#[serde(rename = "voice_user_joined")]
|
||||
VoiceUserJoined { room_id: String, user_id: String },
|
||||
#[serde(rename = "voice_user_left")]
|
||||
VoiceUserLeft { room_id: String, user_id: String },
|
||||
#[serde(rename = "thread_reply")]
|
||||
ThreadReply {
|
||||
room_id: String,
|
||||
root_event_id: String,
|
||||
event_id: String,
|
||||
sender: String,
|
||||
body: String,
|
||||
timestamp: u64,
|
||||
},
|
||||
}
|
||||
|
||||
pub struct LiveKitConfig {
|
||||
pub api_key: String,
|
||||
pub api_secret: String,
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
pub struct Session {
|
||||
pub client: Client,
|
||||
pub user_id: String,
|
||||
pub homeserver: String,
|
||||
pub voice_manager: VoiceManager,
|
||||
pub event_sender: broadcast::Sender<WsEvent>,
|
||||
pub sync_handle: Option<tokio::task::JoinHandle<()>>,
|
||||
pub created_at: Instant,
|
||||
pub expires_at: Option<Instant>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn new(client: Client, user_id: String, homeserver: String) -> Self {
|
||||
let (sender, _) = broadcast::channel(256);
|
||||
Self {
|
||||
client,
|
||||
user_id,
|
||||
homeserver,
|
||||
voice_manager: VoiceManager::new(),
|
||||
event_sender: sender,
|
||||
sync_handle: None,
|
||||
created_at: Instant::now(),
|
||||
expires_at: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_ttl(mut self, ttl: std::time::Duration) -> Self {
|
||||
self.expires_at = Some(self.created_at + ttl);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn is_expired(&self) -> bool {
|
||||
self.expires_at
|
||||
.map(|expires| Instant::now() > expires)
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct VoiceManager {
|
||||
channels: HashMap<String, VoiceChannel>,
|
||||
active_channel: Option<String>,
|
||||
}
|
||||
|
||||
pub struct VoiceChannel {
|
||||
pub room_id: String,
|
||||
pub participants: Vec<VoiceParticipant>,
|
||||
}
|
||||
|
||||
pub struct VoiceParticipant {
|
||||
pub user_id: String,
|
||||
pub active_channel: Option<String>,
|
||||
pub muted: bool,
|
||||
pub deafened: bool,
|
||||
pub streaming: bool,
|
||||
}
|
||||
|
||||
impl VoiceManager {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
channels: HashMap::new(),
|
||||
active_channel: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn join_channel(&mut self, room_id: String, user_id: String) -> bool {
|
||||
if let Some(ref old_channel) = self.active_channel {
|
||||
self.leave_channel_internal(old_channel, &user_id);
|
||||
}
|
||||
let channel = self.channels.entry(room_id.clone()).or_insert_with(|| VoiceChannel {
|
||||
room_id: room_id.clone(),
|
||||
participants: Vec::new(),
|
||||
});
|
||||
if channel.participants.iter().any(|p| p.user_id == user_id) {
|
||||
return false;
|
||||
}
|
||||
channel.participants.push(VoiceParticipant {
|
||||
user_id,
|
||||
muted: false,
|
||||
deafened: false,
|
||||
streaming: false,
|
||||
});
|
||||
self.active_channel = Some(room_id);
|
||||
true
|
||||
}
|
||||
|
||||
fn leave_channel_internal(&mut self, room_id: &str, user_id: &str) {
|
||||
if let Some(channel) = self.channels.get_mut(room_id) {
|
||||
channel.participants.retain(|p| p.user_id != user_id);
|
||||
if channel.participants.is_empty() {
|
||||
self.channels.remove(room_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn leave_channel(&mut self, room_id: &str, user_id: &str) -> bool {
|
||||
self.leave_channel_internal(room_id, user_id);
|
||||
if self.active_channel.as_deref() == Some(room_id) {
|
||||
self.active_channel = None;
|
||||
}
|
||||
true
|
||||
pub fn join_channel(&mut self, room_id: &str) {
|
||||
self.active_channel = Some(room_id.to_string());
|
||||
self.muted = false;
|
||||
self.deafened = false;
|
||||
}
|
||||
|
||||
pub fn toggle_mute(&mut self, room_id: &str, user_id: &str) -> Option<bool> {
|
||||
if let Some(channel) = self.channels.get_mut(room_id) {
|
||||
if let Some(participant) = channel.participants.iter_mut().find(|p| p.user_id == user_id) {
|
||||
participant.muted = !participant.muted;
|
||||
return Some(participant.muted);
|
||||
}
|
||||
}
|
||||
None
|
||||
pub fn leave_channel(&mut self) {
|
||||
self.active_channel = None;
|
||||
self.muted = false;
|
||||
self.deafened = false;
|
||||
}
|
||||
|
||||
pub fn toggle_deafen(&mut self, room_id: &str, user_id: &str) -> Option<bool> {
|
||||
if let Some(channel) = self.channels.get_mut(room_id) {
|
||||
if let Some(participant) = channel.participants.iter_mut().find(|p| p.user_id == user_id) {
|
||||
participant.deafened = !participant.deafened;
|
||||
if participant.deafened {
|
||||
participant.muted = true;
|
||||
}
|
||||
return Some(participant.deafened);
|
||||
}
|
||||
pub fn toggle_mute(&mut self) -> bool {
|
||||
self.muted = !self.muted;
|
||||
self.muted
|
||||
}
|
||||
|
||||
pub fn toggle_deafen(&mut self) -> bool {
|
||||
self.deafened = !self.deafened;
|
||||
if self.deafened {
|
||||
self.muted = true;
|
||||
}
|
||||
None
|
||||
self.deafened
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn voice_manager_new_is_idle() {
|
||||
let vm = VoiceManager::new();
|
||||
assert!(vm.active_channel.is_none());
|
||||
assert!(!vm.muted);
|
||||
assert!(!vm.deafened);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn voice_manager_default_is_idle() {
|
||||
let vm = VoiceManager::default();
|
||||
assert!(vm.active_channel.is_none());
|
||||
assert!(!vm.muted);
|
||||
assert!(!vm.deafened);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn join_channel_sets_active() {
|
||||
let mut vm = VoiceManager::new();
|
||||
vm.join_channel("!room:server");
|
||||
assert_eq!(vm.active_channel, Some("!room:server".to_string()));
|
||||
assert!(!vm.muted);
|
||||
assert!(!vm.deafened);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn join_channel_resets_mute_deafen() {
|
||||
let mut vm = VoiceManager::new();
|
||||
vm.muted = true;
|
||||
vm.deafened = true;
|
||||
vm.join_channel("!room:server");
|
||||
assert!(!vm.muted);
|
||||
assert!(!vm.deafened);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn leave_channel_clears_state() {
|
||||
let mut vm = VoiceManager::new();
|
||||
vm.join_channel("!room:server");
|
||||
vm.muted = true;
|
||||
vm.leave_channel();
|
||||
assert!(vm.active_channel.is_none());
|
||||
assert!(!vm.muted);
|
||||
assert!(!vm.deafened);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn toggle_mute_flips_state() {
|
||||
let mut vm = VoiceManager::new();
|
||||
assert!(!vm.muted);
|
||||
let result = vm.toggle_mute();
|
||||
assert!(result);
|
||||
assert!(vm.muted);
|
||||
let result = vm.toggle_mute();
|
||||
assert!(!result);
|
||||
assert!(!vm.muted);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn toggle_deafen_sets_muted_when_deafened() {
|
||||
let mut vm = VoiceManager::new();
|
||||
let result = vm.toggle_deafen();
|
||||
assert!(result);
|
||||
assert!(vm.deafened);
|
||||
assert!(vm.muted);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn toggle_deafen_off_keeps_muted() {
|
||||
let mut vm = VoiceManager::new();
|
||||
vm.muted = true;
|
||||
vm.toggle_deafen();
|
||||
assert!(vm.muted);
|
||||
vm.toggle_deafen();
|
||||
assert!(vm.deafened == false);
|
||||
assert!(vm.muted);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ws_event_serialization() {
|
||||
let event = WsEvent::Message {
|
||||
room_id: "!room:server".to_string(),
|
||||
event_id: "$event".to_string(),
|
||||
sender: "@user:server".to_string(),
|
||||
body: "hello".to_string(),
|
||||
timestamp: 123456,
|
||||
reply_to: None,
|
||||
msgtype: Some("m.text".to_string()),
|
||||
media_url: None,
|
||||
filename: None,
|
||||
mimetype: None,
|
||||
};
|
||||
let json = serde_json::to_string(&event).unwrap();
|
||||
assert!(json.contains("\"type\":\"message\""));
|
||||
assert!(json.contains("\"body\":\"hello\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ws_event_presence_serialization() {
|
||||
let event = WsEvent::Presence {
|
||||
user_id: "@alice:server".to_string(),
|
||||
status: "online".to_string(),
|
||||
status_msg: Some("working".to_string()),
|
||||
};
|
||||
let json = serde_json::to_string(&event).unwrap();
|
||||
assert!(json.contains("\"type\":\"presence\""));
|
||||
assert!(json.contains("\"status\":\"online\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_expired_true_when_past() {
|
||||
let created = Instant::now() - std::time::Duration::from_secs(120);
|
||||
let expires = created + std::time::Duration::from_secs(60);
|
||||
assert!(Instant::now() > expires);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_expired_none_means_not_expired() {
|
||||
let expires_at: Option<Instant> = None;
|
||||
let expired = expires_at
|
||||
.map(|expires| Instant::now() > expires)
|
||||
.unwrap_or(false);
|
||||
assert!(!expired);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_expired_future_means_not_expired() {
|
||||
let expires_at: Option<Instant> =
|
||||
Some(Instant::now() + std::time::Duration::from_secs(3600));
|
||||
let expired = expires_at
|
||||
.map(|expires| Instant::now() > expires)
|
||||
.unwrap_or(false);
|
||||
assert!(!expired);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_expired_past_means_expired() {
|
||||
let expires_at: Option<Instant> = Some(Instant::now() - std::time::Duration::from_secs(1));
|
||||
let expired = expires_at
|
||||
.map(|expires| Instant::now() > expires)
|
||||
.unwrap_or(false);
|
||||
assert!(expired);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ServerStateInner {
|
||||
pub sessions: HashMap<String, Session>,
|
||||
pub session_store: Arc<SessionStore>,
|
||||
pub livekit: LiveKitConfig,
|
||||
pub voice_rooms: VoiceRooms,
|
||||
pub session_ttl: Option<std::time::Duration>,
|
||||
}
|
||||
|
||||
pub struct VoiceRooms {
|
||||
pub rooms: HashMap<String, VoiceRoom>,
|
||||
}
|
||||
|
||||
pub struct VoiceRoom {
|
||||
pub participants: HashMap<String, VoiceParticipant>,
|
||||
}
|
||||
|
||||
#[derive(Clone, serde::Serialize)]
|
||||
pub struct VoiceParticipant {
|
||||
pub user_id: String,
|
||||
pub muted: bool,
|
||||
pub deafened: bool,
|
||||
}
|
||||
|
||||
impl ServerStateInner {
|
||||
pub fn new() -> Self {
|
||||
pub fn new(
|
||||
session_store: Arc<SessionStore>,
|
||||
livekit: LiveKitConfig,
|
||||
session_ttl: Option<std::time::Duration>,
|
||||
) -> Self {
|
||||
Self {
|
||||
sessions: HashMap::new(),
|
||||
session_store,
|
||||
livekit,
|
||||
voice_rooms: VoiceRooms {
|
||||
rooms: HashMap::new(),
|
||||
},
|
||||
session_ttl,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type ServerState = Arc<RwLock<ServerStateInner>>;
|
||||
#[derive(Clone)]
|
||||
pub struct ServerState {
|
||||
inner: Arc<RwLock<ServerStateInner>>,
|
||||
}
|
||||
|
||||
impl ServerState {
|
||||
pub fn new() -> Self {
|
||||
Arc::new(RwLock::new(ServerStateInner::new()))
|
||||
pub fn new(
|
||||
session_store: Arc<SessionStore>,
|
||||
livekit: LiveKitConfig,
|
||||
session_ttl: Option<std::time::Duration>,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(RwLock::new(ServerStateInner::new(
|
||||
session_store,
|
||||
livekit,
|
||||
session_ttl,
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read(&self) -> tokio::sync::RwLockReadGuard<'_, ServerStateInner> {
|
||||
self.inner.read().await
|
||||
}
|
||||
|
||||
pub async fn write(&self) -> tokio::sync::RwLockWriteGuard<'_, ServerStateInner> {
|
||||
self.inner.write().await
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user