Add k_proxy unit tests with mocked card and upstream
100 tests covering session management, enrollment CRUD, probe and direct FIDO2 auth routing, UpstreamPool connection handling, and all HTTP endpoints via a live in-process server. Card (FIDO2/CTAP) and k_server are fully mocked so the suite runs locally without hardware or VMs. Also hardens the fido2.features.webauthn_json_mapping import guard to tolerate older python-fido2 versions that lack the attribute. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
86189793b7
commit
e7212b49a0
|
|
@ -55,8 +55,11 @@ from fido2.webauthn import (
|
|||
UserVerificationRequirement,
|
||||
)
|
||||
|
||||
if getattr(fido2.features.webauthn_json_mapping, "_enabled", None) is None:
|
||||
fido2.features.webauthn_json_mapping.enabled = True
|
||||
try:
|
||||
if getattr(fido2.features.webauthn_json_mapping, "_enabled", None) is None:
|
||||
fido2.features.webauthn_json_mapping.enabled = True
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
HTML = """<!doctype html>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,804 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Unit tests for k_proxy_app.py.
|
||||
|
||||
Card (FIDO2/CTAP) and k_server (UpstreamPool) are mocked throughout.
|
||||
All tests run locally without any Qubes VMs or attached hardware.
|
||||
"""
|
||||
|
||||
import http.client
|
||||
import json
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from http.server import ThreadingHTTPServer
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
import k_proxy_app as app
|
||||
from k_proxy_app import (
|
||||
AUTH_MODE_FIDO2_DIRECT,
|
||||
AUTH_MODE_PROBE,
|
||||
Enrollment,
|
||||
Handler,
|
||||
ProxyState,
|
||||
UpstreamPool,
|
||||
b64u_decode,
|
||||
b64u_encode,
|
||||
enrollment_payload,
|
||||
normalize_display_name,
|
||||
normalize_username,
|
||||
)
|
||||
|
||||
|
||||
# ── test helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _make_state(tmp_path, *, auth_mode=AUTH_MODE_PROBE, session_ttl=300):
|
||||
return ProxyState(
|
||||
session_ttl_s=session_ttl,
|
||||
auth_mode=auth_mode,
|
||||
auth_command="echo ok",
|
||||
server_base_url="http://127.0.0.1:19999",
|
||||
server_ca_file=None,
|
||||
server_max_connections=1,
|
||||
proxy_token="test-token",
|
||||
enrollment_db=tmp_path / "enrollments.json",
|
||||
rp_id="localhost",
|
||||
rp_name="Test RP",
|
||||
origin="https://localhost",
|
||||
direct_device_path="",
|
||||
)
|
||||
|
||||
|
||||
def _enrollment(username="alice", display_name=None, *, credential_data_b64=None):
|
||||
now = int(time.time())
|
||||
return Enrollment(
|
||||
username=username,
|
||||
display_name=display_name,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
credential_data_b64=credential_data_b64,
|
||||
)
|
||||
|
||||
|
||||
# ── pure function tests ───────────────────────────────────────────────────────
|
||||
|
||||
class TestNormalizeUsername(unittest.TestCase):
|
||||
def test_simple_valid(self):
|
||||
self.assertEqual(normalize_username("alice"), "alice")
|
||||
|
||||
def test_strips_and_lowercases(self):
|
||||
self.assertEqual(normalize_username(" Alice "), "alice")
|
||||
|
||||
def test_valid_with_dots_dashes_underscores(self):
|
||||
for name in ("alice.smith", "alice-smith", "alice_smith", "a1b"):
|
||||
with self.subTest(name=name):
|
||||
self.assertEqual(normalize_username(name), name)
|
||||
|
||||
def test_too_short_raises(self):
|
||||
with self.assertRaises(ValueError):
|
||||
normalize_username("ab")
|
||||
|
||||
def test_too_long_raises(self):
|
||||
with self.assertRaises(ValueError):
|
||||
normalize_username("a" * 33)
|
||||
|
||||
def test_invalid_chars_raise(self):
|
||||
for bad in ("Alice!", "al ice", "al@ice", "AB"):
|
||||
with self.subTest(bad=bad):
|
||||
with self.assertRaises(ValueError):
|
||||
normalize_username(bad)
|
||||
|
||||
def test_minimum_length_valid(self):
|
||||
self.assertEqual(normalize_username("abc"), "abc")
|
||||
|
||||
def test_maximum_length_valid(self):
|
||||
self.assertEqual(normalize_username("a" * 32), "a" * 32)
|
||||
|
||||
|
||||
class TestNormalizeDisplayName(unittest.TestCase):
|
||||
def test_none_returns_none(self):
|
||||
self.assertIsNone(normalize_display_name(None))
|
||||
|
||||
def test_whitespace_only_returns_none(self):
|
||||
self.assertIsNone(normalize_display_name(" "))
|
||||
|
||||
def test_strips_whitespace(self):
|
||||
self.assertEqual(normalize_display_name(" Alice Smith "), "Alice Smith")
|
||||
|
||||
def test_max_length_accepted(self):
|
||||
self.assertEqual(normalize_display_name("a" * 64), "a" * 64)
|
||||
|
||||
def test_over_max_length_raises(self):
|
||||
with self.assertRaises(ValueError):
|
||||
normalize_display_name("a" * 65)
|
||||
|
||||
|
||||
class TestBase64Utils(unittest.TestCase):
|
||||
def test_round_trip(self):
|
||||
original = b"\x00\x01\x02\xffsome\xffbinary"
|
||||
self.assertEqual(b64u_decode(b64u_encode(original)), original)
|
||||
|
||||
def test_no_padding_chars_in_output(self):
|
||||
encoded = b64u_encode(b"x")
|
||||
self.assertNotIn("=", encoded)
|
||||
|
||||
def test_decode_handles_missing_padding(self):
|
||||
encoded = b64u_encode(b"hello")
|
||||
self.assertEqual(b64u_decode(encoded), b"hello")
|
||||
|
||||
|
||||
class TestEnrollmentPayload(unittest.TestCase):
|
||||
def test_basic_fields(self):
|
||||
e = _enrollment("alice", "Alice Smith")
|
||||
payload = enrollment_payload(e)
|
||||
self.assertTrue(payload["ok"])
|
||||
self.assertEqual(payload["username"], "alice")
|
||||
self.assertEqual(payload["display_name"], "Alice Smith")
|
||||
self.assertFalse(payload["has_credential"])
|
||||
|
||||
def test_has_credential_true_when_data_present(self):
|
||||
e = _enrollment(credential_data_b64="abc")
|
||||
self.assertTrue(enrollment_payload(e)["has_credential"])
|
||||
|
||||
def test_created_flag_included_when_given(self):
|
||||
e = _enrollment()
|
||||
self.assertIn("created", enrollment_payload(e, created=True))
|
||||
self.assertNotIn("created", enrollment_payload(e))
|
||||
|
||||
|
||||
# ── session management ────────────────────────────────────────────────────────
|
||||
|
||||
class TestSessionManagement(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self._tmpdir = tempfile.TemporaryDirectory()
|
||||
self.state = _make_state(Path(self._tmpdir.name))
|
||||
|
||||
def tearDown(self):
|
||||
self._tmpdir.cleanup()
|
||||
|
||||
def test_create_returns_token_and_future_expiry(self):
|
||||
token, expires_at = self.state.create_session("alice")
|
||||
self.assertIsInstance(token, str)
|
||||
self.assertGreater(len(token), 16)
|
||||
self.assertGreater(expires_at, time.time())
|
||||
|
||||
def test_get_session_returns_correct_username(self):
|
||||
token, _ = self.state.create_session("alice")
|
||||
session = self.state.get_session(token)
|
||||
self.assertIsNotNone(session)
|
||||
self.assertEqual(session.username, "alice")
|
||||
|
||||
def test_get_session_unknown_token_returns_none(self):
|
||||
self.assertIsNone(self.state.get_session("not-a-real-token"))
|
||||
|
||||
def test_expired_session_returns_none(self):
|
||||
state = _make_state(Path(self._tmpdir.name), session_ttl=-1)
|
||||
token, _ = state.create_session("alice")
|
||||
self.assertIsNone(state.get_session(token))
|
||||
|
||||
def test_invalidate_session_removes_it(self):
|
||||
token, _ = self.state.create_session("alice")
|
||||
self.assertTrue(self.state.invalidate_session(token))
|
||||
self.assertIsNone(self.state.get_session(token))
|
||||
|
||||
def test_invalidate_unknown_token_returns_false(self):
|
||||
self.assertFalse(self.state.invalidate_session("ghost"))
|
||||
|
||||
def test_active_session_count_tracks_correctly(self):
|
||||
self.assertEqual(self.state.active_session_count(), 0)
|
||||
t1, _ = self.state.create_session("alice")
|
||||
t2, _ = self.state.create_session("bob")
|
||||
self.assertEqual(self.state.active_session_count(), 2)
|
||||
self.state.invalidate_session(t1)
|
||||
self.assertEqual(self.state.active_session_count(), 1)
|
||||
|
||||
def test_expired_sessions_garbage_collected(self):
|
||||
state = _make_state(Path(self._tmpdir.name), session_ttl=-1)
|
||||
state.create_session("alice")
|
||||
state.create_session("bob")
|
||||
self.assertEqual(state.active_session_count(), 0)
|
||||
|
||||
def test_tokens_are_unique(self):
|
||||
tokens = {self.state.create_session("alice")[0] for _ in range(20)}
|
||||
self.assertEqual(len(tokens), 20)
|
||||
|
||||
def test_uses_direct_fido2_false_in_probe_mode(self):
|
||||
self.assertFalse(self.state.uses_direct_fido2())
|
||||
|
||||
def test_uses_direct_fido2_true_in_direct_mode(self):
|
||||
state = _make_state(Path(self._tmpdir.name), auth_mode=AUTH_MODE_FIDO2_DIRECT)
|
||||
self.assertTrue(state.uses_direct_fido2())
|
||||
|
||||
def test_auth_mode_label_probe(self):
|
||||
self.assertEqual(self.state.auth_mode_label(), "card_presence_probe")
|
||||
|
||||
def test_auth_mode_label_direct(self):
|
||||
state = _make_state(Path(self._tmpdir.name), auth_mode=AUTH_MODE_FIDO2_DIRECT)
|
||||
self.assertEqual(state.auth_mode_label(), "fido2_assertion")
|
||||
|
||||
|
||||
# ── enrollment management ─────────────────────────────────────────────────────
|
||||
|
||||
class TestEnrollmentManagement(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self._tmpdir = tempfile.TemporaryDirectory()
|
||||
self.tmp_path = Path(self._tmpdir.name)
|
||||
self.state = _make_state(self.tmp_path)
|
||||
|
||||
def tearDown(self):
|
||||
self._tmpdir.cleanup()
|
||||
|
||||
def test_register_creates_enrollment(self):
|
||||
e = self.state.register_enrollment("alice", "Alice Smith")
|
||||
self.assertEqual(e.username, "alice")
|
||||
self.assertEqual(e.display_name, "Alice Smith")
|
||||
self.assertTrue(self.state.has_enrollment("alice"))
|
||||
|
||||
def test_register_persists_across_state_reload(self):
|
||||
self.state.register_enrollment("alice", None)
|
||||
state2 = _make_state(self.tmp_path)
|
||||
self.assertTrue(state2.has_enrollment("alice"))
|
||||
|
||||
def test_register_duplicate_raises_file_exists_error(self):
|
||||
self.state.register_enrollment("alice", None)
|
||||
with self.assertRaises(FileExistsError):
|
||||
self.state.register_enrollment("alice", None)
|
||||
|
||||
def test_register_invalid_username_raises_value_error(self):
|
||||
with self.assertRaises(ValueError):
|
||||
self.state.register_enrollment("A!", None)
|
||||
|
||||
def test_register_display_name_too_long_raises(self):
|
||||
with self.assertRaises(ValueError):
|
||||
self.state.register_enrollment("alice", "x" * 65)
|
||||
|
||||
def test_update_changes_display_name(self):
|
||||
self.state.register_enrollment("alice", "Old")
|
||||
updated = self.state.update_enrollment("alice", "New")
|
||||
self.assertEqual(updated.display_name, "New")
|
||||
self.assertEqual(self.state.get_enrollment("alice").display_name, "New")
|
||||
|
||||
def test_update_unknown_user_raises_key_error(self):
|
||||
with self.assertRaises(KeyError):
|
||||
self.state.update_enrollment("nobody", "Name")
|
||||
|
||||
def test_delete_removes_enrollment(self):
|
||||
self.state.register_enrollment("alice", None)
|
||||
self.state.delete_enrollment("alice")
|
||||
self.assertFalse(self.state.has_enrollment("alice"))
|
||||
|
||||
def test_delete_invalidates_active_sessions(self):
|
||||
self.state.register_enrollment("alice", None)
|
||||
token, _ = self.state.create_session("alice")
|
||||
self.state.delete_enrollment("alice")
|
||||
self.assertIsNone(self.state.get_session(token))
|
||||
|
||||
def test_delete_does_not_affect_other_users_sessions(self):
|
||||
self.state.register_enrollment("alice", None)
|
||||
self.state.register_enrollment("bob", None)
|
||||
bob_token, _ = self.state.create_session("bob")
|
||||
self.state.delete_enrollment("alice")
|
||||
self.assertIsNotNone(self.state.get_session(bob_token))
|
||||
|
||||
def test_delete_unknown_user_raises_key_error(self):
|
||||
with self.assertRaises(KeyError):
|
||||
self.state.delete_enrollment("nobody")
|
||||
|
||||
def test_list_enrollments_sorted_alphabetically(self):
|
||||
self.state.register_enrollment("charlie", None)
|
||||
self.state.register_enrollment("alice", None)
|
||||
self.state.register_enrollment("bob", None)
|
||||
names = [e.username for e in self.state.list_enrollments()]
|
||||
self.assertEqual(names, ["alice", "bob", "charlie"])
|
||||
|
||||
def test_get_enrollment_found(self):
|
||||
self.state.register_enrollment("alice", "Alice")
|
||||
e = self.state.get_enrollment("alice")
|
||||
self.assertIsNotNone(e)
|
||||
self.assertEqual(e.username, "alice")
|
||||
|
||||
def test_get_enrollment_not_found_returns_none(self):
|
||||
self.assertIsNone(self.state.get_enrollment("nobody"))
|
||||
|
||||
def test_get_enrollment_invalid_username_returns_none(self):
|
||||
self.assertIsNone(self.state.get_enrollment("!bad!"))
|
||||
|
||||
def test_has_enrollment_true(self):
|
||||
self.state.register_enrollment("alice", None)
|
||||
self.assertTrue(self.state.has_enrollment("alice"))
|
||||
|
||||
def test_has_enrollment_false(self):
|
||||
self.assertFalse(self.state.has_enrollment("nobody"))
|
||||
|
||||
def test_register_direct_mode_delegates_to_direct_method(self):
|
||||
state = _make_state(self.tmp_path, auth_mode=AUTH_MODE_FIDO2_DIRECT)
|
||||
fake = _enrollment("alice", credential_data_b64="cred")
|
||||
with patch.object(state, "_register_direct_fido2", return_value=fake) as mock_direct:
|
||||
result = state.register_enrollment("alice", None)
|
||||
mock_direct.assert_called_once_with("alice", None)
|
||||
self.assertEqual(result.username, "alice")
|
||||
|
||||
|
||||
# ── authentication ────────────────────────────────────────────────────────────
|
||||
|
||||
class TestProbeAuth(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self._tmpdir = tempfile.TemporaryDirectory()
|
||||
self.state = _make_state(Path(self._tmpdir.name))
|
||||
|
||||
def tearDown(self):
|
||||
self._tmpdir.cleanup()
|
||||
|
||||
def _mock_proc(self, returncode, stdout="", stderr=""):
|
||||
proc = MagicMock()
|
||||
proc.returncode = returncode
|
||||
proc.stdout = stdout
|
||||
proc.stderr = stderr
|
||||
return proc
|
||||
|
||||
def test_success_when_subprocess_returns_zero(self):
|
||||
with patch("k_proxy_app.subprocess.run", return_value=self._mock_proc(0, '{"ok": true}')):
|
||||
ok, _ = self.state.authenticate_with_card("alice")
|
||||
self.assertTrue(ok)
|
||||
|
||||
def test_failure_when_subprocess_returns_nonzero(self):
|
||||
with patch("k_proxy_app.subprocess.run", return_value=self._mock_proc(1, stderr="No CTAP HID devices")):
|
||||
ok, msg = self.state.authenticate_with_card("alice")
|
||||
self.assertFalse(ok)
|
||||
self.assertIn("No CTAP HID devices", msg)
|
||||
|
||||
def test_failure_uses_stdout_when_stderr_empty(self):
|
||||
with patch("k_proxy_app.subprocess.run", return_value=self._mock_proc(2, stdout="probe failed")):
|
||||
ok, msg = self.state.authenticate_with_card("alice")
|
||||
self.assertFalse(ok)
|
||||
self.assertIn("probe failed", msg)
|
||||
|
||||
def test_failure_when_subprocess_raises(self):
|
||||
with patch("k_proxy_app.subprocess.run", side_effect=TimeoutError("timed out")):
|
||||
ok, msg = self.state.authenticate_with_card("alice")
|
||||
self.assertFalse(ok)
|
||||
self.assertIn("auth command failed", msg)
|
||||
|
||||
|
||||
class TestDirectFido2Auth(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self._tmpdir = tempfile.TemporaryDirectory()
|
||||
self.state = _make_state(Path(self._tmpdir.name), auth_mode=AUTH_MODE_FIDO2_DIRECT)
|
||||
|
||||
def tearDown(self):
|
||||
self._tmpdir.cleanup()
|
||||
|
||||
def test_unenrolled_user_returns_false(self):
|
||||
ok, msg = self.state.authenticate_with_card("nobody")
|
||||
self.assertFalse(ok)
|
||||
self.assertEqual(msg, "user not enrolled")
|
||||
|
||||
def test_enrolled_without_credential_returns_false(self):
|
||||
self.state.enrollments["alice"] = _enrollment("alice")
|
||||
ok, msg = self.state.authenticate_with_card("alice")
|
||||
self.assertFalse(ok)
|
||||
self.assertEqual(msg, "user has no registered credential")
|
||||
|
||||
def test_exception_from_ctap_returns_false_with_message(self):
|
||||
self.state.enrollments["alice"] = _enrollment("alice", credential_data_b64="dW5pY29kZQ")
|
||||
with patch("k_proxy_app.AttestedCredentialData", side_effect=Exception("bad cbor")):
|
||||
ok, msg = self.state.authenticate_with_card("alice")
|
||||
self.assertFalse(ok)
|
||||
self.assertIn("assertion verification failed", msg)
|
||||
|
||||
def test_success_path_with_mocked_internals(self):
|
||||
self.state.enrollments["alice"] = _enrollment("alice", credential_data_b64=b64u_encode(b"fake_cred"))
|
||||
|
||||
mock_cred = MagicMock()
|
||||
mock_options = MagicMock()
|
||||
mock_options.public_key.rp_id = "localhost"
|
||||
mock_options.public_key.allow_credentials = []
|
||||
mock_options.public_key.challenge = b"challenge"
|
||||
mock_client_data = MagicMock()
|
||||
mock_client_data.hash = b"hash"
|
||||
mock_assertion = MagicMock()
|
||||
mock_assertion.assertions = None
|
||||
mock_assertion.credential = {"id": b"cred_id"}
|
||||
mock_assertion.auth_data = b"auth"
|
||||
mock_assertion.signature = b"sig"
|
||||
mock_assertion.user = None
|
||||
|
||||
with patch("k_proxy_app.AttestedCredentialData", return_value=mock_cred), \
|
||||
patch("k_proxy_app.AuthenticationResponse", return_value=MagicMock()), \
|
||||
patch("k_proxy_app.AuthenticatorAssertionResponse", return_value=MagicMock()), \
|
||||
patch.object(self.state, "_drop_direct_device"), \
|
||||
patch.object(self.state.fido_server, "authenticate_begin", return_value=(mock_options, {})), \
|
||||
patch.object(self.state, "_collect_client_data", return_value=mock_client_data), \
|
||||
patch.object(self.state, "_with_direct_ctap2", return_value=mock_assertion), \
|
||||
patch.object(self.state.fido_server, "authenticate_complete"):
|
||||
ok, msg = self.state.authenticate_with_card("alice")
|
||||
|
||||
self.assertTrue(ok)
|
||||
self.assertEqual(msg, "assertion verified")
|
||||
|
||||
|
||||
# ── upstream pool ─────────────────────────────────────────────────────────────
|
||||
|
||||
class TestUpstreamPool(unittest.TestCase):
|
||||
def _pool(self):
|
||||
return UpstreamPool(
|
||||
server_base_url="http://127.0.0.1:19999",
|
||||
server_ca_file=None,
|
||||
max_connections=2,
|
||||
)
|
||||
|
||||
def _mock_response(self, status, body, will_close=True):
|
||||
resp = MagicMock()
|
||||
resp.status = status
|
||||
resp.read.return_value = body
|
||||
resp.will_close = will_close
|
||||
return resp
|
||||
|
||||
def test_successful_request_returns_status_and_parsed_json(self):
|
||||
pool = self._pool()
|
||||
conn = MagicMock()
|
||||
conn.getresponse.return_value = self._mock_response(200, b'{"ok": true, "value": 7}')
|
||||
with patch.object(pool, "_new_connection", return_value=conn):
|
||||
status, data = pool.request_json("/resource/counter", {"X-Proxy-Token": "tok"}, {})
|
||||
self.assertEqual(status, 200)
|
||||
self.assertTrue(data["ok"])
|
||||
self.assertEqual(data["value"], 7)
|
||||
|
||||
def test_non_200_status_is_returned_as_is(self):
|
||||
pool = self._pool()
|
||||
conn = MagicMock()
|
||||
conn.getresponse.return_value = self._mock_response(403, b'{"ok": false, "error": "forbidden"}')
|
||||
with patch.object(pool, "_new_connection", return_value=conn):
|
||||
status, data = pool.request_json("/test", {}, {})
|
||||
self.assertEqual(status, 403)
|
||||
self.assertFalse(data["ok"])
|
||||
|
||||
def test_oserror_returns_502(self):
|
||||
pool = self._pool()
|
||||
conn = MagicMock()
|
||||
conn.request.side_effect = OSError("connection refused")
|
||||
with patch.object(pool, "_new_connection", return_value=conn):
|
||||
status, data = pool.request_json("/test", {}, {})
|
||||
self.assertEqual(status, 502)
|
||||
self.assertIn("server unavailable", data["error"])
|
||||
|
||||
def test_empty_body_returns_empty_dict(self):
|
||||
pool = self._pool()
|
||||
conn = MagicMock()
|
||||
conn.getresponse.return_value = self._mock_response(200, b"")
|
||||
with patch.object(pool, "_new_connection", return_value=conn):
|
||||
status, data = pool.request_json("/test", {}, {})
|
||||
self.assertEqual(data, {})
|
||||
|
||||
def test_connection_reused_when_will_close_false(self):
|
||||
pool = self._pool()
|
||||
conn = MagicMock()
|
||||
conn.getresponse.return_value = self._mock_response(200, b'{"ok": true}', will_close=False)
|
||||
with patch.object(pool, "_new_connection", return_value=conn) as mock_new:
|
||||
pool.request_json("/test", {}, {})
|
||||
pool.request_json("/test", {}, {})
|
||||
self.assertEqual(mock_new.call_count, 1)
|
||||
self.assertEqual(conn.request.call_count, 2)
|
||||
|
||||
def test_connection_not_reused_when_will_close_true(self):
|
||||
pool = self._pool()
|
||||
conn = MagicMock()
|
||||
conn.getresponse.return_value = self._mock_response(200, b'{"ok": true}', will_close=True)
|
||||
with patch.object(pool, "_new_connection", return_value=conn) as mock_new:
|
||||
pool.request_json("/test", {}, {})
|
||||
pool.request_json("/test", {}, {})
|
||||
self.assertEqual(mock_new.call_count, 2)
|
||||
|
||||
|
||||
# ── HTTP handler integration tests ────────────────────────────────────────────
|
||||
|
||||
class ServerFixture(unittest.TestCase):
|
||||
"""Spins up a real ThreadingHTTPServer backed by a ProxyState with mocked
|
||||
card and upstream. Card auth and fetch_counter are patched per-test via
|
||||
patch.object(self.state, ...) or the _login() helper."""
|
||||
|
||||
def setUp(self):
|
||||
self._tmpdir = tempfile.TemporaryDirectory()
|
||||
self.tmp_path = Path(self._tmpdir.name)
|
||||
self.state = _make_state(self.tmp_path)
|
||||
Handler.state = self.state
|
||||
self.server = ThreadingHTTPServer(("127.0.0.1", 0), Handler)
|
||||
self.port = self.server.server_address[1]
|
||||
self._thread = threading.Thread(target=self.server.serve_forever, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def tearDown(self):
|
||||
self.server.shutdown()
|
||||
self.server.server_close()
|
||||
self._tmpdir.cleanup()
|
||||
|
||||
# ── request helpers ──
|
||||
|
||||
def _conn(self):
|
||||
return http.client.HTTPConnection("127.0.0.1", self.port, timeout=5)
|
||||
|
||||
def _get(self, path):
|
||||
conn = self._conn()
|
||||
try:
|
||||
conn.request("GET", path)
|
||||
resp = conn.getresponse()
|
||||
return resp.status, resp.read()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _get_json(self, path):
|
||||
status, body = self._get(path)
|
||||
return status, json.loads(body)
|
||||
|
||||
def _post(self, path, payload=None, token=None):
|
||||
conn = self._conn()
|
||||
try:
|
||||
body = json.dumps(payload or {}).encode()
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Content-Length": str(len(body)),
|
||||
}
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
conn.request("POST", path, body=body, headers=headers)
|
||||
resp = conn.getresponse()
|
||||
return resp.status, json.loads(resp.read())
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _post_raw(self, path, raw_body):
|
||||
conn = self._conn()
|
||||
try:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Content-Length": str(len(raw_body)),
|
||||
}
|
||||
conn.request("POST", path, body=raw_body, headers=headers)
|
||||
resp = conn.getresponse()
|
||||
return resp.status, resp.read()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
# ── state helpers ──
|
||||
|
||||
def _enroll(self, username="alice", display_name=None):
|
||||
self.state.register_enrollment(username, display_name)
|
||||
|
||||
def _login(self, username="alice"):
|
||||
"""Enroll user and obtain a session token with the card mocked to succeed."""
|
||||
self._enroll(username)
|
||||
with patch.object(self.state, "authenticate_with_card", return_value=(True, "ok")):
|
||||
status, data = self._post("/session/login", {"username": username})
|
||||
self.assertEqual(status, 200, f"login setup failed: {data}")
|
||||
return data["session_token"]
|
||||
|
||||
|
||||
class TestHandlerHealth(ServerFixture):
|
||||
def test_get_root_returns_html(self):
|
||||
status, body = self._get("/")
|
||||
self.assertEqual(status, 200)
|
||||
self.assertIn(b"ChromeCard", body)
|
||||
|
||||
def test_health_returns_service_info(self):
|
||||
status, data = self._get_json("/health")
|
||||
self.assertEqual(status, 200)
|
||||
self.assertTrue(data["ok"])
|
||||
self.assertEqual(data["service"], "k_proxy")
|
||||
self.assertIn("active_sessions", data)
|
||||
|
||||
def test_health_reflects_active_session_count(self):
|
||||
self.state.create_session("alice")
|
||||
_, data = self._get_json("/health")
|
||||
self.assertEqual(data["active_sessions"], 1)
|
||||
|
||||
def test_unknown_get_returns_404(self):
|
||||
status, _ = self._get("/nonexistent")
|
||||
self.assertEqual(status, 404)
|
||||
|
||||
def test_unknown_post_returns_404(self):
|
||||
status, _ = self._post_raw("/nonexistent", b"{}")
|
||||
self.assertEqual(status, 404)
|
||||
|
||||
|
||||
class TestHandlerEnrollment(ServerFixture):
|
||||
def test_register_new_user_returns_200(self):
|
||||
status, data = self._post("/enroll/register", {"username": "alice", "display_name": "Alice"})
|
||||
self.assertEqual(status, 200)
|
||||
self.assertTrue(data["ok"])
|
||||
self.assertEqual(data["username"], "alice")
|
||||
self.assertEqual(data["display_name"], "Alice")
|
||||
|
||||
def test_register_duplicate_returns_409(self):
|
||||
self._enroll("alice")
|
||||
status, data = self._post("/enroll/register", {"username": "alice"})
|
||||
self.assertEqual(status, 409)
|
||||
self.assertFalse(data["ok"])
|
||||
|
||||
def test_register_invalid_username_returns_400(self):
|
||||
status, data = self._post("/enroll/register", {"username": "A!"})
|
||||
self.assertEqual(status, 400)
|
||||
self.assertFalse(data["ok"])
|
||||
|
||||
def test_register_invalid_json_returns_400(self):
|
||||
status, _ = self._post_raw("/enroll/register", b"not-json")
|
||||
self.assertEqual(status, 400)
|
||||
|
||||
def test_enroll_status_found(self):
|
||||
self._enroll("alice", "Alice Smith")
|
||||
status, data = self._get_json("/enroll/status?username=alice")
|
||||
self.assertEqual(status, 200)
|
||||
self.assertTrue(data["ok"])
|
||||
self.assertEqual(data["display_name"], "Alice Smith")
|
||||
|
||||
def test_enroll_status_not_found_returns_404(self):
|
||||
status, data = self._get_json("/enroll/status?username=nobody")
|
||||
self.assertEqual(status, 404)
|
||||
|
||||
def test_enroll_status_missing_param_returns_400(self):
|
||||
status, data = self._get_json("/enroll/status")
|
||||
self.assertEqual(status, 400)
|
||||
|
||||
def test_enroll_list_empty(self):
|
||||
status, data = self._get_json("/enroll/list")
|
||||
self.assertEqual(status, 200)
|
||||
self.assertEqual(data["users"], [])
|
||||
|
||||
def test_enroll_list_returns_sorted_users(self):
|
||||
self._enroll("charlie")
|
||||
self._enroll("alice")
|
||||
_, data = self._get_json("/enroll/list")
|
||||
names = [u["username"] for u in data["users"]]
|
||||
self.assertEqual(names, ["alice", "charlie"])
|
||||
|
||||
def test_enroll_update_changes_display_name(self):
|
||||
self._enroll("alice", "Old")
|
||||
status, data = self._post("/enroll/update", {"username": "alice", "display_name": "New"})
|
||||
self.assertEqual(status, 200)
|
||||
self.assertEqual(data["display_name"], "New")
|
||||
|
||||
def test_enroll_update_unknown_returns_404(self):
|
||||
status, _ = self._post("/enroll/update", {"username": "nobody"})
|
||||
self.assertEqual(status, 404)
|
||||
|
||||
def test_enroll_delete_returns_200_and_deleted_true(self):
|
||||
self._enroll("alice")
|
||||
status, data = self._post("/enroll/delete", {"username": "alice"})
|
||||
self.assertEqual(status, 200)
|
||||
self.assertTrue(data["deleted"])
|
||||
self.assertFalse(self.state.has_enrollment("alice"))
|
||||
|
||||
def test_enroll_delete_unknown_returns_404(self):
|
||||
status, _ = self._post("/enroll/delete", {"username": "nobody"})
|
||||
self.assertEqual(status, 404)
|
||||
|
||||
|
||||
class TestHandlerSession(ServerFixture):
|
||||
def test_login_success_returns_token(self):
|
||||
self._enroll("alice")
|
||||
with patch.object(self.state, "authenticate_with_card", return_value=(True, "ok")):
|
||||
status, data = self._post("/session/login", {"username": "alice"})
|
||||
self.assertEqual(status, 200)
|
||||
self.assertTrue(data["ok"])
|
||||
self.assertIn("session_token", data)
|
||||
self.assertIn("expires_at", data)
|
||||
self.assertEqual(data["auth_mode"], "card_presence_probe")
|
||||
|
||||
def test_login_unenrolled_user_returns_403(self):
|
||||
status, data = self._post("/session/login", {"username": "nobody"})
|
||||
self.assertEqual(status, 403)
|
||||
self.assertFalse(data["ok"])
|
||||
self.assertIn("not enrolled", data["error"])
|
||||
|
||||
def test_login_card_failure_returns_401(self):
|
||||
self._enroll("alice")
|
||||
with patch.object(self.state, "authenticate_with_card", return_value=(False, "No CTAP devices")):
|
||||
status, data = self._post("/session/login", {"username": "alice"})
|
||||
self.assertEqual(status, 401)
|
||||
self.assertFalse(data["ok"])
|
||||
self.assertIn("card auth failed", data["error"])
|
||||
self.assertIn("No CTAP devices", data["details"])
|
||||
|
||||
def test_login_invalid_username_returns_400(self):
|
||||
status, data = self._post("/session/login", {"username": "!bad!"})
|
||||
self.assertEqual(status, 400)
|
||||
|
||||
def test_session_status_valid_token(self):
|
||||
token = self._login()
|
||||
status, data = self._post("/session/status", {}, token=token)
|
||||
self.assertEqual(status, 200)
|
||||
self.assertTrue(data["ok"])
|
||||
self.assertEqual(data["username"], "alice")
|
||||
self.assertIn("expires_at", data)
|
||||
self.assertGreaterEqual(data["seconds_remaining"], 0)
|
||||
|
||||
def test_session_status_no_token_returns_401(self):
|
||||
status, data = self._post("/session/status", {})
|
||||
self.assertEqual(status, 401)
|
||||
|
||||
def test_session_status_invalid_token_returns_401(self):
|
||||
status, data = self._post("/session/status", {}, token="bad-token")
|
||||
self.assertEqual(status, 401)
|
||||
self.assertIn("invalid or expired", data["error"])
|
||||
|
||||
def test_logout_valid_token(self):
|
||||
token = self._login()
|
||||
status, data = self._post("/session/logout", {}, token=token)
|
||||
self.assertEqual(status, 200)
|
||||
self.assertTrue(data["ok"])
|
||||
self.assertTrue(data["invalidated"])
|
||||
self.assertIsNone(self.state.get_session(token))
|
||||
|
||||
def test_logout_invalid_token_returns_200_not_invalidated(self):
|
||||
status, data = self._post("/session/logout", {}, token="ghost")
|
||||
self.assertEqual(status, 200)
|
||||
self.assertFalse(data["invalidated"])
|
||||
|
||||
def test_logout_no_token_returns_401(self):
|
||||
status, data = self._post("/session/logout", {})
|
||||
self.assertEqual(status, 401)
|
||||
|
||||
def test_session_invalid_after_logout(self):
|
||||
token = self._login()
|
||||
self._post("/session/logout", {}, token=token)
|
||||
status, data = self._post("/session/status", {}, token=token)
|
||||
self.assertEqual(status, 401)
|
||||
|
||||
def test_multiple_sessions_independent(self):
|
||||
t1 = self._login("alice")
|
||||
t2 = self._login("bob")
|
||||
# logout alice, bob's session still valid
|
||||
self._post("/session/logout", {}, token=t1)
|
||||
status, data = self._post("/session/status", {}, token=t2)
|
||||
self.assertEqual(status, 200)
|
||||
self.assertEqual(data["username"], "bob")
|
||||
|
||||
|
||||
class TestHandlerResource(ServerFixture):
|
||||
def test_counter_with_valid_session(self):
|
||||
token = self._login()
|
||||
with patch.object(self.state, "fetch_counter", return_value=(200, {"ok": True, "value": 5})):
|
||||
status, data = self._post("/resource/counter", {}, token=token)
|
||||
self.assertEqual(status, 200)
|
||||
self.assertTrue(data["ok"])
|
||||
self.assertEqual(data["upstream"]["value"], 5)
|
||||
self.assertEqual(data["username"], "alice")
|
||||
self.assertTrue(data["session_reused"])
|
||||
|
||||
def test_counter_no_token_returns_401(self):
|
||||
status, data = self._post("/resource/counter", {})
|
||||
self.assertEqual(status, 401)
|
||||
|
||||
def test_counter_invalid_token_returns_401(self):
|
||||
status, data = self._post("/resource/counter", {}, token="garbage")
|
||||
self.assertEqual(status, 401)
|
||||
|
||||
def test_counter_upstream_failure_propagated(self):
|
||||
token = self._login()
|
||||
with patch.object(self.state, "fetch_counter", return_value=(502, {"ok": False, "error": "server unavailable"})):
|
||||
status, data = self._post("/resource/counter", {}, token=token)
|
||||
self.assertEqual(status, 502)
|
||||
self.assertFalse(data["ok"])
|
||||
self.assertIn("upstream failed", data["error"])
|
||||
|
||||
def test_counter_returns_upstream_non_200_as_error(self):
|
||||
token = self._login()
|
||||
with patch.object(self.state, "fetch_counter", return_value=(403, {"ok": False, "error": "forbidden"})):
|
||||
status, data = self._post("/resource/counter", {}, token=token)
|
||||
self.assertEqual(status, 403)
|
||||
self.assertFalse(data["ok"])
|
||||
|
||||
def test_counter_session_still_valid_after_call(self):
|
||||
token = self._login()
|
||||
with patch.object(self.state, "fetch_counter", return_value=(200, {"ok": True, "value": 1})):
|
||||
self._post("/resource/counter", {}, token=token)
|
||||
status, _ = self._post("/session/status", {}, token=token)
|
||||
self.assertEqual(status, 200)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
Loading…
Reference in New Issue