Architecture
Patch Extraction
Patches are extracted using MONAI's RandWeightedCropd (when heatmaps are available) or RandCropByPosNegLabeld (without heatmaps):
- With heatmaps: The combined DWI/ADC heatmap multiplied by the prostate mask serves as the sampling weight map — regions with high DWI and low ADC are sampled more frequently
- Without heatmaps: Crops are sampled from positive (prostate) regions based on the binary mask
Each scan yields N patches (default 24) of size tile_size x tile_size x depth (default 64x64x3).
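The weighted sampling idea can be illustrated with a minimal NumPy sketch. This is not MONAI's RandWeightedCropd implementation: the function name sample_patch_centers, the toy 2D arrays, and the uniform fallback are hypothetical and only show how a heatmap multiplied by a mask can act as a sampling distribution for patch centers.

```python
import numpy as np

def sample_patch_centers(heatmap, mask, num_patches=24, seed=0):
    """Draw patch centers with probability proportional to heatmap * mask.

    Assumes the mask contains at least one positive voxel.
    """
    weights = heatmap * mask                      # zero weight outside the prostate
    flat = weights.ravel().astype(np.float64)
    if flat.sum() <= 0:                           # assumption: fall back to uniform over the mask
        flat = mask.ravel().astype(np.float64)
    probs = flat / flat.sum()
    rng = np.random.default_rng(seed)
    idx = rng.choice(flat.size, size=num_patches, p=probs)
    # Convert flat indices back to coordinates: [num_patches, ndim]
    return np.stack(np.unravel_index(idx, weights.shape), axis=1)

# Toy 2D example: the mask covers the centre, the heatmap peaks at one voxel.
heatmap = np.full((8, 8), 0.1)
heatmap[4, 4] = 5.0
mask = np.zeros((8, 8))
mask[2:6, 2:6] = 1.0
centers = sample_patch_centers(heatmap, mask)
```

Every sampled center lies inside the mask, and the high-weight voxel is drawn far more often than its neighbours.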
Tensor Shape Convention
Throughout the pipeline, tensors follow the shape [B, N, C, D, H, W]:
| Dim | Meaning | Typical Value |
|---|---|---|
| B | Batch size | 4–8 |
| N | Number of patches (instances) | 24 |
| C | Channels (T2 + DWI + ADC) | 3 |
| D | Depth (slices per patch) | 3 |
| H | Patch height | 64 |
| W | Patch width | 64 |
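As a quick illustration of the convention, the sketch below builds a dummy batch with the typical values from the table and shows how the batch and patch axes can be merged and split again. NumPy is used only for illustration (the pipeline itself operates on tensors), and the variable names are hypothetical.

```python
import numpy as np

B, N, C, D, H, W = 4, 24, 3, 3, 64, 64   # typical values from the table
batch = np.zeros((B, N, C, D, H, W), dtype=np.float32)

# Merge the batch and patch axes so each patch becomes an independent sample.
flat = batch.reshape(B * N, C, D, H, W)

# Per-patch outputs (here a dummy 512-dim vector) can be regrouped per scan.
features = np.zeros((B * N, 512), dtype=np.float32)
per_scan = features.reshape(B, N, 512)
```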
MILModel3D
The core model processes each patch independently through a CNN backbone, then aggregates patch-level features via a transformer encoder and attention pooling.
flowchart TD
A["Input [B, N, C, D, H, W]"] --> B[ResNet18-3D Backbone]
B --> C[Transformer Encoder<br/>4 layers, 8 heads]
C --> D[Attention Pooling<br/>512 → 2048 → 1]
D --> E["Weighted Sum [B, 512]"]
E --> F["FC Head [B, num_classes]"]
Forward Pass
- Backbone: Input is reshaped from [B, N, C, D, H, W] to [B*N, C, D, H, W] and passed through a 3D ResNet18 (with 3 input channels). The final FC layer is removed, yielding 512-dimensional features per patch.
- Transformer: The per-patch features from the ResNet18 backbone are passed to the transformer encoder.
- Attention: A two-layer attention network (512 → 2048 → 1 with Tanh) computes a scalar weight per patch, normalized via softmax.
- Classification: The attention-weighted sum of patch features produces a single [B, 512] vector per scan, which is projected to class logits by a linear layer.
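The attention and classification steps can be sketched in NumPy as follows. This is a minimal illustration, not the model's actual code: the weights are random, the transformer stage is omitted, num_classes = 4 is an arbitrary choice for the example, and the function name attention_pool is hypothetical.

```python
import numpy as np

def attention_pool(features, w1, b1, w2, b2):
    """features: [B, N, 512] -> pooled [B, 512] and attention weights [B, N]."""
    hidden = np.tanh(features @ w1 + b1)                     # [B, N, 2048]
    scores = (hidden @ w2 + b2).squeeze(-1)                  # [B, N]
    scores = scores - scores.max(axis=1, keepdims=True)      # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=1, keepdims=True)   # softmax over patches
    pooled = (weights[..., None] * features).sum(axis=1)     # weighted sum: [B, 512]
    return pooled, weights

rng = np.random.default_rng(0)
B, N, feat_dim, hidden_dim, num_classes = 2, 24, 512, 2048, 4
features = rng.normal(size=(B, N, feat_dim))                 # stand-in for patch features

w1 = rng.normal(scale=0.01, size=(feat_dim, hidden_dim))
b1 = np.zeros(hidden_dim)
w2 = rng.normal(scale=0.01, size=(hidden_dim, 1))
b2 = np.zeros(1)
pooled, weights = attention_pool(features, w1, b1, w2, b2)

# Linear classification head on the pooled vector.
w_fc = rng.normal(scale=0.01, size=(feat_dim, num_classes))
logits = pooled @ w_fc                                       # [B, num_classes]
```

The attention weights sum to 1 across the N patches of each scan, so the pooled vector is a convex combination of the patch features.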
MIL Modes
| Mode | Aggregation Strategy |
|---|---|
| mean | Average logits across patches |
| max | Max logits across patches |
| att | Attention-weighted feature pooling |
| att_trans | Transformer encoder + attention pooling (primary mode) |
| att_trans_pyramid | Pyramid transformer on intermediate ResNet layers + attention |
The default and primary mode is att_trans.
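For contrast with the attention-based modes, the two simplest modes reduce per-patch logits directly. The sketch below is a hypothetical NumPy illustration (the function name aggregate_logits is not taken from the codebase) and assumes per-patch logits of shape [B, N, num_classes].

```python
import numpy as np

def aggregate_logits(patch_logits, mode):
    """patch_logits: [B, N, num_classes] -> [B, num_classes]."""
    if mode == "mean":
        return patch_logits.mean(axis=1)   # average over the patch axis
    if mode == "max":
        return patch_logits.max(axis=1)    # element-wise max over the patch axis
    raise ValueError(f"mode not covered by this sketch: {mode}")

# B=1 scan, N=3 patches, 2 classes
patch_logits = np.array([[[0.0, 1.0], [2.0, 3.0], [4.0, -1.0]]])
mean_logits = aggregate_logits(patch_logits, "mean")   # [[2.0, 1.0]]
max_logits = aggregate_logits(patch_logits, "max")     # [[4.0, 3.0]]
```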
csPCa_Model
Wraps a frozen MILModel_3D backbone and replaces the classification head:
flowchart TD
A["Input [B, N, C, D, H, W]"] --> B["Frozen Backbone<br/>(ResNet18 + Transformer)"]
B --> C["Pooled Features [B, 512]"]
C --> D["SimpleNN Head<br/>512 → 256 → 128 → 1"]
D --> E["Sigmoid → csPCa Probability"]
SimpleNN
During csPCa training, the backbone's net (ResNet18) and transformer are frozen. The attention module and SimpleNN head remain trainable.
Attention Loss
During PI-RADS training with heatmaps enabled, the model uses a dual-loss objective:
- Classification loss: Standard CrossEntropy on PI-RADS labels
- Attention loss: 1 - cosine_similarity(predicted_attention, heatmap_attention)
  - Heatmap-derived attention labels are computed by summing spatial heatmap values per patch, squaring for sharpness, and normalizing
  - PI-RADS 2 samples get uniform attention (no expected lesion)
  - lambda_att warms up linearly from 0 to 2.0 over the first 25 epochs
  - The attention predictions are computed with detached transformer outputs to avoid gradient interference with classification
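The attention-loss components described above can be sketched in NumPy. Several details here are assumptions rather than facts from the source: the labels are normalized to sum to 1 per scan, the loss is averaged over the batch, epochs are counted from 0, and all function names are hypothetical. Gradient detachment is not shown because NumPy has no autograd.

```python
import numpy as np

def heatmap_attention_labels(patch_heatmaps, eps=1e-8):
    """patch_heatmaps: [B, N, D, H, W] -> attention labels [B, N]."""
    b, n = patch_heatmaps.shape[:2]
    sums = patch_heatmaps.reshape(b, n, -1).sum(axis=-1)     # sum heatmap per patch
    sharp = sums ** 2                                        # square for sharpness
    return sharp / (sharp.sum(axis=1, keepdims=True) + eps)  # assumed: sum-to-one

def attention_loss(pred, target, eps=1e-8):
    """Batch mean of 1 - cosine similarity between [B, N] attention vectors."""
    dot = (pred * target).sum(axis=1)
    norms = np.linalg.norm(pred, axis=1) * np.linalg.norm(target, axis=1)
    return float(np.mean(1.0 - dot / (norms + eps)))

def lambda_att(epoch, max_value=2.0, warmup_epochs=25):
    """Linear warm-up from 0 to max_value, then constant."""
    return max_value * min(epoch / warmup_epochs, 1.0)

# One scan with three 1x2x2 patches whose heatmap sums are 1, 3 and 0.
patch_heatmaps = np.zeros((1, 3, 1, 2, 2))
patch_heatmaps[0, 0] = 0.25
patch_heatmaps[0, 1] = 0.75
labels = heatmap_attention_labels(patch_heatmaps)   # approx. [[0.1, 0.9, 0.0]]

uniform = np.full((1, 3), 1.0 / 3.0)                # e.g. a PI-RADS 2 target
loss_match = attention_loss(labels, labels)         # close to 0
loss_uniform = attention_loss(uniform, labels)      # larger: the distributions differ
```

Under these assumptions the attention term would enter the total objective as classification_loss + lambda_att(epoch) * attention_loss(...).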