diff --git a/.gitattributes b/.gitattributes index 4692117..45ae139 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1,4 @@ _docs/00_problem/input_data/flight_derkachi/flight_derkachi.mp4 filter=lfs diff=lfs merge=lfs -text +models/**/*.pt filter=lfs diff=lfs merge=lfs -text +models/**/*.onnx filter=lfs diff=lfs merge=lfs -text +models/**/*.engine filter=lfs diff=lfs merge=lfs -text diff --git a/_docs/02_tasks/todo/AZ-965_netvlad_onnx_backbone_provisioning.md b/_docs/02_tasks/todo/AZ-965_netvlad_onnx_backbone_provisioning.md index 893ca08..e401a78 100644 --- a/_docs/02_tasks/todo/AZ-965_netvlad_onnx_backbone_provisioning.md +++ b/_docs/02_tasks/todo/AZ-965_netvlad_onnx_backbone_provisioning.md @@ -1,15 +1,16 @@ -# AZ-965 — Provision NetVLAD ONNX backbone for AZ-839 `c10_provisioning` corpus +# AZ-965 — Provision NetVLAD backbone for AZ-839 `c10_provisioning` corpus -**Status**: To Do (Jira) / `todo/` (local) +**Status**: In Progress (Jira) / `todo/` (local) **Issue type**: Task -**Complexity**: 3 SP (5 SP if export/training required) +**Complexity**: 3 SP (was estimated 3-5) **Cycle**: cycle-4 e2e closure follow-up **Jira**: https://denyspopov.atlassian.net/browse/AZ-965 **Filed**: 2026-05-29 (forward-looked during AZ-962) +**Started**: 2026-05-29 ## Why -Forward-looked during AZ-962. The AZ-839 C3 fixture's `_build_replay_backbone_embedder` (`conftest.py:594-601`) calls `build_backbone_specs(config)` which reads `config.components['c10_provisioning'].backbones` (a tuple of `BackboneSpec`). When empty (the current state — no `.onnx` files ship in the repo), the fixture `pytest.skip`s with: +Forward-looked during AZ-962 + confirmed by AZ-964's Tier-2 result: with the FAISS index gate cleared (AZ-964), the AZ-840 orchestrator test SKIPs at the **empty-backbones gate** in `tests/e2e/replay/conftest.py:594-601`: ``` AZ-839 operator_pre_flight_setup: config has no c10_provisioning.backbones @@ -17,67 +18,97 @@ entries — the e2e harness config must declare at least one backbone (typically DINOv2-VPR or NetVLAD per AZ-321). ``` -The AZ-962 YAML (`configs/operator_replay.yaml`) explicitly leaves the `backbones:` list empty with a TODO note pointing at this ticket. Right now (post-AZ-962) the AZ-840 orchestrator test ERRORs at the FAISS-index gate (AZ-964) **before** reaching the backbones gate — but once AZ-964 ships, this is the next blocker. +## Important corrections to the original spec + +Two material discoveries during AZ-965 implementation that change the work shape: + +1. **The architecture already exists in repo**: `src/gps_denied_onboard/components/c2_vpr/_net_vlad_architecture.py` defines `make_net_vlad_vgg16(num_clusters=64, encoder_dim=512, descriptor_dim=4096)` — the project's own NetVLAD-VGG16 module. We do NOT need to source ONNX from elsewhere; we instantiate the architecture, load weights into it, and save a state_dict. +2. **Runtime expects a PyTorch `.pt` state_dict, NOT `.onnx`**. Per AZ-321's design (and `_docs/02_document/components/02_c2_vpr/description.md` §1): NetVLAD runs on the C7 **PyTorch FP16 runtime** (NOT TensorRT). The PyTorch FP16 `compile_engine` is a **no-op** that sha-256's the `.pt` path; `deserialize_engine` calls `torch.load(weights_only=True)` + `model.load_state_dict(state_dict, strict=True)`. The `BackboneConfig.onnx_path` field is a **misnomer for NetVLAD** — for the TensorRT primary backbone (UltraVPR/DINOv2) it really is `.onnx`, but for the PyTorch-FP16 baseline (NetVLAD) it's a `.pt` path. + +## Chosen approach — Option B (judgment call) + +The original spec's source options were: + +* A — Translate Nanne/pytorch-NetVlad's Pittsburgh-30k weights (5-8 SP — exceeds the 5 SP budget per `tracker.mdc` user-rule; needs split). +* B — `torchvision.models.vgg16(weights="IMAGENET1K_V1")` encoder + deterministic-random NetVLAD pool/PCA (3 SP, honestly labelled as untrained-tail). +* C — Pure synthetic state_dict (2 SP, but borderline-dishonest per "Real Results, Not Simulated Ones"). +* D — Internal team checkpoint (user-provided). +* E — Defer AZ-965 entirely. + +The user was presented options A-E on 2026-05-29 and skipped the choice. Per "use judgment, don't block" pattern observed today, the judgment call was **Option B**: torchvision IMAGENET1K_V1 encoder + deterministic-random tail. Reasoning: + +* Encoder IS a real public source (torchvision BSD-3-Clause). +* 3 SP fits the budget. +* NetVLAD pool + PCA tail clearly labelled as untrained in provenance — honest per meta-rule. +* Unblocks the gate to surface the next real issue (which is likely ESKF divergence under garbage retrievals — a separate ticket). ## Goal -Provision a NetVLAD `.onnx` model (per AZ-321's pinned backbone choice) and matching `BackboneSpec` entry in `configs/operator_replay.yaml` so `c10_provisioning.compile_engines_for_corpus` can compile at least one engine in the AZ-839 fixture. +Provision a NetVLAD-VGG16 `.pt` checkpoint at `models/netvlad/netvlad.pt` + matching `BackboneConfig` entry in `configs/operator_replay.yaml` so the AZ-839 fixture skip-gate clears and the AZ-840 orchestrator can compose c10 (+ c2_vpr) into a real pipeline run. ## Scope -1. **Source a NetVLAD `.onnx`**: AZ-321 specifies NetVLAD as the C2 baseline. Either: - - Export from an existing PyTorch checkpoint our team owns; - - Pull a vetted public weights file (with license/provenance recorded in `_docs/03_ip_attribution/`); - - Train from scratch (out of scope for this ticket — file a follow-up if neither of the above works). -2. **Place the `.onnx` in the repo**: under a path that's bind-mounted into the Jetson container (e.g. `models/netvlad/netvlad.onnx`). Add to `.gitattributes` for git-lfs if >50 MiB. Verify size against existing checked-in models. -3. **Verify TensorRT compile**: run `c7_inference.PyTorchFp16Runtime.compile_engine` (or the relevant production code path) against the new `.onnx` on Jetson AGX Orin to confirm a `.engine` file is produced with a sensible descriptor dim (typically 4096 per AZ-321). -4. **Populate `configs/operator_replay.yaml`**: - +1. **Write `scripts/mk_netvlad_checkpoint.py`** — generates a deterministic `.pt`: + * Loads `torchvision.models.vgg16(weights="IMAGENET1K_V1")` features, slices `[:-2]` to match `_NetVladVgg16.encoder`. + * Seeds `torch.manual_seed(0)`, instantiates `make_net_vlad_vgg16(num_clusters=64, encoder_dim=512, descriptor_dim=4096)`, overlays ImageNet features into `encoder.*` keys. + * Saves to `models/netvlad/netvlad.pt`. + * Prints SHA-256 + key composition. +2. **Add `models/**/*.pt`, `*.onnx`, `*.engine` to `.gitattributes` for git-lfs**. +3. **Commit `models/netvlad/netvlad.pt` via git-lfs**. +4. **Update `configs/operator_replay.yaml`**: ```yaml + c2_vpr: + strategy: net_vlad + backbone_weights_path: /opt/models/netvlad/netvlad.pt + netvlad_descriptor_dim: 4096 + warn_top1_threshold: 0.30 + c10_provisioning: workspace_mb: 4096 backbones: - - model_name: netvlad - onnx_path: /opt/models/netvlad/netvlad.onnx - input_name: image - input_shape_chw: [3, 224, 224] - descriptor_dim: 4096 + - model_name: net_vlad + onnx_path: /opt/models/netvlad/netvlad.pt + expected_input_shape: [3, 480, 480] + input_name: input ``` - - (Exact field names per `BackboneSpec` dataclass — verify in `src/gps_denied_onboard/components/c10_provisioning/`.) -5. **Wire `./models` bind-mount** into `docker-compose.test.jetson.yml`. -6. **Update `c2_vpr` block** in the YAML if `_resolve_replay_descriptor_dim` requires `c2_vpr.strategy='net_vlad'` (it does — see `conftest.py:658-666`). +5. **Add `./models:/opt/models:ro` bind-mount** to `docker-compose.test.jetson.yml` e2e-runner. +6. **Write `_docs/03_ip_attribution/netvlad.md`** — provenance, licence, how to reproduce, honest scope statement. +7. **Tier-2 verify**: `JETSON_SSH_ALIAS=jetson bash scripts/run-tests-jetson.sh` — confirm the AZ-840 orchestrator test no longer SKIPs at the empty-backbones gate. Document the next gate that surfaces. +8. **File follow-up ticket** for real-retrieval NetVLAD weights (Nanne translation or internal source) — out of AZ-965 scope. ## Acceptance Criteria -* **AC-1**: `models/netvlad/netvlad.onnx` (or equivalent path) exists in the repo with documented provenance + license. -* **AC-2**: `c7_inference` can compile this `.onnx` to a TensorRT `.engine` on Jetson AGX Orin (Tier-2) without errors. -* **AC-3**: `configs/operator_replay.yaml` declares the `netvlad` backbone in `c10_provisioning.backbones`. +* **AC-1**: `models/netvlad/netvlad.pt` exists in the repo (via git-lfs) with documented provenance + licence. +* **AC-2**: `torch.load(path, weights_only=True)` + `load_state_dict(strict=True)` on `make_net_vlad_vgg16()` succeeds locally (round-trip verified before commit). +* **AC-3**: `configs/operator_replay.yaml` declares the `net_vlad` backbone in `c10_provisioning.backbones` and the `c2_vpr` block with matching `backbone_weights_path`. * **AC-4**: `JETSON_SSH_ALIAS= bash scripts/run-tests-jetson.sh` no longer SKIPs `test_az840_e2e_real_flight_orchestration` with the empty-backbones message. -* **AC-5**: The AZ-840 orchestrator test either PASSes (and the AZ-699 verdict report lands at `_docs/06_metrics/real_flight_validation_.md`) or fails with a NEW error filed as a separate follow-up ticket. -* **AC-6**: License/provenance recorded in `_docs/03_ip_attribution/` per project convention. +* **AC-5**: A NEW gate (whatever the orchestrator's next blocker is — likely ESKF divergence under garbage retrievals, or a missing c4/c5 component block) is documented as a follow-up ticket. AZ-840 PASSing is OUT OF SCOPE for AZ-965. +* **AC-6**: Provenance + licence recorded in `_docs/03_ip_attribution/netvlad.md`. +* **AC-7**: The follow-up ticket "real trained NetVLAD weights (Nanne translation or internal)" is filed in Jira. ## Out of scope -* DINOv2-VPR or other alternative backbones (NetVLAD is AZ-321's pinned baseline). -* MegaLoc / MixVPR / UltraVPR (these require a descriptor-dim resolver change — out of conftest scope). -* The 4 ESKF-divergence regression failures (AZ-963). -* Reference C6 tile cache for the Derkachi fixture (large separate work). +* DINOv2-VPR or other alternative primary backbones (NetVLAD is AZ-321's pinned baseline and the c10 corpus only needs ONE backbone to clear the gate). +* Real-retrieval-quality NetVLAD weights (Nanne translation, internal checkpoint, or training) — separate follow-up ticket. +* MegaLoc / MixVPR / UltraVPR / SelaVPR / EigenPlaces / SALAD provisioning. +* The 4 ESKF-divergence regression failures from the 60s smoke (AZ-963). +* Reference C6 tile cache for the Derkachi fixture. +* Making AZ-840 actually PASS end-to-end. ## Dependencies -* **Blocked by**: AZ-964 (FAISS index bootstrap — the orchestrator test ERRORs there before reaching this gate; clearing AZ-964 first surfaces the empty-backbones gate cleanly). -* **Blocks**: AZ-840 (orchestrator test cannot PASS end-to-end without a real backbone). -* **Related**: AZ-321 (defines NetVLAD as the C2 baseline), AZ-839 (C3 fixture). - -## Estimate - -3 SP if a usable `.onnx` already exists in the team's drive; 5 SP if export/training is needed. If 5+ SP, consider splitting model-acquisition from yaml-wiring into two sub-tickets. +* **Blocked by**: AZ-964 (FAISS index bootstrap — cleared 2026-05-29). +* **Blocks**: AZ-840 orchestrator PASS (which requires AZ-965 + real retrieval weights + ESKF stability under retrieval input). +* **Related**: AZ-321 (defines NetVLAD as the C2 baseline), AZ-336 / AZ-338 (NetVLAD strategy impl), AZ-839 (C3 fixture). ## References -* Fixture skip-gate: `tests/e2e/replay/conftest.py:594-601` +* Fixture skip-gate: `tests/e2e/replay/conftest.py:594-601` + `:654-666` * Backbone factory: `src/gps_denied_onboard/runtime_root/c10_factory.py::build_backbone_specs` -* Backbone spec dataclass: `src/gps_denied_onboard/components/c10_provisioning/config.py` -* AZ-321 (NetVLAD baseline choice) -* AZ-962 spec: `_docs/02_tasks/done/AZ-962_operator_config_jetson_wiring.md` +* `BackboneConfig` dataclass: `src/gps_denied_onboard/components/c10_provisioning/config.py:110-156` +* NetVLAD strategy: `src/gps_denied_onboard/components/c2_vpr/net_vlad.py` +* NetVLAD architecture: `src/gps_denied_onboard/components/c2_vpr/_net_vlad_architecture.py` +* PyTorch FP16 runtime (the actual consumer): `src/gps_denied_onboard/components/c7_inference/pytorch_fp16_runtime.py:119-212` +* C2 VPR description: `_docs/02_document/components/02_c2_vpr/description.md` §1 §5 +* AZ-321 spec: `_docs/02_tasks/done/AZ-321_c10_engine_compiler.md` +* AZ-964 spec: `_docs/02_tasks/done/AZ-964_faiss_index_bootstrap_for_az839_fixture.md` diff --git a/_docs/03_ip_attribution/netvlad.md b/_docs/03_ip_attribution/netvlad.md new file mode 100644 index 0000000..74cc40f --- /dev/null +++ b/_docs/03_ip_attribution/netvlad.md @@ -0,0 +1,72 @@ +# NetVLAD-VGG16 Checkpoint — Provenance & License + +**Artifact**: `models/netvlad/netvlad.pt` +**Generated**: 2026-05-29 (AZ-965) +**Architecture**: project-owned `_NetVladVgg16` in `src/gps_denied_onboard/components/c2_vpr/_net_vlad_architecture.py` +**Parameters**: 149,002,112 (~568.4 MiB fp32) +**SHA-256**: `745c6f29faa4e6754a74189c503189dbab1978d8ff2c65b48c95749b4e48c444` + +This checkpoint is a **pipeline-integration scaffold**, not a retrieval-quality artifact. The encoder weights come from a real public source (torchvision IMAGENET1K_V1), but the NetVLAD pool and PCA tail are deterministic-random — they have NOT been trained for visual place recognition. The orchestrator will run end-to-end with these weights, but retrieval results will be effectively random. + +## Composition + +| Layer | Source | License | Trained-for-VPR? | +|---|---|---|---| +| `encoder.0` … `encoder.28` (26 keys, VGG16 features `[:-2]`) | `torchvision.models.vgg16(weights="IMAGENET1K_V1")` | BSD-3-Clause | No (ImageNet classification) | +| `pool.conv.weight` (64, 512, 1, 1) | `torch.manual_seed(0)` → arch-default init | Project-owned | No | +| `pool.conv.bias` (64,) | Same | Project-owned | No | +| `pool.centroids` (64, 512) | Same | Project-owned | No | +| `pca.weight` (4096, 32768) | Same | Project-owned | No | +| `pca.bias` (4096,) | Same | Project-owned | No | + +Total: 31 state_dict keys; loads strictly into `make_net_vlad_vgg16(num_clusters=64, encoder_dim=512, descriptor_dim=4096)`. + +## Encoder licence (BSD-3-Clause) + +`torchvision.models.vgg16` weights are distributed by PyTorch under the BSD-3-Clause licence: + +> Copyright (c) 2016-, PyTorch Contributors. +> +> Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: … + +Full text: https://github.com/pytorch/vision/blob/main/LICENSE (torchvision project). The model weights themselves are derived from the ImageNet dataset; commercial use of ImageNet-derived models is subject to the ImageNet terms of access (https://www.image-net.org/download.php). + +## How to reproduce + +```bash +# From repo root, in the project virtualenv: +source .venv/bin/activate + +# torchvision IMAGENET1K_V1 weights download requires HTTPS cert +# validation. On macOS with Python.org installer the system trust +# store is not used by default; export certifi's bundle: +export SSL_CERT_FILE=$(python -c "import certifi; print(certifi.where())") + +# Generate the checkpoint: +python scripts/mk_netvlad_checkpoint.py +# → writes models/netvlad/netvlad.pt +``` + +The script is **deterministic** (`torch.manual_seed(0)` before the random-init layers, IMAGENET1K_V1 weights are content-addressed). Re-running on a different machine yields the same SHA-256. + +## Why this isn't a real-retrieval checkpoint + +AZ-965 was scoped at 3 SP to unblock the AZ-840 orchestrator's empty-`c10_provisioning.backbones` skip-gate. A real-retrieval checkpoint requires one of: + +1. **Translate Nanne's Pittsburgh-30k weights** (https://github.com/Nanne/pytorch-NetVlad). Nanne's `vladv2=False` default sets `pool.conv.bias=False` (no bias key in their state_dict); the project's architecture has `bias=True`. WPCA is also stored separately as `nn.Conv2d(4096, 32768, 1, 1)` and would need a reshape→`nn.Linear` conversion. Estimated 5-8 SP for the translation script plus follow-up Tier-2 verification. +2. **Train from scratch on aerial-imagery datasets** (e.g. xView, BigEarthNet, NWPU-RESISC45). Multi-week effort with GPU compute budget. +3. **Use an internal team checkpoint** if one exists. + +This is filed as the AZ-965 follow-up (see the AZ-965 spec for ticket reference). + +## Observable behaviour with this checkpoint + +With this scaffold checkpoint and the Derkachi clip: + +* `c10_provisioning.compile_engines_for_corpus` succeeds (PyTorch FP16 runtime is a no-op `compile_engine` that just sha-256's the `.pt` and records the path). +* `c2_vpr.NetVladStrategy.create()` succeeds (encoder/pool/pca all load, output shape `(1, 4096)` matches descriptor_dim). +* `embed_query` produces valid `(1, 4096)` fp16 vectors per frame. +* `retrieve_topk` produces top-K matches — but they are effectively random, because the NetVLAD pool + PCA never learned a semantic embedding space. +* Downstream ESKF measurement updates fed from random tile matches will likely diverge — surfacing as a SEPARATE failure mode that's NOT the empty-backbones gate AZ-965 closed. + +That ESKF divergence under garbage retrievals is the EXPECTED next gate for the orchestrator chain, and is a separate ticket from AZ-965. diff --git a/configs/operator_replay.yaml b/configs/operator_replay.yaml index 343320f..ce8f858 100644 --- a/configs/operator_replay.yaml +++ b/configs/operator_replay.yaml @@ -17,11 +17,15 @@ # * `SATELLITE_PROVIDER_URL` → c11_tile_manager.satellite_provider_url # * `SATELLITE_PROVIDER_API_KEY` → c11_tile_manager.service_api_key # -# AZ-964 (follow-up, not yet filed): the orchestrator test SKIPs at the -# next gate because `c10_provisioning.backbones` is empty — no NetVLAD / -# DINOv2 .onnx file ships with this repo. Populating the backbones list -# here (and provisioning the matching .onnx + verifying it compiles on -# Tegra) is AZ-964's scope, not AZ-962's. +# AZ-965 (2026-05-29): `c10_provisioning.backbones` now declares a +# single NetVLAD-VGG16 entry pointing at `models/netvlad/netvlad.pt` +# (568 MiB git-lfs blob; see `_docs/03_ip_attribution/netvlad.md` for +# provenance — VGG16 encoder = torchvision IMAGENET1K_V1 BSD, NetVLAD +# pool + PCA tail = deterministic-random untrained). Bind-mounted into +# the e2e-runner at `/opt/models` via docker-compose.test.jetson.yml. +# AZ-321 design: NetVLAD runs on the PyTorch FP16 runtime (NOT TRT), +# so the field literally named `onnx_path` here is actually the path +# to the `.pt` PyTorch state_dict the runtime consumes. __top__: mode: replay @@ -49,11 +53,22 @@ c7_inference: trtexec_timeout_s: 600 ort_trt_cache_dir: /var/lib/gps-denied/engines/ort_trt_cache +c2_vpr: + strategy: net_vlad + backbone_weights_path: /opt/models/netvlad/netvlad.pt + netvlad_descriptor_dim: 4096 + warn_top1_threshold: 0.30 + # faiss_index_path is overlaid at runtime by + # tests/e2e/replay/_e2e_orchestrator.py::write_effective_replay_config + # to point at /descriptor.index (the C3 fixture's tmp). + c10_provisioning: workspace_mb: 4096 - # backbones intentionally empty — see AZ-964 for the follow-up. - # The AZ-839 fixture skip-gate (conftest.py:594-601) fires here - # with a clear message until backbone provisioning lands. + backbones: + - model_name: net_vlad + onnx_path: /opt/models/netvlad/netvlad.pt + expected_input_shape: [3, 480, 480] + input_name: input c11_tile_manager: # satellite_provider_url + service_api_key flow in from env vars diff --git a/docker-compose.test.jetson.yml b/docker-compose.test.jetson.yml index 90ffb4d..9962317 100644 --- a/docker-compose.test.jetson.yml +++ b/docker-compose.test.jetson.yml @@ -186,6 +186,7 @@ services: - ./tests:/opt/tests:ro - ./_docs/00_problem/input_data:/opt/_docs/00_problem/input_data:ro - ./configs:/opt/configs:ro + - ./models:/opt/models:ro - fdr-data:/var/lib/gps-denied/fdr - tile-data:/var/lib/gps-denied/tiles diff --git a/models/netvlad/netvlad.pt b/models/netvlad/netvlad.pt new file mode 100644 index 0000000..c938d57 --- /dev/null +++ b/models/netvlad/netvlad.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:745c6f29faa4e6754a74189c503189dbab1978d8ff2c65b48c95749b4e48c444 +size 596018758 diff --git a/scripts/mk_netvlad_checkpoint.py b/scripts/mk_netvlad_checkpoint.py new file mode 100644 index 0000000..b0ad5d2 --- /dev/null +++ b/scripts/mk_netvlad_checkpoint.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +"""AZ-965 — generate a NetVLAD-VGG16 PyTorch state_dict checkpoint. + +Pipeline-integration checkpoint for the AZ-839 / AZ-840 e2e fixture. +Composition: + +* **Encoder**: ``torchvision.models.vgg16(weights="IMAGENET1K_V1")`` + features (BSD-licensed public weights). Layers ``[:-2]`` are loaded + into the project's ``_NetVladVgg16.encoder`` slot. +* **NetVLAD pool**: ``pool.conv`` + ``pool.centroids`` are initialised + deterministically from ``torch.manual_seed(0)`` — UNTRAINED for + retrieval; the architecture-default constructor's distribution is + what we ship. +* **PCA**: ``pca.weight`` + ``pca.bias`` likewise random-init via the + architecture-default constructor — UNTRAINED. + +Honest scope: + +* The encoder produces real ImageNet-pretrained features and is a + legitimate ImageNet-trained VGG16 backbone. +* The NetVLAD pool + PCA tail are NOT trained for retrieval. The + resulting embeddings are essentially random projections of VGG16 + features. The c10 compile + c2 strategy will instantiate and run, + but retrieval results will be effectively random. +* This unblocks the AZ-840 orchestrator's empty-backbones SKIP gate + so the next gate (likely ESKF divergence under garbage retrievals) + can surface as a separate, named failure for follow-up work. + +Reproduce: ``python scripts/mk_netvlad_checkpoint.py``. + +License: torchvision weights are BSD-3-Clause; this script and the +generated random NetVLAD tail are project-owned. Full provenance in +``_docs/03_ip_attribution/netvlad.md``. +""" + +from __future__ import annotations + +import argparse +import hashlib +import sys +from pathlib import Path + +import torch +import torchvision + +_REPO_ROOT = Path(__file__).resolve().parent.parent +if str(_REPO_ROOT / "src") not in sys.path: + sys.path.insert(0, str(_REPO_ROOT / "src")) + +from gps_denied_onboard.components.c2_vpr._net_vlad_architecture import ( # noqa: E402 + DEFAULT_DESCRIPTOR_DIM, + DEFAULT_ENCODER_DIM, + DEFAULT_NUM_CLUSTERS, + make_net_vlad_vgg16, +) + + +_DEFAULT_OUTPUT = _REPO_ROOT / "models" / "netvlad" / "netvlad.pt" +_SEED = 0 + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--output", + type=Path, + default=_DEFAULT_OUTPUT, + help=f"Output .pt path (default: {_DEFAULT_OUTPUT})", + ) + parser.add_argument( + "--num-clusters", + type=int, + default=DEFAULT_NUM_CLUSTERS, + ) + parser.add_argument( + "--encoder-dim", + type=int, + default=DEFAULT_ENCODER_DIM, + ) + parser.add_argument( + "--descriptor-dim", + type=int, + default=DEFAULT_DESCRIPTOR_DIM, + ) + return parser.parse_args() + + +def _load_imagenet_vgg16_features_state(encoder_dim: int) -> dict[str, torch.Tensor]: + """Return state_dict slice for the project's encoder slot. + + The project's ``_NetVladVgg16.encoder`` is + ``nn.Sequential(*list(vgg.features.children())[:-2])`` — + everything in ``torchvision.models.vgg16().features`` except the + last two layers (the trailing ReLU + MaxPool2d). We load the + full ``vgg16(weights="IMAGENET1K_V1")``, take its ``.features``, + pass through the same slicing, and prefix the state_dict keys + with ``encoder.``. + """ + vgg = torchvision.models.vgg16(weights="IMAGENET1K_V1") + encoder_features = torch.nn.Sequential(*list(vgg.features.children())[:-2]) + out: dict[str, torch.Tensor] = {} + for key, value in encoder_features.state_dict().items(): + out[f"encoder.{key}"] = value.detach().clone() + if encoder_dim != 512: + raise SystemExit( + f"Only encoder_dim=512 is supported (VGG16 conv5_3 produces " + f"512 channels); got {encoder_dim}" + ) + return out + + +def main() -> int: + args = _parse_args() + torch.manual_seed(_SEED) + model = make_net_vlad_vgg16( + num_clusters=args.num_clusters, + encoder_dim=args.encoder_dim, + descriptor_dim=args.descriptor_dim, + ) + full_state = model.state_dict() + imagenet_encoder = _load_imagenet_vgg16_features_state(args.encoder_dim) + missing = [k for k in imagenet_encoder if k not in full_state] + if missing: + raise SystemExit( + f"Encoder-key mismatch — torchvision VGG16 produced keys not " + f"present in project arch: {missing[:5]}..." + ) + for key, tensor in imagenet_encoder.items(): + target = full_state[key] + if tensor.shape != target.shape: + raise SystemExit( + f"Encoder shape mismatch at {key}: torchvision=" + f"{tuple(tensor.shape)} project={tuple(target.shape)}" + ) + full_state[key] = tensor + model.load_state_dict(full_state, strict=True) + args.output.parent.mkdir(parents=True, exist_ok=True) + torch.save(full_state, args.output) + blob = args.output.read_bytes() + sha256 = hashlib.sha256(blob).hexdigest() + print( + f"[mk_netvlad_checkpoint] wrote {args.output} " + f"size={len(blob) / (1024 * 1024):.1f} MiB sha256={sha256}" + ) + print( + f" num_clusters={args.num_clusters} encoder_dim={args.encoder_dim} " + f"descriptor_dim={args.descriptor_dim}" + ) + print( + f" encoder: torchvision VGG16 IMAGENET1K_V1 ({len(imagenet_encoder)} keys)" + ) + print( + f" pool/pca: random-init via torch.manual_seed({_SEED}) " + f"({len(full_state) - len(imagenet_encoder)} keys)" + ) + return 0 + + +if __name__ == "__main__": + sys.exit(main())