diff --git a/cuslines/cuda_python/cu_tractography.py b/cuslines/cuda_python/cu_tractography.py index 7064cc9..9b539be 100644 --- a/cuslines/cuda_python/cu_tractography.py +++ b/cuslines/cuda_python/cu_tractography.py @@ -49,7 +49,7 @@ def __init__( ngpus: int = 1, rng_seed: int = 0, rng_offset: int = 0, - chunk_size: int = 100000, + chunk_size: int = 25000, ): """ Initialize GPUTracker with necessary data. @@ -91,6 +91,9 @@ def __init__( rng_offset : int, optional Offset for random number generator default: 0 + chunk_size : int, optional + Number of seeds to process in each chunk per GPU + default: 25000 """ self.dataf = np.ascontiguousarray(dataf, dtype=REAL_DTYPE) self.metric_map = np.ascontiguousarray(stop_map, dtype=REAL_DTYPE) @@ -252,22 +255,26 @@ def generate_sft(self, seeds, ref_img): ) return StatefulTractogram(array_sequence, ref_img, Space.VOX) - # TODO: performance: consider a way to just output in VOX space directly def generate_trx(self, seeds, ref_img): global_chunk_sz, nchunks = self._divide_chunks(seeds) # Will resize by a factor of 2 if these are exceeded sl_len_guess = 100 - sl_per_seed_guess = 3 + sl_per_seed_guess = 4 n_sls_guess = sl_per_seed_guess * seeds.shape[0] # trx files use memory mapping + trx_reference = TrxFile( + reference=ref_img + ) + trx_reference.streamlines._data = trx_reference.streamlines._data.astype(np.float32) + trx_reference.streamlines._offsets = trx_reference.streamlines._offsets.astype(np.uint64) + trx_file = TrxFile( - reference=ref_img, nb_streamlines=n_sls_guess, nb_vertices=n_sls_guess * sl_len_guess, + init_as=trx_reference ) - trx_file.streamlines._offsets = trx_file.streamlines._offsets.astype(np.uint64) offsets_idx = 0 sls_data_idx = 0