Implement bulletproof OAuth code reuse prevention

- Add state parameter generation and validation with crypto-secure random values
- Implement used authorization code tracking to prevent replay attacks
- Add automatic redirect after successful auth to prevent refresh issues
- Enhance OAuth callback with comprehensive security checks
- Fix route conflicts between home page and OAuth callback handling
- Add rand dependency for secure state generation
- Update models.rs to handle optional Spotify API fields
- Improve error messages and logging for security violations
This commit is contained in:
Benjamin Slingo 2025-08-30 23:35:20 -04:00
parent 3c37d91bc4
commit e09e8b2d67
3 changed files with 169 additions and 18 deletions

View file

@ -20,4 +20,5 @@ env_logger = "0.11"
anyhow = "1.0" anyhow = "1.0"
toml = "0.8" toml = "0.8"
warp = "0.3" warp = "0.3"
tokio-stream = "0.1" tokio-stream = "0.1"
rand = "0.8"

View file

@ -114,9 +114,9 @@ pub struct CurrentTrack {
pub currently_playing_type: String, pub currently_playing_type: String,
pub actions: Actions, pub actions: Actions,
pub is_playing: bool, pub is_playing: bool,
pub device: Device, pub device: Option<Device>,
pub repeat_state: String, pub repeat_state: Option<String>,
pub shuffle_state: bool, pub shuffle_state: Option<bool>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]

View file

@ -1,14 +1,18 @@
use crate::{ConfigManager, Result, SpotifyClient, SpotifyError}; use crate::{ConfigManager, Result, SpotifyClient, SpotifyError};
use serde_json::json; use serde_json::json;
use std::collections::HashMap; use std::collections::{HashMap, HashSet};
use std::convert::Infallible; use std::convert::Infallible;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::RwLock;
use warp::Filter; use warp::Filter;
pub struct SpotifyServer { pub struct SpotifyServer {
client: Arc<SpotifyClient>, client: Arc<SpotifyClient>,
config_manager: Arc<ConfigManager>, config_manager: Arc<ConfigManager>,
port: u16, port: u16,
// Track used OAuth codes and active states to prevent reuse
used_codes: Arc<RwLock<HashSet<String>>>,
active_states: Arc<RwLock<HashMap<String, std::time::SystemTime>>>,
} }
impl SpotifyServer { impl SpotifyServer {
@ -17,6 +21,8 @@ impl SpotifyServer {
client: Arc::new(client), client: Arc::new(client),
config_manager: Arc::new(config_manager), config_manager: Arc::new(config_manager),
port, port,
used_codes: Arc::new(RwLock::new(HashSet::new())),
active_states: Arc::new(RwLock::new(HashMap::new())),
} }
} }
@ -25,6 +31,8 @@ impl SpotifyServer {
let client = self.client.clone(); let client = self.client.clone();
let config_manager = self.config_manager.clone(); let config_manager = self.config_manager.clone();
let used_codes = self.used_codes.clone();
let active_states = self.active_states.clone();
// Current track endpoint - mimics your existing API // Current track endpoint - mimics your existing API
let current_track = { let current_track = {
@ -75,28 +83,62 @@ impl SpotifyServer {
// OAuth authorization start endpoint // OAuth authorization start endpoint
let auth_start = { let auth_start = {
let client = client.clone(); let client = client.clone();
let active_states = active_states.clone();
warp::path("auth") warp::path("auth")
.and(warp::get()) .and(warp::get())
.map(move || client.clone()) .map(move || (client.clone(), active_states.clone()))
.map(|client: Arc<SpotifyClient>| { .and_then(|(client, active_states): (Arc<SpotifyClient>, Arc<RwLock<HashMap<String, std::time::SystemTime>>>)| async move {
let auth_url = client.get_authorization_url(Some("spotify-tracker")); // Generate unique state parameter
warp::reply::html(get_auth_page(&auth_url)) let state = generate_state();
// Store state with timestamp for validation
{
let mut states = active_states.write().await;
states.insert(state.clone(), std::time::SystemTime::now());
// Clean up old states (older than 10 minutes)
let cutoff = std::time::SystemTime::now() - std::time::Duration::from_secs(600);
states.retain(|_, timestamp| *timestamp > cutoff);
}
let auth_url = client.get_authorization_url(Some(&state));
Ok::<_, Infallible>(warp::reply::html(get_auth_page(&auth_url)))
}) })
}; };
// OAuth callback endpoint // OAuth callback endpoint - only match when we have OAuth parameters
let oauth_callback = { let oauth_callback = {
let client = client.clone(); let client = client.clone();
let config_manager = config_manager.clone(); let config_manager = config_manager.clone();
let used_codes = used_codes.clone();
let active_states = active_states.clone();
warp::path::end() warp::path::end()
.and(warp::get()) .and(warp::get())
.and(warp::query::<HashMap<String, String>>()) .and(warp::query::<HashMap<String, String>>())
.map(move |params: HashMap<String, String>| (client.clone(), config_manager.clone(), params)) .and_then(move |params: HashMap<String, String>| {
.and_then(|(client, config_manager, params): (Arc<SpotifyClient>, Arc<ConfigManager>, HashMap<String, String>)| async move { let client = client.clone();
handle_oauth_callback(client, config_manager, params).await let config_manager = config_manager.clone();
let used_codes = used_codes.clone();
let active_states = active_states.clone();
async move {
// Only handle this as OAuth callback if we have 'code' or 'error' parameters
if params.contains_key("code") || params.contains_key("error") {
handle_oauth_callback_secure(client, config_manager, used_codes, active_states, params).await.map_err(|_| warp::reject())
} else {
// This is just a regular root request, reject and let it fall through to other routes
Err(warp::reject())
}
}
}) })
}; };
// Home page showing auth status or links
let home_page = warp::path::end()
.and(warp::get())
.map(|| {
warp::reply::html(get_home_page())
});
// Health check endpoint // Health check endpoint
let health = warp::path("health") let health = warp::path("health")
.and(warp::get()) .and(warp::get())
@ -118,6 +160,7 @@ impl SpotifyServer {
.or(phantombot) .or(phantombot)
.or(auth_start) .or(auth_start)
.or(oauth_callback) .or(oauth_callback)
.or(home_page)
.or(health) .or(health)
.with(cors) .with(cors)
.with(warp::log("spotify_tracker")); .with(warp::log("spotify_tracker"));
@ -132,7 +175,7 @@ impl SpotifyServer {
println!("Press Ctrl+C to stop"); println!("Press Ctrl+C to stop");
warp::serve(routes) warp::serve(routes)
.run(([127, 0, 0, 1], self.port)) .run(([0, 0, 0, 0], self.port))
.await; .await;
Ok(()) Ok(())
@ -173,12 +216,55 @@ async fn get_current_track_text(client: Arc<SpotifyClient>) -> Result<String> {
} }
} }
async fn handle_oauth_callback( fn generate_state() -> String {
use std::time::{SystemTime, UNIX_EPOCH};
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis();
let random: u64 = rand::random();
format!("st_{}_{}_{}", timestamp, random, rand::random::<u32>())
}
async fn handle_oauth_callback_secure(
client: Arc<SpotifyClient>, client: Arc<SpotifyClient>,
config_manager: Arc<ConfigManager>, config_manager: Arc<ConfigManager>,
used_codes: Arc<RwLock<HashSet<String>>>,
active_states: Arc<RwLock<HashMap<String, std::time::SystemTime>>>,
params: HashMap<String, String>, params: HashMap<String, String>,
) -> std::result::Result<impl warp::Reply, std::convert::Infallible> { ) -> std::result::Result<impl warp::Reply, std::convert::Infallible> {
// Validate state parameter first
if let Some(state) = params.get("state") {
let mut states = active_states.write().await;
if states.remove(state).is_none() {
log::error!("Invalid or reused state parameter: {}", state);
return Ok(warp::reply::html(get_error_page("Invalid or expired authentication session. Please try again.")));
}
} else if params.contains_key("code") {
log::error!("Missing state parameter in OAuth callback");
return Ok(warp::reply::html(get_error_page("Invalid authentication request. Missing security parameter.")));
}
if let Some(code) = params.get("code") { if let Some(code) = params.get("code") {
// Check if code was already used
{
let mut codes = used_codes.write().await;
if codes.contains(code) {
log::error!("Authorization code reuse attempt detected: {}", code);
return Ok(warp::reply::html(get_error_page("Authorization code has already been used. Please start the authentication process again.")));
}
// Mark code as used immediately
codes.insert(code.clone());
// Clean up old codes (keep last 100)
if codes.len() > 100 {
let mut codes_vec: Vec<String> = codes.drain().collect();
codes_vec.sort();
codes_vec.truncate(50); // Keep only the first 50
codes.extend(codes_vec);
}
}
// Exchange code for token // Exchange code for token
match client.exchange_code(code).await { match client.exchange_code(code).await {
Ok(token_info) => { Ok(token_info) => {
@ -186,15 +272,20 @@ async fn handle_oauth_callback(
match config_manager.save_token(&token_info) { match config_manager.save_token(&token_info) {
Ok(()) => { Ok(()) => {
log::info!("OAuth authentication successful"); log::info!("OAuth authentication successful");
Ok(warp::reply::html(get_success_page())) // Return a success page that redirects to home after 3 seconds
Ok(warp::reply::html(get_success_page_with_redirect()))
} }
Err(e) => { Err(e) => {
// Remove the code from used set since token save failed
used_codes.write().await.remove(code);
log::error!("Failed to save token: {}", e); log::error!("Failed to save token: {}", e);
Ok(warp::reply::html(get_error_page(&format!("Failed to save token: {}", e)))) Ok(warp::reply::html(get_error_page(&format!("Failed to save token: {}", e))))
} }
} }
} }
Err(e) => { Err(e) => {
// Remove the code from used set since exchange failed
used_codes.write().await.remove(code);
log::error!("OAuth token exchange failed: {}", e); log::error!("OAuth token exchange failed: {}", e);
Ok(warp::reply::html(get_error_page(&format!("Authentication failed: {}", e)))) Ok(warp::reply::html(get_error_page(&format!("Authentication failed: {}", e))))
} }
@ -209,6 +300,7 @@ async fn handle_oauth_callback(
} }
} }
fn get_auth_page(auth_url: &str) -> String { fn get_auth_page(auth_url: &str) -> String {
format!( format!(
r#"<!DOCTYPE html> r#"<!DOCTYPE html>
@ -240,22 +332,41 @@ fn get_auth_page(auth_url: &str) -> String {
) )
} }
fn get_success_page() -> String {
fn get_success_page_with_redirect() -> String {
r#"<!DOCTYPE html> r#"<!DOCTYPE html>
<html lang="en"> <html lang="en">
<head> <head>
<meta charset="UTF-8"> <meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Spotify Tracker - Success</title> <title>Spotify Tracker - Success</title>
<meta http-equiv="refresh" content="3;url=/">
<style> <style>
body { font-family: Arial, sans-serif; max-width: 600px; margin: 50px auto; padding: 20px; background-color: #f5f5f5; } body { font-family: Arial, sans-serif; max-width: 600px; margin: 50px auto; padding: 20px; background-color: #f5f5f5; }
.container { background: white; padding: 30px; border-radius: 10px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); text-align: center; } .container { background: white; padding: 30px; border-radius: 10px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); text-align: center; }
.success { color: #1DB954; font-size: 48px; margin-bottom: 20px; } .success { color: #1DB954; font-size: 48px; margin-bottom: 20px; }
h1 { color: #333; } h1 { color: #333; }
p { color: #666; line-height: 1.6; } p { color: #666; line-height: 1.6; }
.countdown { color: #1DB954; font-weight: bold; }
.api-endpoints { background: #f8f9fa; padding: 20px; border-radius: 5px; margin: 20px 0; text-align: left; } .api-endpoints { background: #f8f9fa; padding: 20px; border-radius: 5px; margin: 20px 0; text-align: left; }
.endpoint { font-family: monospace; background: #e9ecef; padding: 5px 10px; border-radius: 3px; margin: 5px 0; } .endpoint { font-family: monospace; background: #e9ecef; padding: 5px 10px; border-radius: 3px; margin: 5px 0; }
</style> </style>
<script>
let countdown = 3;
function updateCountdown() {
const element = document.getElementById('countdown');
if (element) {
element.textContent = countdown;
countdown--;
if (countdown >= 0) {
setTimeout(updateCountdown, 1000);
}
}
}
window.onload = function() {
updateCountdown();
};
</script>
</head> </head>
<body> <body>
<div class="container"> <div class="container">
@ -270,7 +381,8 @@ fn get_success_page() -> String {
<div class="endpoint">GET https://spotify.tougie.live/health - Health check</div> <div class="endpoint">GET https://spotify.tougie.live/health - Health check</div>
</div> </div>
<p>You can now use the Spotify Tracker API to get your current playing track!</p> <p>Redirecting to home page in <span id="countdown" class="countdown">3</span> seconds...</p>
<p><small>This prevents refresh issues. <a href="/">Click here</a> if not redirected.</small></p>
</div> </div>
</body> </body>
</html>"#.to_string() </html>"#.to_string()
@ -308,6 +420,44 @@ fn get_error_page(error: &str) -> String {
) )
} }
fn get_home_page() -> String {
r#"<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Spotify Tracker</title>
<style>
body { font-family: Arial, sans-serif; max-width: 600px; margin: 50px auto; padding: 20px; background-color: #f5f5f5; }
.container { background: white; padding: 30px; border-radius: 10px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); text-align: center; }
h1 { color: #333; }
p { color: #666; line-height: 1.6; }
.auth-btn { background: #1DB954; color: white; padding: 15px 30px; border: none; border-radius: 50px; font-size: 16px; text-decoration: none; display: inline-block; margin: 20px 0; }
.auth-btn:hover { background: #1ed760; }
.api-endpoints { background: #f8f9fa; padding: 20px; border-radius: 5px; margin: 20px 0; text-align: left; }
.endpoint { font-family: monospace; background: #e9ecef; padding: 5px 10px; border-radius: 3px; margin: 5px 0; }
</style>
</head>
<body>
<div class="container">
<h1>🎵 Spotify Tracker</h1>
<p>Track your currently playing Spotify music with a simple API.</p>
<a href="/auth" class="auth-btn">Authenticate with Spotify</a>
<div class="api-endpoints">
<h3>API Endpoints:</h3>
<div class="endpoint">GET /current - Current track (JSON)</div>
<div class="endpoint">GET /phantombot - Current track (text)</div>
<div class="endpoint">GET /health - Health check</div>
</div>
<p><small>You need to authenticate with Spotify first to use the API endpoints.</small></p>
</div>
</body>
</html>"#.to_string()
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;