diff options
| author | Sophie Forrest <git@sophieforrest.com> | 2024-08-30 23:13:20 +1200 |
|---|---|---|
| committer | Sophie Forrest <git@sophieforrest.com> | 2024-08-30 23:13:44 +1200 |
| commit | e3cb82a3b33bd2a2e49c58ce18d1258fb505869e (patch) | |
| tree | 2375279182fb4f90f5c28560a08cda90591f608b /crates/messenger_server/src/websocket.rs | |
Diffstat (limited to 'crates/messenger_server/src/websocket.rs')
| -rw-r--r-- | crates/messenger_server/src/websocket.rs | 169 |
1 files changed, 169 insertions, 0 deletions
diff --git a/crates/messenger_server/src/websocket.rs b/crates/messenger_server/src/websocket.rs new file mode 100644 index 0000000..db235af --- /dev/null +++ b/crates/messenger_server/src/websocket.rs @@ -0,0 +1,169 @@ +//! Handles the `WebSocket` connections for the chat server. +//! +//! Contains no public members as it is a top-level route with no dependants. + +use std::sync::Arc; + +use axum::{ + extract::{ + ws::{Message, WebSocket}, + State, WebSocketUpgrade, + }, + response::IntoResponse, +}; +use futures::{ + stream::{SplitSink, SplitStream}, + StreamExt, +}; +use messenger_common::{client::MessageType as ClientMessageType, server::MessageType}; + +use crate::{ + app::{check_username, State as AppState}, + message::{self, Server}, + session::Session, +}; + +/// Represents a username choice that was either `Invalid` or `Valid`. +enum UsernameValidity { + /// Username choice is considered invalid + Invalid, + + /// Username choice is considered valid + Valid, +} + +/// Handles setting and validating of a username. +/// Returns whether the username choice was valid or invalid. +async fn handle_username_choice( + state: &Arc<AppState>, + receiver: &mut SplitStream<WebSocket>, + sender: &mut SplitSink<WebSocket, Message>, + username: &mut String, +) -> UsernameValidity { + // Loop until a text message is found. + while let Some(Ok(message)) = receiver.next().await { + if let Message::Text(name) = message { + if let Ok(inbound_message) = message::deserialize(&name) { + if let ClientMessageType::SetUsername(name) = inbound_message { + // If username that is sent by client is not taken, fill username string. + check_username(state, username, &name).await; + + // If not empty we want to quit the function with a valid username choice. + if !username.is_empty() { + return UsernameValidity::Valid; + } + + // Only send the client that the username is taken + message::send_error( + sender, + messenger_common::server::Error::UsernameNotAvailable { + chosen_username: name, + }, + ) + .await; + + return UsernameValidity::Invalid; + } + + // Client has sent an unexpected message, and should be notified + message::send_error( + sender, + messenger_common::server::Error::UnexpectedMessage { + expected: ClientMessageType::SetUsername(String::new()), + received: inbound_message, + }, + ) + .await; + } else { + message::send_error(sender, messenger_common::server::Error::InvalidMessage).await; + } + } + } + + // No items to iterate on, so username is invalid by default + UsernameValidity::Invalid +} + +/// Handle websockets +#[expect( + clippy::unused_async, + reason = "axum requires this function to be async, but clippy disallows this due to no \ + .await's" +)] +pub async fn handler( + ws: WebSocketUpgrade, + State(state): State<Arc<AppState>>, +) -> impl IntoResponse { + ws.on_upgrade(|socket| websocket(socket, state)) +} + +/// This function deals with a single websocket connection, i.e., a single +/// connected client / user, for which we will spawn two independent tasks (for +/// receiving / sending chat messages). +async fn websocket(stream: WebSocket, state: Arc<AppState>) { + // By splitting, we can send and receive at the same time. + let (mut sender, mut receiver) = stream.split(); + + // Username gets set in the receive loop, if it's valid. + let mut username = String::new(); + + // Handle username validity + if matches!( + handle_username_choice(&state, &mut receiver, &mut sender, &mut username).await, + UsernameValidity::Invalid + ) { + return; + } + + let session = Arc::new(Session::new( + receiver, + sender, + state.tx.subscribe(), + username, + )); + + // Now send the "joined" message to all subscribers. + MessageType::UserJoined(session.username().clone()).send(&state.tx); + + // Provide newly added users with the last 100 messages and the currently online + // users + if let Err(error) = session.transmit_initial_data(&state).await { + // Due to sending online users being the only operation that can fail, this + // error message is correct + tracing::error!( + "an error occurred while attempting to send a list of online users: {error}" + ); + } + + let send_session = Arc::clone(&session); + + // Spawn the first task that will receive broadcast messages and send text + // messages over the websocket to our client. + let mut send_task = tokio::spawn(async move { + send_session.send().await; + }); + + // Clone the state for `recv_task` to prevent consuming state when appending + // messages + let recv_state = Arc::clone(&state); + + let recv_session = Arc::clone(&session); + + // Spawn a task that takes messages from the websocket, prepends the user + // name, and sends them to all broadcast subscribers. + let mut recv_task = tokio::spawn(async move { + recv_session.receive(&recv_state).await; + }); + + // If any one of the tasks run to completion, we abort the other. + tokio::select! { + _ = (&mut send_task) => recv_task.abort(), + _ = (&mut recv_task) => send_task.abort(), + }; + + // Send "user left" message (similar to "joined" above). + MessageType::UserLeft(session.username().clone()).send(&state.tx); + + // Remove username from map so new clients can take it again. + state.user_set.lock().await.remove(session.username()); +} |