k_card/k_server_app.py

221 lines
7.4 KiB
Python

#!/usr/bin/env python3
"""
k_server — protected resource backend.
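Exposes a monotonic counter behind proxy authorization: either the shared
token in the X-Proxy-Token header, or a FIDO2 per-request assertion sent as an
Authorization: Bearer token. Only k_proxy is expected to reach this service;
k_client should have no direct path.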
All state is process-local and resets on restart.
"""
from __future__ import annotations

import argparse
import base64
import hashlib
import hmac
import json
import ssl
import sys
import threading
import time
from collections.abc import Collection
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from typing import Any
from urllib.parse import urlparse
def _b64u_decode(s: str) -> bytes:
padded = s + "=" * ((4 - len(s) % 4) % 4)
return base64.urlsafe_b64decode(padded)
def _verify_assertion_token(
    token: str,
    request_path: str,
    request_method: str,
    *,
    trusted_credentials: Collection[bytes] | None = None,
) -> bool:
    """Verify a base64url-encoded FIDO2 per-request assertion bundle.

    Bundle fields (JSON, then base64url-encoded):
        v         version (1)
        url       full URL used to derive the challenge
        method    HTTP method used to derive the challenge
        nonce     random hex nonce used to derive the challenge
        authData  base64url authenticator data
        sig       base64url ECDSA signature
        cdj       base64url clientDataJson bytes
        cred      base64url AttestedCredentialData (aaguid+credIdLen+credId+coseKey)
        user      enrolled username (informational)

    Args:
        token: the base64url bundle (the part after "Bearer ").
        request_path: path of the incoming request; must equal the path
            component of ``bundle["url"]``.
        request_method: HTTP method of the incoming request; compared with
            ``bundle["method"]`` case-insensitively.
        trusted_credentials: optional collection of enrolled credentials, each
            the decoded bytes of a bundle's ``cred`` field. When given, the
            bundle's credential must be byte-identical to one of them.

    Returns:
        True only if every check passes. False for any malformed, mismatched
        or unverifiable bundle, including when cbor2/cryptography cannot be
        imported (fail closed).

    SECURITY: the public key is read from the bundle itself. Without
    ``trusted_credentials`` a True result only proves the bundle is
    self-consistent; anyone can mint a key pair and produce one. Pass
    ``trusted_credentials`` to actually authenticate the caller.
    NOTE(review): there is no replay protection (the nonce is not remembered).
    The authData rpIdHash and flags (e.g. user-present) are not checked. The
    scheme and host of ``bundle["url"]`` are not compared to this server.
    Confirm whether the callers rely on these.
    """
    try:
        import cbor2
        from cryptography.hazmat.backends import default_backend
        from cryptography.hazmat.primitives.asymmetric.ec import (
            ECDSA,
            EllipticCurvePublicNumbers,
            SECP256R1,
        )
        from cryptography.hazmat.primitives.hashes import SHA256

        bundle = json.loads(_b64u_decode(token).decode("utf-8"))
        url = bundle["url"]
        method = bundle["method"]
        nonce = bundle["nonce"]

        # Path and method must match the actual request.
        if urlparse(url).path != request_path:
            return False
        if method.upper() != request_method.upper():
            return False

        # The signed challenge must be b64u(SHA256(url|method|nonce)), using the
        # exact url/method strings the client put in the bundle.
        binding = f"{url}|{method}|{nonce}".encode()
        expected_challenge = (
            base64.urlsafe_b64encode(hashlib.sha256(binding).digest()).rstrip(b"=").decode()
        )
        cdj_bytes = _b64u_decode(bundle["cdj"])
        cdj = json.loads(cdj_bytes)
        if cdj.get("type") != "webauthn.get":
            return False
        if cdj.get("challenge") != expected_challenge:
            return False

        # AttestedCredentialData: aaguid(16) | credIdLen(2, big-endian) | credId | COSE key.
        cred_data = _b64u_decode(bundle["cred"])
        if len(cred_data) < 18:
            return False
        if trusted_credentials is not None and cred_data not in trusted_credentials:
            return False
        cred_id_len = int.from_bytes(cred_data[16:18], "big")
        cose_key = cbor2.loads(cred_data[18 + cred_id_len:])

        # Only accept an EC2 key (kty 2) on P-256 (crv 1) for ES256 (alg -7),
        # with 32-byte affine coordinates (-2 = x, -3 = y).
        if cose_key.get(1) != 2 or cose_key.get(-1) != 1 or cose_key.get(3, -7) != -7:
            return False
        x = cose_key[-2]
        y = cose_key[-3]
        if len(x) != 32 or len(y) != 32:
            return False
        pub_key = EllipticCurvePublicNumbers(
            x=int.from_bytes(x, "big"),
            y=int.from_bytes(y, "big"),
            curve=SECP256R1(),
        ).public_key(default_backend())

        # ECDSA-P256 signature over authData || SHA256(clientDataJson).
        auth_data = _b64u_decode(bundle["authData"])
        signature = _b64u_decode(bundle["sig"])
        message = auth_data + hashlib.sha256(cdj_bytes).digest()
        pub_key.verify(signature, message, ECDSA(SHA256()))
        return True
    except Exception:
        # Fail closed. This covers InvalidSignature, malformed input (bad
        # base64/JSON/CBOR, missing keys, wrong types, invalid curve points)
        # and missing optional dependencies.
        return False
class ServerState:
    """Mutable state shared by every request-handler thread.

    Everything lives in process memory; after a restart the counter begins
    again from zero.
    """

    def __init__(self, proxy_token: str):
        # Guards ``counter`` against concurrent increments from handler threads.
        self.lock = threading.Lock()
        self.counter = 0
        self.proxy_token = proxy_token

    def next_counter(self) -> int:
        """Increment the counter under the lock and return the new value.

        The first call returns 1.
        """
        self.lock.acquire()
        try:
            self.counter = self.counter + 1
            issued = self.counter
        finally:
            self.lock.release()
        return issued
class Handler(BaseHTTPRequestHandler):
    """HTTP/1.1 request handler for the k_server counter service.

    Routes:
        GET  /health            unauthenticated liveness probe
        POST /resource/counter  increment the shared counter (proxy auth required)
    Any other path is answered with ``send_error(404)``.
    """

    # Assigned by main() before serving. The same ServerState object is shared
    # by every handler thread.
    state: ServerState
    protocol_version = "HTTP/1.1"
    # Largest request body we are willing to read just to throw it away.
    # Larger bodies are not read; the connection is closed after the response.
    max_discard_bytes = 1 << 20
    # Read size used while draining a request body.
    _discard_chunk_size = 64 * 1024

    def _json(self, status: int, payload: dict[str, Any]) -> None:
        """Send ``payload`` as a JSON response with an explicit Content-Length."""
        body = json.dumps(payload).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        if self.close_connection:
            # Let keep-alive clients know this connection will not be reused.
            self.send_header("Connection", "close")
        self.end_headers()
        self.wfile.write(body)

    def _discard_request_body(self) -> None:
        """Consume and drop the request body before a response is sent.

        HTTP/1.1 keep-alive reuses the connection, so any unread body bytes
        would be parsed as the next request. If the body cannot be skipped
        safely, the connection is marked to close after the response instead.
        That happens when:
          - the body uses Transfer-Encoding framing (not decoded here);
          - Content-Length is malformed or larger than ``max_discard_bytes``;
          - the client disconnects early.
        """
        if self.headers.get("Transfer-Encoding") is not None:
            self.close_connection = True
            return
        raw_length = self.headers.get("Content-Length", "0").strip()
        # Accept plain ASCII digits only; int() alone would also accept "1_0",
        # "+5", etc., and raise on garbage.
        if not (raw_length.isascii() and raw_length.isdigit()):
            self.close_connection = True
            return
        remaining = int(raw_length)
        if remaining > self.max_discard_bytes:
            self.close_connection = True
            return
        while remaining > 0:
            # Bounded reads: a single rfile.read(n) may allocate n bytes up front.
            chunk = self.rfile.read(min(remaining, self._discard_chunk_size))
            if not chunk:
                # Peer closed before sending the advertised body.
                self.close_connection = True
                return
            remaining -= len(chunk)

    def _is_proxy_authorized(self) -> bool:
        """Return True if the request carries valid proxy credentials.

        Accepts either the legacy shared secret in X-Proxy-Token (sent by
        k_proxy_app.py), or an ``Authorization: Bearer`` FIDO2 assertion bundle.
        The bundle is checked by ``_verify_assertion_token`` against this
        request's path and method.
        """
        presented = self.headers.get("X-Proxy-Token")
        expected = self.state.proxy_token
        # Constant-time comparison, so response timing does not reveal how much
        # of a guessed secret was correct. An empty secret never matches.
        if presented and expected and hmac.compare_digest(
            presented.encode("utf-8"), expected.encode("utf-8")
        ):
            return True
        auth = self.headers.get("Authorization", "")
        if auth.startswith("Bearer "):
            return _verify_assertion_token(
                auth[7:].strip(),
                request_path=urlparse(self.path).path,
                request_method=self.command,
            )
        return False

    def do_GET(self) -> None:  # noqa: N802
        """Serve the unauthenticated /health probe; 404 everything else."""
        path = urlparse(self.path).path
        if path != "/health":
            self.send_error(404)
            return
        # Same keep-alive rule as POST: drain any body before responding.
        self._discard_request_body()
        self._json(
            200,
            {
                "ok": True,
                "service": "k_server",
                "time": int(time.time()),
            },
        )

    def do_POST(self) -> None:  # noqa: N802
        """Increment the counter for an authorized proxy request on /resource/counter."""
        path = urlparse(self.path).path
        if path != "/resource/counter":
            self.send_error(404)
            return
        self._discard_request_body()
        if not self._is_proxy_authorized():
            self._json(401, {"ok": False, "error": "unauthorized proxy"})
            return
        value = self.state.next_counter()
        self._json(
            200,
            {
                "ok": True,
                "resource": "counter",
                "value": value,
                "time": int(time.time()),
            },
        )
def parse_args() -> argparse.Namespace:
    """Define the k_server command line and parse ``sys.argv``.

    The returned namespace carries ``host``, ``port``, ``tls_certfile``,
    ``tls_keyfile`` and ``proxy_token``.
    """
    cli = argparse.ArgumentParser(description="Run k_server counter service")
    # Listener address.
    cli.add_argument("--host", default="127.0.0.1")
    cli.add_argument("--port", default=8780, type=int)
    # Optional TLS material.
    for flag, blurb in (
        ("--tls-certfile", "PEM certificate chain for HTTPS listener"),
        ("--tls-keyfile", "PEM private key for HTTPS listener"),
    ):
        cli.add_argument(flag, help=blurb)
    # NOTE(review): the default is a well-known development value; override it
    # in any real deployment.
    cli.add_argument(
        "--proxy-token",
        help="Shared token expected in X-Proxy-Token from k_proxy",
        default="dev-proxy-token",
    )
    namespace = cli.parse_args()
    return namespace
def main() -> int:
    """Parse the CLI, start the (optionally TLS-wrapped) counter server and block.

    Returns:
        0 once serving stops, including after Ctrl-C.

    Raises:
        SystemExit: if only one of --tls-certfile / --tls-keyfile is given.
    """
    args = parse_args()
    if bool(args.tls_certfile) != bool(args.tls_keyfile):
        raise SystemExit("Both --tls-certfile and --tls-keyfile are required to enable HTTPS")
    if args.proxy_token == "dev-proxy-token":
        print(
            "warning: using the default development proxy token; "
            "pass --proxy-token for any real deployment",
            file=sys.stderr,
        )
    Handler.state = ServerState(proxy_token=args.proxy_token)
    # The context manager closes the listening socket on any exit path.
    with ThreadingHTTPServer((args.host, args.port), Handler) as server:
        scheme = "http"
        if args.tls_certfile:
            context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
            context.load_cert_chain(certfile=args.tls_certfile, keyfile=args.tls_keyfile)
            # Defer the TLS handshake to the per-connection handler thread. With
            # the default (handshake inside accept()), one client stalling
            # mid-handshake blocks the single serve_forever() loop, and with it
            # every other client.
            server.socket = context.wrap_socket(
                server.socket,
                server_side=True,
                do_handshake_on_connect=False,
            )
            scheme = "https"
        print(f"k_server listening on {scheme}://{args.host}:{args.port}")
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            # Ctrl-C is the normal way to stop the dev server; exit cleanly.
            pass
    return 0
if __name__ == "__main__":
    # Script entry point: main()'s integer result becomes the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)