summary refs log tree commit diff
path: root/crates/messenger_server/src/websocket.rs
diff options
context:
space:
mode:
authorSophie Forrest <git@sophieforrest.com>2024-08-30 23:13:20 +1200
committerSophie Forrest <git@sophieforrest.com>2024-08-30 23:13:44 +1200
commite3cb82a3b33bd2a2e49c58ce18d1258fb505869e (patch)
tree2375279182fb4f90f5c28560a08cda90591f608b /crates/messenger_server/src/websocket.rs
chore: initial commit (codeberg upload) HEAD main
Diffstat (limited to 'crates/messenger_server/src/websocket.rs')
-rw-r--r--crates/messenger_server/src/websocket.rs169
1 files changed, 169 insertions, 0 deletions
diff --git a/crates/messenger_server/src/websocket.rs b/crates/messenger_server/src/websocket.rs
new file mode 100644
index 0000000..db235af
--- /dev/null
+++ b/crates/messenger_server/src/websocket.rs
@@ -0,0 +1,169 @@
+//! Handles the `WebSocket` connections for the chat server.
+//!
+//! Contains no public members as it is a top-level route with no dependants.
+
+use std::sync::Arc;
+
+use axum::{
+	extract::{
+		ws::{Message, WebSocket},
+		State, WebSocketUpgrade,
+	},
+	response::IntoResponse,
+};
+use futures::{
+	stream::{SplitSink, SplitStream},
+	StreamExt,
+};
+use messenger_common::{client::MessageType as ClientMessageType, server::MessageType};
+
+use crate::{
+	app::{check_username, State as AppState},
+	message::{self, Server},
+	session::Session,
+};
+
+/// Represents a username choice that was either `Invalid` or `Valid`.
+enum UsernameValidity {
+	/// Username choice is considered invalid
+	Invalid,
+
+	/// Username choice is considered valid
+	Valid,
+}
+
+/// Handles setting and validating of a username.
+/// Returns whether the username choice was valid or invalid.
+async fn handle_username_choice(
+	state: &Arc<AppState>,
+	receiver: &mut SplitStream<WebSocket>,
+	sender: &mut SplitSink<WebSocket, Message>,
+	username: &mut String,
+) -> UsernameValidity {
+	// Loop until a text message is found.
+	while let Some(Ok(message)) = receiver.next().await {
+		if let Message::Text(name) = message {
+			if let Ok(inbound_message) = message::deserialize(&name) {
+				if let ClientMessageType::SetUsername(name) = inbound_message {
+					// If username that is sent by client is not taken, fill username string.
+					check_username(state, username, &name).await;
+
+					// If not empty we want to quit the function with a valid username choice.
+					if !username.is_empty() {
+						return UsernameValidity::Valid;
+					}
+
+					// Only send the client that the username is taken
+					message::send_error(
+						sender,
+						messenger_common::server::Error::UsernameNotAvailable {
+							chosen_username: name,
+						},
+					)
+					.await;
+
+					return UsernameValidity::Invalid;
+				}
+
+				// Client has sent an unexpected message, and should be notified
+				message::send_error(
+					sender,
+					messenger_common::server::Error::UnexpectedMessage {
+						expected: ClientMessageType::SetUsername(String::new()),
+						received: inbound_message,
+					},
+				)
+				.await;
+			} else {
+				message::send_error(sender, messenger_common::server::Error::InvalidMessage).await;
+			}
+		}
+	}
+
+	// No items to iterate on, so username is invalid by default
+	UsernameValidity::Invalid
+}
+
+/// Handle websockets
+#[expect(
+	clippy::unused_async,
+	reason = "axum requires this function to be async, but clippy disallows this due to no \
+	          .await's"
+)]
+pub async fn handler(
+	ws: WebSocketUpgrade,
+	State(state): State<Arc<AppState>>,
+) -> impl IntoResponse {
+	ws.on_upgrade(|socket| websocket(socket, state))
+}
+
+/// This function deals with a single websocket connection, i.e., a single
+/// connected client / user, for which we will spawn two independent tasks (for
+/// receiving / sending chat messages).
+async fn websocket(stream: WebSocket, state: Arc<AppState>) {
+	// By splitting, we can send and receive at the same time.
+	let (mut sender, mut receiver) = stream.split();
+
+	// Username gets set in the receive loop, if it's valid.
+	let mut username = String::new();
+
+	// Handle username validity
+	if matches!(
+		handle_username_choice(&state, &mut receiver, &mut sender, &mut username).await,
+		UsernameValidity::Invalid
+	) {
+		return;
+	}
+
+	let session = Arc::new(Session::new(
+		receiver,
+		sender,
+		state.tx.subscribe(),
+		username,
+	));
+
+	// Now send the "joined" message to all subscribers.
+	MessageType::UserJoined(session.username().clone()).send(&state.tx);
+
+	// Provide newly added users with the last 100 messages and the currently online
+	// users
+	if let Err(error) = session.transmit_initial_data(&state).await {
+		// Due to sending online users being the only operation that can fail, this
+		// error message is correct
+		tracing::error!(
+			"an error occurred while attempting to send a list of online users: {error}"
+		);
+	}
+
+	let send_session = Arc::clone(&session);
+
+	// Spawn the first task that will receive broadcast messages and send text
+	// messages over the websocket to our client.
+	let mut send_task = tokio::spawn(async move {
+		send_session.send().await;
+	});
+
+	// Clone the state for `recv_task` to prevent consuming state when appending
+	// messages
+	let recv_state = Arc::clone(&state);
+
+	let recv_session = Arc::clone(&session);
+
+	// Spawn a task that takes messages from the websocket, prepends the user
+	// name, and sends them to all broadcast subscribers.
+	let mut recv_task = tokio::spawn(async move {
+		recv_session.receive(&recv_state).await;
+	});
+
+	// If any one of the tasks run to completion, we abort the other.
+	tokio::select! {
+		_ = (&mut send_task) => recv_task.abort(),
+		_ = (&mut recv_task) => send_task.abort(),
+	};
+
+	// Send "user left" message (similar to "joined" above).
+	MessageType::UserLeft(session.username().clone()).send(&state.tx);
+
+	// Remove username from map so new clients can take it again.
+	state.user_set.lock().await.remove(session.username());
+}