k_card/phase65_concurrency_probe.py

189 lines
6.5 KiB
Python

#!/usr/bin/env python3
"""
Phase 6.5 concurrency probe for the direct browser-to-k_proxy path.

What it does:
- Enrolls a small batch of users (users that are already enrolled are reused).
- Logs each user in through k_proxy over TLS.
- Fires protected counter requests in parallel using the returned bearer tokens.
- Verifies that all calls succeed and that the returned counter values are unique and contiguous.
- Prints a JSON summary and exits with status 0 only if every check passes (1 otherwise).
"""
from __future__ import annotations

import argparse
import json
import ssl
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Any
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
@dataclass
class Session:
    """An authenticated k_proxy session for one enrolled user.

    Instances are built by ``login_user`` from the ``/session/login`` response.
    """

    # Username the session token was issued for.
    username: str
    # Bearer token sent on protected calls. Excluded from repr() so printing a
    # Session (logs, tracebacks, debugging) does not leak the secret.
    token: str = field(repr=False)
@lru_cache(maxsize=None)
def _ssl_context(cafile: str | None) -> ssl.SSLContext:
    """Return a TLS client context trusting ``cafile`` (system CAs if None).

    Built once per distinct ``cafile`` so concurrent probe calls do not each
    re-read the CA bundle inside their timed section. The key space is tiny
    (normally a single CA file), so an unbounded cache is fine.
    """
    return ssl.create_default_context(cafile=cafile)


def _decode_json_object(raw: bytes) -> dict[str, Any] | None:
    """Parse ``raw`` as UTF-8 JSON; return it only if it is a JSON object."""
    try:
        parsed = json.loads(raw.decode("utf-8"))
    except ValueError:  # covers both UnicodeDecodeError and JSONDecodeError
        return None
    return parsed if isinstance(parsed, dict) else None


def request_json(
    base_url: str,
    path: str,
    *,
    method: str = "GET",
    payload: dict[str, Any] | None = None,
    token: str | None = None,
    cafile: str | None = None,
    timeout: int = 10,
) -> tuple[int, dict[str, Any]]:
    """Send one JSON request to ``base_url + path`` and return ``(status, body)``.

    Failures while sending the request or reading the response are reported
    through the return value instead of being raised:

    * HTTP error statuses keep their code; the body is the server's JSON object
      when it has one, otherwise ``{"ok": False, "error": "http error <code>"}``.
    * Connection-level failures map to status 502 with an error body.
    * A response whose body is not a JSON object keeps its real status and gets
      an ``{"ok": False, ...}`` body, so callers can always use ``dict`` access.

    Args:
        base_url: Scheme/host/port prefix; a trailing ``/`` is ignored.
        path: Path appended verbatim to ``base_url``.
        method: HTTP method.
        payload: JSON-serializable request body; ``None`` sends no body.
        token: Bearer token for the ``Authorization`` header, if any.
        cafile: CA bundle used to verify the server for ``https://`` URLs.
        timeout: Socket timeout in seconds.
    """
    req = Request(f"{base_url.rstrip('/')}{path}", method=method)
    req.add_header("Content-Type", "application/json")
    if token:
        req.add_header("Authorization", f"Bearer {token}")
    data = None if payload is None else json.dumps(payload).encode("utf-8")
    context = _ssl_context(cafile) if base_url.startswith("https://") else None
    try:
        with urlopen(req, data=data, timeout=timeout, context=context) as resp:
            status = resp.status
            raw = resp.read()
    except HTTPError as exc:
        # HTTPError subclasses URLError, so it must be handled first.
        try:
            raw_error = exc.read()
        except Exception:  # error body may be unreadable (e.g. dropped connection)
            raw_error = b""
        error_body = _decode_json_object(raw_error)
        if error_body is None:
            error_body = {"ok": False, "error": f"http error {exc.code}"}
        return exc.code, error_body
    except URLError as exc:
        return 502, {"ok": False, "error": f"url error: {exc.reason}"}
    except Exception as exc:
        return 502, {"ok": False, "error": f"request failed: {exc}"}
    # Decode outside the try so a non-JSON success body is not misreported as
    # a 502 transport failure.
    body = _decode_json_object(raw)
    if body is None:
        return status, {"ok": False, "error": f"invalid json response (status {status})"}
    return status, body
def enroll_user(base_url: str, cafile: str, username: str, display_name: str) -> None:
    """Register ``username`` with k_proxy, accepting an existing enrollment.

    Raises:
        RuntimeError: unless the server answers 200, or 409 with the error
            ``"user already enrolled"``.
    """
    registration = {"username": username, "display_name": display_name}
    status, data = request_json(
        base_url,
        "/enroll/register",
        method="POST",
        payload=registration,
        cafile=cafile,
    )
    # A 409 for this specific reason means a previous run already enrolled
    # the user, which is fine for the probe.
    already_enrolled = status == 409 and data.get("error") == "user already enrolled"
    if status != 200 and not already_enrolled:
        raise RuntimeError(f"enroll failed for {username}: status={status} data={data}")
def login_user(base_url: str, cafile: str, username: str) -> Session:
    """Log ``username`` in through ``/session/login`` and wrap the token.

    Raises:
        RuntimeError: if the status is not 200 or no truthy
            ``session_token`` is present in the response.
    """
    status, data = request_json(
        base_url,
        "/session/login",
        method="POST",
        payload={"username": username},
        cafile=cafile,
    )
    session_token = data.get("session_token") if status == 200 else None
    if not session_token:
        raise RuntimeError(f"login failed for {username}: status={status} data={data}")
    return Session(username=username, token=session_token)
def counter_call(base_url: str, cafile: str, session: Session, call_id: int) -> dict[str, Any]:
    """Issue one authenticated POST to ``/resource/counter`` and record it.

    Returns:
        A dict with ``call_id``, ``username``, HTTP ``status``, the decoded
        response body as ``data``, and ``latency_ms`` (whole milliseconds).
    """
    # perf_counter is monotonic, so clock adjustments during the run cannot
    # produce negative or skewed latencies (time.time() could).
    started = time.perf_counter()
    status, data = request_json(
        base_url,
        "/resource/counter",
        method="POST",
        payload={},
        token=session.token,
        cafile=cafile,
    )
    elapsed = time.perf_counter() - started
    return {
        "call_id": call_id,
        "username": session.username,
        "status": status,
        "data": data,
        "latency_ms": int(elapsed * 1000),
    }
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Run Phase 6.5 concurrency probe against k_proxy")
parser.add_argument("--base-url", default="https://127.0.0.1:9771")
parser.add_argument("--ca-file", required=True)
parser.add_argument("--users", type=int, default=3)
parser.add_argument("--requests-per-user", type=int, default=4)
parser.add_argument("--username-prefix", default="phase65")
parser.add_argument(
"--max-workers",
type=int,
help="Maximum number of in-flight protected calls; defaults to total requests",
)
return parser.parse_args()
def _build_jobs(sessions: list[Session], requests_per_user: int) -> list[tuple[Session, int]]:
    """Expand sessions into ``(session, call_id)`` jobs with call_ids 0..N-1.

    Jobs are grouped per session, in ``sessions`` order, ``requests_per_user``
    jobs each.
    """
    jobs: list[tuple[Session, int]] = []
    for session in sessions:
        for _ in range(requests_per_user):
            jobs.append((session, len(jobs)))
    return jobs


def _run_calls(
    base_url: str,
    cafile: str,
    jobs: list[tuple[Session, int]],
    max_workers: int,
) -> list[dict[str, Any]]:
    """Run every job's counter call on a thread pool; return results sorted by call_id.

    A call that raises is recorded as status 599 with latency -1 instead of
    aborting the whole probe.
    """
    results: list[dict[str, Any]] = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_map = {
            executor.submit(counter_call, base_url, cafile, session, job_id): (session.username, job_id)
            for session, job_id in jobs
        }
        for future in as_completed(future_map):
            username, job_id = future_map[future]
            try:
                results.append(future.result())
            except Exception as exc:
                results.append(
                    {
                        "call_id": job_id,
                        "username": username,
                        "status": 599,
                        "data": {"ok": False, "error": str(exc)},
                        "latency_ms": -1,
                    }
                )
    results.sort(key=lambda item: item["call_id"])
    return results


def _counter_value(result: dict[str, Any]) -> int | None:
    """Return the counter value of a successful call result, else None.

    Success means status 200, a truthy ``ok`` and an int ``upstream.value``.
    A malformed success body is counted as a failed call, not a crash.
    """
    if result["status"] != 200:
        return None
    data = result["data"]
    if not isinstance(data, dict) or not data.get("ok"):
        return None
    upstream = data.get("upstream")
    value = upstream.get("value") if isinstance(upstream, dict) else None
    # bool is an int subclass, but true/false is not a counter value.
    if isinstance(value, bool) or not isinstance(value, int):
        return None
    return value


def _summarize(args: argparse.Namespace, max_workers: int, results: list[dict[str, Any]]) -> dict[str, Any]:
    """Build the JSON-serializable probe summary from the per-call results."""
    values = [value for value in map(_counter_value, results) if value is not None]
    values_sorted = sorted(values)
    contiguous = bool(values_sorted) and values_sorted == list(
        range(values_sorted[0], values_sorted[0] + len(values_sorted))
    )
    unique_count = len(set(values))
    return {
        "ok": len(values) == len(results) and unique_count == len(values) and contiguous,
        "users": args.users,
        "requests_per_user": args.requests_per_user,
        "total_requests": len(results),
        "max_workers": max_workers,
        "successful_requests": len(values),
        "unique_counter_values": unique_count,
        "counter_min": values_sorted[0] if values_sorted else None,
        "counter_max": values_sorted[-1] if values_sorted else None,
        "counter_contiguous": contiguous,
        "max_latency_ms": max((item["latency_ms"] for item in results), default=None),
        "results": results,
    }


def main() -> int:
    """Run the probe end to end and print a JSON summary.

    Returns:
        0 if every call succeeded with unique, contiguous counter values,
        otherwise 1. Enrollment/login failures propagate as RuntimeError from
        ``enroll_user`` / ``login_user``.
    """
    args = parse_args()
    sessions: list[Session] = []
    for idx in range(args.users):
        username = f"{args.username_prefix}_{idx}"
        enroll_user(args.base_url, args.ca_file, username, f"Phase65 User {idx}")
        sessions.append(login_user(args.base_url, args.ca_file, username))
    jobs = _build_jobs(sessions, args.requests_per_user)
    # ThreadPoolExecutor rejects max_workers=0, which "one worker per job"
    # would produce when there are no jobs.
    max_workers = args.max_workers or max(len(jobs), 1)
    results = _run_calls(args.base_url, args.ca_file, jobs, max_workers)
    summary = _summarize(args, max_workers, results)
    print(json.dumps(summary, indent=2))
    return 0 if summary["ok"] else 1
if __name__ == "__main__":
    # Propagate main()'s return value (0 = all checks passed) as the exit status.
    sys.exit(main())