@@ -207,6 +207,29 @@ def flip_ids(id_map, node_ids):
207207 return np .concatenate (ids ).astype (basetypes .NODE_ID )
208208
209209
210+ def _get_new_nodes (
211+ cg , nodes : np .ndarray , layer : int , parent_ts : datetime .datetime = None
212+ ):
213+ unique_nodes , inverse = np .unique (nodes , return_inverse = True )
214+ node_root_map = {n : n for n in unique_nodes }
215+ lookup = np .ones (len (unique_nodes ), dtype = unique_nodes .dtype )
216+ while np .any (lookup ):
217+ roots = np .fromiter (node_root_map .values (), dtype = basetypes .NODE_ID )
218+ roots = cg .get_parents (roots , time_stamp = parent_ts , fail_to_zero = True )
219+ layers = cg .get_chunk_layers (roots )
220+ lookup [layers > layer ] = 0
221+ lookup [roots == 0 ] = 0
222+
223+ layer_mask = layers <= layer
224+ non_zero_mask = roots != 0
225+ mask = layer_mask & non_zero_mask
226+ for node , root in zip (unique_nodes [mask ], roots [mask ]):
227+ node_root_map [node ] = root
228+
229+ unique_results = np .fromiter (node_root_map .values (), dtype = basetypes .NODE_ID )
230+ return unique_results [inverse ]
231+
232+
210233def get_stale_nodes (
211234 cg , nodes : Iterable [basetypes .NODE_ID ], parent_ts : datetime .datetime = None
212235):
@@ -215,20 +238,17 @@ def get_stale_nodes(
215238 This is done by getting a supervoxel of a node and checking
216239 if it has a new parent at the same layer as the node.
217240 """
218- nodes = np .array (nodes , dtype = basetypes .NODE_ID )
241+ nodes = np .unique (np .array (nodes , dtype = basetypes .NODE_ID ))
242+ new_ids = set () if cg .cache is None else cg .cache .new_ids
243+ nodes = nodes [~ np .isin (nodes , new_ids )]
219244 supervoxels = cg .get_single_leaf_multiple (nodes )
220245 # nodes can be at different layers due to skip connections
221246 node_layers = cg .get_chunk_layers (nodes )
222247 stale_nodes = [types .empty_1d ]
223248 for layer in np .unique (node_layers ):
224249 _mask = node_layers == layer
225250 layer_nodes = nodes [_mask ]
226- _nodes = cg .get_roots (
227- supervoxels [_mask ],
228- stop_layer = layer ,
229- ceil = False ,
230- time_stamp = parent_ts ,
231- )
251+ _nodes = _get_new_nodes (cg , supervoxels [_mask ], layer , parent_ts )
232252 stale_mask = layer_nodes != _nodes
233253 stale_nodes .append (layer_nodes [stale_mask ])
234254 return np .concatenate (stale_nodes )
@@ -544,10 +564,7 @@ def _get_new_edge(edge, edge_layer, parent_ts, padding, fallback: bool = False):
544564 if fallback :
545565 parents_b = _get_parents_b (_edges , parent_ts , edge_layer , True )
546566
547- parents_b = np .unique (
548- cg .get_roots (parents_b , stop_layer = mlayer , ceil = False , time_stamp = parent_ts )
549- )
550-
567+ parents_b = np .unique (_get_new_nodes (cg , parents_b , mlayer , parent_ts ))
551568 parents_a = np .array ([node_a ] * parents_b .size , dtype = basetypes .NODE_ID )
552569 return np .column_stack ((parents_a , parents_b ))
553570
@@ -607,8 +624,6 @@ def get_latest_edges_wrapper(
607624 stale_edge_layers ,
608625 parent_ts = parent_ts ,
609626 )
610- stale_nodes = get_stale_nodes (cg , latest_edges .ravel (), parent_ts = parent_ts )
611- assert stale_nodes .size == 0 , f"latest_edges failed, stale: { stale_nodes } "
612627 logging .debug (f"{ stale_edges } -> { latest_edges } ; { parent_ts } " )
613628 _new_cx_edges .append (latest_edges )
614629 new_cx_edges_d [layer ] = np .concatenate (_new_cx_edges )
0 commit comments