diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..423aa6c --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,41 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- + + - name: Install cargo-llvm-cov + uses: taiki-e/install-action@cargo-llvm-cov + + - name: Run tests with coverage + run: cargo llvm-cov --lcov --output-path lcov.info + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + files: lcov.info + fail_ci_if_error: false + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/Cargo.lock b/Cargo.lock index c9aa6bc..7e8b38c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,15 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + [[package]] name = "ansi-str" version = "0.9.0" @@ -95,6 +104,16 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d92bec98840b8f03a5ff5413de5293bfcd8bf96467cf5452609f939ec6f5de16" +[[package]] +name = "assert-json-diff" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -223,6 +242,15 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +[[package]] +name = "colored" +version = "3.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "faf9468729b8cbcea668e36183cb69d317348c2e08e994829fb56ebfdfbaac34" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "console" version = "0.15.11" @@ -606,6 +634,18 @@ dependencies = [ "wasi", ] +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi 5.3.0", + "wasip2", +] + [[package]] name = "getrandom" version = "0.4.2" @@ -614,7 +654,7 @@ checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" dependencies = [ "cfg-if", "libc", - "r-efi", + "r-efi 6.0.0", "wasip2", "wasip3", ] @@ -673,9 +713,10 @@ dependencies = [ "flate2", "indicatif", "inquire", + "mockito", "nix", "open", - "rand", + "rand 0.8.5", "reqwest", "semver", "serde", @@ -685,6 +726,7 @@ dependencies = [ "sqlformat", "tabled", "tar", + "tempfile", "tiny_http", ] @@ -747,6 +789,7 @@ dependencies = [ "http", "http-body", "httparse", + "httpdate", "itoa", "pin-project-lite", "pin-utils", @@ -1119,6 +1162,31 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "mockito" +version = "1.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90820618712cab19cfc46b274c6c22546a82affcb3c3bdf0f29e3db8e1bb92c0" +dependencies = [ + "assert-json-diff", + "bytes", + "colored", + "futures-core", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "log", + "pin-project-lite", + "rand 0.9.2", + "regex", + "serde_json", + "serde_urlencoded", + "similar", + "tokio", +] + [[package]] name = "native-tls" version = "0.2.18" @@ -1383,6 +1451,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + [[package]] name = "r-efi" version = "6.0.0" @@ -1396,8 +1470,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.5", ] [[package]] @@ -1407,7 +1491,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.5", ] [[package]] @@ -1419,6 +1513,15 @@ dependencies = [ "getrandom 0.2.17", ] +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.4", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -1448,6 +1551,35 @@ dependencies = [ "thiserror", ] +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + [[package]] name = "reqwest" version = "0.12.28" @@ -1750,6 +1882,12 @@ version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" +[[package]] +name = "similar" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" + [[package]] name = "slab" version = "0.4.12" @@ -1972,6 +2110,7 @@ dependencies = [ "bytes", "libc", "mio", + "parking_lot", "pin-project-lite", "socket2", "windows-sys 0.61.2", diff --git a/Cargo.toml b/Cargo.toml index fe164ff..b2e9737 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,11 +36,16 @@ tar = "0.4" semver = "1" sqlformat = "0.5.0" +[dev-dependencies] +mockito = "1" +tempfile = "3" + [package.metadata.release] pre-release-hook = ["git", "cliff", "-o", "CHANGELOG.md", "--tag", "{{version}}" ] publish = false pre-release-replacements = [ { file = "skills/hotdata-cli/SKILL.md", search = "^version: .+", replace = "version: {{version}}", exactly = 1 }, + { file = "README.md", search = "version-[0-9.]+-blue", replace = "version-{{version}}-blue", exactly = 1 }, ] # The profile that 'dist' will build with diff --git a/README.md b/README.md index 81bcf82..7c7b965 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,16 @@ -# hotdata-cli - -Command line interface for [Hotdata](https://www.hotdata.dev). +

+ Hotdata +
+ Hotdata CLI +
+ Command line interface for Hotdata. +

+ version + build + coverage +

+ +--- ## Install diff --git a/src/auth.rs b/src/auth.rs index a85353e..25483d5 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -16,6 +16,34 @@ pub fn logout(profile: &str) { println!("{}", "Logged out.".green()); } +#[derive(Debug, PartialEq)] +pub enum AuthStatus { + Authenticated, + NotConfigured, + Invalid(u16), + ConnectionError(String), +} + +pub fn check_status(profile_config: &config::ProfileConfig) -> AuthStatus { + let api_key = match &profile_config.api_key { + Some(key) if key != "PLACEHOLDER" => key.clone(), + _ => return AuthStatus::NotConfigured, + }; + + let url = format!("{}/workspaces", profile_config.api_url); + let client = reqwest::blocking::Client::new(); + + match client + .get(&url) + .header("Authorization", format!("Bearer {api_key}")) + .send() + { + Ok(resp) if resp.status().is_success() => AuthStatus::Authenticated, + Ok(resp) => AuthStatus::Invalid(resp.status().as_u16()), + Err(e) => AuthStatus::ConnectionError(e.to_string()), + } +} + pub fn status(profile: &str) { let profile_config = match config::load(profile) { Ok(c) => c, @@ -31,24 +59,12 @@ pub fn status(profile: &str) { "" }; - let api_key = match &profile_config.api_key { - Some(key) if key != "PLACEHOLDER" => key.clone(), - _ => { + match check_status(&profile_config) { + AuthStatus::NotConfigured => { print_row("Authenticated", &"No".red().to_string()); print_row("API Key", &"Not configured".red().to_string()); - return; } - }; - - let url = format!("{}/workspaces", profile_config.api_url); - let client = reqwest::blocking::Client::new(); - - match client - .get(&url) - .header("Authorization", format!("Bearer {api_key}")) - .send() - { - Ok(resp) if resp.status().is_success() => { + AuthStatus::Authenticated => { print_row("API URL", &profile_config.api_url.cyan().to_string()); print_row("Authenticated", &"Yes".green().to_string()); print_row("API Key", &format!("{}{source_label}", "Valid".green())); @@ -60,106 +76,113 @@ pub fn status(profile: &str) { None => print_row("Current Workspace", &"None".dark_grey().to_string()), } } - Ok(resp) => { + AuthStatus::Invalid(code) => { print_row("API URL", &profile_config.api_url.cyan().to_string()); print_row("Authenticated", &"No".red().to_string()); print_row( "API Key", &format!( "{}{source_label}", - format!("Invalid (HTTP {})", resp.status()).red() + format!("Invalid (HTTP {})", code).red() ), ); } - Err(e) => { + AuthStatus::ConnectionError(e) => { eprintln!("error connecting to API: {e}"); std::process::exit(1); } } } -pub fn login() { - let profile_config = config::load("default").unwrap_or_default(); - let api_url = profile_config.api_url.to_string(); - let app_url = profile_config.app_url.to_string(); +#[derive(Debug, PartialEq)] +pub enum LoginResult { + Success { token: String, workspace: Option }, + Forbidden, + Failed(String), + ConnectionError(String), +} - // Check if already authenticated - if let Some(api_key) = &profile_config.api_key { - if api_key != "PLACEHOLDER" { - let client = reqwest::blocking::Client::new(); - if let Ok(resp) = client - .get(format!("{api_url}/workspaces")) - .header("Authorization", format!("Bearer {api_key}")) - .send() - { - if resp.status().is_success() { - println!("{}", "You are already signed in.".green()); - print!("Do you want to log in again? [y/N] "); - use std::io::Write; - std::io::stdout().flush().unwrap(); - let mut input = String::new(); - std::io::stdin().read_line(&mut input).unwrap(); - if !input.trim().eq_ignore_ascii_case("y") { - return; - } - } - } - } - } +#[derive(Deserialize)] +struct TokenResponse { + token: String, +} - let code_verifier = generate_code_verifier(); - let code_challenge = generate_code_challenge(&code_verifier); - let state = generate_random_string(32); +#[derive(Deserialize)] +struct WsListResponse { workspaces: Vec } - // Bind to port 0 so the OS picks an available port - let server = - tiny_http::Server::http("127.0.0.1:0").expect("failed to start local callback server"); - let port = server.server_addr().to_ip().unwrap().port(); +#[derive(Deserialize)] +struct WsItem { public_id: String, name: String } - let login_url = format!( - "{app_url}/auth/cli-login?code_challenge={code_challenge}&code_challenge_method=S256&state={state}&callback_port={port}" - ); +/// Exchange an authorization code + PKCE verifier for an API token, +/// then fetch available workspaces. +fn exchange_and_save_token(api_url: &str, code: &str, code_verifier: &str) -> LoginResult { + let token_url = format!("{api_url}/auth/token"); + let client = reqwest::blocking::Client::new(); - println!("Opening browser to log in..."); - stdout() - .execute(Print("If your browser does not open, visit:\n ")) - .unwrap() - .execute(SetForegroundColor(Color::DarkGrey)) - .unwrap() - .execute(Print(format!("{login_url}\n"))) - .unwrap() - .execute(ResetColor) - .unwrap(); + let resp = match client + .post(&token_url) + .json(&serde_json::json!({ "code": code, "code_verifier": code_verifier })) + .send() + { + Ok(r) => r, + Err(e) => return LoginResult::ConnectionError(e.to_string()), + }; - if let Err(e) = open::that(&login_url) { - eprintln!("failed to open browser: {e}"); + if resp.status() == reqwest::StatusCode::FORBIDDEN { + return LoginResult::Forbidden; } - println!("Waiting for login callback..."); + if !resp.status().is_success() { + return LoginResult::Failed(format!("HTTP {}", resp.status())); + } - let request = server.recv().expect("failed to receive callback request"); + let body: TokenResponse = match resp.json() { + Ok(b) => b, + Err(e) => return LoginResult::Failed(format!("error parsing token response: {e}")), + }; + + // Save the token + if let Err(e) = config::save_api_key("default", &body.token) { + return LoginResult::Failed(format!("error saving token: {e}")); + } + + // Fetch and cache workspaces + let ws_url = format!("{api_url}/workspaces"); + let default_workspace = if let Ok(r) = client.get(&ws_url).header("Authorization", format!("Bearer {}", body.token)).send() { + if r.status().is_success() { + if let Ok(ws) = r.json::() { + let entries: Vec = ws.workspaces.into_iter() + .map(|w| config::WorkspaceEntry { public_id: w.public_id, name: w.name }) + .collect(); + let first = entries.first().cloned(); + let _ = config::save_workspaces("default", entries); + first + } else { None } + } else { None } + } else { None }; + + LoginResult::Success { token: body.token, workspace: default_workspace } +} + +/// Wait for the browser callback, verify state, and extract the authorization code. +fn receive_callback(server: &tiny_http::Server, expected_state: &str) -> Result { + let request = server.recv().map_err(|e| format!("failed to receive callback: {e}"))?; let raw_url = request.url().to_string(); let params = parse_query_params(&raw_url); - // Verify state to prevent CSRF - if params.get("state").map(String::as_str) != Some(state.as_str()) { - let _ = request.respond(tiny_http::Response::from_string( - "Login failed: state mismatch", - )); - eprintln!("error: state mismatch — possible CSRF attack"); - std::process::exit(1); + if params.get("state").map(String::as_str) != Some(expected_state) { + let _ = request.respond(tiny_http::Response::from_string("Login failed: state mismatch")); + return Err("state mismatch — possible CSRF attack".into()); } let code = match params.get("code") { Some(c) => c.clone(), None => { let _ = request.respond(tiny_http::Response::from_string("Login failed: no code")); - eprintln!("error: no authorization code received in callback"); - std::process::exit(1); + return Err("no authorization code received in callback".into()); } }; - // Respond to the browser before making the token exchange request let html = r#" @@ -220,55 +243,71 @@ pub fn login() { ); let _ = request.respond(response); - // Exchange the authorization code + verifier for the real API token - #[derive(Deserialize)] - struct TokenResponse { - token: String, + Ok(code) +} + +fn is_already_signed_in(profile_config: &config::ProfileConfig) -> bool { + check_status(profile_config) == AuthStatus::Authenticated +} + +pub fn login() { + let profile_config = config::load("default").unwrap_or_default(); + let api_url = profile_config.api_url.to_string(); + let app_url = profile_config.app_url.to_string(); + + // Check if already authenticated + if is_already_signed_in(&profile_config) { + println!("{}", "You are already signed in.".green()); + print!("Do you want to log in again? [y/N] "); + use std::io::Write; + std::io::stdout().flush().unwrap(); + let mut input = String::new(); + std::io::stdin().read_line(&mut input).unwrap(); + if !input.trim().eq_ignore_ascii_case("y") { + return; + } } - let token_url = format!("{api_url}/auth/token"); - let client = reqwest::blocking::Client::new(); + let code_verifier = generate_code_verifier(); + let code_challenge = generate_code_challenge(&code_verifier); + let state = generate_random_string(32); - let resp: Result = client - .post(&token_url) - .json(&serde_json::json!({ "code": code, "code_verifier": code_verifier })) - .send(); - - match resp { - Ok(r) if r.status().is_success() => { - let body: TokenResponse = match r.json() { - Ok(b) => b, - Err(e) => { - eprintln!("error parsing token response: {e}"); - std::process::exit(1); - } - }; + // Bind to port 0 so the OS picks an available port + let server = + tiny_http::Server::http("127.0.0.1:0").expect("failed to start local callback server"); + let port = server.server_addr().to_ip().unwrap().port(); - if let Err(e) = config::save_api_key("default", &body.token) { - eprintln!("error saving token: {e}"); - std::process::exit(1); - } + let login_url = format!( + "{app_url}/auth/cli-login?code_challenge={code_challenge}&code_challenge_method=S256&state={state}&callback_port={port}" + ); + + println!("Opening browser to log in..."); + stdout() + .execute(Print("If your browser does not open, visit:\n ")) + .unwrap() + .execute(SetForegroundColor(Color::DarkGrey)) + .unwrap() + .execute(Print(format!("{login_url}\n"))) + .unwrap() + .execute(ResetColor) + .unwrap(); + + if let Err(e) = open::that(&login_url) { + eprintln!("failed to open browser: {e}"); + } - // Fetch and cache workspace IDs for use as default - #[derive(Deserialize)] - struct WsListResponse { workspaces: Vec } - #[derive(Deserialize)] - struct WsItem { public_id: String, name: String } - - let ws_url = format!("{api_url}/workspaces"); - let default_workspace = if let Ok(r) = client.get(&ws_url).header("Authorization", format!("Bearer {}", body.token)).send() { - if r.status().is_success() { - if let Ok(ws) = r.json::() { - let entries: Vec = ws.workspaces.into_iter() - .map(|w| config::WorkspaceEntry { public_id: w.public_id, name: w.name }) - .collect(); - let first = entries.first().cloned(); - let _ = config::save_workspaces("default", entries); - first - } else { None } - } else { None } - } else { None }; + println!("Waiting for login callback..."); + let code = match receive_callback(&server, &state) { + Ok(c) => c, + Err(e) => { + eprintln!("error: {e}"); + std::process::exit(1); + } + }; + + match exchange_and_save_token(&api_url, &code, &code_verifier) { + LoginResult::Success { workspace, .. } => { stdout() .execute(SetForegroundColor(Color::Green)) .unwrap() @@ -277,7 +316,7 @@ pub fn login() { .execute(ResetColor) .unwrap(); - match default_workspace { + match workspace { Some(w) => { print_row("Workspace", &format!("{} {}", w.name.as_str().cyan(), format!("({})", w.public_id).dark_grey())); print_row("", &"use 'hotdata workspaces set' to switch workspaces".dark_grey().to_string()); @@ -285,15 +324,15 @@ pub fn login() { None => print_row("Workspace", &"None".dark_grey().to_string()), } } - Ok(r) if r.status() == reqwest::StatusCode::FORBIDDEN => { + LoginResult::Forbidden => { eprintln!("{}", "You are not authorized to create a new API token.".red()); std::process::exit(1); } - Ok(r) => { - eprintln!("token exchange failed: HTTP {}", r.status()); + LoginResult::Failed(msg) => { + eprintln!("token exchange failed: {msg}"); std::process::exit(1); } - Err(e) => { + LoginResult::ConnectionError(e) => { eprintln!("error connecting to API: {e}"); std::process::exit(1); } @@ -328,6 +367,323 @@ fn parse_query_params(url: &str) -> HashMap { .collect() } +#[cfg(test)] +mod tests { + use super::*; + use config::{ApiUrl, ProfileConfig, test_helpers::with_temp_config_dir}; + + fn mock_profile(api_url: &str, api_key: Option<&str>) -> ProfileConfig { + ProfileConfig { + api_key: api_key.map(String::from), + api_url: ApiUrl(Some(api_url.to_string())), + ..Default::default() + } + } + + // --- check_status tests --- + + #[test] + fn status_not_configured_when_no_key() { + let profile = mock_profile("http://localhost", None); + assert_eq!(check_status(&profile), AuthStatus::NotConfigured); + } + + #[test] + fn status_not_configured_when_placeholder() { + let profile = mock_profile("http://localhost", Some("PLACEHOLDER")); + assert_eq!(check_status(&profile), AuthStatus::NotConfigured); + } + + #[test] + fn status_authenticated_with_valid_key() { + let mut server = mockito::Server::new(); + let mock = server + .mock("GET", "/workspaces") + .match_header("Authorization", "Bearer valid-key") + .with_status(200) + .with_body(r#"{"workspaces":[]}"#) + .create(); + + let profile = mock_profile(&server.url(), Some("valid-key")); + assert_eq!(check_status(&profile), AuthStatus::Authenticated); + mock.assert(); + } + + #[test] + fn status_invalid_with_bad_key() { + let mut server = mockito::Server::new(); + let mock = server + .mock("GET", "/workspaces") + .with_status(401) + .create(); + + let profile = mock_profile(&server.url(), Some("bad-key")); + assert_eq!(check_status(&profile), AuthStatus::Invalid(401)); + mock.assert(); + } + + #[test] + fn status_invalid_with_forbidden() { + let mut server = mockito::Server::new(); + let mock = server + .mock("GET", "/workspaces") + .with_status(403) + .create(); + + let profile = mock_profile(&server.url(), Some("forbidden-key")); + assert_eq!(check_status(&profile), AuthStatus::Invalid(403)); + mock.assert(); + } + + #[test] + fn status_connection_error() { + let profile = mock_profile("http://127.0.0.1:1", Some("key")); + match check_status(&profile) { + AuthStatus::ConnectionError(_) => {} + other => panic!("expected ConnectionError, got {:?}", other), + } + } + + // --- is_already_signed_in tests --- + + #[test] + fn already_signed_in_when_key_valid() { + let mut server = mockito::Server::new(); + let mock = server + .mock("GET", "/workspaces") + .match_header("Authorization", "Bearer existing-key") + .with_status(200) + .with_body(r#"{"workspaces":[]}"#) + .create(); + + let profile = mock_profile(&server.url(), Some("existing-key")); + assert!(is_already_signed_in(&profile)); + mock.assert(); + } + + #[test] + fn not_signed_in_when_no_key() { + let profile = mock_profile("http://localhost", None); + assert!(!is_already_signed_in(&profile)); + } + + #[test] + fn not_signed_in_when_key_invalid() { + let mut server = mockito::Server::new(); + let mock = server + .mock("GET", "/workspaces") + .with_status(401) + .create(); + + let profile = mock_profile(&server.url(), Some("expired-key")); + assert!(!is_already_signed_in(&profile)); + mock.assert(); + } + + // --- exchange_and_save_token tests --- + + #[test] + fn exchange_and_save_token_success() { + let (_tmp, _guard) = with_temp_config_dir(); + let mut server = mockito::Server::new(); + + let token_mock = server + .mock("POST", "/auth/token") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"token":"new-api-token-xyz"}"#) + .create(); + + let ws_mock = server + .mock("GET", "/workspaces") + .match_header("Authorization", "Bearer new-api-token-xyz") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"workspaces":[{"public_id":"ws-123","name":"My Workspace"}]}"#) + .create(); + + let result = exchange_and_save_token(&server.url(), "auth-code", "verifier"); + + token_mock.assert(); + ws_mock.assert(); + + match result { + LoginResult::Success { token, workspace } => { + assert_eq!(token, "new-api-token-xyz"); + let ws = workspace.expect("should have a workspace"); + assert_eq!(ws.public_id, "ws-123"); + assert_eq!(ws.name, "My Workspace"); + } + other => panic!("expected Success, got {:?}", other), + } + + // Verify token was saved to config + let profile = config::load("default").unwrap(); + assert_eq!(profile.api_key, Some("new-api-token-xyz".to_string())); + } + + #[test] + fn exchange_and_save_token_success_no_workspaces() { + let (_tmp, _guard) = with_temp_config_dir(); + let mut server = mockito::Server::new(); + + let token_mock = server + .mock("POST", "/auth/token") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"token":"token-no-ws"}"#) + .create(); + + let ws_mock = server + .mock("GET", "/workspaces") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"workspaces":[]}"#) + .create(); + + let result = exchange_and_save_token(&server.url(), "code", "verifier"); + + token_mock.assert(); + ws_mock.assert(); + + match result { + LoginResult::Success { token, workspace } => { + assert_eq!(token, "token-no-ws"); + assert!(workspace.is_none()); + } + other => panic!("expected Success, got {:?}", other), + } + } + + #[test] + fn exchange_and_save_token_forbidden() { + let (_tmp, _guard) = with_temp_config_dir(); + let mut server = mockito::Server::new(); + + let mock = server + .mock("POST", "/auth/token") + .with_status(403) + .create(); + + let result = exchange_and_save_token(&server.url(), "code", "verifier"); + mock.assert(); + assert_eq!(result, LoginResult::Forbidden); + } + + #[test] + fn exchange_and_save_token_unauthorized() { + let (_tmp, _guard) = with_temp_config_dir(); + let mut server = mockito::Server::new(); + + let mock = server + .mock("POST", "/auth/token") + .with_status(401) + .create(); + + let result = exchange_and_save_token(&server.url(), "code", "verifier"); + mock.assert(); + match result { + LoginResult::Failed(msg) => assert!(msg.contains("401")), + other => panic!("expected Failed, got {:?}", other), + } + } + + #[test] + fn exchange_and_save_token_server_error() { + let (_tmp, _guard) = with_temp_config_dir(); + let mut server = mockito::Server::new(); + + let mock = server + .mock("POST", "/auth/token") + .with_status(500) + .create(); + + let result = exchange_and_save_token(&server.url(), "code", "verifier"); + mock.assert(); + match result { + LoginResult::Failed(msg) => assert!(msg.contains("500")), + other => panic!("expected Failed, got {:?}", other), + } + } + + #[test] + fn exchange_and_save_token_connection_error() { + let (_tmp, _guard) = with_temp_config_dir(); + + let result = exchange_and_save_token("http://127.0.0.1:1", "code", "verifier"); + match result { + LoginResult::ConnectionError(_) => {} + other => panic!("expected ConnectionError, got {:?}", other), + } + } + + // --- receive_callback tests --- + + #[test] + fn receive_callback_success() { + let server = tiny_http::Server::http("127.0.0.1:0").unwrap(); + let port = server.server_addr().to_ip().unwrap().port(); + + // Simulate browser redirect in a background thread + let handle = std::thread::spawn(move || { + let client = reqwest::blocking::Client::new(); + client + .get(format!( + "http://127.0.0.1:{port}/callback?code=test-auth-code&state=expected-state" + )) + .send() + .unwrap(); + }); + + let result = receive_callback(&server, "expected-state"); + handle.join().unwrap(); + + assert_eq!(result.unwrap(), "test-auth-code"); + } + + #[test] + fn receive_callback_state_mismatch() { + let server = tiny_http::Server::http("127.0.0.1:0").unwrap(); + let port = server.server_addr().to_ip().unwrap().port(); + + let handle = std::thread::spawn(move || { + let client = reqwest::blocking::Client::new(); + let _ = client + .get(format!( + "http://127.0.0.1:{port}/callback?code=code&state=wrong-state" + )) + .send(); + }); + + let result = receive_callback(&server, "expected-state"); + handle.join().unwrap(); + + assert!(result.is_err()); + assert!(result.unwrap_err().contains("state mismatch")); + } + + #[test] + fn receive_callback_no_code() { + let server = tiny_http::Server::http("127.0.0.1:0").unwrap(); + let port = server.server_addr().to_ip().unwrap().port(); + + let handle = std::thread::spawn(move || { + let client = reqwest::blocking::Client::new(); + let _ = client + .get(format!( + "http://127.0.0.1:{port}/callback?state=expected-state" + )) + .send(); + }); + + let result = receive_callback(&server, "expected-state"); + handle.join().unwrap(); + + assert!(result.is_err()); + assert!(result.unwrap_err().contains("no authorization code")); + } +} + fn print_row(label: &str, value: &str) { stdout() .execute(SetForegroundColor(Color::DarkGrey)) diff --git a/src/config.rs b/src/config.rs index ebe4da4..3b80d67 100644 --- a/src/config.rs +++ b/src/config.rs @@ -5,11 +5,26 @@ use std::collections::HashMap; use std::env; use std::fs; use std::ops::Deref; +use std::path::PathBuf; + +/// Returns the config directory, defaulting to ~/.hotdata. +/// Override with HOTDATA_CONFIG_DIR env var (useful for testing). +pub fn config_dir() -> Result { + if let Ok(dir) = env::var("HOTDATA_CONFIG_DIR") { + return Ok(PathBuf::from(dir)); + } + let user_dirs = UserDirs::new().ok_or("could not determine home directory")?; + Ok(user_dirs.home_dir().join(".hotdata")) +} + +fn config_path() -> Result { + Ok(config_dir()?.join("config.yml")) +} pub const DEFAULT_API_URL: &str = "https://api.hotdata.dev/v1"; pub const DEFAULT_APP_URL: &str = "https://app.hotdata.dev"; -#[derive(Debug, Clone, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct WorkspaceEntry { pub public_id: String, pub name: String, @@ -53,7 +68,7 @@ pub enum ApiKeySource { } #[derive(Debug, Clone, Serialize)] -pub struct ApiUrl(Option); +pub struct ApiUrl(pub(crate) Option); impl Default for ApiUrl { fn default() -> Self { @@ -107,8 +122,7 @@ fn write_config(config_path: &std::path::Path, content: &str) -> Result<(), Stri } pub fn save_api_key(profile: &str, api_key: &str) -> Result<(), String> { - let user_dirs = UserDirs::new().ok_or("could not determine home directory")?; - let config_path = user_dirs.home_dir().join(".hotdata").join("config.yml"); + let config_path = config_path()?; let mut config_file: ConfigFile = if config_path.exists() { let content = fs::read_to_string(&config_path) @@ -133,8 +147,7 @@ pub fn save_api_key(profile: &str, api_key: &str) -> Result<(), String> { } pub fn remove_api_key(profile: &str) -> Result<(), String> { - let user_dirs = UserDirs::new().ok_or("could not determine home directory")?; - let config_path = user_dirs.home_dir().join(".hotdata").join("config.yml"); + let config_path = config_path()?; if !config_path.exists() { return Ok(()); @@ -156,8 +169,7 @@ pub fn remove_api_key(profile: &str) -> Result<(), String> { } pub fn save_workspaces(profile: &str, workspaces: Vec) -> Result<(), String> { - let user_dirs = UserDirs::new().ok_or("could not determine home directory")?; - let config_path = user_dirs.home_dir().join(".hotdata").join("config.yml"); + let config_path = config_path()?; let mut config_file: ConfigFile = if config_path.exists() { let content = fs::read_to_string(&config_path) @@ -182,8 +194,7 @@ pub fn save_workspaces(profile: &str, workspaces: Vec) -> Result } pub fn save_default_workspace(profile: &str, workspace: WorkspaceEntry) -> Result<(), String> { - let user_dirs = UserDirs::new().ok_or("could not determine home directory")?; - let config_path = user_dirs.home_dir().join(".hotdata").join("config.yml"); + let config_path = config_path()?; let mut config_file: ConfigFile = if config_path.exists() { let content = fs::read_to_string(&config_path) @@ -222,8 +233,7 @@ pub fn set_api_key_flag(key: String) { } pub fn load(profile: &str) -> Result { - let user_dirs = UserDirs::new().ok_or("could not determine home directory")?; - let config_file = user_dirs.home_dir().join(".hotdata").join("config.yml"); + let config_file = config_path()?; let mut profile_config = if config_file.exists() { let content = @@ -259,3 +269,177 @@ pub fn load(profile: &str) -> Result { Ok(profile_config) } + +/// Test utilities shared across modules. +#[cfg(test)] +pub mod test_helpers { + use std::sync::Mutex; + + // Serialize all tests that modify HOTDATA_CONFIG_DIR env var. + static ENV_LOCK: Mutex<()> = Mutex::new(()); + + /// Set HOTDATA_CONFIG_DIR to a temp dir and return it with a lock guard. + /// Hold the guard for the duration of the test. + pub fn with_temp_config_dir() -> (tempfile::TempDir, std::sync::MutexGuard<'static, ()>) { + let guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + let tmp = tempfile::tempdir().unwrap(); + // SAFETY: tests are serialized via ENV_LOCK mutex, so no concurrent env mutation. + unsafe { std::env::set_var("HOTDATA_CONFIG_DIR", tmp.path()) }; + (tmp, guard) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use super::test_helpers::with_temp_config_dir; + + #[test] + fn save_and_load_api_key() { + let (_tmp, _guard) = with_temp_config_dir(); + + save_api_key("default", "test-key-123").unwrap(); + let profile = load("default").unwrap(); + assert_eq!(profile.api_key, Some("test-key-123".to_string())); + } + + #[test] + fn save_api_key_creates_config_dir() { + let (_tmp, _guard) = with_temp_config_dir(); + + // Config file shouldn't exist yet + let path = config_path().unwrap(); + assert!(!path.exists()); + + save_api_key("default", "key").unwrap(); + assert!(path.exists()); + } + + #[test] + fn remove_api_key_clears_key_and_workspaces() { + let (_tmp, _guard) = with_temp_config_dir(); + + save_api_key("default", "key-to-remove").unwrap(); + save_workspaces( + "default", + vec![WorkspaceEntry { + public_id: "ws-1".into(), + name: "Test WS".into(), + }], + ) + .unwrap(); + + remove_api_key("default").unwrap(); + + let profile = load("default").unwrap(); + assert_eq!(profile.api_key, None); + assert!(profile.workspaces.is_empty()); + } + + #[test] + fn remove_api_key_noop_when_no_config() { + let (_tmp, _guard) = with_temp_config_dir(); + + // Should not error when config file doesn't exist + assert!(remove_api_key("default").is_ok()); + } + + #[test] + fn save_and_load_workspaces() { + let (_tmp, _guard) = with_temp_config_dir(); + + save_api_key("default", "key").unwrap(); + let workspaces = vec![ + WorkspaceEntry { public_id: "ws-1".into(), name: "First".into() }, + WorkspaceEntry { public_id: "ws-2".into(), name: "Second".into() }, + ]; + save_workspaces("default", workspaces).unwrap(); + + let profile = load("default").unwrap(); + assert_eq!(profile.workspaces.len(), 2); + assert_eq!(profile.workspaces[0].public_id, "ws-1"); + assert_eq!(profile.workspaces[1].name, "Second"); + } + + #[test] + fn save_default_workspace_moves_to_front() { + let (_tmp, _guard) = with_temp_config_dir(); + + save_api_key("default", "key").unwrap(); + let workspaces = vec![ + WorkspaceEntry { public_id: "ws-1".into(), name: "First".into() }, + WorkspaceEntry { public_id: "ws-2".into(), name: "Second".into() }, + ]; + save_workspaces("default", workspaces).unwrap(); + + // Set ws-2 as default — should move to front + save_default_workspace( + "default", + WorkspaceEntry { public_id: "ws-2".into(), name: "Second".into() }, + ) + .unwrap(); + + let profile = load("default").unwrap(); + assert_eq!(profile.workspaces[0].public_id, "ws-2"); + assert_eq!(profile.workspaces[1].public_id, "ws-1"); + } + + #[test] + fn load_missing_profile_returns_default() { + let (_tmp, _guard) = with_temp_config_dir(); + + save_api_key("default", "key").unwrap(); + + let profile = load("nonexistent").unwrap(); + assert_eq!(profile.api_key, None); + assert!(profile.workspaces.is_empty()); + } + + #[test] + fn load_no_config_file_returns_default() { + let (_tmp, _guard) = with_temp_config_dir(); + + let profile = load("default").unwrap(); + assert_eq!(profile.api_key, None); + } + + #[test] + fn multiple_profiles() { + let (_tmp, _guard) = with_temp_config_dir(); + + save_api_key("default", "key-default").unwrap(); + save_api_key("staging", "key-staging").unwrap(); + + let default = load("default").unwrap(); + let staging = load("staging").unwrap(); + assert_eq!(default.api_key, Some("key-default".to_string())); + assert_eq!(staging.api_key, Some("key-staging".to_string())); + } + + #[test] + fn resolve_workspace_id_prefers_provided() { + let profile = ProfileConfig { + workspaces: vec![WorkspaceEntry { public_id: "ws-1".into(), name: "WS".into() }], + ..Default::default() + }; + let result = resolve_workspace_id(Some("explicit-id".into()), &profile).unwrap(); + assert_eq!(result, "explicit-id"); + } + + #[test] + fn resolve_workspace_id_falls_back_to_first() { + let profile = ProfileConfig { + workspaces: vec![WorkspaceEntry { public_id: "ws-1".into(), name: "WS".into() }], + ..Default::default() + }; + let result = resolve_workspace_id(None, &profile).unwrap(); + assert_eq!(result, "ws-1"); + } + + #[test] + fn resolve_workspace_id_errors_when_none() { + let profile = ProfileConfig::default(); + let result = resolve_workspace_id(None, &profile); + assert!(result.is_err()); + } +}