experiment_name: liconn_affinity_instance_seg
description: >
  LiConn neuron instance segmentation using a long-range affinity model (distances 1, 3, 9)
  with ABISS-based decoding for large-volume EM data.

  Data: /orcd/scratch/bcs/002/mansour/train_liconn/liconn_18nm_labelled.zarr/
  (training and validation read the cropped copy, liconn_18nm_labelled_cropped.zarr)
    img:   shape (795, 4870, 3825), dtype uint8,  chunks (64, 64, 64), ZYX order
    label: shape (795, 4870, 3825), dtype uint64, chunks (64, 64, 64), ZYX order
    resolution: 25 nm (z) × 18 nm (y) × 18 nm (x)

  Intended split strategy (spatial Z-axis, 80/20). It is not enforced yet: train and val
  currently read the same Zarr (see the data-path note in the train stage).
    train: Z slices 0–635 (636 slices, ~15.9 µm z-depth)
    val:   Z slices 636–794 (159 slices, ~4.0 µm z-depth)
    test:  full volume (run inference on the whole Zarr after training)

  Label erosion (erosion: 1):
    Applies Kisuk Lee's instance erosion: any voxel whose 3×3 XY window contains more
    than one segment ID is set to background before the affinities are computed.
    Essential for tightly-packed axons/neurons, where it creates clean boundary gaps.

  Imperfect-label robustness strategy:
    1. erosion: 1 removes boundary ambiguity left by imperfect proofreading.
    2. FocalLoss (gamma=2) down-weights easy, well-classified voxels and handles class
       imbalance without pos_weight tuning.
    3. TverskyLoss (alpha=0.3, beta=0.7) penalises false negatives more than false
       positives. Affinity targets are 1 for connected voxel pairs, so a false negative
       is a missed connection (a false split) and a false positive is a missed boundary
       (a false merge).
    4. aug_em_neuron applies heavy elastic deformation, misalignment, and missing sections
       to prevent overfitting to label artifacts.
    5. EMA (decay=0.999) smooths weight updates and reduces sensitivity to noisy labels.
    6. cached_sampling_foreground_threshold: 0.10 skips patches with <10% foreground
       (background-heavy patches amplify label noise).
    7. accumulate_grad_batches: 4 gives a larger effective batch, which reduces gradient
       noise from individual mislabeled patches.

  Loss design (FocalLoss + TverskyLoss per channel group):
    Short-range (ch 0-2, offset 1): Focal weight 1.0 + Tversky weight 1.0
      (primary connectivity signal; full weight on both losses)
    Mid-range   (ch 3-5, offset 3): Focal weight 0.8 + Tversky weight 0.8
      (slightly down-weighted; bridges gaps but is noisier than short-range)
    Long-range  (ch 6-8, offset 9): Focal weight 0.5 + Tversky weight 0.5
      (sparser and noisier; half weight so it does not dominate the gradients)
    FocalLoss kwargs: gamma=2.0 (standard), alpha=0.25 (weight on positive targets,
      i.e. affinity=1; negatives get 1 - alpha = 0.75)
    TverskyLoss kwargs: alpha=0.3 (FP weight), beta=0.7 (FN weight), sigmoid=true
      (beta > alpha penalises false splits more than false merges; see item 3 above)

  Affinity channels (9 total):
    - Ch 0-2: short-range (offset 1) for z, y, x
    - Ch 3-5: mid-range (offset 3) for z, y, x
    - Ch 6-8: long-range (offset 9) for z, y, x

  Decoding strategy:
    - Primary: ABISS chunked watershed + agglomeration pipeline (run_abiss_single.py)
    - Fallback: decode_affinity_cc (connected components on short-range affinities)
    - For the full 795×4870×3825 volume, run ABISS externally on saved predictions.
      See .claude/repos_other/ABISS_USAGE_SUMMARY.md for ABISS build/run instructions.
# ── Inherit all shared profile libraries ──────────────────────────────────────
_base_:
  - ../connectomics/config/all_profiles.yaml
# ── Profile selectors ─────────────────────────────────────────────────────────
default:
  # Use the 'binary' pipeline profile (not 'aff12') because:
  #   - aff12 injects model.loss.overrides, which is not in the LossConfig schema
  #   - we define our own losses list inline (FocalLoss + TverskyLoss per channel group)
  #   - we override out_channels, label_transform, and augmentation ourselves below
  pipeline_profile: binary
  system:
    profile: all-gpu-cpu
  model:
    arch:
      profile: mednext_m  # MedNeXt-M (17.6M params); more capacity for 9-channel affinity
    # Override out_channels: 9 affinities (3 distances × 3 axes)
    out_channels: 9
    loss:
      profile:
      # ── Custom loss: FocalLoss + TverskyLoss per channel group ────────────────
      # We bypass the loss_binary profile (profile is left empty) and define the losses inline.
      # Each entry is one loss term, matched to a loss function by its index in the list.
      # pred_slice / target_slice are [start, stop) channel ranges that select which
      # affinity channels the term applies to.
      #
      # FocalLoss (MONAI):
      #   gamma=2.0: standard focusing parameter; down-weights easy, well-classified voxels
      #   alpha=0.25: weight on positive targets (affinity=1, connected pairs); negatives
      #     (affinity=0, boundaries and background) get 1 - alpha = 0.75
      #   no 'sigmoid' kwarg: MONAI FocalLoss applies a sigmoid to the raw logits by
      #     default (use_softmax=False)
      #
      # TverskyLoss (MONAI):
      #   alpha=0.3: false-positive weight (predicted connected where the target is a
      #     boundary, i.e. false merges)
      #   beta=0.7: false-negative weight (predicted boundary where the target is
      #     connected, i.e. false splits)
      #   sigmoid=true: applies a sigmoid before computing the loss (valid for TverskyLoss)
      #   smooth_nr/smooth_dr=1e-5: numerical stability
      #
      # NOTE: FocalLoss does NOT support pos_weight (no spatial_weight_arg);
      # class imbalance is handled by alpha + gamma instead.
      losses:
        # ── Short-range channels (ch 0-2, offset 1) ──────────────────────────
        - function: FocalLoss
          weight: 1.0
          pred_slice: [0, 3]
          target_slice: [0, 3]
          kwargs:
            gamma: 2.0
            alpha: 0.25
        - function: TverskyLoss
          weight: 1.0
          pred_slice: [0, 3]
          target_slice: [0, 3]
          kwargs:
            alpha: 0.3
            beta: 0.7
            sigmoid: true
            smooth_nr: 1.0e-5
            smooth_dr: 1.0e-5
        # ── Mid-range channels (ch 3-5, offset 3) ────────────────────────────
        - function: FocalLoss
          weight: 0.8
          pred_slice: [3, 6]
          target_slice: [3, 6]
          kwargs:
            gamma: 2.0
            alpha: 0.25
        - function: TverskyLoss
          weight: 0.8
          pred_slice: [3, 6]
          target_slice: [3, 6]
          kwargs:
            alpha: 0.3
            beta: 0.7
            sigmoid: true
            smooth_nr: 1.0e-5
            smooth_dr: 1.0e-5
        # ── Long-range channels (ch 6-8, offset 9) ───────────────────────────
        - function: FocalLoss
          weight: 0.5
          pred_slice: [6, 9]
          target_slice: [6, 9]
          kwargs:
            gamma: 2.0
            alpha: 0.25
        - function: TverskyLoss
          weight: 0.5
          pred_slice: [6, 9]
          target_slice: [6, 9]
          kwargs:
            alpha: 0.3
            beta: 0.7
            sigmoid: true
            smooth_nr: 1.0e-5
            smooth_dr: 1.0e-5
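      # Sketch (Python, illustrative only) of how the six entries above could combine into
      # one scalar. It assumes each term is the MONAI loss applied to its [start, stop)
      # channel slice and that the terms are simply summed with their weights; PyTC's real
      # combination logic may differ, and `combined_affinity_loss` is a hypothetical name.
      #
      #   import torch
      #   from monai.losses import FocalLoss, TverskyLoss
      #
      #   focal = FocalLoss(gamma=2.0, alpha=0.25)  # sigmoid applied internally
      #   tversky = TverskyLoss(sigmoid=True, alpha=0.3, beta=0.7,
      #                         smooth_nr=1e-5, smooth_dr=1e-5)
      #   GROUPS = [((0, 3), 1.0), ((3, 6), 0.8), ((6, 9), 0.5)]  # (channel slice, weight)
      #
      #   def combined_affinity_loss(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
      #       """logits, target: (B, 9, Z, Y, X); target is a float affinity map in {0, 1}."""
      #       total = logits.new_zeros(())
      #       for (start, stop), w in GROUPS:
      #           p, t = logits[:, start:stop], target[:, start:stop]
      #           total = total + w * focal(p, t) + w * tversky(p, t)
      #       return total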
    mednext:
      size: M  # MedNeXt-M (17.6M params); previously B (10.5M). More capacity for 9-channel affinity
      kernel_size: 3
      # The `mednext` preset already uses GroupNorm inside its MedNeXt blocks;
      # `model.mednext.norm` is only configurable with `arch.type: mednext_custom`.
      dim: 3d
      checkpoint_style: outside_block  # gradient checkpointing; required for large patches (≥192³)
  data:
    # ── Image normalization ───────────────────────────────────────────────────
    # Normalize uint8 [0, 255] → float [0, 1] before augmentation.
    # clip_percentile_low/high = 0.0/1.0 means no clipping (full intensity range).
    # Taken from neuron_nisb/9nm_liconn.yaml; leaving it out caused training instability.
    image_transform:
      normalize: "0-1"
      clip_percentile_low: 0.0
      clip_percentile_high: 1.0
    data_transform:
      # Context padding in the upsampled coordinate space (applied after the 2x resize).
      # [16, 64, 64] pads 16/64/64 voxels on each side, for a total of [32, 128, 128],
      # i.e. half of the model input [64, 256, 256].
      pad_size: [16, 64, 64]
      # Sample half-size raw patches/volumes and upsample them 2x before the model.
      resize: [64, 256, 256]
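      # Sketch (Python, illustrative only) of the 2x resize from the raw [32, 128, 128]
      # patch to the [64, 256, 256] model input. It assumes trilinear interpolation for the
      # image and nearest-neighbour repetition for labels (so instance IDs are never
      # blended); the actual Resized() transform's interpolation modes may differ.
      #
      #   import torch
      #   import torch.nn.functional as F
      #
      #   img = torch.rand(1, 1, 32, 128, 128)             # raw image patch, (B, C, Z, Y, X)
      #   seg = torch.randint(0, 5, (1, 1, 32, 128, 128))  # raw label patch (int64 IDs)
      #   img_up = F.interpolate(img, size=(64, 256, 256), mode="trilinear", align_corners=False)
      #   seg_up = seg
      #   for dim in (2, 3, 4):                            # exact 2x nearest along Z, Y, X
      #       seg_up = seg_up.repeat_interleave(2, dim=dim)
      #   assert img_up.shape == seg_up.shape == (1, 1, 64, 256, 256)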
    dataloader:
      profile: cached
      patch_size: [32, 128, 128]
      use_cache: false  # no in-memory cache; the volume is 795×4870×3825
      use_lazy_zarr: true  # stream patches directly from Zarr (no RAM preload)
      use_preloaded_cache_train: false
      use_preloaded_cache_val: false
      # Skip patches with <10% foreground to avoid amplifying label noise in
      # background-heavy crops (common with tightly-packed axons at the edges).
      # Raised from 0.05 to 0.10 after cropping out unlabeled black regions,
      # since the remaining volume should have denser label coverage.
      cached_sampling_foreground_threshold: 0.10
    augmentation:
      # aug_em_neuron: heavy elastic deformation, a wide contrast range, and aggressive
      # EM artifacts (misalignment, missing sections, motion blur, missing parts).
      # This is the most important robustness tool for imperfect labels: the model learns
      # to be invariant to the kinds of imaging artifacts that also cause proofreading errors.
      profile: aug_em_neuron
  inference:
    # Use only the short-range affinities (channels 0-2) for the TTA ensemble
    select_channel: [0, 1, 2]
    sliding_window:
      lazy_load: true
      # Large overlap for smooth predictions on big volumes
      overlap: 0.5
      blending: gaussian
      keep_input_on_cpu: true  # keep the raw volume on CPU to save GPU memory
      sw_device: cuda
      output_device: cpu
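      # Sketch (Python, illustrative only) of Gaussian blending of overlapping windows,
      # shown in 1D. It assumes a MONAI-style importance map (sigma = 0.125 × window size);
      # PyTC's actual blending code and sigma may differ, and both helper names are
      # hypothetical. With overlap=0.5 most voxels are covered by several windows, and each
      # prediction counts most near its own window centre, which hides seams at window edges.
      #
      #   import numpy as np
      #
      #   def gaussian_weights(size: int, sigma_scale: float = 0.125) -> np.ndarray:
      #       ax = np.arange(size) - (size - 1) / 2  # distance from the window centre
      #       return np.exp(-(ax ** 2) / (2 * (sigma_scale * size) ** 2))
      #
      #   def blend_1d(preds: list, starts: list, length: int) -> np.ndarray:
      #       acc, norm = np.zeros(length), np.zeros(length)
      #       for p, s in zip(preds, starts):
      #           w = gaussian_weights(len(p))
      #           acc[s:s + len(p)] += w * p  # weighted sum of overlapping predictions
      #           norm[s:s + len(p)] += w     # total weight per position, for normalisation
      #       return acc / np.maximum(norm, 1e-8)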
    test_time_augmentation:
      enabled: false
      ensemble_mode: min  # min-pooling is conservative for affinities
      # activation_profile removed: it is not a TestTimeAugmentationConfig field
    # crop_pad + selected-channel affinity_crop [(1,0), (1,0), (1,0)] = pad_size [16, 64, 64]:
    #   Z: 15+1 = 16, 16+0 = 16 | Y: 63+1 = 64, 64+0 = 64 | X: 63+1 = 64, 64+0 = 64
    crop_pad: [15, 16, 63, 64, 63, 64]
    # ── Decoding ──────────────────────────────────────────────────────────────
    # Option A: ABISS external pipeline; best quality, but requires ABISS to be built.
    # Falls back to connected components automatically if ABISS fails.
    #
    # ABISS binary: abiss/build/ws (built from the abiss/ sources via abiss/build_ws.sh)
    # Intel TBB runtime: /orcd/software/community/001/rocky8/intel/2024.2.1/tbb/2021.13/lib
    #   must be on LD_LIBRARY_PATH at runtime (the ws binary links against libtbb.so.12).
    #
    # Thresholds for axon segmentation at 18 nm:
    #   ws_high_threshold=0.9: seed the watershed only at very high affinity (confident connections)
    #   ws_low_threshold=0.3: extend seeds down to moderate affinity
    #   ws_size_threshold=100: discard fragments smaller than 100 voxels
    decoding:
      steps:
        - name: decode_abiss
          kwargs:
            command: >
              env LD_LIBRARY_PATH=/orcd/software/community/001/rocky8/intel/2024.2.1/tbb/2021.13/lib:$LD_LIBRARY_PATH
              {python_exe} scripts/run_abiss_single.py
              --input {input_h5}
              --output {output_h5}
              --abiss-home /home/mansour8/pytc_dev/pytorch_connectomics/abiss
              --ws-high-threshold 0.9
              --ws-low-threshold 0.3
              --ws-size-threshold 100
            input_dataset: main
            output_dataset: main
            timeout_sec: 3600
      # Option B: lightweight connected components (no ABISS required).
      # Uses only the short-range affinities (ch 0-2); good for quick inspection.
      # To use it, replace the Option A steps list above with:
      # steps:
      #   - name: decode_affinity_cc
      #     kwargs:
      #       threshold: 0.5
      #       backend: auto
    postprocessing:
      enabled: true
# ── Training stage ─────────────────────────────────────────────────────────────
train:
  data:
    label_transform:
      # ── Label erosion ────────────────────────────────────────────────────────
      # erosion: N applies SegErosionInstanced (Kisuk Lee's SNEMI3D preprocessing):
      # any voxel whose (2N+1)×(2N+1) XY window contains more than one segment ID is
      # set to background BEFORE the affinities are computed.
      #
      # For tightly-packed axons/neurons at 18 nm, erosion=1 (3×3 window) is the
      # standard choice: it carves a thin background gap (one voxel from each side)
      # along every instance boundary, which makes the affinity signal cleaner and
      # reduces false merges.
      # Increase to erosion=2 if boundaries are still ambiguous after training.
      erosion: 1
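      # Sketch (Python, illustrative only) of the erosion rule described above, for a ZYX
      # label volume. It is NOT the library's SegErosionInstanced code; it assumes that
      # background (0) counts as a segment ID, so foreground voxels touching background are
      # cleared as well. `erode_instance_boundaries` is a hypothetical helper.
      #
      #   import numpy as np
      #   from scipy import ndimage
      #
      #   def erode_instance_boundaries(seg: np.ndarray, n: int = 1) -> np.ndarray:
      #       """Zero every voxel whose (2n+1)x(2n+1) XY window holds more than one label ID."""
      #       size = (1, 2 * n + 1, 2 * n + 1)  # window spans Y and X only, never Z
      #       hi = ndimage.maximum_filter(seg, size=size, mode="nearest")
      #       lo = ndimage.minimum_filter(seg, size=size, mode="nearest")
      #       out = seg.copy()
      #       out[hi != lo] = 0  # mixed window: the voxel sits on a boundary
      #       return out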
      # 9-channel affinity: distances 1, 3, 9 along each axis (z, y, x)
      targets:
        - name: affinity
          kwargs:
            offsets:
              # Distance 1 (short-range): primary connectivity signal
              - "1-0-0"  # ch 0: z
              - "0-1-0"  # ch 1: y
              - "0-0-1"  # ch 2: x
              # Distance 3 (mid-range): bridges small gaps
              - "3-0-0"  # ch 3: z
              - "0-3-0"  # ch 4: y
              - "0-0-3"  # ch 5: x
              # Distance 9 (long-range): long-range context for agglomeration
              - "9-0-0"  # ch 6: z
              - "0-9-0"  # ch 7: y
              - "0-0-9"  # ch 8: x
            affinity_mode: deepem
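      # Sketch (Python, illustrative only) of what one affinity channel encodes. It assumes
      # the usual convention: 1 where a voxel and the voxel `offset` steps back along the
      # axis carry the same non-zero label, 0 otherwise, and 0 in the first `offset` slices
      # where no partner exists. PyTC's "deepem" mode may differ in direction or border
      # handling; `affinity_channel` is a hypothetical helper.
      #
      #   import numpy as np
      #
      #   def affinity_channel(seg: np.ndarray, offset: tuple[int, int, int]) -> np.ndarray:
      #       """seg: ZYX labels; offset: (dz, dy, dx) with exactly one non-zero entry."""
      #       aff = np.zeros(seg.shape, dtype=np.float32)
      #       axis = int(np.flatnonzero(offset)[0])
      #       d = offset[axis]
      #       cur, prev = [slice(None)] * 3, [slice(None)] * 3
      #       cur[axis], prev[axis] = slice(d, None), slice(None, -d)
      #       a, b = seg[tuple(cur)], seg[tuple(prev)]
      #       aff[tuple(cur)] = (a == b) & (a > 0)  # same non-background segment
      #       return aff
      #
      #   # 9 channels in the config's order: offsets 1, 3, 9, each along z, y, x
      #   affs = np.stack([affinity_channel(seg, o) for d in (1, 3, 9)
      #                    for o in ((d, 0, 0), (0, d, 0), (0, 0, d))])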
    # ── Data paths ───────────────────────────────────────────────────────────
    # NOTE: split_enabled only works with HDF5/TIFF, not Zarr (it raises a ValueError).
    # LazyZarrVolumeDataset samples random patches from the full volume extent, and
    # train and val both read the same Zarr: val patches are only drawn with a different
    # random seed (set_epoch is called every epoch), so they can overlap the training data.
    # For a strict spatial split, create separate Zarr sub-stores offline (e.g. zarr.open(...)[0:636]).
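    # Sketch (Python, illustrative only) of that offline spatial split. It assumes
    # zarr-python's open/open_group/open_array API and hypothetical output paths under
    # /path/to/. Z_SPLIT=636 is the 80/20 boundary quoted in the description for the
    # uncropped volume and must be adjusted to the cropped volume's Z extent. Data is
    # copied in Z slabs so the uint64 label array never has to fit in RAM at once.
    #
    #   import zarr
    #
    #   SRC = "/orcd/scratch/bcs/002/mansour/train_liconn/liconn_18nm_labelled_cropped.zarr"
    #   Z_SPLIT, SLAB = 636, 64  # split index; copy 64 slices (one chunk row) at a time
    #
    #   src = zarr.open(SRC, mode="r")
    #   for split, (z0, z1) in {"train": (0, Z_SPLIT), "val": (Z_SPLIT, None)}.items():
    #       dst = f"/path/to/liconn_{split}.zarr"
    #       zarr.open_group(dst, mode="a")  # write group metadata for the new store
    #       for key in ("img", "label"):
    #           arr = src[key]
    #           stop = arr.shape[0] if z1 is None else z1
    #           out = zarr.open_array(f"{dst}/{key}", mode="w", shape=(stop - z0, *arr.shape[1:]),
    #                                 chunks=arr.chunks, dtype=arr.dtype)
    #           for z in range(z0, stop, SLAB):
    #               zz = min(z + SLAB, stop)
    #               out[z - z0:zz - z0] = arr[z:zz]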
    train:
      image: /orcd/scratch/bcs/002/mansour/train_liconn/liconn_18nm_labelled_cropped.zarr/img
      label: /orcd/scratch/bcs/002/mansour/train_liconn/liconn_18nm_labelled_cropped.zarr/label
      resolution: [25, 18, 18]  # [z_nm, y_nm, x_nm]; anisotropic: 25 nm z, 18 nm xy
    val:
      image: /orcd/scratch/bcs/002/mansour/train_liconn/liconn_18nm_labelled_cropped.zarr/img
      label: /orcd/scratch/bcs/002/mansour/train_liconn/liconn_18nm_labelled_cropped.zarr/label
      resolution: [25, 18, 18]
    # Data is already ZYX, so no transpose is needed:
    # data_transform:
    #   train_transpose: [2, 1, 0]  # only needed if the Zarr is stored XYZ
    dataloader:
      # batch_size=1 halves GPU memory vs batch_size=2; accumulate_grad_batches=4 compensates,
      # keeping 4 samples per optimizer step on each GPU (same as batch_size=2 × accumulate=2).
      batch_size: 1
  model:
    input_size: [64, 256, 256]
    output_size: [64, 256, 256]
  optimization:
    profile: warmup_cosine_lr
    max_epochs: 500
    n_steps_per_epoch: 200
    val_check_interval: 10
    # Gradient accumulation: effective batch = batch_size × accumulate_grad_batches × num_gpus.
    # batch_size=1, accumulate=4, num_gpus=2 (see system below) → effective batch = 8
    # with data-parallel training.
    # A larger effective batch reduces gradient noise from individual mislabeled patches.
    accumulate_grad_batches: 4
    # EMA smooths weight updates over time: averaging model weights across recent steps
    # reduces sensitivity to noisy/imperfect labels.
    # Already enabled in the warmup_cosine_lr profile (ema.enabled: true, decay: 0.999).
    # Validation uses the EMA weights (validate_with_ema: true) for best generalization.
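    # Sketch (Python, illustrative only) of the standard exponential-moving-average update
    # that decay=0.999 refers to; the profile's actual implementation may differ, and
    # `ema_update` is a hypothetical helper. Only parameters are averaged here, not buffers.
    #
    #   import torch
    #
    #   @torch.no_grad()
    #   def ema_update(ema_model: torch.nn.Module, model: torch.nn.Module,
    #                  decay: float = 0.999) -> None:
    #       # ema = decay * ema + (1 - decay) * current, applied after every optimizer step
    #       for e, p in zip(ema_model.parameters(), model.parameters()):
    #           e.mul_(decay).add_(p.detach(), alpha=1.0 - decay)
    #
    #   # With decay=0.999 the average spans roughly 1 / (1 - 0.999) = 1000 optimizer steps.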
  system:
    num_gpus: 2
    num_workers: 16  # num_cpus is not in SystemConfig; use num_workers for dataloader workers
    seed: 42
  monitor:
    logging:
      scalar:
        loss: [train_loss_total_epoch, val_loss_total]
        loss_every_n_steps: 100
      images:
        # dtype cast fixed: _log_multi_channel_viz and _log_single_channel_viz now
        # explicitly cast all tensors to float32 before make_grid (uint8 label channels
        # and bfloat16 model outputs both previously raised a RuntimeError).
        enabled: true
        max_images: 8
        num_slices: 4
        log_every_n_epochs: 5
        channel_mode: all
    checkpoint:
      dirpath: /orcd/scratch/bcs/002/mansour/train_liconn/outputs/checkpoints/
      monitor: val_loss_total
      mode: min
      save_top_k: 3
# ── Test / inference stage ─────────────────────────────────────────────────────
test:
  system:
    num_gpus: 1
    num_workers: 4
  data:
    dataloader:
      batch_size: 4
    # Active test volume: the DL288B_251222S_cond5_40x_12tiles_round1_fused_488 crop
    # (512×1024×1024 according to its filename), read from TIFF.
    # Previous test volume (commented out below): BA_5AA_proteintest_1, FFN-sharpened,
    # uint8, shape (59, 1024, 1024) ZYX. Its Z=59 is smaller than the 64-slice model
    # window, but with the lazy 2x upsampling below each raw read is only ~32 slices,
    # so the sliding window still fits.
    # No label available: inference only (evaluation.enabled: false).
    test:
      # image: /orcd/data/edboyden/002/mansour8/deb_data/BA_5AA_proteintest_1_2026-02-18_17.25.37_channel1_ffn_sharp.zarr
      image:
        /orcd/data/edboyden/002/dleible/DL288B/DL288B_251222S_cond5_40x_12tiles_round1_fused_488_crop512x1024x1024.tif
      # label: omitted; no ground truth is available for this volume
      resolution: [25, 18, 18]  # same voxel size as the training data
  inference:
    sliding_window:
      # Lazy 2x upsampling: _resolve_scale_factors() computes resize / patch_size
      # = [64, 256, 256] / [32, 128, 128] = [2.0, 2.0, 2.0]. LazyVolumeAccessor reads
      # ~[32, 128, 128] raw patches from disk and upsamples them 2x via grid_sample,
      # so the model sees [64, 256, 256], the same transform as Resized() in training.
      lazy_load: true
      window_size: [64, 256, 256]
      overlap: 0.5
      blending: gaussian
      keep_input_on_cpu: true
      sw_device: cuda
      output_device: cpu
    save_prediction:
      enabled: true
      output_formats: [h5]
      output_path: /orcd/scratch/bcs/002/mansour/train_liconn/outputs/results_axons/
      # float32 keeps full-precision affinities; use float16 to halve storage
      storage_dtype: float32
      compression: gzip
    evaluation:
      enabled: false  # no ground truth, so metrics cannot be computed