220 lines
7.5 KiB
Python
220 lines
7.5 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
k_server — protected resource backend.
|
|
|
|
Exposes a monotonic counter behind a shared proxy token. Only k_proxy
|
|
is expected to reach this service; k_client should have no direct path.
|
|
All state is process-local and resets on restart.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import base64
|
|
import hashlib
|
|
import json
|
|
import ssl
|
|
import threading
|
|
import time
|
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
|
from typing import Any
|
|
from urllib.parse import urlparse
|
|
|
|
|
|
def _b64u_decode(s: str) -> bytes:
    """Decode base64url text into raw bytes, tolerating stripped padding.

    The ``=`` characters needed to bring the length up to a multiple of four
    are appended before the text is handed to ``base64.urlsafe_b64decode``.
    """
    missing = -len(s) % 4
    return base64.urlsafe_b64decode(s + "=" * missing)
|
|
|
|
|
|
def _verify_assertion_token(token: str, expected_host: str) -> bool:
    """Verify a base64url-encoded FIDO2 domain-level assertion bundle.

    Bundle fields (JSON, then base64url-encoded):
      v         version (1)
      host      hostname used to derive the challenge
      nonce     random hex nonce used to derive the challenge
      authData  base64url authenticator data
      sig       base64url ECDSA signature
      cdj       base64url clientDataJson bytes
      cred      base64url AttestedCredentialData (aaguid+credIdLen+credId+coseKey)
      user      enrolled username (informational)

    expected_host must match bundle["host"] exactly (case-insensitive) to prevent
    cross-server replay: a token issued for server-a must not pass on server-b.

    Checks performed, in order:
      1. bundle["host"] equals expected_host (case-insensitive).
      2. clientDataJson has type "webauthn.get" and its challenge equals
         b64u(SHA256("<host>|<nonce>")).
      3. The ECDSA P-256 signature over authData || SHA256(clientDataJson)
         verifies against the public key found in bundle["cred"].

    The "v" and "user" fields are not read by this function.

    Args:
        token: The base64url text of the bundle (padding optional).
        expected_host: Hostname the token must have been issued for.

    Returns:
        True only when every check passes. False for any failure, including a
        malformed bundle or missing ``cbor2`` / ``cryptography`` packages. This
        function never raises.
    """
    # Resolve the lazily imported packages on their own. The verification
    # handler below names InvalidSignature, so that name must already be bound
    # when the handler is evaluated; a failed import is therefore turned into
    # a plain rejection here instead of surfacing as UnboundLocalError.
    try:
        import cbor2
        from cryptography.exceptions import InvalidSignature
        from cryptography.hazmat.backends import default_backend
        from cryptography.hazmat.primitives.asymmetric.ec import (
            ECDSA,
            EllipticCurvePublicNumbers,
            SECP256R1,
        )
        from cryptography.hazmat.primitives.hashes import SHA256
    except ImportError:
        return False

    try:
        bundle = json.loads(_b64u_decode(token).decode("utf-8"))

        host = bundle["host"]
        nonce = bundle["nonce"]

        if host.lower() != expected_host.lower():
            return False

        # Verify challenge claim: challenge == b64u(SHA256(host|nonce))
        # NOTE(review): the nonce is only used to recompute the challenge; this
        # function keeps no record of nonces it has already seen.
        binding = f"{host}|{nonce}".encode()
        expected_challenge = (
            base64.urlsafe_b64encode(hashlib.sha256(binding).digest())
            .rstrip(b"=")
            .decode()
        )

        cdj_bytes = _b64u_decode(bundle["cdj"])
        cdj = json.loads(cdj_bytes)
        if cdj.get("type") != "webauthn.get":
            return False
        if cdj.get("challenge") != expected_challenge:
            return False

        # Verify ECDSA-P256 signature over authData || SHA256(clientDataJson).
        # NOTE(review): authData is treated as opaque signed bytes; its
        # contents are not inspected here.
        auth_data = _b64u_decode(bundle["authData"])
        signature = _b64u_decode(bundle["sig"])
        client_data_hash = hashlib.sha256(cdj_bytes).digest()
        message = auth_data + client_data_hash

        # Extract P-256 public key from AttestedCredentialData.
        # Layout: 16-byte aaguid, 2-byte big-endian credIdLen, credId, COSE key.
        # NOTE(review): the key comes from the bundle itself and is not compared
        # against any enrolled credential in this function -- confirm that the
        # caller's trust model accounts for that.
        cred_data = _b64u_decode(bundle["cred"])
        cred_id_len = (cred_data[16] << 8) | cred_data[17]
        cose_bytes = cred_data[18 + cred_id_len:]
        cose_key = cbor2.loads(cose_bytes)
        x = cose_key[-2]  # COSE EC2 label -2: x coordinate
        y = cose_key[-3]  # COSE EC2 label -3: y coordinate

        pub_key = EllipticCurvePublicNumbers(
            x=int.from_bytes(x, "big"),
            y=int.from_bytes(y, "big"),
            curve=SECP256R1(),
        ).public_key(default_backend())

        pub_key.verify(signature, message, ECDSA(SHA256()))
        return True
    except InvalidSignature:
        return False
    except Exception:
        # Fail closed: undecodable base64/JSON/CBOR, missing or mistyped
        # fields, and invalid curve points all mean "not authorized".
        return False
|
|
|
|
|
|
class ServerState:
    """In-memory state for the counter service.

    Nothing here is persisted: a fresh instance (and therefore a restarted
    process) begins with the counter at zero.
    """

    def __init__(self, proxy_token: str, protected_host: str = "127.0.0.1"):
        # Serialises updates to ``counter``.
        self.lock = threading.Lock()
        self.counter = 0
        self.protected_host = protected_host
        self.proxy_token = proxy_token

    def next_counter(self) -> int:
        """Advance the counter by one under the lock and return the new value."""
        with self.lock:
            bumped = self.counter + 1
            self.counter = bumped
        return bumped
|
|
|
|
|
|
class Handler(BaseHTTPRequestHandler):
    """HTTP request handler for the k_server counter service.

    Routes:
      GET  /health            liveness document, no authorization required
      POST /resource/counter  increments and returns the counter; requires
                              proxy authorization

    ``state`` is set on the class by ``main`` before the server starts.
    """

    state: "ServerState"
    protocol_version = "HTTP/1.1"

    # Largest single read issued while draining a request body.
    _DISCARD_CHUNK = 64 * 1024

    def _json(self, status: int, payload: dict[str, Any]) -> None:
        """Send *payload* as a JSON body with the given HTTP status.

        Content-Length is always set so the HTTP/1.1 connection can be reused.
        """
        body = json.dumps(payload).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _discard_request_body(self) -> None:
        """Read and drop the request body announced by Content-Length.

        HTTP/1.1 keep-alive: the connection is reused, so the body must be fully
        consumed before we send the response, even for endpoints that ignore it.

        A Content-Length that is not an integer leaves the message framing
        unknown; in that case nothing is read and the connection is marked to
        be closed after the response instead of being reused.
        """
        try:
            remaining = int(self.headers.get("Content-Length", "0"))
        except ValueError:
            self.close_connection = True
            return
        # Drain in bounded chunks so a large declared length never turns into
        # one equally large allocation.
        while remaining > 0:
            chunk = self.rfile.read(min(remaining, self._DISCARD_CHUNK))
            if not chunk:
                break  # stream ended before the declared length was reached
            remaining -= len(chunk)

    def _is_proxy_authorized(self) -> bool:
        """Return True when the request carries acceptable proxy credentials.

        Accept legacy X-Proxy-Token (k_proxy_app.py) or FIDO2 assertion Bearer.
        A wrong or absent X-Proxy-Token does not reject the request by itself;
        the Authorization header is still consulted.
        """
        import hmac  # function-scope: hmac is not among the module-level imports

        supplied = self.headers.get("X-Proxy-Token")
        if supplied is not None:
            # Constant-time comparison of the shared secret. Both sides are
            # encoded to bytes because compare_digest rejects non-ASCII str.
            if hmac.compare_digest(
                supplied.encode("utf-8", "surrogatepass"),
                self.state.proxy_token.encode("utf-8", "surrogatepass"),
            ):
                return True
        auth = self.headers.get("Authorization", "")
        if auth.startswith("Bearer "):
            return _verify_assertion_token(auth[7:].strip(), self.state.protected_host)
        return False

    def do_GET(self) -> None:  # noqa: N802
        """Serve ``GET /health``; any other path is answered with 404."""
        path = urlparse(self.path).path
        if path == "/health":
            self._json(
                200,
                {
                    "ok": True,
                    "service": "k_server",
                    "time": int(time.time()),
                },
            )
            return
        self.send_error(404)

    def do_POST(self) -> None:  # noqa: N802
        """Serve ``POST /resource/counter``; any other path is answered with 404.

        The request body is drained first, then authorization is checked.
        Unauthorized callers receive a 401 JSON error and the counter is left
        untouched; authorized callers receive the incremented counter value.
        """
        path = urlparse(self.path).path
        if path != "/resource/counter":
            self.send_error(404)
            return
        self._discard_request_body()
        if not self._is_proxy_authorized():
            self._json(401, {"ok": False, "error": "unauthorized proxy"})
            return

        value = self.state.next_counter()
        self._json(
            200,
            {
                "ok": True,
                "resource": "counter",
                "value": value,
                "time": int(time.time()),
            },
        )
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
|
|
parser = argparse.ArgumentParser(description="Run k_server counter service")
|
|
parser.add_argument("--host", default="127.0.0.1")
|
|
parser.add_argument("--port", type=int, default=8780)
|
|
parser.add_argument("--tls-certfile", help="PEM certificate chain for HTTPS listener")
|
|
parser.add_argument("--tls-keyfile", help="PEM private key for HTTPS listener")
|
|
parser.add_argument(
|
|
"--proxy-token",
|
|
default="dev-proxy-token",
|
|
help="Shared token expected in X-Proxy-Token from k_proxy",
|
|
)
|
|
parser.add_argument(
|
|
"--protected-host",
|
|
default="127.0.0.1",
|
|
help="Hostname this server protects; Bearer tokens must be issued for this host",
|
|
)
|
|
return parser.parse_args()
|
|
|
|
|
|
def main() -> int:
    """Run the k_server counter service until it is interrupted.

    Parses the command line, installs the shared ``ServerState`` on
    ``Handler``, optionally wraps the listening socket in TLS, and serves
    requests.

    Returns:
        0 once the server has stopped (for example after Ctrl-C).

    Raises:
        SystemExit: if only one of --tls-certfile / --tls-keyfile was given.
    """
    args = parse_args()
    if bool(args.tls_certfile) != bool(args.tls_keyfile):
        raise SystemExit("Both --tls-certfile and --tls-keyfile are required to enable HTTPS")

    state = ServerState(proxy_token=args.proxy_token, protected_host=args.protected_host)
    Handler.state = state
    # The context manager closes the listening socket on every exit path,
    # including a failure while loading the TLS certificate.
    with ThreadingHTTPServer((args.host, args.port), Handler) as server:
        scheme = "http"
        if args.tls_certfile:
            context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
            context.load_cert_chain(certfile=args.tls_certfile, keyfile=args.tls_keyfile)
            server.socket = context.wrap_socket(server.socket, server_side=True)
            scheme = "https"

        print(f"k_server listening on {scheme}://{args.host}:{args.port}")
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            # Ctrl-C is the expected way to stop a foreground server; fall
            # through to the normal exit status instead of a traceback.
            pass
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)
|