Guard against cross-server token replay in _verify_assertion_token

_verify_assertion_token now takes expected_host and rejects any token
whose bundle["host"] does not match — closing the cross-server replay
path where a token issued for server-a could have passed on server-b.

ServerState gains protected_host (default 127.0.0.1); k_server exposes
--protected-host CLI flag so operators declare which host they protect.

New abuse tests (unit + round-trip):
  test_cross_server_replay_rejected
  test_cross_server_replay_case_insensitive
  test_roundtrip_cross_server_replay_rejected
  test_roundtrip_cross_server_replay_accepted_on_correct_server

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Morten V. Christiansen 2026-05-09 23:57:28 +02:00
parent 4b719a0846
commit 592af0c314
2 changed files with 51 additions and 11 deletions

View File

@ -26,7 +26,7 @@ def _b64u_decode(s: str) -> bytes:
return base64.urlsafe_b64decode(padded) return base64.urlsafe_b64decode(padded)
def _verify_assertion_token(token: str) -> bool: def _verify_assertion_token(token: str, expected_host: str) -> bool:
"""Verify a base64url-encoded FIDO2 domain-level assertion bundle. """Verify a base64url-encoded FIDO2 domain-level assertion bundle.
Bundle fields (JSON, then base64url-encoded): Bundle fields (JSON, then base64url-encoded):
@ -38,6 +38,9 @@ def _verify_assertion_token(token: str) -> bool:
cdj base64url clientDataJson bytes cdj base64url clientDataJson bytes
cred base64url AttestedCredentialData (aaguid+credIdLen+credId+coseKey) cred base64url AttestedCredentialData (aaguid+credIdLen+credId+coseKey)
user enrolled username (informational) user enrolled username (informational)
expected_host must match bundle["host"] exactly (case-insensitive) to prevent
cross-server replay: a token issued for server-a must not pass on server-b.
""" """
try: try:
import cbor2 import cbor2
@ -55,6 +58,9 @@ def _verify_assertion_token(token: str) -> bool:
host = bundle["host"] host = bundle["host"]
nonce = bundle["nonce"] nonce = bundle["nonce"]
if host.lower() != expected_host.lower():
return False
# Verify challenge claim: challenge == b64u(SHA256(host|nonce)) # Verify challenge claim: challenge == b64u(SHA256(host|nonce))
binding = f"{host}|{nonce}".encode() binding = f"{host}|{nonce}".encode()
expected_challenge = base64.urlsafe_b64encode(hashlib.sha256(binding).digest()).rstrip(b"=").decode() expected_challenge = base64.urlsafe_b64encode(hashlib.sha256(binding).digest()).rstrip(b"=").decode()
@ -94,8 +100,9 @@ def _verify_assertion_token(token: str) -> bool:
class ServerState: class ServerState:
# All state is process-local; a restart resets the counter to zero. # All state is process-local; a restart resets the counter to zero.
def __init__(self, proxy_token: str): def __init__(self, proxy_token: str, protected_host: str = "127.0.0.1"):
self.proxy_token = proxy_token self.proxy_token = proxy_token
self.protected_host = protected_host
self.counter = 0 self.counter = 0
self.lock = threading.Lock() self.lock = threading.Lock()
@ -130,7 +137,7 @@ class Handler(BaseHTTPRequestHandler):
return True return True
auth = self.headers.get("Authorization", "") auth = self.headers.get("Authorization", "")
if auth.startswith("Bearer "): if auth.startswith("Bearer "):
return _verify_assertion_token(auth[7:].strip()) return _verify_assertion_token(auth[7:].strip(), self.state.protected_host)
return False return False
def do_GET(self) -> None: # noqa: N802 def do_GET(self) -> None: # noqa: N802
@ -180,6 +187,11 @@ def parse_args() -> argparse.Namespace:
default="dev-proxy-token", default="dev-proxy-token",
help="Shared token expected in X-Proxy-Token from k_proxy", help="Shared token expected in X-Proxy-Token from k_proxy",
) )
parser.add_argument(
"--protected-host",
default="127.0.0.1",
help="Hostname this server protects; Bearer tokens must be issued for this host",
)
return parser.parse_args() return parser.parse_args()
@ -188,7 +200,7 @@ def main() -> int:
if bool(args.tls_certfile) != bool(args.tls_keyfile): if bool(args.tls_certfile) != bool(args.tls_keyfile):
raise SystemExit("Both --tls-certfile and --tls-keyfile are required to enable HTTPS") raise SystemExit("Both --tls-certfile and --tls-keyfile are required to enable HTTPS")
state = ServerState(proxy_token=args.proxy_token) state = ServerState(proxy_token=args.proxy_token, protected_host=args.protected_host)
Handler.state = state Handler.state = state
server = ThreadingHTTPServer((args.host, args.port), Handler) server = ThreadingHTTPServer((args.host, args.port), Handler)
scheme = "http" scheme = "http"

View File

@ -140,9 +140,10 @@ class TestVerifyAssertionToken(unittest.TestCase):
self.nonce = "deadbeef01234567" self.nonce = "deadbeef01234567"
self.token, _ = _make_bundle(self.host, self.nonce) self.token, _ = _make_bundle(self.host, self.nonce)
def _check(self, token=None) -> bool: def _check(self, token=None, host=None) -> bool:
return k_server_app._verify_assertion_token( return k_server_app._verify_assertion_token(
self.token if token is None else token self.token if token is None else token,
host if host is not None else self.host,
) )
def test_valid_token_accepted(self): def test_valid_token_accepted(self):
@ -152,6 +153,15 @@ class TestVerifyAssertionToken(unittest.TestCase):
# Domain-level binding: the same token covers all paths on the host. # Domain-level binding: the same token covers all paths on the host.
self.assertTrue(self._check()) self.assertTrue(self._check())
def test_cross_server_replay_rejected(self):
# Token issued for self.host must not pass when a different server verifies it.
self.assertFalse(self._check(host="other-server.com"))
def test_cross_server_replay_case_insensitive(self):
# Case variation of the expected host still rejects a token for a different host.
token_b, _ = _make_bundle("BANK.com", self.nonce)
self.assertFalse(k_server_app._verify_assertion_token(token_b, "evil.com"))
def test_tampered_nonce_invalidates_challenge(self): def test_tampered_nonce_invalidates_challenge(self):
tampered = _tamper(self.token, "nonce", lambda _: "tampered00000000") tampered = _tamper(self.token, "nonce", lambda _: "tampered00000000")
self.assertFalse(self._check(tampered)) self.assertFalse(self._check(tampered))
@ -276,7 +286,7 @@ class TestVerifyAssertionTokenRoundTrip(unittest.TestCase):
emulator = CardEmulator() emulator = CardEmulator()
token = self._register_and_assert(emulator, "example.com", "cafebabe12345678") token = self._register_and_assert(emulator, "example.com", "cafebabe12345678")
self.assertTrue( self.assertTrue(
k_server_app._verify_assertion_token(token), k_server_app._verify_assertion_token(token, "example.com"),
"valid round-trip token must be accepted", "valid round-trip token must be accepted",
) )
@ -285,7 +295,7 @@ class TestVerifyAssertionTokenRoundTrip(unittest.TestCase):
emulator = CardEmulator() emulator = CardEmulator()
token = self._register_and_assert(emulator, "example.com", "aabbccdd11223344") token = self._register_and_assert(emulator, "example.com", "aabbccdd11223344")
self.assertTrue( self.assertTrue(
k_server_app._verify_assertion_token(token), k_server_app._verify_assertion_token(token, "example.com"),
"domain-level token must be accepted for any path on the host", "domain-level token must be accepted for any path on the host",
) )
@ -295,7 +305,7 @@ class TestVerifyAssertionTokenRoundTrip(unittest.TestCase):
token = self._register_and_assert(emulator, "example.com", "original00000000") token = self._register_and_assert(emulator, "example.com", "original00000000")
tampered = _tamper(token, "nonce", lambda _: "tampered11111111") tampered = _tamper(token, "nonce", lambda _: "tampered11111111")
self.assertFalse( self.assertFalse(
k_server_app._verify_assertion_token(tampered), k_server_app._verify_assertion_token(tampered, "example.com"),
"tampered nonce must break challenge verification", "tampered nonce must break challenge verification",
) )
@ -305,10 +315,28 @@ class TestVerifyAssertionTokenRoundTrip(unittest.TestCase):
token = self._register_and_assert(emulator, "example.com", "deadbeef00112233") token = self._register_and_assert(emulator, "example.com", "deadbeef00112233")
tampered = _tamper(token, "host", lambda _: "attacker.com") tampered = _tamper(token, "host", lambda _: "attacker.com")
self.assertFalse( self.assertFalse(
k_server_app._verify_assertion_token(tampered), k_server_app._verify_assertion_token(tampered, "example.com"),
"tampered host must break challenge verification", "tampered host must break challenge verification",
) )
def test_roundtrip_cross_server_replay_rejected(self):
"""Token issued for server-a must not validate on server-b."""
emulator = CardEmulator()
token = self._register_and_assert(emulator, "server-a.com", "1122334455667788")
self.assertFalse(
k_server_app._verify_assertion_token(token, "server-b.com"),
"cross-server replay: token for server-a must be rejected by server-b",
)
def test_roundtrip_cross_server_replay_accepted_on_correct_server(self):
"""Sanity: same token is accepted on the server it was issued for."""
emulator = CardEmulator()
token = self._register_and_assert(emulator, "server-a.com", "aabbccdd99887766")
self.assertTrue(
k_server_app._verify_assertion_token(token, "server-a.com"),
"token must still be valid on the correct server",
)
def test_roundtrip_replayed_for_different_user_rejected(self): def test_roundtrip_replayed_for_different_user_rejected(self):
"""Two users register separate credentials; each token is only valid for its own key.""" """Two users register separate credentials; each token is only valid for its own key."""
em_a = CardEmulator() em_a = CardEmulator()
@ -325,7 +353,7 @@ class TestVerifyAssertionTokenRoundTrip(unittest.TestCase):
cross = _b64u_encode(json.dumps(bundle_a, separators=(",", ":")).encode()) cross = _b64u_encode(json.dumps(bundle_a, separators=(",", ":")).encode())
self.assertFalse( self.assertFalse(
k_server_app._verify_assertion_token(cross), k_server_app._verify_assertion_token(cross, host),
"cross-user key swap must fail verification", "cross-user key swap must fail verification",
) )