diff --git a/k_proxy_app.py b/k_proxy_app.py index f762927..da8a9cf 100644 --- a/k_proxy_app.py +++ b/k_proxy_app.py @@ -55,8 +55,11 @@ from fido2.webauthn import ( UserVerificationRequirement, ) -if getattr(fido2.features.webauthn_json_mapping, "_enabled", None) is None: - fido2.features.webauthn_json_mapping.enabled = True +try: + if getattr(fido2.features.webauthn_json_mapping, "_enabled", None) is None: + fido2.features.webauthn_json_mapping.enabled = True +except AttributeError: + pass HTML = """ diff --git a/tests/test_k_proxy.py b/tests/test_k_proxy.py new file mode 100644 index 0000000..91f9cf0 --- /dev/null +++ b/tests/test_k_proxy.py @@ -0,0 +1,804 @@ +#!/usr/bin/env python3 +""" +Unit tests for k_proxy_app.py. + +Card (FIDO2/CTAP) and k_server (UpstreamPool) are mocked throughout. +All tests run locally without any Qubes VMs or attached hardware. +""" + +import http.client +import json +import sys +import tempfile +import threading +import time +import unittest +from http.server import ThreadingHTTPServer +from pathlib import Path +from unittest.mock import MagicMock, patch + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +import k_proxy_app as app +from k_proxy_app import ( + AUTH_MODE_FIDO2_DIRECT, + AUTH_MODE_PROBE, + Enrollment, + Handler, + ProxyState, + UpstreamPool, + b64u_decode, + b64u_encode, + enrollment_payload, + normalize_display_name, + normalize_username, +) + + +# ── test helpers ────────────────────────────────────────────────────────────── + +def _make_state(tmp_path, *, auth_mode=AUTH_MODE_PROBE, session_ttl=300): + return ProxyState( + session_ttl_s=session_ttl, + auth_mode=auth_mode, + auth_command="echo ok", + server_base_url="http://127.0.0.1:19999", + server_ca_file=None, + server_max_connections=1, + proxy_token="test-token", + enrollment_db=tmp_path / "enrollments.json", + rp_id="localhost", + rp_name="Test RP", + origin="https://localhost", + direct_device_path="", + ) + + +def _enrollment(username="alice", display_name=None, *, credential_data_b64=None): + now = int(time.time()) + return Enrollment( + username=username, + display_name=display_name, + created_at=now, + updated_at=now, + credential_data_b64=credential_data_b64, + ) + + +# ── pure function tests ─────────────────────────────────────────────────────── + +class TestNormalizeUsername(unittest.TestCase): + def test_simple_valid(self): + self.assertEqual(normalize_username("alice"), "alice") + + def test_strips_and_lowercases(self): + self.assertEqual(normalize_username(" Alice "), "alice") + + def test_valid_with_dots_dashes_underscores(self): + for name in ("alice.smith", "alice-smith", "alice_smith", "a1b"): + with self.subTest(name=name): + self.assertEqual(normalize_username(name), name) + + def test_too_short_raises(self): + with self.assertRaises(ValueError): + normalize_username("ab") + + def test_too_long_raises(self): + with self.assertRaises(ValueError): + normalize_username("a" * 33) + + def test_invalid_chars_raise(self): + for bad in ("Alice!", "al ice", "al@ice", "AB"): + with self.subTest(bad=bad): + with self.assertRaises(ValueError): + normalize_username(bad) + + def test_minimum_length_valid(self): + self.assertEqual(normalize_username("abc"), "abc") + + def test_maximum_length_valid(self): + self.assertEqual(normalize_username("a" * 32), "a" * 32) + + +class TestNormalizeDisplayName(unittest.TestCase): + def test_none_returns_none(self): + self.assertIsNone(normalize_display_name(None)) + + def test_whitespace_only_returns_none(self): + self.assertIsNone(normalize_display_name(" ")) + + def test_strips_whitespace(self): + self.assertEqual(normalize_display_name(" Alice Smith "), "Alice Smith") + + def test_max_length_accepted(self): + self.assertEqual(normalize_display_name("a" * 64), "a" * 64) + + def test_over_max_length_raises(self): + with self.assertRaises(ValueError): + normalize_display_name("a" * 65) + + +class TestBase64Utils(unittest.TestCase): + def test_round_trip(self): + original = b"\x00\x01\x02\xffsome\xffbinary" + self.assertEqual(b64u_decode(b64u_encode(original)), original) + + def test_no_padding_chars_in_output(self): + encoded = b64u_encode(b"x") + self.assertNotIn("=", encoded) + + def test_decode_handles_missing_padding(self): + encoded = b64u_encode(b"hello") + self.assertEqual(b64u_decode(encoded), b"hello") + + +class TestEnrollmentPayload(unittest.TestCase): + def test_basic_fields(self): + e = _enrollment("alice", "Alice Smith") + payload = enrollment_payload(e) + self.assertTrue(payload["ok"]) + self.assertEqual(payload["username"], "alice") + self.assertEqual(payload["display_name"], "Alice Smith") + self.assertFalse(payload["has_credential"]) + + def test_has_credential_true_when_data_present(self): + e = _enrollment(credential_data_b64="abc") + self.assertTrue(enrollment_payload(e)["has_credential"]) + + def test_created_flag_included_when_given(self): + e = _enrollment() + self.assertIn("created", enrollment_payload(e, created=True)) + self.assertNotIn("created", enrollment_payload(e)) + + +# ── session management ──────────────────────────────────────────────────────── + +class TestSessionManagement(unittest.TestCase): + def setUp(self): + self._tmpdir = tempfile.TemporaryDirectory() + self.state = _make_state(Path(self._tmpdir.name)) + + def tearDown(self): + self._tmpdir.cleanup() + + def test_create_returns_token_and_future_expiry(self): + token, expires_at = self.state.create_session("alice") + self.assertIsInstance(token, str) + self.assertGreater(len(token), 16) + self.assertGreater(expires_at, time.time()) + + def test_get_session_returns_correct_username(self): + token, _ = self.state.create_session("alice") + session = self.state.get_session(token) + self.assertIsNotNone(session) + self.assertEqual(session.username, "alice") + + def test_get_session_unknown_token_returns_none(self): + self.assertIsNone(self.state.get_session("not-a-real-token")) + + def test_expired_session_returns_none(self): + state = _make_state(Path(self._tmpdir.name), session_ttl=-1) + token, _ = state.create_session("alice") + self.assertIsNone(state.get_session(token)) + + def test_invalidate_session_removes_it(self): + token, _ = self.state.create_session("alice") + self.assertTrue(self.state.invalidate_session(token)) + self.assertIsNone(self.state.get_session(token)) + + def test_invalidate_unknown_token_returns_false(self): + self.assertFalse(self.state.invalidate_session("ghost")) + + def test_active_session_count_tracks_correctly(self): + self.assertEqual(self.state.active_session_count(), 0) + t1, _ = self.state.create_session("alice") + t2, _ = self.state.create_session("bob") + self.assertEqual(self.state.active_session_count(), 2) + self.state.invalidate_session(t1) + self.assertEqual(self.state.active_session_count(), 1) + + def test_expired_sessions_garbage_collected(self): + state = _make_state(Path(self._tmpdir.name), session_ttl=-1) + state.create_session("alice") + state.create_session("bob") + self.assertEqual(state.active_session_count(), 0) + + def test_tokens_are_unique(self): + tokens = {self.state.create_session("alice")[0] for _ in range(20)} + self.assertEqual(len(tokens), 20) + + def test_uses_direct_fido2_false_in_probe_mode(self): + self.assertFalse(self.state.uses_direct_fido2()) + + def test_uses_direct_fido2_true_in_direct_mode(self): + state = _make_state(Path(self._tmpdir.name), auth_mode=AUTH_MODE_FIDO2_DIRECT) + self.assertTrue(state.uses_direct_fido2()) + + def test_auth_mode_label_probe(self): + self.assertEqual(self.state.auth_mode_label(), "card_presence_probe") + + def test_auth_mode_label_direct(self): + state = _make_state(Path(self._tmpdir.name), auth_mode=AUTH_MODE_FIDO2_DIRECT) + self.assertEqual(state.auth_mode_label(), "fido2_assertion") + + +# ── enrollment management ───────────────────────────────────────────────────── + +class TestEnrollmentManagement(unittest.TestCase): + def setUp(self): + self._tmpdir = tempfile.TemporaryDirectory() + self.tmp_path = Path(self._tmpdir.name) + self.state = _make_state(self.tmp_path) + + def tearDown(self): + self._tmpdir.cleanup() + + def test_register_creates_enrollment(self): + e = self.state.register_enrollment("alice", "Alice Smith") + self.assertEqual(e.username, "alice") + self.assertEqual(e.display_name, "Alice Smith") + self.assertTrue(self.state.has_enrollment("alice")) + + def test_register_persists_across_state_reload(self): + self.state.register_enrollment("alice", None) + state2 = _make_state(self.tmp_path) + self.assertTrue(state2.has_enrollment("alice")) + + def test_register_duplicate_raises_file_exists_error(self): + self.state.register_enrollment("alice", None) + with self.assertRaises(FileExistsError): + self.state.register_enrollment("alice", None) + + def test_register_invalid_username_raises_value_error(self): + with self.assertRaises(ValueError): + self.state.register_enrollment("A!", None) + + def test_register_display_name_too_long_raises(self): + with self.assertRaises(ValueError): + self.state.register_enrollment("alice", "x" * 65) + + def test_update_changes_display_name(self): + self.state.register_enrollment("alice", "Old") + updated = self.state.update_enrollment("alice", "New") + self.assertEqual(updated.display_name, "New") + self.assertEqual(self.state.get_enrollment("alice").display_name, "New") + + def test_update_unknown_user_raises_key_error(self): + with self.assertRaises(KeyError): + self.state.update_enrollment("nobody", "Name") + + def test_delete_removes_enrollment(self): + self.state.register_enrollment("alice", None) + self.state.delete_enrollment("alice") + self.assertFalse(self.state.has_enrollment("alice")) + + def test_delete_invalidates_active_sessions(self): + self.state.register_enrollment("alice", None) + token, _ = self.state.create_session("alice") + self.state.delete_enrollment("alice") + self.assertIsNone(self.state.get_session(token)) + + def test_delete_does_not_affect_other_users_sessions(self): + self.state.register_enrollment("alice", None) + self.state.register_enrollment("bob", None) + bob_token, _ = self.state.create_session("bob") + self.state.delete_enrollment("alice") + self.assertIsNotNone(self.state.get_session(bob_token)) + + def test_delete_unknown_user_raises_key_error(self): + with self.assertRaises(KeyError): + self.state.delete_enrollment("nobody") + + def test_list_enrollments_sorted_alphabetically(self): + self.state.register_enrollment("charlie", None) + self.state.register_enrollment("alice", None) + self.state.register_enrollment("bob", None) + names = [e.username for e in self.state.list_enrollments()] + self.assertEqual(names, ["alice", "bob", "charlie"]) + + def test_get_enrollment_found(self): + self.state.register_enrollment("alice", "Alice") + e = self.state.get_enrollment("alice") + self.assertIsNotNone(e) + self.assertEqual(e.username, "alice") + + def test_get_enrollment_not_found_returns_none(self): + self.assertIsNone(self.state.get_enrollment("nobody")) + + def test_get_enrollment_invalid_username_returns_none(self): + self.assertIsNone(self.state.get_enrollment("!bad!")) + + def test_has_enrollment_true(self): + self.state.register_enrollment("alice", None) + self.assertTrue(self.state.has_enrollment("alice")) + + def test_has_enrollment_false(self): + self.assertFalse(self.state.has_enrollment("nobody")) + + def test_register_direct_mode_delegates_to_direct_method(self): + state = _make_state(self.tmp_path, auth_mode=AUTH_MODE_FIDO2_DIRECT) + fake = _enrollment("alice", credential_data_b64="cred") + with patch.object(state, "_register_direct_fido2", return_value=fake) as mock_direct: + result = state.register_enrollment("alice", None) + mock_direct.assert_called_once_with("alice", None) + self.assertEqual(result.username, "alice") + + +# ── authentication ──────────────────────────────────────────────────────────── + +class TestProbeAuth(unittest.TestCase): + def setUp(self): + self._tmpdir = tempfile.TemporaryDirectory() + self.state = _make_state(Path(self._tmpdir.name)) + + def tearDown(self): + self._tmpdir.cleanup() + + def _mock_proc(self, returncode, stdout="", stderr=""): + proc = MagicMock() + proc.returncode = returncode + proc.stdout = stdout + proc.stderr = stderr + return proc + + def test_success_when_subprocess_returns_zero(self): + with patch("k_proxy_app.subprocess.run", return_value=self._mock_proc(0, '{"ok": true}')): + ok, _ = self.state.authenticate_with_card("alice") + self.assertTrue(ok) + + def test_failure_when_subprocess_returns_nonzero(self): + with patch("k_proxy_app.subprocess.run", return_value=self._mock_proc(1, stderr="No CTAP HID devices")): + ok, msg = self.state.authenticate_with_card("alice") + self.assertFalse(ok) + self.assertIn("No CTAP HID devices", msg) + + def test_failure_uses_stdout_when_stderr_empty(self): + with patch("k_proxy_app.subprocess.run", return_value=self._mock_proc(2, stdout="probe failed")): + ok, msg = self.state.authenticate_with_card("alice") + self.assertFalse(ok) + self.assertIn("probe failed", msg) + + def test_failure_when_subprocess_raises(self): + with patch("k_proxy_app.subprocess.run", side_effect=TimeoutError("timed out")): + ok, msg = self.state.authenticate_with_card("alice") + self.assertFalse(ok) + self.assertIn("auth command failed", msg) + + +class TestDirectFido2Auth(unittest.TestCase): + def setUp(self): + self._tmpdir = tempfile.TemporaryDirectory() + self.state = _make_state(Path(self._tmpdir.name), auth_mode=AUTH_MODE_FIDO2_DIRECT) + + def tearDown(self): + self._tmpdir.cleanup() + + def test_unenrolled_user_returns_false(self): + ok, msg = self.state.authenticate_with_card("nobody") + self.assertFalse(ok) + self.assertEqual(msg, "user not enrolled") + + def test_enrolled_without_credential_returns_false(self): + self.state.enrollments["alice"] = _enrollment("alice") + ok, msg = self.state.authenticate_with_card("alice") + self.assertFalse(ok) + self.assertEqual(msg, "user has no registered credential") + + def test_exception_from_ctap_returns_false_with_message(self): + self.state.enrollments["alice"] = _enrollment("alice", credential_data_b64="dW5pY29kZQ") + with patch("k_proxy_app.AttestedCredentialData", side_effect=Exception("bad cbor")): + ok, msg = self.state.authenticate_with_card("alice") + self.assertFalse(ok) + self.assertIn("assertion verification failed", msg) + + def test_success_path_with_mocked_internals(self): + self.state.enrollments["alice"] = _enrollment("alice", credential_data_b64=b64u_encode(b"fake_cred")) + + mock_cred = MagicMock() + mock_options = MagicMock() + mock_options.public_key.rp_id = "localhost" + mock_options.public_key.allow_credentials = [] + mock_options.public_key.challenge = b"challenge" + mock_client_data = MagicMock() + mock_client_data.hash = b"hash" + mock_assertion = MagicMock() + mock_assertion.assertions = None + mock_assertion.credential = {"id": b"cred_id"} + mock_assertion.auth_data = b"auth" + mock_assertion.signature = b"sig" + mock_assertion.user = None + + with patch("k_proxy_app.AttestedCredentialData", return_value=mock_cred), \ + patch("k_proxy_app.AuthenticationResponse", return_value=MagicMock()), \ + patch("k_proxy_app.AuthenticatorAssertionResponse", return_value=MagicMock()), \ + patch.object(self.state, "_drop_direct_device"), \ + patch.object(self.state.fido_server, "authenticate_begin", return_value=(mock_options, {})), \ + patch.object(self.state, "_collect_client_data", return_value=mock_client_data), \ + patch.object(self.state, "_with_direct_ctap2", return_value=mock_assertion), \ + patch.object(self.state.fido_server, "authenticate_complete"): + ok, msg = self.state.authenticate_with_card("alice") + + self.assertTrue(ok) + self.assertEqual(msg, "assertion verified") + + +# ── upstream pool ───────────────────────────────────────────────────────────── + +class TestUpstreamPool(unittest.TestCase): + def _pool(self): + return UpstreamPool( + server_base_url="http://127.0.0.1:19999", + server_ca_file=None, + max_connections=2, + ) + + def _mock_response(self, status, body, will_close=True): + resp = MagicMock() + resp.status = status + resp.read.return_value = body + resp.will_close = will_close + return resp + + def test_successful_request_returns_status_and_parsed_json(self): + pool = self._pool() + conn = MagicMock() + conn.getresponse.return_value = self._mock_response(200, b'{"ok": true, "value": 7}') + with patch.object(pool, "_new_connection", return_value=conn): + status, data = pool.request_json("/resource/counter", {"X-Proxy-Token": "tok"}, {}) + self.assertEqual(status, 200) + self.assertTrue(data["ok"]) + self.assertEqual(data["value"], 7) + + def test_non_200_status_is_returned_as_is(self): + pool = self._pool() + conn = MagicMock() + conn.getresponse.return_value = self._mock_response(403, b'{"ok": false, "error": "forbidden"}') + with patch.object(pool, "_new_connection", return_value=conn): + status, data = pool.request_json("/test", {}, {}) + self.assertEqual(status, 403) + self.assertFalse(data["ok"]) + + def test_oserror_returns_502(self): + pool = self._pool() + conn = MagicMock() + conn.request.side_effect = OSError("connection refused") + with patch.object(pool, "_new_connection", return_value=conn): + status, data = pool.request_json("/test", {}, {}) + self.assertEqual(status, 502) + self.assertIn("server unavailable", data["error"]) + + def test_empty_body_returns_empty_dict(self): + pool = self._pool() + conn = MagicMock() + conn.getresponse.return_value = self._mock_response(200, b"") + with patch.object(pool, "_new_connection", return_value=conn): + status, data = pool.request_json("/test", {}, {}) + self.assertEqual(data, {}) + + def test_connection_reused_when_will_close_false(self): + pool = self._pool() + conn = MagicMock() + conn.getresponse.return_value = self._mock_response(200, b'{"ok": true}', will_close=False) + with patch.object(pool, "_new_connection", return_value=conn) as mock_new: + pool.request_json("/test", {}, {}) + pool.request_json("/test", {}, {}) + self.assertEqual(mock_new.call_count, 1) + self.assertEqual(conn.request.call_count, 2) + + def test_connection_not_reused_when_will_close_true(self): + pool = self._pool() + conn = MagicMock() + conn.getresponse.return_value = self._mock_response(200, b'{"ok": true}', will_close=True) + with patch.object(pool, "_new_connection", return_value=conn) as mock_new: + pool.request_json("/test", {}, {}) + pool.request_json("/test", {}, {}) + self.assertEqual(mock_new.call_count, 2) + + +# ── HTTP handler integration tests ──────────────────────────────────────────── + +class ServerFixture(unittest.TestCase): + """Spins up a real ThreadingHTTPServer backed by a ProxyState with mocked + card and upstream. Card auth and fetch_counter are patched per-test via + patch.object(self.state, ...) or the _login() helper.""" + + def setUp(self): + self._tmpdir = tempfile.TemporaryDirectory() + self.tmp_path = Path(self._tmpdir.name) + self.state = _make_state(self.tmp_path) + Handler.state = self.state + self.server = ThreadingHTTPServer(("127.0.0.1", 0), Handler) + self.port = self.server.server_address[1] + self._thread = threading.Thread(target=self.server.serve_forever, daemon=True) + self._thread.start() + + def tearDown(self): + self.server.shutdown() + self.server.server_close() + self._tmpdir.cleanup() + + # ── request helpers ── + + def _conn(self): + return http.client.HTTPConnection("127.0.0.1", self.port, timeout=5) + + def _get(self, path): + conn = self._conn() + try: + conn.request("GET", path) + resp = conn.getresponse() + return resp.status, resp.read() + finally: + conn.close() + + def _get_json(self, path): + status, body = self._get(path) + return status, json.loads(body) + + def _post(self, path, payload=None, token=None): + conn = self._conn() + try: + body = json.dumps(payload or {}).encode() + headers = { + "Content-Type": "application/json", + "Content-Length": str(len(body)), + } + if token: + headers["Authorization"] = f"Bearer {token}" + conn.request("POST", path, body=body, headers=headers) + resp = conn.getresponse() + return resp.status, json.loads(resp.read()) + finally: + conn.close() + + def _post_raw(self, path, raw_body): + conn = self._conn() + try: + headers = { + "Content-Type": "application/json", + "Content-Length": str(len(raw_body)), + } + conn.request("POST", path, body=raw_body, headers=headers) + resp = conn.getresponse() + return resp.status, resp.read() + finally: + conn.close() + + # ── state helpers ── + + def _enroll(self, username="alice", display_name=None): + self.state.register_enrollment(username, display_name) + + def _login(self, username="alice"): + """Enroll user and obtain a session token with the card mocked to succeed.""" + self._enroll(username) + with patch.object(self.state, "authenticate_with_card", return_value=(True, "ok")): + status, data = self._post("/session/login", {"username": username}) + self.assertEqual(status, 200, f"login setup failed: {data}") + return data["session_token"] + + +class TestHandlerHealth(ServerFixture): + def test_get_root_returns_html(self): + status, body = self._get("/") + self.assertEqual(status, 200) + self.assertIn(b"ChromeCard", body) + + def test_health_returns_service_info(self): + status, data = self._get_json("/health") + self.assertEqual(status, 200) + self.assertTrue(data["ok"]) + self.assertEqual(data["service"], "k_proxy") + self.assertIn("active_sessions", data) + + def test_health_reflects_active_session_count(self): + self.state.create_session("alice") + _, data = self._get_json("/health") + self.assertEqual(data["active_sessions"], 1) + + def test_unknown_get_returns_404(self): + status, _ = self._get("/nonexistent") + self.assertEqual(status, 404) + + def test_unknown_post_returns_404(self): + status, _ = self._post_raw("/nonexistent", b"{}") + self.assertEqual(status, 404) + + +class TestHandlerEnrollment(ServerFixture): + def test_register_new_user_returns_200(self): + status, data = self._post("/enroll/register", {"username": "alice", "display_name": "Alice"}) + self.assertEqual(status, 200) + self.assertTrue(data["ok"]) + self.assertEqual(data["username"], "alice") + self.assertEqual(data["display_name"], "Alice") + + def test_register_duplicate_returns_409(self): + self._enroll("alice") + status, data = self._post("/enroll/register", {"username": "alice"}) + self.assertEqual(status, 409) + self.assertFalse(data["ok"]) + + def test_register_invalid_username_returns_400(self): + status, data = self._post("/enroll/register", {"username": "A!"}) + self.assertEqual(status, 400) + self.assertFalse(data["ok"]) + + def test_register_invalid_json_returns_400(self): + status, _ = self._post_raw("/enroll/register", b"not-json") + self.assertEqual(status, 400) + + def test_enroll_status_found(self): + self._enroll("alice", "Alice Smith") + status, data = self._get_json("/enroll/status?username=alice") + self.assertEqual(status, 200) + self.assertTrue(data["ok"]) + self.assertEqual(data["display_name"], "Alice Smith") + + def test_enroll_status_not_found_returns_404(self): + status, data = self._get_json("/enroll/status?username=nobody") + self.assertEqual(status, 404) + + def test_enroll_status_missing_param_returns_400(self): + status, data = self._get_json("/enroll/status") + self.assertEqual(status, 400) + + def test_enroll_list_empty(self): + status, data = self._get_json("/enroll/list") + self.assertEqual(status, 200) + self.assertEqual(data["users"], []) + + def test_enroll_list_returns_sorted_users(self): + self._enroll("charlie") + self._enroll("alice") + _, data = self._get_json("/enroll/list") + names = [u["username"] for u in data["users"]] + self.assertEqual(names, ["alice", "charlie"]) + + def test_enroll_update_changes_display_name(self): + self._enroll("alice", "Old") + status, data = self._post("/enroll/update", {"username": "alice", "display_name": "New"}) + self.assertEqual(status, 200) + self.assertEqual(data["display_name"], "New") + + def test_enroll_update_unknown_returns_404(self): + status, _ = self._post("/enroll/update", {"username": "nobody"}) + self.assertEqual(status, 404) + + def test_enroll_delete_returns_200_and_deleted_true(self): + self._enroll("alice") + status, data = self._post("/enroll/delete", {"username": "alice"}) + self.assertEqual(status, 200) + self.assertTrue(data["deleted"]) + self.assertFalse(self.state.has_enrollment("alice")) + + def test_enroll_delete_unknown_returns_404(self): + status, _ = self._post("/enroll/delete", {"username": "nobody"}) + self.assertEqual(status, 404) + + +class TestHandlerSession(ServerFixture): + def test_login_success_returns_token(self): + self._enroll("alice") + with patch.object(self.state, "authenticate_with_card", return_value=(True, "ok")): + status, data = self._post("/session/login", {"username": "alice"}) + self.assertEqual(status, 200) + self.assertTrue(data["ok"]) + self.assertIn("session_token", data) + self.assertIn("expires_at", data) + self.assertEqual(data["auth_mode"], "card_presence_probe") + + def test_login_unenrolled_user_returns_403(self): + status, data = self._post("/session/login", {"username": "nobody"}) + self.assertEqual(status, 403) + self.assertFalse(data["ok"]) + self.assertIn("not enrolled", data["error"]) + + def test_login_card_failure_returns_401(self): + self._enroll("alice") + with patch.object(self.state, "authenticate_with_card", return_value=(False, "No CTAP devices")): + status, data = self._post("/session/login", {"username": "alice"}) + self.assertEqual(status, 401) + self.assertFalse(data["ok"]) + self.assertIn("card auth failed", data["error"]) + self.assertIn("No CTAP devices", data["details"]) + + def test_login_invalid_username_returns_400(self): + status, data = self._post("/session/login", {"username": "!bad!"}) + self.assertEqual(status, 400) + + def test_session_status_valid_token(self): + token = self._login() + status, data = self._post("/session/status", {}, token=token) + self.assertEqual(status, 200) + self.assertTrue(data["ok"]) + self.assertEqual(data["username"], "alice") + self.assertIn("expires_at", data) + self.assertGreaterEqual(data["seconds_remaining"], 0) + + def test_session_status_no_token_returns_401(self): + status, data = self._post("/session/status", {}) + self.assertEqual(status, 401) + + def test_session_status_invalid_token_returns_401(self): + status, data = self._post("/session/status", {}, token="bad-token") + self.assertEqual(status, 401) + self.assertIn("invalid or expired", data["error"]) + + def test_logout_valid_token(self): + token = self._login() + status, data = self._post("/session/logout", {}, token=token) + self.assertEqual(status, 200) + self.assertTrue(data["ok"]) + self.assertTrue(data["invalidated"]) + self.assertIsNone(self.state.get_session(token)) + + def test_logout_invalid_token_returns_200_not_invalidated(self): + status, data = self._post("/session/logout", {}, token="ghost") + self.assertEqual(status, 200) + self.assertFalse(data["invalidated"]) + + def test_logout_no_token_returns_401(self): + status, data = self._post("/session/logout", {}) + self.assertEqual(status, 401) + + def test_session_invalid_after_logout(self): + token = self._login() + self._post("/session/logout", {}, token=token) + status, data = self._post("/session/status", {}, token=token) + self.assertEqual(status, 401) + + def test_multiple_sessions_independent(self): + t1 = self._login("alice") + t2 = self._login("bob") + # logout alice, bob's session still valid + self._post("/session/logout", {}, token=t1) + status, data = self._post("/session/status", {}, token=t2) + self.assertEqual(status, 200) + self.assertEqual(data["username"], "bob") + + +class TestHandlerResource(ServerFixture): + def test_counter_with_valid_session(self): + token = self._login() + with patch.object(self.state, "fetch_counter", return_value=(200, {"ok": True, "value": 5})): + status, data = self._post("/resource/counter", {}, token=token) + self.assertEqual(status, 200) + self.assertTrue(data["ok"]) + self.assertEqual(data["upstream"]["value"], 5) + self.assertEqual(data["username"], "alice") + self.assertTrue(data["session_reused"]) + + def test_counter_no_token_returns_401(self): + status, data = self._post("/resource/counter", {}) + self.assertEqual(status, 401) + + def test_counter_invalid_token_returns_401(self): + status, data = self._post("/resource/counter", {}, token="garbage") + self.assertEqual(status, 401) + + def test_counter_upstream_failure_propagated(self): + token = self._login() + with patch.object(self.state, "fetch_counter", return_value=(502, {"ok": False, "error": "server unavailable"})): + status, data = self._post("/resource/counter", {}, token=token) + self.assertEqual(status, 502) + self.assertFalse(data["ok"]) + self.assertIn("upstream failed", data["error"]) + + def test_counter_returns_upstream_non_200_as_error(self): + token = self._login() + with patch.object(self.state, "fetch_counter", return_value=(403, {"ok": False, "error": "forbidden"})): + status, data = self._post("/resource/counter", {}, token=token) + self.assertEqual(status, 403) + self.assertFalse(data["ok"]) + + def test_counter_session_still_valid_after_call(self): + token = self._login() + with patch.object(self.state, "fetch_counter", return_value=(200, {"ok": True, "value": 1})): + self._post("/resource/counter", {}, token=token) + status, _ = self._post("/session/status", {}, token=token) + self.assertEqual(status, 200) + + +if __name__ == "__main__": + unittest.main(verbosity=2)