1049 lines
42 KiB
Python
1049 lines
42 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Unit tests for k_proxy_app.py.
|
|
|
|
Card (FIDO2/CTAP) and k_server (UpstreamPool) are mocked throughout.
|
|
All tests run locally without any Qubes VMs or attached hardware.
|
|
"""
|
|
|
|
import http.client
|
|
import json
|
|
import sys
|
|
import tempfile
|
|
import threading
|
|
import time
|
|
import unittest
|
|
from http.server import ThreadingHTTPServer
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
sys.path.insert(0, str(Path(__file__).parent))
|
|
|
|
import k_proxy_app as app
|
|
from k_proxy_app import (
|
|
AUTH_MODE_FIDO2_DIRECT,
|
|
AUTH_MODE_PROBE,
|
|
Enrollment,
|
|
Handler,
|
|
ProxyState,
|
|
UpstreamPool,
|
|
b64u_decode,
|
|
b64u_encode,
|
|
enrollment_payload,
|
|
normalize_display_name,
|
|
normalize_username,
|
|
)
|
|
|
|
|
|
# ── test helpers ──────────────────────────────────────────────────────────────
|
|
|
|
def _make_state(tmp_path, *, auth_mode=AUTH_MODE_PROBE, session_ttl=300):
    """Build a ProxyState for tests, keeping its enrollment DB under *tmp_path*.

    Only the auth mode and the session TTL can be chosen by the caller; every
    other constructor argument is a fixed test value.
    """
    settings = {
        "session_ttl_s": session_ttl,
        "auth_mode": auth_mode,
        "auth_command": "echo ok",
        "server_base_url": "http://127.0.0.1:19999",
        "server_ca_file": None,
        "server_max_connections": 1,
        "proxy_token": "test-token",
        "enrollment_db": tmp_path / "enrollments.json",
        "rp_id": "localhost",
        "rp_name": "Test RP",
        "origin": "https://localhost",
        "direct_device_path": "",
    }
    return ProxyState(**settings)
|
|
|
|
|
|
def _enrollment(username="alice", display_name=None, *, credential_data_b64=None):
    """Create an Enrollment whose created/updated timestamps are both "now"."""
    # One clock reading is shared so created_at and updated_at are equal.
    stamp = int(time.time())
    fields = {
        "username": username,
        "display_name": display_name,
        "created_at": stamp,
        "updated_at": stamp,
        "credential_data_b64": credential_data_b64,
    }
    return Enrollment(**fields)
|
|
|
|
|
|
# ── pure function tests ───────────────────────────────────────────────────────
|
|
|
|
class TestNormalizeUsername(unittest.TestCase):
    """normalize_username: accepted spellings, trimming/lower-casing, length bounds."""

    def _assert_rejected(self, raw):
        # Callable form of assertRaises: *raw* must be refused with ValueError.
        self.assertRaises(ValueError, normalize_username, raw)

    def test_simple_valid(self):
        normalized = normalize_username("alice")
        self.assertEqual(normalized, "alice")

    def test_strips_and_lowercases(self):
        normalized = normalize_username(" Alice ")
        self.assertEqual(normalized, "alice")

    def test_valid_with_dots_dashes_underscores(self):
        accepted = ("alice.smith", "alice-smith", "alice_smith", "a1b")
        for name in accepted:
            with self.subTest(name=name):
                # Already-normalized names must come back unchanged.
                self.assertEqual(normalize_username(name), name)

    def test_too_short_raises(self):
        self._assert_rejected("ab")

    def test_too_long_raises(self):
        self._assert_rejected("a" * 33)

    def test_invalid_chars_raise(self):
        rejected = ("Alice!", "al ice", "al@ice", "AB")
        for bad in rejected:
            with self.subTest(bad=bad):
                self._assert_rejected(bad)

    def test_minimum_length_valid(self):
        shortest = "abc"
        self.assertEqual(normalize_username(shortest), shortest)

    def test_maximum_length_valid(self):
        longest = "a" * 32
        self.assertEqual(normalize_username(longest), longest)
|
|
|
|
|
|
class TestNormalizeDisplayName(unittest.TestCase):
    """normalize_display_name: None/blank handling, trimming and the 64-char cap."""

    def test_none_returns_none(self):
        outcome = normalize_display_name(None)
        self.assertIsNone(outcome)

    def test_whitespace_only_returns_none(self):
        outcome = normalize_display_name(" ")
        self.assertIsNone(outcome)

    def test_strips_whitespace(self):
        outcome = normalize_display_name(" Alice Smith ")
        self.assertEqual(outcome, "Alice Smith")

    def test_max_length_accepted(self):
        at_limit = "a" * 64
        self.assertEqual(normalize_display_name(at_limit), at_limit)

    def test_over_max_length_raises(self):
        # One character over the accepted maximum must be refused.
        self.assertRaises(ValueError, normalize_display_name, "a" * 65)
|
|
|
|
|
|
class TestBase64Utils(unittest.TestCase):
    """b64u_encode / b64u_decode: round trip and padding behaviour."""

    def test_round_trip(self):
        payload = b"\x00\x01\x02\xffsome\xffbinary"
        text = b64u_encode(payload)
        self.assertEqual(b64u_decode(text), payload)

    def test_no_padding_chars_in_output(self):
        # A single input byte is a length standard base64 would pad with "=".
        self.assertNotIn("=", b64u_encode(b"x"))

    def test_decode_handles_missing_padding(self):
        unpadded = b64u_encode(b"hello")
        decoded = b64u_decode(unpadded)
        self.assertEqual(decoded, b"hello")
|
|
|
|
|
|
class TestEnrollmentPayload(unittest.TestCase):
    """enrollment_payload: the fields reported for an Enrollment record."""

    def test_basic_fields(self):
        record = _enrollment("alice", "Alice Smith")
        body = enrollment_payload(record)
        self.assertTrue(body["ok"])
        self.assertEqual(body["username"], "alice")
        self.assertEqual(body["display_name"], "Alice Smith")
        # No credential data was supplied, so none may be reported.
        self.assertFalse(body["has_credential"])

    def test_has_credential_true_when_data_present(self):
        record = _enrollment(credential_data_b64="abc")
        body = enrollment_payload(record)
        self.assertTrue(body["has_credential"])

    def test_created_flag_included_when_given(self):
        record = _enrollment()
        with_flag = enrollment_payload(record, created=True)
        self.assertIn("created", with_flag)
        without_flag = enrollment_payload(record)
        self.assertNotIn("created", without_flag)
|
|
|
|
|
|
# ── session management ────────────────────────────────────────────────────────
|
|
|
|
class TestSessionManagement(unittest.TestCase):
    """Session lifecycle on ProxyState: create, look up, expire, invalidate, count."""

    def setUp(self):
        self._tmpdir = tempfile.TemporaryDirectory()
        self.state = self._state_with()

    def tearDown(self):
        self._tmpdir.cleanup()

    def _state_with(self, **overrides):
        # Build a ProxyState over this test's temporary directory; keyword
        # overrides are forwarded to _make_state unchanged.
        return _make_state(Path(self._tmpdir.name), **overrides)

    def test_create_returns_token_and_future_expiry(self):
        created = self.state.create_session("alice")
        token, expires_at = created
        self.assertIsInstance(token, str)
        self.assertGreater(len(token), 16)
        self.assertGreater(expires_at, time.time())

    def test_get_session_returns_correct_username(self):
        token, _expiry = self.state.create_session("alice")
        found = self.state.get_session(token)
        self.assertIsNotNone(found)
        self.assertEqual(found.username, "alice")

    def test_get_session_unknown_token_returns_none(self):
        found = self.state.get_session("not-a-real-token")
        self.assertIsNone(found)

    def test_expired_session_returns_none(self):
        # A negative TTL makes every session expired as soon as it exists.
        short_lived = self._state_with(session_ttl=-1)
        token, _expiry = short_lived.create_session("alice")
        self.assertIsNone(short_lived.get_session(token))

    def test_invalidate_session_removes_it(self):
        token, _expiry = self.state.create_session("alice")
        removed = self.state.invalidate_session(token)
        self.assertTrue(removed)
        self.assertIsNone(self.state.get_session(token))

    def test_invalidate_unknown_token_returns_false(self):
        removed = self.state.invalidate_session("ghost")
        self.assertFalse(removed)

    def test_active_session_count_tracks_correctly(self):
        self.assertEqual(self.state.active_session_count(), 0)
        alice_token, _alice_expiry = self.state.create_session("alice")
        _bob_token, _bob_expiry = self.state.create_session("bob")
        self.assertEqual(self.state.active_session_count(), 2)
        self.state.invalidate_session(alice_token)
        self.assertEqual(self.state.active_session_count(), 1)

    def test_expired_sessions_garbage_collected(self):
        short_lived = self._state_with(session_ttl=-1)
        for user in ("alice", "bob"):
            short_lived.create_session(user)
        self.assertEqual(short_lived.active_session_count(), 0)

    def test_tokens_are_unique(self):
        seen = set()
        for _ in range(20):
            token = self.state.create_session("alice")[0]
            seen.add(token)
        # Any repeated token would collapse in the set and shrink its size.
        self.assertEqual(len(seen), 20)

    def test_uses_direct_fido2_false_in_probe_mode(self):
        self.assertFalse(self.state.uses_direct_fido2())

    def test_uses_direct_fido2_true_in_direct_mode(self):
        direct = self._state_with(auth_mode=AUTH_MODE_FIDO2_DIRECT)
        self.assertTrue(direct.uses_direct_fido2())

    def test_auth_mode_label_probe(self):
        label = self.state.auth_mode_label()
        self.assertEqual(label, "card_presence_probe")

    def test_auth_mode_label_direct(self):
        direct = self._state_with(auth_mode=AUTH_MODE_FIDO2_DIRECT)
        label = direct.auth_mode_label()
        self.assertEqual(label, "fido2_assertion")
|
|
|
|
|
|
# ── enrollment management ─────────────────────────────────────────────────────
|
|
|
|
class TestEnrollmentManagement(unittest.TestCase):
    """Enrollment CRUD on ProxyState: register, update, delete, list, look up."""

    def setUp(self):
        self._tmpdir = tempfile.TemporaryDirectory()
        self.tmp_path = Path(self._tmpdir.name)
        self.state = _make_state(self.tmp_path)

    def tearDown(self):
        self._tmpdir.cleanup()

    def _register_all(self, *usernames):
        # Register each username, in the given order, without a display name.
        for username in usernames:
            self.state.register_enrollment(username, None)

    def test_register_creates_enrollment(self):
        record = self.state.register_enrollment("alice", "Alice Smith")
        self.assertEqual(record.username, "alice")
        self.assertEqual(record.display_name, "Alice Smith")
        self.assertTrue(self.state.has_enrollment("alice"))

    def test_register_persists_across_state_reload(self):
        self._register_all("alice")
        # A second ProxyState over the same directory must see the record.
        reloaded = _make_state(self.tmp_path)
        self.assertTrue(reloaded.has_enrollment("alice"))

    def test_register_duplicate_raises_file_exists_error(self):
        self._register_all("alice")
        self.assertRaises(
            FileExistsError, self.state.register_enrollment, "alice", None
        )

    def test_register_invalid_username_raises_value_error(self):
        self.assertRaises(ValueError, self.state.register_enrollment, "A!", None)

    def test_register_display_name_too_long_raises(self):
        self.assertRaises(
            ValueError, self.state.register_enrollment, "alice", "x" * 65
        )

    def test_update_changes_display_name(self):
        self.state.register_enrollment("alice", "Old")
        changed = self.state.update_enrollment("alice", "New")
        self.assertEqual(changed.display_name, "New")
        stored = self.state.get_enrollment("alice")
        self.assertEqual(stored.display_name, "New")

    def test_update_unknown_user_raises_key_error(self):
        self.assertRaises(KeyError, self.state.update_enrollment, "nobody", "Name")

    def test_delete_removes_enrollment(self):
        self._register_all("alice")
        self.state.delete_enrollment("alice")
        self.assertFalse(self.state.has_enrollment("alice"))

    def test_delete_invalidates_active_sessions(self):
        self._register_all("alice")
        token, _expiry = self.state.create_session("alice")
        self.state.delete_enrollment("alice")
        self.assertIsNone(self.state.get_session(token))

    def test_delete_does_not_affect_other_users_sessions(self):
        self._register_all("alice", "bob")
        bob_token, _expiry = self.state.create_session("bob")
        self.state.delete_enrollment("alice")
        self.assertIsNotNone(self.state.get_session(bob_token))

    def test_delete_unknown_user_raises_key_error(self):
        self.assertRaises(KeyError, self.state.delete_enrollment, "nobody")

    def test_list_enrollments_sorted_alphabetically(self):
        # Registered out of order on purpose.
        self._register_all("charlie", "alice", "bob")
        listed = []
        for record in self.state.list_enrollments():
            listed.append(record.username)
        self.assertEqual(listed, ["alice", "bob", "charlie"])

    def test_get_enrollment_found(self):
        self.state.register_enrollment("alice", "Alice")
        record = self.state.get_enrollment("alice")
        self.assertIsNotNone(record)
        self.assertEqual(record.username, "alice")

    def test_get_enrollment_not_found_returns_none(self):
        record = self.state.get_enrollment("nobody")
        self.assertIsNone(record)

    def test_get_enrollment_invalid_username_returns_none(self):
        record = self.state.get_enrollment("!bad!")
        self.assertIsNone(record)

    def test_has_enrollment_true(self):
        self._register_all("alice")
        self.assertTrue(self.state.has_enrollment("alice"))

    def test_has_enrollment_false(self):
        self.assertFalse(self.state.has_enrollment("nobody"))

    def test_register_direct_mode_delegates_to_direct_method(self):
        direct_state = _make_state(self.tmp_path, auth_mode=AUTH_MODE_FIDO2_DIRECT)
        canned = _enrollment("alice", credential_data_b64="cred")
        with patch.object(
            direct_state, "_register_direct_fido2", return_value=canned
        ) as direct_call:
            outcome = direct_state.register_enrollment("alice", None)
        direct_call.assert_called_once_with("alice", None)
        self.assertEqual(outcome.username, "alice")
|
|
|
|
|
|
# ── authentication ────────────────────────────────────────────────────────────
|
|
|
|
class TestProbeAuth(unittest.TestCase):
    """authenticate_with_card in probe mode, with subprocess.run patched."""

    def setUp(self):
        self._tmpdir = tempfile.TemporaryDirectory()
        self.state = _make_state(Path(self._tmpdir.name))

    def tearDown(self):
        self._tmpdir.cleanup()

    def _mock_proc(self, returncode, stdout="", stderr=""):
        """Return a stand-in for the object subprocess.run would hand back."""
        return MagicMock(returncode=returncode, stdout=stdout, stderr=stderr)

    def _authenticate_with_run(self, **run_behaviour):
        # Patch subprocess.run as seen from k_proxy_app, then authenticate
        # "alice"; keyword arguments are passed straight to patch().
        with patch("k_proxy_app.subprocess.run", **run_behaviour):
            return self.state.authenticate_with_card("alice")

    def test_success_when_subprocess_returns_zero(self):
        finished = self._mock_proc(0, '{"ok": true}')
        ok, _ = self._authenticate_with_run(return_value=finished)
        self.assertTrue(ok)

    def test_failure_when_subprocess_returns_nonzero(self):
        finished = self._mock_proc(1, stderr="No CTAP HID devices")
        ok, msg = self._authenticate_with_run(return_value=finished)
        self.assertFalse(ok)
        self.assertIn("No CTAP HID devices", msg)

    def test_failure_uses_stdout_when_stderr_empty(self):
        finished = self._mock_proc(2, stdout="probe failed")
        ok, msg = self._authenticate_with_run(return_value=finished)
        self.assertFalse(ok)
        self.assertIn("probe failed", msg)

    def test_failure_when_subprocess_raises(self):
        failure = TimeoutError("timed out")
        ok, msg = self._authenticate_with_run(side_effect=failure)
        self.assertFalse(ok)
        self.assertIn("auth command failed", msg)
|
|
|
|
|
|
class TestDirectFido2Auth(unittest.TestCase):
    """authenticate_with_card in direct FIDO2 mode, with the card path patched."""

    def setUp(self):
        self._tmpdir = tempfile.TemporaryDirectory()
        self.state = _make_state(
            Path(self._tmpdir.name), auth_mode=AUTH_MODE_FIDO2_DIRECT
        )

    def tearDown(self):
        self._tmpdir.cleanup()

    def _store(self, username, **enrollment_kwargs):
        # Put an enrollment straight into the state's in-memory table.
        self.state.enrollments[username] = _enrollment(username, **enrollment_kwargs)

    def test_unenrolled_user_returns_false(self):
        ok, msg = self.state.authenticate_with_card("nobody")
        self.assertFalse(ok)
        self.assertEqual(msg, "user not enrolled")

    def test_enrolled_without_credential_returns_false(self):
        self._store("alice")
        ok, msg = self.state.authenticate_with_card("alice")
        self.assertFalse(ok)
        self.assertEqual(msg, "user has no registered credential")

    def test_exception_from_ctap_returns_false_with_message(self):
        self._store("alice", credential_data_b64="dW5pY29kZQ")
        broken_parser = patch(
            "k_proxy_app.AttestedCredentialData", side_effect=Exception("bad cbor")
        )
        with broken_parser:
            ok, msg = self.state.authenticate_with_card("alice")
        self.assertFalse(ok)
        self.assertIn("assertion verification failed", msg)

    def test_success_path_with_mocked_internals(self):
        self._store("alice", credential_data_b64=b64u_encode(b"fake_cred"))

        parsed_cred = MagicMock()

        begin_options = MagicMock()
        public_key = begin_options.public_key
        public_key.rp_id = "localhost"
        public_key.allow_credentials = []
        public_key.challenge = b"challenge"

        client_data = MagicMock()
        client_data.hash = b"hash"

        card_reply = MagicMock()
        card_reply.assertions = None
        card_reply.credential = {"id": b"cred_id"}
        card_reply.auth_data = b"auth"
        card_reply.signature = b"sig"
        card_reply.user = None

        state = self.state
        fido = state.fido_server
        # Every collaborator on the direct-auth path is replaced.
        with patch("k_proxy_app.AttestedCredentialData", return_value=parsed_cred), \
                patch("k_proxy_app.AuthenticationResponse", return_value=MagicMock()), \
                patch("k_proxy_app.AuthenticatorAssertionResponse", return_value=MagicMock()), \
                patch.object(state, "_drop_direct_device"), \
                patch.object(fido, "authenticate_begin", return_value=(begin_options, {})), \
                patch.object(state, "_collect_client_data", return_value=client_data), \
                patch.object(state, "_with_direct_ctap2", return_value=card_reply), \
                patch.object(fido, "authenticate_complete"):
            ok, msg = state.authenticate_with_card("alice")

        self.assertTrue(ok)
        self.assertEqual(msg, "assertion verified")
|
|
|
|
|
|
# ── upstream pool ─────────────────────────────────────────────────────────────
|
|
|
|
class TestUpstreamPool(unittest.TestCase):
    """UpstreamPool.request_json with the pool's connection factory patched."""

    def _pool(self):
        """Return a pool of at most two connections to a fixed test URL."""
        options = {
            "server_base_url": "http://127.0.0.1:19999",
            "server_ca_file": None,
            "max_connections": 2,
        }
        return UpstreamPool(**options)

    def _mock_response(self, status, body, will_close=True):
        """Return a fake HTTP response with the given status, body and will_close."""
        reply = MagicMock(status=status, will_close=will_close)
        reply.read.return_value = body
        return reply

    def _connection_answering(self, *response_args, **response_kwargs):
        # A fake connection whose getresponse() yields one canned response.
        link = MagicMock()
        link.getresponse.return_value = self._mock_response(
            *response_args, **response_kwargs
        )
        return link

    def test_successful_request_returns_status_and_parsed_json(self):
        pool = self._pool()
        link = self._connection_answering(200, b'{"ok": true, "value": 7}')
        with patch.object(pool, "_new_connection", return_value=link):
            status, data = pool.request_json(
                "/resource/counter", {"X-Proxy-Token": "tok"}, {}
            )
        self.assertEqual(status, 200)
        self.assertTrue(data["ok"])
        self.assertEqual(data["value"], 7)

    def test_non_200_status_is_returned_as_is(self):
        pool = self._pool()
        link = self._connection_answering(
            403, b'{"ok": false, "error": "forbidden"}'
        )
        with patch.object(pool, "_new_connection", return_value=link):
            status, data = pool.request_json("/test", {}, {})
        self.assertEqual(status, 403)
        self.assertFalse(data["ok"])

    def test_oserror_returns_502(self):
        pool = self._pool()
        link = MagicMock()
        # Sending the request itself fails.
        link.request.side_effect = OSError("connection refused")
        with patch.object(pool, "_new_connection", return_value=link):
            status, data = pool.request_json("/test", {}, {})
        self.assertEqual(status, 502)
        self.assertIn("server unavailable", data["error"])

    def test_empty_body_returns_empty_dict(self):
        pool = self._pool()
        link = self._connection_answering(200, b"")
        with patch.object(pool, "_new_connection", return_value=link):
            _status, data = pool.request_json("/test", {}, {})
        self.assertEqual(data, {})

    def test_connection_reused_when_will_close_false(self):
        pool = self._pool()
        link = self._connection_answering(200, b'{"ok": true}', will_close=False)
        with patch.object(pool, "_new_connection", return_value=link) as factory:
            for _ in range(2):
                pool.request_json("/test", {}, {})
        # One connection created, two requests sent over it.
        self.assertEqual(factory.call_count, 1)
        self.assertEqual(link.request.call_count, 2)

    def test_connection_not_reused_when_will_close_true(self):
        pool = self._pool()
        link = self._connection_answering(200, b'{"ok": true}', will_close=True)
        with patch.object(pool, "_new_connection", return_value=link) as factory:
            for _ in range(2):
                pool.request_json("/test", {}, {})
        self.assertEqual(factory.call_count, 2)
|
|
|
|
|
|
# ── HTTP handler integration tests ────────────────────────────────────────────
|
|
|
|
class ServerFixture(unittest.TestCase):
    """Base fixture that serves Handler from a real ThreadingHTTPServer.

    Each test gets its own ProxyState (built by _make_state over a temporary
    directory) and its own server bound to an ephemeral loopback port.  Card
    auth and fetch_counter are patched per-test via
    patch.object(self.state, ...) or the _login() helper.
    """

    def setUp(self):
        """Start the server; every resource is released through addCleanup.

        Cleanups run in reverse registration order, and they also run when a
        later setUp step raises, so a half-built fixture does not leak its
        temporary directory, listening socket or serving thread.
        """
        self._tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(self._tmpdir.cleanup)
        self.tmp_path = Path(self._tmpdir.name)
        self.state = _make_state(self.tmp_path)

        # Handler.state lives on the class, so it is shared by the whole
        # process.  Patching it (instead of a plain assignment) restores the
        # previous value when the test ends.
        state_patch = patch.object(Handler, "state", self.state, create=True)
        state_patch.start()
        self.addCleanup(state_patch.stop)

        # Port 0 lets the OS choose a free port; it is read back below.
        self.server = ThreadingHTTPServer(("127.0.0.1", 0), Handler)
        self.addCleanup(self.server.server_close)
        self.port = self.server.server_address[1]

        self._thread = threading.Thread(target=self.server.serve_forever, daemon=True)
        self._thread.start()
        # Registered only after the thread is running: shutdown() waits for
        # serve_forever() to return.
        self.addCleanup(self._stop_server)

    def _stop_server(self):
        """Stop serve_forever() and wait for the serving thread to finish."""
        self.server.shutdown()
        self._thread.join(timeout=5)

    # ── request helpers ──

    def _conn(self):
        """Open a new HTTP connection to the test server (5 s timeout)."""
        return http.client.HTTPConnection("127.0.0.1", self.port, timeout=5)

    def _get(self, path):
        """GET *path*; return (status, raw body bytes)."""
        conn = self._conn()
        try:
            conn.request("GET", path)
            resp = conn.getresponse()
            return resp.status, resp.read()
        finally:
            conn.close()

    def _get_json(self, path):
        """GET *path*; return (status, body decoded as JSON)."""
        status, body = self._get(path)
        return status, json.loads(body)

    def _post(self, path, payload=None, token=None):
        """POST *payload* as JSON; return (status, body decoded as JSON).

        A falsy payload is sent as an empty JSON object.  A truthy *token* is
        sent in a Bearer Authorization header.
        """
        conn = self._conn()
        try:
            body = json.dumps(payload or {}).encode()
            headers = {
                "Content-Type": "application/json",
                "Content-Length": str(len(body)),
            }
            if token:
                headers["Authorization"] = f"Bearer {token}"
            conn.request("POST", path, body=body, headers=headers)
            resp = conn.getresponse()
            return resp.status, json.loads(resp.read())
        finally:
            conn.close()

    def _post_raw(self, path, raw_body):
        """POST *raw_body* bytes unchanged; return (status, raw body bytes).

        Used where the request body is not valid JSON or the response is not
        decoded.
        """
        conn = self._conn()
        try:
            headers = {
                "Content-Type": "application/json",
                "Content-Length": str(len(raw_body)),
            }
            conn.request("POST", path, body=raw_body, headers=headers)
            resp = conn.getresponse()
            return resp.status, resp.read()
        finally:
            conn.close()

    # ── state helpers ──

    def _enroll(self, username="alice", display_name=None):
        """Register *username* directly on the ProxyState, bypassing HTTP."""
        self.state.register_enrollment(username, display_name)

    def _login(self, username="alice"):
        """Enroll user and obtain a session token with the card mocked to succeed."""
        self._enroll(username)
        with patch.object(self.state, "authenticate_with_card", return_value=(True, "ok")):
            status, data = self._post("/session/login", {"username": username})
        self.assertEqual(status, 200, f"login setup failed: {data}")
        return data["session_token"]
|
|
|
|
|
|
class TestHandlerHealth(ServerFixture):
    """Landing page, /health and unknown-route behaviour of the handler."""

    def test_get_root_returns_html(self):
        code, page = self._get("/")
        self.assertEqual(code, 200)
        self.assertIn(b"ChromeCard", page)

    def test_health_returns_service_info(self):
        code, info = self._get_json("/health")
        self.assertEqual(code, 200)
        self.assertTrue(info["ok"])
        self.assertEqual(info["service"], "k_proxy")
        self.assertIn("active_sessions", info)

    def test_health_reflects_active_session_count(self):
        # The session is created on the state directly, not through HTTP.
        self.state.create_session("alice")
        _code, info = self._get_json("/health")
        self.assertEqual(info["active_sessions"], 1)

    def test_unknown_get_returns_404(self):
        code, _page = self._get("/nonexistent")
        self.assertEqual(code, 404)

    def test_unknown_post_returns_404(self):
        code, _page = self._post_raw("/nonexistent", b"{}")
        self.assertEqual(code, 404)
|
|
|
|
|
|
class TestHandlerEnrollment(ServerFixture):
    """HTTP behaviour of the /enroll/* endpoints."""

    def test_register_new_user_returns_200(self):
        request = {"username": "alice", "display_name": "Alice"}
        code, reply = self._post("/enroll/register", request)
        self.assertEqual(code, 200)
        self.assertTrue(reply["ok"])
        self.assertEqual(reply["username"], "alice")
        self.assertEqual(reply["display_name"], "Alice")

    def test_register_duplicate_returns_409(self):
        self._enroll("alice")
        code, reply = self._post("/enroll/register", {"username": "alice"})
        self.assertEqual(code, 409)
        self.assertFalse(reply["ok"])

    def test_register_invalid_username_returns_400(self):
        code, reply = self._post("/enroll/register", {"username": "A!"})
        self.assertEqual(code, 400)
        self.assertFalse(reply["ok"])

    def test_register_invalid_json_returns_400(self):
        # Raw helper: the body is deliberately not JSON.
        code, _raw = self._post_raw("/enroll/register", b"not-json")
        self.assertEqual(code, 400)

    def test_enroll_status_found(self):
        self._enroll("alice", "Alice Smith")
        code, reply = self._get_json("/enroll/status?username=alice")
        self.assertEqual(code, 200)
        self.assertTrue(reply["ok"])
        self.assertEqual(reply["display_name"], "Alice Smith")

    def test_enroll_status_not_found_returns_404(self):
        code, _reply = self._get_json("/enroll/status?username=nobody")
        self.assertEqual(code, 404)

    def test_enroll_status_missing_param_returns_400(self):
        code, _reply = self._get_json("/enroll/status")
        self.assertEqual(code, 400)

    def test_enroll_list_empty(self):
        code, reply = self._get_json("/enroll/list")
        self.assertEqual(code, 200)
        self.assertEqual(reply["users"], [])

    def test_enroll_list_returns_sorted_users(self):
        # Enrolled out of order on purpose.
        for username in ("charlie", "alice"):
            self._enroll(username)
        _code, reply = self._get_json("/enroll/list")
        listed = []
        for user in reply["users"]:
            listed.append(user["username"])
        self.assertEqual(listed, ["alice", "charlie"])

    def test_enroll_update_changes_display_name(self):
        self._enroll("alice", "Old")
        request = {"username": "alice", "display_name": "New"}
        code, reply = self._post("/enroll/update", request)
        self.assertEqual(code, 200)
        self.assertEqual(reply["display_name"], "New")

    def test_enroll_update_unknown_returns_404(self):
        code, _reply = self._post("/enroll/update", {"username": "nobody"})
        self.assertEqual(code, 404)

    def test_enroll_delete_returns_200_and_deleted_true(self):
        self._enroll("alice")
        code, reply = self._post("/enroll/delete", {"username": "alice"})
        self.assertEqual(code, 200)
        self.assertTrue(reply["deleted"])
        # The record must also be gone from the state itself.
        self.assertFalse(self.state.has_enrollment("alice"))

    def test_enroll_delete_unknown_returns_404(self):
        code, _reply = self._post("/enroll/delete", {"username": "nobody"})
        self.assertEqual(code, 404)
|
|
|
|
|
|
class TestHandlerSession(ServerFixture):
    """HTTP behaviour of /session/login, /session/status and /session/logout."""

    def _login_with_card_result(self, username, card_result):
        # POST a login while authenticate_with_card returns *card_result*.
        with patch.object(
            self.state, "authenticate_with_card", return_value=card_result
        ):
            return self._post("/session/login", {"username": username})

    def test_login_success_returns_token(self):
        self._enroll("alice")
        code, reply = self._login_with_card_result("alice", (True, "ok"))
        self.assertEqual(code, 200)
        self.assertTrue(reply["ok"])
        self.assertIn("session_token", reply)
        self.assertIn("expires_at", reply)
        self.assertEqual(reply["auth_mode"], "card_presence_probe")

    def test_login_unenrolled_user_returns_403(self):
        code, reply = self._post("/session/login", {"username": "nobody"})
        self.assertEqual(code, 403)
        self.assertFalse(reply["ok"])
        self.assertIn("not enrolled", reply["error"])

    def test_login_card_failure_returns_401(self):
        self._enroll("alice")
        code, reply = self._login_with_card_result(
            "alice", (False, "No CTAP devices")
        )
        self.assertEqual(code, 401)
        self.assertFalse(reply["ok"])
        self.assertIn("card auth failed", reply["error"])
        self.assertIn("No CTAP devices", reply["details"])

    def test_login_invalid_username_returns_400(self):
        code, _reply = self._post("/session/login", {"username": "!bad!"})
        self.assertEqual(code, 400)

    def test_session_status_valid_token(self):
        session = self._login()
        code, reply = self._post("/session/status", {}, token=session)
        self.assertEqual(code, 200)
        self.assertTrue(reply["ok"])
        self.assertEqual(reply["username"], "alice")
        self.assertIn("expires_at", reply)
        self.assertGreaterEqual(reply["seconds_remaining"], 0)

    def test_session_status_no_token_returns_401(self):
        code, _reply = self._post("/session/status", {})
        self.assertEqual(code, 401)

    def test_session_status_invalid_token_returns_401(self):
        code, reply = self._post("/session/status", {}, token="bad-token")
        self.assertEqual(code, 401)
        self.assertIn("invalid or expired", reply["error"])

    def test_logout_valid_token(self):
        session = self._login()
        code, reply = self._post("/session/logout", {}, token=session)
        self.assertEqual(code, 200)
        self.assertTrue(reply["ok"])
        self.assertTrue(reply["invalidated"])
        # The state must no longer know the token either.
        self.assertIsNone(self.state.get_session(session))

    def test_logout_invalid_token_returns_200_not_invalidated(self):
        code, reply = self._post("/session/logout", {}, token="ghost")
        self.assertEqual(code, 200)
        self.assertFalse(reply["invalidated"])

    def test_logout_no_token_returns_401(self):
        code, _reply = self._post("/session/logout", {})
        self.assertEqual(code, 401)

    def test_session_invalid_after_logout(self):
        session = self._login()
        self._post("/session/logout", {}, token=session)
        code, _reply = self._post("/session/status", {}, token=session)
        self.assertEqual(code, 401)

    def test_multiple_sessions_independent(self):
        alice_session = self._login("alice")
        bob_session = self._login("bob")
        # Logging alice out must leave bob's session valid.
        self._post("/session/logout", {}, token=alice_session)
        code, reply = self._post("/session/status", {}, token=bob_session)
        self.assertEqual(code, 200)
        self.assertEqual(reply["username"], "bob")
|
|
|
|
|
|
class TestHandlerResource(ServerFixture):
    """HTTP behaviour of /resource/counter with fetch_counter patched."""

    def _counter_with_upstream(self, token, upstream_result):
        # POST /resource/counter while fetch_counter returns *upstream_result*.
        with patch.object(self.state, "fetch_counter", return_value=upstream_result):
            return self._post("/resource/counter", {}, token=token)

    def test_counter_with_valid_session(self):
        session = self._login()
        code, reply = self._counter_with_upstream(
            session, (200, {"ok": True, "value": 5})
        )
        self.assertEqual(code, 200)
        self.assertTrue(reply["ok"])
        self.assertEqual(reply["upstream"]["value"], 5)
        self.assertEqual(reply["username"], "alice")
        self.assertTrue(reply["session_reused"])

    def test_counter_no_token_returns_401(self):
        code, _reply = self._post("/resource/counter", {})
        self.assertEqual(code, 401)

    def test_counter_invalid_token_returns_401(self):
        code, _reply = self._post("/resource/counter", {}, token="garbage")
        self.assertEqual(code, 401)

    def test_counter_upstream_failure_propagated(self):
        session = self._login()
        code, reply = self._counter_with_upstream(
            session, (502, {"ok": False, "error": "server unavailable"})
        )
        self.assertEqual(code, 502)
        self.assertFalse(reply["ok"])
        self.assertIn("upstream failed", reply["error"])

    def test_counter_returns_upstream_non_200_as_error(self):
        session = self._login()
        code, reply = self._counter_with_upstream(
            session, (403, {"ok": False, "error": "forbidden"})
        )
        self.assertEqual(code, 403)
        self.assertFalse(reply["ok"])

    def test_counter_session_still_valid_after_call(self):
        session = self._login()
        self._counter_with_upstream(session, (200, {"ok": True, "value": 1}))
        # Using the resource must not consume the session.
        code, _reply = self._post("/session/status", {}, token=session)
        self.assertEqual(code, 200)
|
|
|
|
|
|
# ── card emulator integration tests ──────────────────────────────────────────
|
|
|
|
from card_emulator import CardEmulator
|
|
|
|
|
|
def _make_direct_state(tmp_path):
    """Return a test ProxyState whose auth mode is AUTH_MODE_FIDO2_DIRECT."""
    direct_mode = AUTH_MODE_FIDO2_DIRECT
    return _make_state(tmp_path, auth_mode=direct_mode)
|
|
|
|
|
|
def _patch_emulator(state, emulator):
    """Return a context manager that wires *emulator* into *state* as the card.

    While the context is active, state._with_direct_ctap2(fn) returns
    fn(emulator) and state._drop_direct_device() does nothing.
    """

    def run_on_emulator(fn):
        # Hand the emulator to the callback in place of a real device.
        return fn(emulator)

    def ignore_drop():
        return None

    replacements = {
        "_with_direct_ctap2": run_on_emulator,
        "_drop_direct_device": ignore_drop,
    }
    return patch.multiple(state, **replacements)
|
|
|
|
|
|
class TestCardEmulatorUnit(unittest.TestCase):
|
|
"""Direct calls to the emulator — no ProxyState involved."""
|
|
|
|
def setUp(self):
    """Create a new CardEmulator for each test."""
    fresh_card = CardEmulator()
    self.emulator = fresh_card
|
|
|
|
def _register(self, username="alice", rp_id="localhost"):
    """Create a credential on the emulator and return make_credential's result.

    *username* is used for both the user name and display name; *rp_id* is
    the relying-party id placed in the rp entry.  The client data hash, user
    id and key parameters are fixed test values.
    """
    # The earlier rp_id_hash local (a SHA-256 of rp_id) was never passed to
    # the emulator, so it has been removed.
    return self.emulator.make_credential(
        client_data_hash=b"\x00" * 32,
        rp={"id": rp_id, "name": "Test RP"},
        user={"id": b"user-id", "name": username, "displayName": username},
        key_params=[{"type": "public-key", "alg": -7}],
    )
|
|
|
|
def test_make_credential_returns_none_attestation(self):
    # Expect the "none" attestation format with an empty statement.
    registration = self._register()
    fmt, statement = registration.fmt, registration.att_stmt
    self.assertEqual(fmt, "none")
    self.assertEqual(statement, {})
|
|
|
|
def test_make_credential_stores_credential(self):
    # One registration must leave exactly one credential on the emulator.
    self._register()
    stored = self.emulator.credential_count()
    self.assertEqual(stored, 1)
|
|
|
|
def test_make_credential_auth_data_is_attested(self):
|
|
attest = self._register()
|
|
self.assertTrue(attest.auth_data.is_attested())
|
|
|
|
def test_make_credential_cred_id_is_32_bytes(self):
|
|
attest = self._register()
|
|
self.assertEqual(len(attest.auth_data.credential_data.credential_id), 32)
|
|
|
|
def test_make_credential_user_confirms_false_raises(self):
|
|
from fido2.ctap import CtapError
|
|
with self.assertRaises(CtapError) as ctx:
|
|
self._register() # first register so there's a credential
|
|
self.emulator.make_credential(
|
|
client_data_hash=b"\x00" * 32,
|
|
rp={"id": "localhost", "name": "Test RP"},
|
|
user={"id": b"user-id", "name": "bob", "displayName": "bob"},
|
|
key_params=[{"type": "public-key", "alg": -7}],
|
|
user_confirms=False,
|
|
)
|
|
self.assertEqual(ctx.exception.code, CtapError.ERR.OPERATION_DENIED)
|
|
|
|
def test_get_assertion_user_confirms_false_raises(self):
|
|
from fido2.ctap import CtapError
|
|
attest = self._register()
|
|
cred_id = attest.auth_data.credential_data.credential_id
|
|
with self.assertRaises(CtapError) as ctx:
|
|
self.emulator.get_assertion(
|
|
rp_id="localhost",
|
|
client_data_hash=b"\x01" * 32,
|
|
allow_list=[{"id": cred_id, "type": "public-key"}],
|
|
user_confirms=False,
|
|
)
|
|
self.assertEqual(ctx.exception.code, CtapError.ERR.OPERATION_DENIED)
|
|
|
|
def test_get_assertion_wrong_rp_raises(self):
|
|
from fido2.ctap import CtapError
|
|
attest = self._register(rp_id="localhost")
|
|
cred_id = attest.auth_data.credential_data.credential_id
|
|
with self.assertRaises(CtapError):
|
|
self.emulator.get_assertion(
|
|
rp_id="evil.example",
|
|
client_data_hash=b"\x01" * 32,
|
|
allow_list=[{"id": cred_id, "type": "public-key"}],
|
|
)
|
|
|
|
def test_get_assertion_empty_allow_list_raises(self):
|
|
from fido2.ctap import CtapError
|
|
self._register()
|
|
with self.assertRaises(CtapError):
|
|
self.emulator.get_assertion(
|
|
rp_id="localhost",
|
|
client_data_hash=b"\x01" * 32,
|
|
allow_list=None,
|
|
)
|
|
|
|
def test_sign_count_increments_across_assertions(self):
|
|
import struct
|
|
attest = self._register()
|
|
cred_id = attest.auth_data.credential_data.credential_id
|
|
|
|
def _count(assertion):
|
|
return struct.unpack(">I", bytes(assertion.auth_data)[33:37])[0]
|
|
|
|
a1 = self.emulator.get_assertion("localhost", b"\x01" * 32,
|
|
[{"id": cred_id, "type": "public-key"}])
|
|
a2 = self.emulator.get_assertion("localhost", b"\x02" * 32,
|
|
[{"id": cred_id, "type": "public-key"}])
|
|
self.assertGreater(_count(a2), _count(a1))
|
|
|
|
def test_forget_user_removes_credential(self):
|
|
self._register()
|
|
removed = self.emulator.forget_user("alice")
|
|
self.assertEqual(removed, 1)
|
|
self.assertEqual(self.emulator.credential_count(), 0)
|
|
|
|
def test_forget_unknown_user_returns_zero(self):
|
|
self._register()
|
|
self.assertEqual(self.emulator.forget_user("nobody"), 0)
|
|
self.assertEqual(self.emulator.credential_count(), 1)
|
|
|
|
def test_refusing_view_make_credential_raises(self):
|
|
from fido2.ctap import CtapError
|
|
with self.assertRaises(CtapError) as ctx:
|
|
self.emulator.refusing().make_credential(
|
|
client_data_hash=b"\x00" * 32,
|
|
rp={"id": "localhost", "name": "Test RP"},
|
|
user={"id": b"u", "name": "alice", "displayName": "Alice"},
|
|
key_params=[{"type": "public-key", "alg": -7}],
|
|
)
|
|
self.assertEqual(ctx.exception.code, CtapError.ERR.OPERATION_DENIED)
|
|
|
|
def test_refusing_view_get_assertion_raises(self):
|
|
from fido2.ctap import CtapError
|
|
attest = self._register()
|
|
cred_id = attest.auth_data.credential_data.credential_id
|
|
with self.assertRaises(CtapError) as ctx:
|
|
self.emulator.refusing().get_assertion(
|
|
rp_id="localhost",
|
|
client_data_hash=b"\x01" * 32,
|
|
allow_list=[{"id": cred_id, "type": "public-key"}],
|
|
)
|
|
self.assertEqual(ctx.exception.code, CtapError.ERR.OPERATION_DENIED)
|
|
|
|
|
|
class TestCardEmulatorIntegration(unittest.TestCase):
    """Full register → authenticate flow through ProxyState with the emulator.

    Each test gets its own temporary directory, a ProxyState built by
    ``_make_direct_state`` on top of it, and a fresh ``CardEmulator``.
    """

    def setUp(self):
        self._tmpdir = tempfile.TemporaryDirectory()
        # Registered immediately: unlike tearDown, cleanups also run when a
        # later step of setUp raises, so the directory is never leaked.
        self.addCleanup(self._tmpdir.cleanup)
        self.tmp_path = Path(self._tmpdir.name)
        self.state = _make_direct_state(self.tmp_path)
        self.emulator = CardEmulator()

    def _register(self, username="alice", display_name=None):
        """Enroll *username* through ProxyState with the emulator as the card."""
        with _patch_emulator(self.state, self.emulator):
            return self.state.register_enrollment(username, display_name)

    def _authenticate(self, username="alice"):
        """Run ``authenticate_with_card`` for *username* against the emulator."""
        with _patch_emulator(self.state, self.emulator):
            return self.state.authenticate_with_card(username)

    def _authenticate_refusing(self, username="alice"):
        """Like ``_authenticate`` but using the emulator's ``refusing()`` view."""
        with _patch_emulator(self.state, self.emulator.refusing()):
            return self.state.authenticate_with_card(username)

    def test_register_produces_credential_data(self):
        enrollment = self._register("alice", "Alice")
        self.assertIsNotNone(enrollment.credential_data_b64)
        self.assertEqual(enrollment.username, "alice")

    def test_register_persists_to_disk(self):
        self._register("alice")
        # A second state built over the same directory must see the enrollment.
        state2 = _make_direct_state(self.tmp_path)
        self.assertTrue(state2.has_enrollment("alice"))
        self.assertIsNotNone(state2.get_enrollment("alice").credential_data_b64)

    def test_authenticate_after_register_succeeds(self):
        self._register("alice")
        ok, msg = self._authenticate("alice")
        self.assertTrue(ok)
        self.assertEqual(msg, "assertion verified")

    def test_authenticate_user_says_no_fails(self):
        self._register("alice")
        ok, msg = self._authenticate_refusing("alice")
        self.assertFalse(ok)
        self.assertIn("assertion verification failed", msg)

    def test_register_user_says_no_fails(self):
        with _patch_emulator(self.state, self.emulator.refusing()):
            with self.assertRaises(RuntimeError) as ctx:
                self.state.register_enrollment("alice", None)
        self.assertIn("card registration failed", str(ctx.exception))

    def test_authenticate_after_forget_fails(self):
        self._register("alice")
        self.emulator.forget_user("alice")
        ok, _ = self._authenticate("alice")
        self.assertFalse(ok)

    def test_two_users_independent(self):
        self._register("alice")
        self._register("bob")
        ok_a, _ = self._authenticate("alice")
        ok_b, _ = self._authenticate("bob")
        self.assertTrue(ok_a)
        self.assertTrue(ok_b)

    def test_forget_one_user_leaves_other_intact(self):
        self._register("alice")
        self._register("bob")
        self.emulator.forget_user("alice")
        ok_a, _ = self._authenticate("alice")
        ok_b, _ = self._authenticate("bob")
        self.assertFalse(ok_a)
        self.assertTrue(ok_b)

    def test_sign_count_increases_across_logins(self):
        import struct

        from k_proxy_app import AttestedCredentialData

        self._register("alice")
        enrollment = self.state.get_enrollment("alice")
        cred_data = AttestedCredentialData(b64u_decode(enrollment.credential_data_b64))
        cred_id = cred_data.credential_id

        sign_counts = []
        for _ in range(3):
            assertion = self.emulator.get_assertion(
                rp_id=self.state.rp_id,
                client_data_hash=b"\xAB" * 32,
                allow_list=[{"id": cred_id, "type": "public-key"}],
            )
            # Bytes 33..36 are read as a big-endian unsigned 32-bit counter
            # (WebAuthn authenticator data: 32-byte RP id hash, 1 flag byte,
            # then the signature counter).
            sign_counts.append(struct.unpack(">I", bytes(assertion.auth_data)[33:37])[0])

        self.assertLess(sign_counts[0], sign_counts[1])
        self.assertLess(sign_counts[1], sign_counts[2])
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct execution of this file runs the whole module's tests;
    # verbosity=2 makes unittest report each test individually.
    unittest.main(verbosity=2)
|