summary refs log tree commit diff
path: root/crates/messenger_server/src
diff options
context:
space:
mode:
authorSophie Forrest <git@sophieforrest.com>2024-08-30 23:13:20 +1200
committerSophie Forrest <git@sophieforrest.com>2024-08-30 23:13:44 +1200
commite3cb82a3b33bd2a2e49c58ce18d1258fb505869e (patch)
tree2375279182fb4f90f5c28560a08cda90591f608b /crates/messenger_server/src
chore: initial commit (codeberg upload) HEAD main
Diffstat (limited to 'crates/messenger_server/src')
-rw-r--r--crates/messenger_server/src/app.rs34
-rw-r--r--crates/messenger_server/src/main.rs142
-rw-r--r--crates/messenger_server/src/message.rs103
-rw-r--r--crates/messenger_server/src/session.rs193
-rw-r--r--crates/messenger_server/src/websocket.rs169
5 files changed, 641 insertions, 0 deletions
diff --git a/crates/messenger_server/src/app.rs b/crates/messenger_server/src/app.rs
new file mode 100644
index 0000000..2d91919
--- /dev/null
+++ b/crates/messenger_server/src/app.rs
@@ -0,0 +1,34 @@
+//! Contains functions and structures useful to the general web server
+
+use std::collections::{HashSet, VecDeque};
+
+use tokio::sync::{broadcast, Mutex};
+
+/// Our shared state
+#[derive(Debug)]
+pub struct State {
+	/// Contains the history of the last 100 messages sent
+	pub message_history: Mutex<VecDeque<String>>,
+
+	/// Channel used to send messages to all connected clients.
+	pub tx: broadcast::Sender<String>,
+
+	/// We require unique usernames. This tracks which usernames have been
+	/// taken.
+	pub user_set: Mutex<HashSet<String>>,
+}
+
+/// Doc
+pub async fn check_username(state: &State, string: &mut String, name: &str) {
+	let mut user_set = state.user_set.lock().await;
+
+	let name = name.trim();
+
+	if !name.is_empty() && !user_set.contains(name) {
+		user_set.insert(name.to_owned());
+
+		drop(user_set);
+
+		string.push_str(name);
+	}
+}
diff --git a/crates/messenger_server/src/main.rs b/crates/messenger_server/src/main.rs
new file mode 100644
index 0000000..76ce81d
--- /dev/null
+++ b/crates/messenger_server/src/main.rs
@@ -0,0 +1,142 @@
+#![feature(async_fn_in_trait)]
+#![feature(custom_inner_attributes)]
+#![feature(lint_reasons)]
+#![feature(never_type)]
+#![feature(lazy_cell)]
+#![clippy::msrv = "1.69.0"]
+#![deny(clippy::nursery)]
+#![deny(clippy::pedantic)]
+#![deny(clippy::alloc_instead_of_core)]
+#![deny(clippy::as_underscore)]
+#![deny(clippy::clone_on_ref_ptr)]
+#![deny(clippy::create_dir)]
+#![warn(clippy::dbg_macro)]
+#![deny(clippy::default_numeric_fallback)]
+#![deny(clippy::default_union_representation)]
+#![deny(clippy::deref_by_slicing)]
+#![deny(clippy::else_if_without_else)]
+#![deny(clippy::empty_structs_with_brackets)]
+#![deny(clippy::exit)]
+#![deny(clippy::expect_used)]
+#![deny(clippy::filetype_is_file)]
+#![deny(clippy::fn_to_numeric_cast)]
+#![deny(clippy::format_push_string)]
+#![deny(clippy::get_unwrap)]
+#![deny(clippy::if_then_some_else_none)]
+#![allow(
+	clippy::implicit_return,
+	reason = "returns should be done implicitly, not explicitly"
+)]
+#![deny(clippy::indexing_slicing)]
+#![deny(clippy::large_include_file)]
+#![deny(clippy::let_underscore_must_use)]
+#![deny(clippy::lossy_float_literal)]
+#![deny(clippy::map_err_ignore)]
+#![deny(clippy::mem_forget)]
+#![deny(clippy::missing_docs_in_private_items)]
+#![deny(clippy::missing_trait_methods)]
+#![deny(clippy::multiple_inherent_impl)]
+#![deny(clippy::needless_return)]
+#![deny(clippy::non_ascii_literal)]
+#![deny(clippy::panic_in_result_fn)]
+#![deny(clippy::pattern_type_mismatch)]
+#![deny(clippy::rc_buffer)]
+#![deny(clippy::rc_mutex)]
+#![deny(clippy::rest_pat_in_fully_bound_structs)]
+#![deny(clippy::same_name_method)]
+#![deny(clippy::separated_literal_suffix)]
+#![deny(clippy::str_to_string)]
+#![deny(clippy::string_add)]
+#![deny(clippy::string_slice)]
+#![deny(clippy::string_to_string)]
+#![allow(
+	clippy::tabs_in_doc_comments,
+	reason = "tabs are preferred for this project"
+)]
+#![deny(clippy::try_err)]
+#![deny(clippy::undocumented_unsafe_blocks)]
+#![deny(clippy::unnecessary_self_imports)]
+#![deny(clippy::unneeded_field_pattern)]
+#![deny(clippy::unwrap_in_result)]
+#![deny(clippy::unwrap_used)]
+#![warn(clippy::use_debug)]
+#![deny(clippy::verbose_file_reads)]
+#![deny(clippy::wildcard_dependencies)]
+#![deny(clippy::wildcard_enum_match_arm)]
+#![deny(clippy::missing_panics_doc)]
+#![deny(missing_copy_implementations)]
+#![deny(missing_debug_implementations)]
+#![deny(missing_docs)]
+#![deny(single_use_lifetimes)]
+#![deny(unsafe_code)]
+#![deny(unused)]
+// Server-specific lint disables
+#![allow(clippy::redundant_pub_crate)]
+
+//! # Messenger Server
+//!
+//! Provides a server-side implementation of the messenger protocol
+
+mod app;
+mod message;
+mod session;
+mod websocket;
+
+use std::{
+	collections::{HashSet, VecDeque},
+	net::{Ipv4Addr, SocketAddr, SocketAddrV4},
+	sync::Arc,
+};
+
+use app::State as AppState;
+use axum::{response::Html, routing::get, Router};
+use tokio::sync::{broadcast, Mutex};
+use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
+
+/// Socket Address the server is bound to when ran.
+const ADDR: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 3000));
+
+#[tokio::main]
+async fn main() {
+	tracing_subscriber::registry()
+		.with(
+			tracing_subscriber::EnvFilter::try_from_default_env()
+				.unwrap_or_else(|_| "chat=trace".into()),
+		)
+		.with(tracing_subscriber::fmt::layer())
+		.init();
+
+	// Set up application state for use with with_state().
+	let user_set: Mutex<HashSet<String>> = Mutex::new(HashSet::new());
+	let (tx, _rx) = broadcast::channel::<String>(100);
+	let message_history = Mutex::new(VecDeque::new());
+
+	let app_state = Arc::new(AppState {
+		message_history,
+		tx,
+		user_set,
+	});
+
+	let app = Router::new()
+		.route("/", get(index))
+		.route("/websocket", get(websocket::handler))
+		.with_state(app_state);
+
+	tracing::debug!("listening on {ADDR}");
+
+	if let Err(error) = axum::Server::bind(&ADDR)
+		.serve(app.into_make_service())
+		.await
+	{
+		tracing::error!("server failed to start: {error}");
+	}
+}
+
+/// Include utf-8 file at **compile** time.
+#[expect(
+	clippy::unused_async,
+	reason = "axum requires this function to by async, but clippy disallows this"
+)]
+async fn index() -> Html<&'static str> {
+	Html(std::include_str!("../chat.html"))
+}
diff --git a/crates/messenger_server/src/message.rs b/crates/messenger_server/src/message.rs
new file mode 100644
index 0000000..2682c1b
--- /dev/null
+++ b/crates/messenger_server/src/message.rs
@@ -0,0 +1,103 @@
+//! Abstraction for the messaging system to increase overall type safety and
+//! code readability.
+
+use std::sync::Arc;
+
+use axum::extract::ws::{Message, WebSocket};
+use futures::{stream::SplitSink, SinkExt};
+use messenger_common::{client::MessageType as ClientMessageType, server::MessageType};
+use tokio::sync::broadcast::Sender;
+use tracing::error;
+
+use crate::app::State;
+
+/// Represents messages which can be sent through a [`WebSocket`].
+pub trait Server
+where
+	Self: std::marker::Sized + serde::Serialize,
+{
+	/// Adds a message to the message history of the app state.
+	async fn append_to_history(&self, state: &Arc<State>) -> serde_json::Result<()>;
+
+	/// Performs serialization of a message.
+	/// Semantic alternative to `serde_json::to_string()`.
+	fn serialize(&self) -> serde_json::Result<String> {
+		serde_json::to_string(&self)
+	}
+
+	/// Sends a message through the provided sender.
+	///
+	/// The main purpose of this function is to enforce type-safety when sending
+	/// a message. This prevents accidentally sending non-`MessageType` messages
+	/// through the server.
+	fn send(&self, tx: &Sender<String>);
+}
+
+impl Server for MessageType {
+	/// Adds a message to the message history of the app state.
+	async fn append_to_history(&self, state: &Arc<State>) -> serde_json::Result<()> {
+		// Join and leave messages shouldn't be saved to history
+		// We ignore the inner message here as we want to serialize the entire message
+		if let Self::UserMessage(..) = *self {
+			let mut history_guard = state.message_history.lock().await;
+
+			// Only save the last 100 messages
+			if history_guard.len() > 99 {
+				history_guard.pop_front();
+			}
+
+			// Add a new message to the back of the queue
+			history_guard.push_back(serde_json::to_string(self)?);
+		}
+
+		Ok(())
+	}
+
+	fn serialize(&self) -> serde_json::Result<String> {
+		serde_json::to_string(&self)
+	}
+
+	fn send(&self, tx: &Sender<String>) {
+		tracing::debug!("sending message, content: {:?}", &self);
+
+		match self.serialize() {
+			Ok(json_message) => {
+				if let Err(error) = tx.send(json_message) {
+					error!("error occurred while sending a message through a channel: {error}");
+				}
+			}
+			Err(error) => {
+				error!("error occurred while converting message to json: {error}");
+			}
+		};
+	}
+}
+
+/// Performs deserialization of a message.
+/// Semantic alternative to `serde_json::from_str::<MessageType>()`.
+pub fn deserialize(message: &str) -> serde_json::Result<ClientMessageType> {
+	serde_json::from_str(message)
+}
+
+/// Sends an error through a websocket to the client.
+/// Contains error handling to reduce overall code bloat.
+pub async fn send_error(
+	sender: &mut SplitSink<WebSocket, Message>,
+	error: messenger_common::server::Error,
+) {
+	// Log error through tracing to show invalid client behaviour
+	error!("received message from client that is considered an error: {error}");
+
+	// Handle deserialization errors correctly to avoid a panic
+	match MessageType::Error(error).serialize() {
+		Ok(outbound_error) => {
+			if let Err(error) = sender.send(Message::Text(outbound_error)).await {
+				error!("unable to send error message through a websocket: {error}");
+			}
+		}
+		Err(error) => {
+			// Errors can also occur during serialization, so these should be covered
+			error!("unable to serialize outbound error message: {error}");
+		}
+	};
+}
diff --git a/crates/messenger_server/src/session.rs b/crates/messenger_server/src/session.rs
new file mode 100644
index 0000000..1b15d5a
--- /dev/null
+++ b/crates/messenger_server/src/session.rs
@@ -0,0 +1,193 @@
+//! Code for running websocket related functions on the web server.
+//!
+//! This includes the messaging system as a whole.
+
+use std::sync::Arc;
+
+use axum::extract::ws::{Message, WebSocket};
+use futures::{
+	stream::{SplitSink, SplitStream},
+	SinkExt, StreamExt,
+};
+use messenger_common::{
+	client::MessageType as ClientMessageType,
+	server::{MessageType, UserMessage},
+};
+use thiserror::Error;
+use tokio::sync::{broadcast::Receiver, Mutex};
+
+use crate::{
+	app::State as AppState,
+	message::{self, deserialize, Server},
+};
+
+/// Represents an error that can occur during a session.
+#[derive(Debug, Error)]
+pub enum Error {
+	/// An error with a Axum
+	#[error("an error occurred while attempting to interact with Axum")]
+	Axum(#[from] axum::Error),
+
+	/// An error with serde_json
+	#[error("an error occurred while ser/deserializing data with serde_json")]
+	SerdeJson(#[from] serde_json::Error),
+}
+
+/// Represents a singular session for a user. Handles sending and receiving
+/// messages for this session.
+#[derive(Debug)]
+pub struct Session {
+	/// Receiving component of the WebSocket split.
+	receiver: Mutex<SplitStream<WebSocket>>,
+
+	/// Sending component of the WebSocket split.
+	sender: Mutex<SplitSink<WebSocket, Message>>,
+
+	/// Receiver from the apps state.
+	state_rx: Mutex<Receiver<String>>,
+
+	/// The username associated with this session.
+	username: String,
+}
+
+impl Session {
+	/// Constructs a new instance of [`Session`].
+	#[must_use]
+	pub fn new(
+		receiver: SplitStream<WebSocket>,
+		sender: SplitSink<WebSocket, Message>,
+		state_rx: Receiver<String>,
+		username: String,
+	) -> Self {
+		Self {
+			receiver: Mutex::new(receiver),
+			sender: Mutex::new(sender),
+			state_rx: Mutex::new(state_rx),
+			username,
+		}
+	}
+
+	/// Feeds the last 100 messages to a singular sender, differentiated by the
+	/// `WebSocket` sender.
+	pub async fn feed_message_history(&self, state: &Arc<crate::app::State>) {
+		// Iterate through the message history and feed it to the sender
+		for message in &*state.message_history.lock().await {
+			if self
+				.sender
+				.lock()
+				.await
+				.feed(Message::Text(message.clone()))
+				.await
+				.is_err()
+			{
+				break;
+			}
+		}
+	}
+
+	/// Feeds a list of online users to a singular sender, differentiated by the
+	/// `WebSocket` sender.
+	pub async fn feed_online_users(&self, state: &Arc<crate::app::State>) -> Result<(), Error> {
+		self.sender
+			.lock()
+			.await
+			.feed(Message::Text(serde_json::to_string(
+				&MessageType::OnlineUsers(state.user_set.lock().await.clone()),
+			)?))
+			.await?;
+
+		Ok(())
+	}
+
+	/// Receives messages for this session. For use inside a tokio select
+	/// receive task.
+	pub async fn receive(&self, state: &Arc<AppState>) {
+		while let Some(Ok(Message::Text(text))) = self.receiver.lock().await.next().await {
+			let result = match deserialize(&text) {
+				Ok(ClientMessageType::UserMessage(content)) => {
+					// Handle the possibility of a timestamp being unable to be created when a
+					// message is bring processed
+					let Ok(timestamp) = time::OffsetDateTime::now_local() else {
+						tracing::error!("could not create an OffsetDateTime for received message");
+
+						let mut lock = self.sender.lock().await;
+
+						// Report to the client that their message could not be processed
+						message::send_error(
+							&mut lock,
+							messenger_common::server::Error::CannotProcess,
+						)
+						.await;
+
+						// Drop the lock early
+						drop(lock);
+
+						// Skip to the next message.
+						continue;
+					};
+
+					let message: MessageType =
+						UserMessage::new(content, self.username.clone(), timestamp).into();
+
+					Some(message)
+				}
+				Ok(_) | Err(_) => {
+					let mut lock = self.sender.lock().await;
+
+					// Messages outside of type UserMessage are unexpected, and should be
+					// reported
+					message::send_error(&mut lock, messenger_common::server::Error::InvalidMessage)
+						.await;
+
+					drop(lock);
+
+					None
+				}
+			};
+
+			if let Some(message) = result {
+				message.send(&state.tx);
+				if let Err(error) = message.append_to_history(state).await {
+					tracing::error!(
+						"error encountered when appending a message to history: {error}"
+					);
+				}
+			}
+		}
+	}
+
+	/// Processes a send operation for this session.
+	pub async fn send(&self) {
+		let mut lock = self.state_rx.lock().await;
+
+		while let Ok(msg) = lock.recv().await {
+			// In any websocket error, break loop.
+			if self
+				.sender
+				.lock()
+				.await
+				.send(Message::Text(msg))
+				.await
+				.is_err()
+			{
+				break;
+			}
+		}
+	}
+
+	/// Transmits a copy of the last 100 messages and a list of online users to
+	/// the user.
+	pub async fn transmit_initial_data(&self, state: &Arc<crate::app::State>) -> Result<(), Error> {
+		self.feed_message_history(state).await;
+		self.feed_online_users(state).await?;
+
+		self.sender.lock().await.flush().await?;
+
+		Ok(())
+	}
+
+	/// Retrieves a copy to the username associated with this session.
+	pub const fn username(&self) -> &String {
+		&self.username
+	}
+}
diff --git a/crates/messenger_server/src/websocket.rs b/crates/messenger_server/src/websocket.rs
new file mode 100644
index 0000000..db235af
--- /dev/null
+++ b/crates/messenger_server/src/websocket.rs
@@ -0,0 +1,169 @@
+//! Handles the `WebSocket` connections for the chat server.
+//!
+//! Contains no public members as it is a top-level route with no dependants.
+
+use std::sync::Arc;
+
+use axum::{
+	extract::{
+		ws::{Message, WebSocket},
+		State, WebSocketUpgrade,
+	},
+	response::IntoResponse,
+};
+use futures::{
+	stream::{SplitSink, SplitStream},
+	StreamExt,
+};
+use messenger_common::{client::MessageType as ClientMessageType, server::MessageType};
+
+use crate::{
+	app::{check_username, State as AppState},
+	message::{self, Server},
+	session::Session,
+};
+
+/// Represents a username choice that was either `Invalid` or `Valid`.
+enum UsernameValidity {
+	/// Username choice is considered invalid
+	Invalid,
+
+	/// Username choice is considered valid
+	Valid,
+}
+
+/// Handles setting and validating of a username.
+/// Returns whether the username choice was valid or invalid.
+async fn handle_username_choice(
+	state: &Arc<AppState>,
+	receiver: &mut SplitStream<WebSocket>,
+	sender: &mut SplitSink<WebSocket, Message>,
+	username: &mut String,
+) -> UsernameValidity {
+	// Loop until a text message is found.
+	while let Some(Ok(message)) = receiver.next().await {
+		if let Message::Text(name) = message {
+			if let Ok(inbound_message) = message::deserialize(&name) {
+				if let ClientMessageType::SetUsername(name) = inbound_message {
+					// If username that is sent by client is not taken, fill username string.
+					check_username(state, username, &name).await;
+
+					// If not empty we want to quit the function with a valid username choice.
+					if !username.is_empty() {
+						return UsernameValidity::Valid;
+					}
+
+					// Only send the client that the username is taken
+					message::send_error(
+						sender,
+						messenger_common::server::Error::UsernameNotAvailable {
+							chosen_username: name,
+						},
+					)
+					.await;
+
+					return UsernameValidity::Invalid;
+				}
+
+				// Client has sent an unexpected message, and should be notified
+				message::send_error(
+					sender,
+					messenger_common::server::Error::UnexpectedMessage {
+						expected: ClientMessageType::SetUsername(String::new()),
+						received: inbound_message,
+					},
+				)
+				.await;
+			} else {
+				message::send_error(sender, messenger_common::server::Error::InvalidMessage).await;
+			}
+		}
+	}
+
+	// No items to iterate on, so username is invalid by default
+	UsernameValidity::Invalid
+}
+
+/// Handle websockets
+#[expect(
+	clippy::unused_async,
+	reason = "axum requires this function to be async, but clippy disallows this due to no \
+	          .await's"
+)]
+pub async fn handler(
+	ws: WebSocketUpgrade,
+	State(state): State<Arc<AppState>>,
+) -> impl IntoResponse {
+	ws.on_upgrade(|socket| websocket(socket, state))
+}
+
+/// This function deals with a single websocket connection, i.e., a single
+/// connected client / user, for which we will spawn two independent tasks (for
+/// receiving / sending chat messages).
+async fn websocket(stream: WebSocket, state: Arc<AppState>) {
+	// By splitting, we can send and receive at the same time.
+	let (mut sender, mut receiver) = stream.split();
+
+	// Username gets set in the receive loop, if it's valid.
+	let mut username = String::new();
+
+	// Handle username validity
+	if matches!(
+		handle_username_choice(&state, &mut receiver, &mut sender, &mut username).await,
+		UsernameValidity::Invalid
+	) {
+		return;
+	}
+
+	let session = Arc::new(Session::new(
+		receiver,
+		sender,
+		state.tx.subscribe(),
+		username,
+	));
+
+	// Now send the "joined" message to all subscribers.
+	MessageType::UserJoined(session.username().clone()).send(&state.tx);
+
+	// Provide newly added users with the last 100 messages and the currently online
+	// users
+	if let Err(error) = session.transmit_initial_data(&state).await {
+		// Due to sending online users being the only operation that can fail, this
+		// error message is correct
+		tracing::error!(
+			"an error occurred while attempting to send a list of online users: {error}"
+		);
+	}
+
+	let send_session = Arc::clone(&session);
+
+	// Spawn the first task that will receive broadcast messages and send text
+	// messages over the websocket to our client.
+	let mut send_task = tokio::spawn(async move {
+		send_session.send().await;
+	});
+
+	// Clone the state for `recv_task` to prevent consuming state when appending
+	// messages
+	let recv_state = Arc::clone(&state);
+
+	let recv_session = Arc::clone(&session);
+
+	// Spawn a task that takes messages from the websocket, prepends the user
+	// name, and sends them to all broadcast subscribers.
+	let mut recv_task = tokio::spawn(async move {
+		recv_session.receive(&recv_state).await;
+	});
+
+	// If any one of the tasks run to completion, we abort the other.
+	tokio::select! {
+		_ = (&mut send_task) => recv_task.abort(),
+		_ = (&mut recv_task) => send_task.abort(),
+	};
+
+	// Send "user left" message (similar to "joined" above).
+	MessageType::UserLeft(session.username().clone()).send(&state.tx);
+
+	// Remove username from map so new clients can take it again.
+	state.user_set.lock().await.remove(session.username());
+}