summary refs log tree commit diff
path: root/crates/messenger_server/src/websocket.rs
blob: db235afa772790818339f487411ae9752ba64fd6 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
//! Handles the `WebSocket` connections for the chat server.
//!
//! The only public item is [`handler`], the route entry point; everything else
//! in this module is a private implementation detail.

use std::sync::Arc;

use axum::{
	extract::{
		ws::{Message, WebSocket},
		State, WebSocketUpgrade,
	},
	response::IntoResponse,
};
use futures::{
	stream::{SplitSink, SplitStream},
	StreamExt,
};
use messenger_common::{client::MessageType as ClientMessageType, server::MessageType};

use crate::{
	app::{check_username, State as AppState},
	message::{self, Server},
	session::Session,
};

/// Outcome of a client's attempt to pick a username when it first connects.
enum UsernameValidity {
	/// The requested name was unavailable, or the client never sent a name
	/// request before its stream ended.
	Invalid,

	/// The requested name was available and now belongs to this connection.
	Valid,
}

/// Waits for the client's `SetUsername` request and tries to claim that name.
///
/// Frames are processed as follows while the client keeps sending them:
/// - non-text frames are skipped;
/// - text that cannot be deserialized is answered with `Error::InvalidMessage`
///   and the function keeps waiting;
/// - any well-formed message other than `SetUsername` is answered with
///   `Error::UnexpectedMessage` and the function keeps waiting;
/// - a `SetUsername` request is passed to `check_username`, which fills
///   `username` only when the name is not taken.
///
/// Returns [`UsernameValidity::Valid`] when `username` ends up non-empty.
/// Otherwise `Error::UsernameNotAvailable` is sent and
/// [`UsernameValidity::Invalid`] is returned immediately; the client is not
/// offered a second attempt here.
/// Also returns `Invalid` if the stream ends or yields a receive error before
/// any `SetUsername` request arrives.
async fn handle_username_choice(
	state: &Arc<AppState>,
	receiver: &mut SplitStream<WebSocket>,
	sender: &mut SplitSink<WebSocket, Message>,
	username: &mut String,
) -> UsernameValidity {
	while let Some(Ok(frame)) = receiver.next().await {
		// Only text frames can carry a client message; ignore everything else.
		let Message::Text(payload) = frame else {
			continue;
		};

		let Ok(inbound_message) = message::deserialize(&payload) else {
			message::send_error(sender, messenger_common::server::Error::InvalidMessage).await;
			continue;
		};

		let requested = match inbound_message {
			ClientMessageType::SetUsername(requested) => requested,
			other => {
				// The client must pick a name before anything else; tell it so.
				message::send_error(
					sender,
					messenger_common::server::Error::UnexpectedMessage {
						expected: ClientMessageType::SetUsername(String::new()),
						received: other,
					},
				)
				.await;
				continue;
			}
		};

		// `check_username` writes into `username` only if `requested` is free.
		check_username(state, username, &requested).await;

		if username.is_empty() {
			// The name is taken: report it to this client only and give up.
			message::send_error(
				sender,
				messenger_common::server::Error::UsernameNotAvailable {
					chosen_username: requested,
				},
			)
			.await;
			return UsernameValidity::Invalid;
		}

		return UsernameValidity::Valid;
	}

	// The stream closed (or errored) before a name request was handled.
	UsernameValidity::Invalid
}

/// Axum route handler for the chat `WebSocket` endpoint.
///
/// Upgrades the incoming HTTP request and, once the upgrade completes, passes
/// the resulting socket together with the shared application state to
/// [`websocket`], which drives the connection.
#[expect(
	clippy::unused_async,
	reason = "axum requires this function to be async, but clippy disallows this due to no \
	          .await's"
)]
pub async fn handler(
	ws: WebSocketUpgrade,
	State(state): State<Arc<AppState>>,
) -> impl IntoResponse {
	ws.on_upgrade(move |upgraded| websocket(upgraded, state))
}

/// Drives one client connection from username negotiation to disconnect.
///
/// The socket is split into its sending and receiving halves. The client must
/// first choose an available username; if it does not, this function returns
/// without joining the chat.
///
/// Otherwise a shared [`Session`] is built, a `UserJoined` event is sent on
/// `state.tx`, and the client is sent its initial data (recent messages and
/// the online-user list). Two tasks are then spawned:
/// - one forwards broadcast events to the client;
/// - one processes messages the client sends.
///
/// As soon as either task finishes, the other is aborted. A `UserLeft` event
/// is then sent and the username is removed from `state.user_set`.
async fn websocket(stream: WebSocket, state: Arc<AppState>) {
	// Splitting lets the two halves be driven independently.
	let (mut outgoing, mut incoming) = stream.split();

	// Filled in by `handle_username_choice` once the client picks a free name.
	let mut chosen_name = String::new();

	match handle_username_choice(&state, &mut incoming, &mut outgoing, &mut chosen_name).await {
		UsernameValidity::Valid => {}
		UsernameValidity::Invalid => return,
	}

	// The broadcast receiver is created before the join announcement below.
	let broadcast_rx = state.tx.subscribe();
	let session = Arc::new(Session::new(incoming, outgoing, broadcast_rx, chosen_name));

	MessageType::UserJoined(session.username().clone()).send(&state.tx);

	// Give the newcomer the last 100 messages and the list of online users.
	// Only sending the online-user list is expected to fail here, hence the
	// wording of the log message.
	if let Err(error) = session.transmit_initial_data(&state).await {
		tracing::error!(
			"an error occurred while attempting to send a list of online users: {error}"
		);
	}

	// Forwards broadcast events to this client over the websocket.
	let mut outbound_task = tokio::spawn({
		let session = Arc::clone(&session);
		async move {
			session.send().await;
		}
	});

	// Reads messages from this client and publishes them to other subscribers.
	let mut inbound_task = tokio::spawn({
		let state = Arc::clone(&state);
		let session = Arc::clone(&session);
		async move {
			session.receive(&state).await;
		}
	});

	// Whichever side stops first, cancel its counterpart.
	tokio::select! {
		_ = (&mut outbound_task) => inbound_task.abort(),
		_ = (&mut inbound_task) => outbound_task.abort(),
	};

	// Announce the departure, then free the name for future clients.
	MessageType::UserLeft(session.username().clone()).send(&state.tx);
	state.user_set.lock().await.remove(session.username());
}