Skip to content

Commit 8db7db6

Browse files
committed
feat(edits): batch get descendents, make sanity checks optional
1 parent e74b9b8 commit 8db7db6

5 files changed

Lines changed: 124 additions & 73 deletions

File tree

pychunkedgraph/graph/cache.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Cache nodes, parents, children and cross edges.
44
"""
55
import traceback
6-
from collections import defaultdict
6+
from collections import defaultdict as defaultd
77
from sys import maxsize
88
from datetime import datetime
99

@@ -51,7 +51,7 @@ def __init__(self, cg):
5151
"cross_chunk_edges": {"hits": 0, "misses": 0, "calls": 0},
5252
}
5353
# Track where calls/misses come from
54-
self.call_sources = defaultdict(lambda: defaultdict(lambda: {"calls": 0, "misses": 0}))
54+
self.sources = defaultd(lambda: defaultd(lambda: {"calls": 0, "misses": 0}))
5555

5656
def _get_caller(self, skip_frames=2):
5757
"""Get caller info (filename:line:function)."""
@@ -65,8 +65,8 @@ def _get_caller(self, skip_frames=2):
6565
def _record_call(self, cache_type, misses=0):
6666
"""Record a call and its source."""
6767
caller = self._get_caller(skip_frames=3)
68-
self.call_sources[cache_type][caller]["calls"] += 1
69-
self.call_sources[cache_type][caller]["misses"] += misses
68+
self.sources[cache_type][caller]["calls"] += 1
69+
self.sources[cache_type][caller]["misses"] += misses
7070

7171
def __len__(self):
7272
return (
@@ -90,7 +90,7 @@ def get_stats(self):
9090
**s,
9191
"total": total,
9292
"hit_rate": f"{hit_rate:.1%}",
93-
"sources": dict(self.call_sources[name]),
93+
"sources": dict(self.sources[name]),
9494
}
9595
return result
9696

@@ -99,7 +99,7 @@ def reset_stats(self):
9999
s["hits"] = 0
100100
s["misses"] = 0
101101
s["calls"] = 0
102-
self.call_sources.clear()
102+
self.sources.clear()
103103

104104
def parent(self, node_id: np.uint64, *, time_stamp: datetime = None):
105105
self.stats["parents"]["calls"] += 1
@@ -154,7 +154,13 @@ def cross_edges_decorated(node_id):
154154

155155
return cross_edges_decorated(node_id)
156156

157-
def parents_multiple(self, node_ids: np.ndarray, *, time_stamp: datetime = None):
157+
def parents_multiple(
158+
self,
159+
node_ids: np.ndarray,
160+
*,
161+
time_stamp: datetime = None,
162+
fail_to_zero: bool = False,
163+
):
158164
node_ids = np.asarray(node_ids, dtype=NODE_ID)
159165
if not node_ids.size:
160166
return node_ids
@@ -168,7 +174,10 @@ def parents_multiple(self, node_ids: np.ndarray, *, time_stamp: datetime = None)
168174
parents = node_ids.copy()
169175
parents[mask] = self._parent_vec(node_ids[mask])
170176
parents[~mask] = self._cg.get_parents(
171-
node_ids[~mask], raw_only=True, time_stamp=time_stamp
177+
node_ids[~mask],
178+
raw_only=True,
179+
time_stamp=time_stamp,
180+
fail_to_zero=fail_to_zero,
172181
)
173182
update(self.parents_cache, node_ids[~mask], parents[~mask])
174183
return parents

pychunkedgraph/graph/chunkedgraph.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,9 @@ def get_parents(
241241
else:
242242
raise KeyError from exc
243243
return parents
244-
return self.cache.parents_multiple(node_ids, time_stamp=time_stamp)
244+
return self.cache.parents_multiple(
245+
node_ids, time_stamp=time_stamp, fail_to_zero=fail_to_zero
246+
)
245247

246248
def get_parent(
247249
self,
@@ -807,6 +809,7 @@ def add_edges(
807809
source_coords: typing.Sequence[int] = None,
808810
sink_coords: typing.Sequence[int] = None,
809811
allow_same_segment_merge: typing.Optional[bool] = False,
812+
do_sanity_check: typing.Optional[bool] = False,
810813
) -> operation.GraphEditOperation.Result:
811814
"""
812815
Adds an edge to the chunkedgraph
@@ -823,6 +826,7 @@ def add_edges(
823826
source_coords=source_coords,
824827
sink_coords=sink_coords,
825828
allow_same_segment_merge=allow_same_segment_merge,
829+
do_sanity_check=do_sanity_check,
826830
).execute()
827831

828832
def remove_edges(
@@ -838,6 +842,7 @@ def remove_edges(
838842
path_augment: bool = True,
839843
disallow_isolating_cut: bool = True,
840844
bb_offset: typing.Tuple[int, int, int] = (240, 240, 24),
845+
do_sanity_check: typing.Optional[bool] = False,
841846
) -> operation.GraphEditOperation.Result:
842847
"""
843848
Removes edges - either directly or after applying a mincut
@@ -862,6 +867,7 @@ def remove_edges(
862867
bbox_offset=bb_offset,
863868
path_augment=path_augment,
864869
disallow_isolating_cut=disallow_isolating_cut,
870+
do_sanity_check=do_sanity_check,
865871
).execute()
866872

867873
if not atomic_edges:

pychunkedgraph/graph/edges/__init__.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,29 @@ def flip_ids(id_map, node_ids):
207207
return np.concatenate(ids).astype(basetypes.NODE_ID)
208208

209209

210+
def _get_new_nodes(
    cg, nodes: np.ndarray, layer: int, parent_ts: datetime.datetime = None
):
    """
    Map every entry of `nodes` to its highest ancestor at or below `layer`.

    Ancestors are resolved level by level with batched `cg.get_parents`
    calls. A node stops climbing as soon as its next parent is either
    above `layer` or reported as 0 (presumably "no parent at `parent_ts`",
    given `fail_to_zero=True` -- confirm against `get_parents`).

    Args:
        cg: graph object providing `get_parents` and `get_chunk_layers`.
        nodes: node IDs to resolve; duplicates are allowed.
        layer: highest layer an ancestor may have.
        parent_ts: timestamp forwarded to `cg.get_parents`.

    Returns:
        Array of resolved IDs aligned with `nodes`; a node with no
        qualifying parent maps to itself.
    """
    unique_nodes, inverse = np.unique(nodes, return_inverse=True)
    # Current best ancestor per unique node; starts as the node itself.
    current = unique_nodes.astype(basetypes.NODE_ID)
    # True while a node may still have a higher ancestor within `layer`.
    active = np.ones(current.size, dtype=bool)
    while active.any():
        pending = np.flatnonzero(active)
        # Only unresolved nodes are queried; finished ones cannot change.
        parents = cg.get_parents(
            current[pending], time_stamp=parent_ts, fail_to_zero=True
        )
        layers = cg.get_chunk_layers(parents)
        climb = (parents != 0) & (layers <= layer)
        current[pending[climb]] = parents[climb]
        active[pending[~climb]] = False
    return current[inverse]
231+
232+
210233
def get_stale_nodes(
211234
cg, nodes: Iterable[basetypes.NODE_ID], parent_ts: datetime.datetime = None
212235
):
@@ -215,20 +238,17 @@ def get_stale_nodes(
215238
This is done by getting a supervoxel of a node and checking
216239
if it has a new parent at the same layer as the node.
217240
"""
218-
nodes = np.array(nodes, dtype=basetypes.NODE_ID)
241+
nodes = np.unique(np.array(nodes, dtype=basetypes.NODE_ID))
242+
new_ids = set() if cg.cache is None else cg.cache.new_ids
243+
nodes = nodes[~np.isin(nodes, new_ids)]
219244
supervoxels = cg.get_single_leaf_multiple(nodes)
220245
# nodes can be at different layers due to skip connections
221246
node_layers = cg.get_chunk_layers(nodes)
222247
stale_nodes = [types.empty_1d]
223248
for layer in np.unique(node_layers):
224249
_mask = node_layers == layer
225250
layer_nodes = nodes[_mask]
226-
_nodes = cg.get_roots(
227-
supervoxels[_mask],
228-
stop_layer=layer,
229-
ceil=False,
230-
time_stamp=parent_ts,
231-
)
251+
_nodes = _get_new_nodes(cg, supervoxels[_mask], layer, parent_ts)
232252
stale_mask = layer_nodes != _nodes
233253
stale_nodes.append(layer_nodes[stale_mask])
234254
return np.concatenate(stale_nodes)
@@ -544,10 +564,7 @@ def _get_new_edge(edge, edge_layer, parent_ts, padding, fallback: bool = False):
544564
if fallback:
545565
parents_b = _get_parents_b(_edges, parent_ts, edge_layer, True)
546566

547-
parents_b = np.unique(
548-
cg.get_roots(parents_b, stop_layer=mlayer, ceil=False, time_stamp=parent_ts)
549-
)
550-
567+
parents_b = np.unique(_get_new_nodes(cg, parents_b, mlayer, parent_ts))
551568
parents_a = np.array([node_a] * parents_b.size, dtype=basetypes.NODE_ID)
552569
return np.column_stack((parents_a, parents_b))
553570

@@ -607,8 +624,6 @@ def get_latest_edges_wrapper(
607624
stale_edge_layers,
608625
parent_ts=parent_ts,
609626
)
610-
stale_nodes = get_stale_nodes(cg, latest_edges.ravel(), parent_ts=parent_ts)
611-
assert stale_nodes.size == 0, f"latest_edges failed, stale: {stale_nodes}"
612627
logging.debug(f"{stale_edges} -> {latest_edges}; {parent_ts}")
613628
_new_cx_edges.append(latest_edges)
614629
new_cx_edges_d[layer] = np.concatenate(_new_cx_edges)

0 commit comments

Comments
 (0)