Skip to content

Commit 3e509b5

Browse files
committed
pre-commit fixes
1 parent 83397bc commit 3e509b5

16 files changed

Lines changed: 443 additions & 270 deletions

gateway/sds_gateway/api_methods/helpers/temporal_filtering.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import re
22

33
from django.db.models import QuerySet
4-
4+
from loguru import logger as log
55
from opensearchpy.exceptions import NotFoundError as OpenSearchNotFoundError
6-
from sds_gateway.api_methods.models import CaptureType, Capture, File, DRF_RF_FILENAME_REGEX_STR
6+
7+
from sds_gateway.api_methods.models import DRF_RF_FILENAME_REGEX_STR
8+
from sds_gateway.api_methods.models import Capture
9+
from sds_gateway.api_methods.models import CaptureType
10+
from sds_gateway.api_methods.models import File
711
from sds_gateway.api_methods.utils.opensearch_client import get_opensearch_client
812
from sds_gateway.api_methods.utils.relationship_utils import get_capture_files
9-
from loguru import logger as log
1013

1114
# Digital RF spec: rf@SECONDS.MILLISECONDS.h5 (e.g. rf@1396379502.000.h5)
1215
# https://github.com/MITHaystack/digital_rf
@@ -55,14 +58,12 @@ def get_capture_bounds(capture_type: CaptureType, capture_uuid: str) -> tuple[in
5558
try:
5659
response = client.get(index=index, id=capture_uuid)
5760
except OpenSearchNotFoundError as e:
58-
raise ValueError(
59-
f"Capture {capture_uuid} not found in OpenSearch index {index}"
60-
) from e
61+
msg = f"Capture {capture_uuid} not found in OpenSearch index {index}"
62+
raise ValueError(msg) from e
6163

6264
if not response.get("found"):
63-
raise ValueError(
64-
f"Capture {capture_uuid} not found in OpenSearch index {index}"
65-
)
65+
msg = f"Capture {capture_uuid} not found in OpenSearch index {index}"
66+
raise ValueError(msg)
6667

6768
source = response.get("_source", {})
6869
search_props = source.get("search_props", {})
@@ -90,7 +91,7 @@ def filter_capture_data_files_selection_bounds(
9091
capture_type: CaptureType,
9192
capture: Capture,
9293
start_time: int, # relative ms from start of capture (from UI)
93-
end_time: int, # relative ms from start of capture (from UI)
94+
end_time: int, # relative ms from start of capture (from UI)
9495
) -> QuerySet[File]:
9596
"""Filter the capture file selection bounds to the given start and end times."""
9697
_catch_capture_type_error(capture_type)
@@ -108,26 +109,32 @@ def filter_capture_data_files_selection_bounds(
108109
name__lte=end_file_name,
109110
).order_by("name")
110111

112+
111113
def get_capture_files_with_temporal_filter(
112114
capture_type: CaptureType,
113115
capture: Capture,
114-
start_time: int | None = None, # milliseconds since start of capture
116+
start_time: int | None = None, # milliseconds since start of capture
115117
end_time: int | None = None,
116118
) -> QuerySet[File]:
117119
"""Get the capture files with temporal filtering."""
118120
_catch_capture_type_error(capture_type)
119121

120122
if start_time is None or end_time is None:
121-
log.warning("Start or end time is None, returning all capture files without temporal filtering")
123+
log.warning(
124+
"Start or end time is None; returning all capture files without "
125+
"temporal filtering"
126+
)
122127
return get_capture_files(capture)
123128

124129
# get non-data files
125-
non_data_files = get_capture_files(capture).exclude(name__regex=DRF_RF_FILENAME_REGEX_STR)
130+
non_data_files = get_capture_files(capture).exclude(
131+
name__regex=DRF_RF_FILENAME_REGEX_STR
132+
)
126133

127134
# get data files with temporal filtering
128135
data_files = filter_capture_data_files_selection_bounds(
129136
capture_type, capture, start_time, end_time
130137
)
131138

132139
# return all files
133-
return non_data_files.union(data_files)
140+
return non_data_files.union(data_files)

gateway/sds_gateway/api_methods/models.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
from blake3 import blake3 as Blake3 # noqa: N812
1414
from django.conf import settings
1515
from django.db import models
16-
from django.db.models import Sum
1716
from django.db.models import Count
1817
from django.db.models import ProtectedError
1918
from django.db.models import QuerySet
19+
from django.db.models import Sum
2020
from django.db.models.signals import post_save
2121
from django.db.models.signals import pre_delete
2222
from django.dispatch import receiver
@@ -423,22 +423,27 @@ def get_capture(self) -> dict[str, Any]:
423423
"owner": self.owner,
424424
}
425425

426-
427426
def get_drf_data_files_queryset(self) -> QuerySet[File]:
428427
"""DRF data files (rf@*.h5) for this capture (M2M + FK)."""
429428
if self.capture_type != CaptureType.DigitalRF:
430429
log.warning("Capture %s is not a DigitalRF capture", self.uuid)
431430
return File.objects.none()
432431

433432
# Local import avoids circular import (relationship_utils imports Capture).
434-
from sds_gateway.api_methods.utils.relationship_utils import get_capture_files
433+
from sds_gateway.api_methods.utils.relationship_utils import ( # noqa: PLC0415
434+
get_capture_files,
435+
)
435436

436437
return get_capture_files(self, include_deleted=False).filter(
437438
name__regex=DRF_RF_FILENAME_REGEX_STR,
438439
)
439440

440441
def get_drf_data_files_stats(self) -> dict[str, int]:
441-
"""Count + total size in one query; cached per instance. File PK is ``uuid`` — use ``pk``."""
442+
"""
443+
Count + total size in one query; cached per instance.
444+
445+
File primary key is ``uuid``; use ``pk`` in aggregates.
446+
"""
442447
if hasattr(self, "_drf_data_files_stats_cache"):
443448
return self._drf_data_files_stats_cache
444449

gateway/sds_gateway/api_methods/serializers/capture_serializers.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def get_files(self, capture: Capture) -> ReturnList[File]:
9999
@extend_schema_field(serializers.IntegerField(allow_null=True))
100100
def get_total_file_size(self, capture: Capture) -> int | None:
101101
"""Get the total file size of all files associated with this capture."""
102-
102+
103103
if capture.capture_type != CaptureType.DigitalRF:
104104
return None
105105

@@ -109,11 +109,16 @@ def get_total_file_size(self, capture: Capture) -> int | None:
109109
data_total = self.get_data_files_info(capture).get("total_size", 0)
110110
if total < data_total:
111111
logging.getLogger(__name__).warning(
112-
"Capture %s: total_file_size (%s) < data_files_total_size (%s); using data total.",
113-
str(capture.uuid), total, data_total,
112+
(
113+
"Capture %s: total_file_size (%s) < data_files_total_size (%s); "
114+
"using data total."
115+
),
116+
str(capture.uuid),
117+
total,
118+
data_total,
114119
)
115120
total = data_total
116-
121+
117122
return total
118123

119124
@extend_schema_field(serializers.DictField(allow_null=True))
@@ -135,17 +140,19 @@ def get_data_files_info(self, capture: Capture) -> dict[str, Any]:
135140
def get_center_frequency_ghz(self, capture: Capture) -> float | None:
136141
"""Get the center frequency in GHz from the capture model property."""
137142
return capture.center_frequency_ghz
138-
143+
139144
@extend_schema_field(serializers.FloatField(allow_null=True))
140145
def get_sample_rate_mhz(self, capture: Capture) -> float | None:
141-
"""Get the sample rate in MHz from the capture model property. None if not indexed in OpenSearch."""
146+
"""Sample rate in MHz from the model. None if not indexed in OpenSearch."""
142147
return capture.sample_rate_mhz
143148

144149
@extend_schema_field(serializers.IntegerField(allow_null=True))
145150
def get_length_of_capture_ms(self, capture: Capture) -> int | None:
146-
"""Get the length of the capture in milliseconds. OpenSearch bounds are in seconds."""
151+
"""Capture length in milliseconds (OpenSearch bounds are seconds)."""
147152
try:
148-
start_time, end_time = get_capture_bounds(capture.capture_type, str(capture.uuid))
153+
start_time, end_time = get_capture_bounds(
154+
capture.capture_type, str(capture.uuid)
155+
)
149156
return (end_time - start_time) * 1000
150157
except (ValueError, IndexError, KeyError):
151158
return None
@@ -160,13 +167,14 @@ def get_file_cadence_ms(self, capture: Capture) -> int | None:
160167

161168
@extend_schema_field(serializers.IntegerField(allow_null=True))
162169
def get_capture_start_epoch_sec(self, capture: Capture) -> int | None:
163-
"""Get the capture start time as Unix epoch seconds. None if not indexed in OpenSearch."""
170+
"""Capture start as Unix epoch seconds. None if not in OpenSearch."""
164171
try:
165172
start_time, _ = get_capture_bounds(capture.capture_type, str(capture.uuid))
166-
return start_time
167173
except (ValueError, IndexError, KeyError):
168174
return None
169-
175+
else:
176+
return start_time
177+
170178
@extend_schema_field(serializers.DictField)
171179
def get_capture_props(self, capture: Capture) -> dict[str, Any]:
172180
"""Retrieve the indexed metadata for the capture."""
@@ -382,21 +390,25 @@ def get_total_file_size(self, obj: dict[str, Any]) -> int | None:
382390
"""Get the total file size across all channels."""
383391
if obj["capture_type"] != CaptureType.DigitalRF:
384392
return None
385-
393+
386394
total_size = 0
387395
for channel_data in obj["channels"]:
388396
capture_uuid = channel_data["uuid"]
389397
capture = Capture.objects.get(uuid=capture_uuid)
390398
all_files = get_capture_files(capture, include_deleted=False)
391399
result = all_files.aggregate(total_size=Sum("size"))
392400
total_size += result["total_size"] or 0
393-
401+
394402
data_total = self.get_data_files_info(obj).get("total_size", 0)
395-
403+
396404
if total_size < data_total:
397405
logging.getLogger(__name__).warning(
398-
"Composite capture: total_file_size (%s) < data_files_total_size (%s); using data total.",
399-
total_size, data_total,
406+
(
407+
"Composite capture: total_file_size (%s) < "
408+
"data_files_total_size (%s); using data total."
409+
),
410+
total_size,
411+
data_total,
400412
)
401413
total_size = data_total
402414
return total_size
@@ -405,7 +417,7 @@ def get_data_files_info(self, obj: dict[str, Any]) -> dict[str, Any]:
405417
"""Get the data files info for the composite capture."""
406418
if obj["capture_type"] != CaptureType.DigitalRF:
407419
return {}
408-
420+
409421
total_count = 0
410422
total_size = 0
411423
for channel_data in obj["channels"]:
@@ -418,7 +430,9 @@ def get_data_files_info(self, obj: dict[str, Any]) -> dict[str, Any]:
418430
return {
419431
"count": total_count,
420432
"total_size": total_size,
421-
"per_data_file_size": (float(total_size) / total_count) if total_count else None,
433+
"per_data_file_size": (float(total_size) / total_count)
434+
if total_count
435+
else None,
422436
}
423437

424438
@extend_schema_field(serializers.CharField)
@@ -464,12 +478,11 @@ def get_capture_start_epoch_sec(self, obj: dict[str, Any]) -> int | None:
464478
return None
465479
try:
466480
capture = Capture.objects.get(uuid=channels[0]["uuid"])
467-
start_time, _ = get_capture_bounds(
468-
capture.capture_type, str(capture.uuid)
469-
)
470-
return start_time
481+
start_time, _ = get_capture_bounds(capture.capture_type, str(capture.uuid))
471482
except (ValueError, IndexError, KeyError):
472483
return None
484+
else:
485+
return start_time
473486

474487

475488
def build_composite_capture_data(captures: list[Capture]) -> dict[str, Any]:

gateway/sds_gateway/api_methods/tests/factories.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
from unittest.mock import patch
1414

1515
from django.core.files.base import ContentFile
16-
from faker import Faker as FakerInstance
1716
from factory import Faker as FactoryFaker
1817
from factory import LazyAttribute
1918
from factory import LazyFunction
20-
from factory import post_generation
2119
from factory import Sequence
20+
from factory import post_generation
2221
from factory.django import DjangoModelFactory
22+
from faker import Faker as FakerInstance
2323

2424
from sds_gateway.api_methods.helpers.temporal_filtering import drf_rf_filename_from_ms
2525
from sds_gateway.api_methods.models import Capture
@@ -30,9 +30,10 @@
3030
from sds_gateway.api_methods.models import UserSharePermission
3131
from sds_gateway.users.tests.factories import UserFactory
3232

33-
# Standalone Faker for LazyFunction callbacks (not factory_boy's FactoryFaker declaration)
33+
# Standalone Faker for LazyFunction callbacks (not FactoryFaker from factory_boy)
3434
_faker = FakerInstance()
3535

36+
3637
class DatasetFactory(DjangoModelFactory):
3738
"""Factory for creating Dataset instances for testing.
3839
@@ -237,26 +238,23 @@ class Meta:
237238

238239
channel = FactoryFaker("word")
239240
capture_type = "drf"
240-
top_level_dir = LazyFunction(
241-
lambda: _faker.file_path(depth=2).replace("/", "_")
242-
)
241+
top_level_dir = LazyFunction(lambda: _faker.file_path(depth=2).replace("/", "_"))
243242
owner = FactoryFaker("subfactory", factory=UserFactory)
244243
name = FactoryFaker("slug")
245244
index_name = "captures-drf"
246245

247246

248247
class DRFDataFileFactory(DjangoModelFactory):
249-
"""Factory for creating DRF data file instances for testing.
250-
251-
This factory creates realistic DRF data file objects that represent files stored in the system.
252-
It generates test data for file metadata and creates a Django ContentFile for the actual file content.
253-
254-
The factory creates files with realistic metadata including size, checksums, and proper file extensions.
255-
It also handles the creation of the Django file field with test content.
248+
"""Factory for DRF data file instances used in tests.
249+
250+
Creates file metadata plus a Django ContentFile for content. Includes
251+
checksums, sizes, and extensions; wires the file field for uploads.
256252
"""
257253

258254
uuid = FactoryFaker("uuid4")
259-
directory = LazyAttribute(lambda obj: f"/files/{obj.owner.email}/{obj.capture.top_level_dir}/")
255+
directory = LazyAttribute(
256+
lambda obj: f"/files/{obj.owner.email}/{obj.capture.top_level_dir}/"
257+
)
260258
name = Sequence(lambda n: drf_rf_filename_from_ms(1000 + n * 1000))
261259
media_type = "application/x-hdf5"
262260
permissions = "rw-r----"
@@ -275,8 +273,6 @@ def file(self, create, extracted, **kwargs):
275273
else:
276274
content = b"test drf file content"
277275
self.file = ContentFile(content, name=self.name)
278-
279-
280276

281277
class Meta:
282278
model = File
@@ -316,7 +312,9 @@ class UserSharePermissionFactory(DjangoModelFactory):
316312

317313
owner = FactoryFaker("subfactory", factory=UserFactory)
318314
shared_with = FactoryFaker("subfactory", factory=UserFactory)
319-
item_type = FactoryFaker("random_element", elements=[ItemType.DATASET, ItemType.CAPTURE])
315+
item_type = FactoryFaker(
316+
"random_element", elements=[ItemType.DATASET, ItemType.CAPTURE]
317+
)
320318
item_uuid = FactoryFaker("uuid4")
321319
is_enabled = True
322320
message = FactoryFaker("sentence", nb_words=5)

gateway/sds_gateway/api_methods/tests/test_celery_tasks.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from sds_gateway.api_methods.models import ItemType
2626
from sds_gateway.api_methods.models import TemporaryZipFile
2727
from sds_gateway.api_methods.models import ZipFileStatus
28+
from sds_gateway.api_methods.tasks import _get_item_files
2829
from sds_gateway.api_methods.tasks import acquire_user_lock
2930
from sds_gateway.api_methods.tasks import check_celery_task
3031
from sds_gateway.api_methods.tasks import check_disk_space_available
@@ -36,7 +37,6 @@
3637
from sds_gateway.api_methods.tasks import get_user_task_status
3738
from sds_gateway.api_methods.tasks import is_user_locked
3839
from sds_gateway.api_methods.tasks import release_user_lock
39-
from sds_gateway.api_methods.tasks import _get_item_files
4040
from sds_gateway.api_methods.tasks import send_item_files_email
4141
from sds_gateway.api_methods.utils.disk_utils import estimate_disk_size
4242

@@ -1270,7 +1270,7 @@ def test_get_item_files_with_temporal_bounds_returns_expected_rf_subset(self):
12701270
"sds_gateway.api_methods.helpers.temporal_filtering.get_opensearch_client"
12711271
) as m:
12721272
m.return_value.get.return_value = mock_response
1273-
# Relative ms: 1000–4000 from capture start → absolute 2s–5s filenames
1273+
# Relative ms: 1000-4000 from capture start; absolute 2s-5s filenames
12741274
result = _get_item_files(
12751275
self.user,
12761276
self.capture,
@@ -1279,10 +1279,12 @@ def test_get_item_files_with_temporal_bounds_returns_expected_rf_subset(self):
12791279
end_time=4000,
12801280
)
12811281
names = [f.name for f in result]
1282-
# DRF files in [2s, 5s] inclusive (see filter_capture_data_files_selection_bounds)
1282+
# DRF files [2s,5s] inclusive; see filter_capture_data_files_selection_bounds
12831283
expected_rf = [f"rf@{i}.000.h5" for i in range(2, 6)]
12841284
rf_names = sorted(n for n in names if n.startswith("rf@"))
1285-
assert rf_names == expected_rf, f"Expected RF files {expected_rf}, got {rf_names}"
1286-
# Metadata / non-DRF capture files from setUp are still included in the download set
1285+
assert rf_names == expected_rf, (
1286+
f"Expected RF files {expected_rf}, got {rf_names}"
1287+
)
1288+
# Metadata and non-DRF files from setUp stay in the download set
12871289
assert "test_file1.txt" in names
12881290
assert "test_file2.txt" in names

0 commit comments

Comments
 (0)