Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions dataretrieval/waterdata/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from dataretrieval.waterdata.utils import (
SAMPLES_URL,
_check_monitoring_location_id,
_check_profiles,
_default_headers,
_get_args,
Expand Down Expand Up @@ -205,6 +206,7 @@ def get_daily(
... approval_status = "Approved",
... time = "2024-01-01/.."
"""
_check_monitoring_location_id(monitoring_location_id)
service = "daily"
output_id = "daily_id"

Expand Down Expand Up @@ -371,6 +373,7 @@ def get_continuous(
... time="2021-01-01T00:00:00Z/2022-01-01T00:00:00Z",
... )
"""
_check_monitoring_location_id(monitoring_location_id)
service = "continuous"
output_id = "continuous_id"

Expand Down Expand Up @@ -662,6 +665,7 @@ def get_monitoring_locations(
... properties=["monitoring_location_id", "state_name", "country_name"],
... )
"""
_check_monitoring_location_id(monitoring_location_id)
service = "monitoring-locations"
output_id = "monitoring_location_id"

Expand Down Expand Up @@ -878,6 +882,7 @@ def get_time_series_metadata(
... begin="1990-01-01/..",
... )
"""
_check_monitoring_location_id(monitoring_location_id)
service = "time-series-metadata"
output_id = "time_series_id"

Expand Down Expand Up @@ -1050,6 +1055,7 @@ def get_latest_continuous(
... monitoring_location_id=["USGS-05114000", "USGS-09423350"]
... )
"""
_check_monitoring_location_id(monitoring_location_id)
service = "latest-continuous"
output_id = "latest_continuous_id"

Expand Down Expand Up @@ -1224,6 +1230,7 @@ def get_latest_daily(
... monitoring_location_id=["USGS-05114000", "USGS-09423350"]
... )
"""
_check_monitoring_location_id(monitoring_location_id)
service = "latest-daily"
output_id = "latest_daily_id"

Expand Down Expand Up @@ -1397,6 +1404,7 @@ def get_field_measurements(
... time = "P20Y"
... )
"""
_check_monitoring_location_id(monitoring_location_id)
service = "field-measurements"
output_id = "field_measurement_id"

Expand Down Expand Up @@ -1850,6 +1858,7 @@ def get_stats_por(
... )
"""
# Build argument dictionary, omitting None values
_check_monitoring_location_id(monitoring_location_id)
params = _get_args(locals(), exclude={"expand_percentiles"})

return get_stats_data(
Expand Down Expand Up @@ -1979,6 +1988,7 @@ def get_stats_date_range(
... )
"""
# Build argument dictionary, omitting None values
_check_monitoring_location_id(monitoring_location_id)
params = _get_args(locals(), exclude={"expand_percentiles"})

return get_stats_data(
Expand Down Expand Up @@ -2144,6 +2154,7 @@ def get_channel(
... monitoring_location_id="USGS-02238500",
... )
"""
_check_monitoring_location_id(monitoring_location_id)
service = "channel-measurements"
output_id = "channel_measurements_id"

Expand Down
51 changes: 51 additions & 0 deletions dataretrieval/waterdata/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,6 +1109,57 @@ def _check_profiles(
)


_MONITORING_LOCATION_ID_RE = re.compile(r"^.+-.+$")


def _check_monitoring_location_id(
monitoring_location_id: str | list[str] | None,
) -> None:
"""Validate the format of a monitoring_location_id value.

Parameters
----------
monitoring_location_id : str, list of str, or None
One or more monitoring location identifiers.

Raises
------
TypeError
If any identifier is not a string (e.g. an integer was passed).
ValueError
If any string identifier does not follow the required
``'AGENCY-ID'`` format (e.g. ``'USGS-01646500'``).
"""
if monitoring_location_id is None:
return

if not isinstance(monitoring_location_id, (str, list)):
raise TypeError(
f"monitoring_location_id must be a string or list of strings, "
f"not {type(monitoring_location_id).__name__}. "
f"Expected format: 'AGENCY-ID', e.g., 'USGS-{monitoring_location_id}'."
)

ids = (
[monitoring_location_id]
if isinstance(monitoring_location_id, str)
else monitoring_location_id
)

for id_ in ids:
if not isinstance(id_, str):
raise TypeError(
f"monitoring_location_id must be a string or list of strings, "
f"not {type(id_).__name__}. "
f"Expected format: 'AGENCY-ID', e.g., 'USGS-{id_}'."
)
if not _MONITORING_LOCATION_ID_RE.match(id_):
raise ValueError(
f"Invalid monitoring_location_id: {id_!r}. "
f"Expected 'AGENCY-ID' format, e.g., 'USGS-01646500'."
)


def _get_args(
local_vars: dict[str, Any], exclude: set[str] | None = None
) -> dict[str, Any]:
Expand Down
51 changes: 50 additions & 1 deletion tests/waterdata_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
get_stats_por,
get_time_series_metadata,
)
from dataretrieval.waterdata.utils import _check_profiles
from dataretrieval.waterdata.utils import _check_monitoring_location_id, _check_profiles


def mock_request(requests_mock, request_url, file_path):
Expand Down Expand Up @@ -380,3 +380,52 @@ def test_get_channel():
assert df.shape[0] > 470
assert df.shape[1] == 27 # if geopandas installed, 21 columns if not
assert "channel_measurements_id" in df.columns


class TestCheckMonitoringLocationId:
    """Input-validation tests for ``_check_monitoring_location_id``.

    Guards against a regression of GitHub issue #188.
    """

    # Shared fixtures: two well-formed identifiers, an integer site number,
    # and the message fragments the validator is expected to emit.
    _SITE = "USGS-01646500"
    _OTHER_SITE = "USGS-02238500"
    _INT_SITE = 5129115
    _PARAMETER_CODE = "00060"
    _INT_TYPE_MSG = "not int"
    _BAD_FORMAT_MSG = "Invalid monitoring_location_id"

    def test_valid_string(self):
        """A single well-formed identifier is accepted."""
        _check_monitoring_location_id(self._SITE)

    def test_valid_list(self):
        """A list containing only well-formed identifiers is accepted."""
        _check_monitoring_location_id([self._SITE, self._OTHER_SITE])

    def test_none_passes(self):
        """None is accepted, since the parameter is optional."""
        _check_monitoring_location_id(None)

    def test_integer_raises_type_error(self):
        """A bare integer is rejected with a TypeError naming the type."""
        with pytest.raises(TypeError, match=self._INT_TYPE_MSG):
            _check_monitoring_location_id(self._INT_SITE)

    def test_integer_in_list_raises_type_error(self):
        """An integer element inside a list is rejected with TypeError."""
        mixed = [self._SITE, self._INT_SITE]
        with pytest.raises(TypeError, match=self._INT_TYPE_MSG):
            _check_monitoring_location_id(mixed)

    def test_missing_agency_prefix_raises_value_error(self):
        """A string with no ``AGENCY-`` prefix is rejected with ValueError."""
        with pytest.raises(ValueError, match=self._BAD_FORMAT_MSG):
            _check_monitoring_location_id("dog")

    def test_bare_site_number_raises_value_error(self):
        """A site number given as a string, without agency, is rejected."""
        with pytest.raises(ValueError, match=self._BAD_FORMAT_MSG):
            _check_monitoring_location_id("01646500")

    def test_get_daily_integer_id_raises(self):
        """get_daily rejects an integer ID with TypeError."""
        with pytest.raises(TypeError):
            get_daily(
                monitoring_location_id=self._INT_SITE,
                parameter_code=self._PARAMETER_CODE,
            )

    def test_get_daily_malformed_id_raises(self):
        """get_daily rejects a malformed string ID with ValueError."""
        with pytest.raises(ValueError):
            get_daily(
                monitoring_location_id="dog",
                parameter_code=self._PARAMETER_CODE,
            )
Loading