From b0aa389e05bb16376d65338082184ae2543cf2e2 Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Mon, 13 Apr 2026 17:53:25 -0500 Subject: [PATCH] Validate monitoring_location_id format in waterdata functions Passing an integer (e.g. 5129115) or a bare string without an agency prefix (e.g. "dog") to any waterdata function silently wasted an API call and returned empty data. Now all ten public functions that accept monitoring_location_id raise before touching the network: - TypeError if the value is not a string or list of strings - ValueError if any string doesn't match the 'AGENCY-ID' format (e.g. 'USGS-01646500') Closes #188. Co-Authored-By: Claude Sonnet 4.6 --- dataretrieval/waterdata/api.py | 11 +++++++ dataretrieval/waterdata/utils.py | 51 ++++++++++++++++++++++++++++++++ tests/waterdata_test.py | 51 +++++++++++++++++++++++++++++++- 3 files changed, 112 insertions(+), 1 deletion(-) diff --git a/dataretrieval/waterdata/api.py b/dataretrieval/waterdata/api.py index b2310e7a..27dd5531 100644 --- a/dataretrieval/waterdata/api.py +++ b/dataretrieval/waterdata/api.py @@ -24,6 +24,7 @@ ) from dataretrieval.waterdata.utils import ( SAMPLES_URL, + _check_monitoring_location_id, _check_profiles, _default_headers, _get_args, @@ -205,6 +206,7 @@ def get_daily( ... approval_status = "Approved", ... time = "2024-01-01/.." """ + _check_monitoring_location_id(monitoring_location_id) service = "daily" output_id = "daily_id" @@ -371,6 +373,7 @@ def get_continuous( ... time="2021-01-01T00:00:00Z/2022-01-01T00:00:00Z", ... ) """ + _check_monitoring_location_id(monitoring_location_id) service = "continuous" output_id = "continuous_id" @@ -662,6 +665,7 @@ def get_monitoring_locations( ... properties=["monitoring_location_id", "state_name", "country_name"], ... ) """ + _check_monitoring_location_id(monitoring_location_id) service = "monitoring-locations" output_id = "monitoring_location_id" @@ -878,6 +882,7 @@ def get_time_series_metadata( ... begin="1990-01-01/..", ... ) """ + _check_monitoring_location_id(monitoring_location_id) service = "time-series-metadata" output_id = "time_series_id" @@ -1050,6 +1055,7 @@ def get_latest_continuous( ... monitoring_location_id=["USGS-05114000", "USGS-09423350"] ... ) """ + _check_monitoring_location_id(monitoring_location_id) service = "latest-continuous" output_id = "latest_continuous_id" @@ -1224,6 +1230,7 @@ def get_latest_daily( ... monitoring_location_id=["USGS-05114000", "USGS-09423350"] ... ) """ + _check_monitoring_location_id(monitoring_location_id) service = "latest-daily" output_id = "latest_daily_id" @@ -1397,6 +1404,7 @@ def get_field_measurements( ... time = "P20Y" ... ) """ + _check_monitoring_location_id(monitoring_location_id) service = "field-measurements" output_id = "field_measurement_id" @@ -1850,6 +1858,7 @@ def get_stats_por( ... ) """ # Build argument dictionary, omitting None values + _check_monitoring_location_id(monitoring_location_id) params = _get_args(locals(), exclude={"expand_percentiles"}) return get_stats_data( @@ -1979,6 +1988,7 @@ def get_stats_date_range( ... ) """ # Build argument dictionary, omitting None values + _check_monitoring_location_id(monitoring_location_id) params = _get_args(locals(), exclude={"expand_percentiles"}) return get_stats_data( @@ -2144,6 +2154,7 @@ def get_channel( ... monitoring_location_id="USGS-02238500", ... ) """ + _check_monitoring_location_id(monitoring_location_id) service = "channel-measurements" output_id = "channel_measurements_id" diff --git a/dataretrieval/waterdata/utils.py b/dataretrieval/waterdata/utils.py index c58148d5..56310e68 100644 --- a/dataretrieval/waterdata/utils.py +++ b/dataretrieval/waterdata/utils.py @@ -1109,6 +1109,57 @@ def _check_profiles( ) +_MONITORING_LOCATION_ID_RE = re.compile(r"^.+-.+$") + + +def _check_monitoring_location_id( + monitoring_location_id: str | list[str] | None, +) -> None: + """Validate the format of a monitoring_location_id value. + + Parameters + ---------- + monitoring_location_id : str, list of str, or None + One or more monitoring location identifiers. + + Raises + ------ + TypeError + If any identifier is not a string (e.g. an integer was passed). + ValueError + If any string identifier does not follow the required + ``'AGENCY-ID'`` format (e.g. ``'USGS-01646500'``). + """ + if monitoring_location_id is None: + return + + if not isinstance(monitoring_location_id, (str, list)): + raise TypeError( + f"monitoring_location_id must be a string or list of strings, " + f"not {type(monitoring_location_id).__name__}. " + f"Expected format: 'AGENCY-ID', e.g., 'USGS-{monitoring_location_id}'." + ) + + ids = ( + [monitoring_location_id] + if isinstance(monitoring_location_id, str) + else monitoring_location_id + ) + + for id_ in ids: + if not isinstance(id_, str): + raise TypeError( + f"monitoring_location_id must be a string or list of strings, " + f"not {type(id_).__name__}. " + f"Expected format: 'AGENCY-ID', e.g., 'USGS-{id_}'." + ) + if not _MONITORING_LOCATION_ID_RE.match(id_): + raise ValueError( + f"Invalid monitoring_location_id: {id_!r}. " + f"Expected 'AGENCY-ID' format, e.g., 'USGS-01646500'." + ) + + def _get_args( local_vars: dict[str, Any], exclude: set[str] | None = None ) -> dict[str, Any]: diff --git a/tests/waterdata_test.py b/tests/waterdata_test.py index 195441e5..d7dd7562 100644 --- a/tests/waterdata_test.py +++ b/tests/waterdata_test.py @@ -21,7 +21,7 @@ get_stats_por, get_time_series_metadata, ) -from dataretrieval.waterdata.utils import _check_profiles +from dataretrieval.waterdata.utils import _check_monitoring_location_id, _check_profiles def mock_request(requests_mock, request_url, file_path): @@ -380,3 +380,52 @@ def test_get_channel(): assert df.shape[0] > 470 assert df.shape[1] == 27 # if geopandas installed, 21 columns if not assert "channel_measurements_id" in df.columns + + +class TestCheckMonitoringLocationId: + """Tests for _check_monitoring_location_id input validation. + + Regression tests for GitHub issue #188. + """ + + def test_valid_string(self): + """A correctly formatted string passes without error.""" + _check_monitoring_location_id("USGS-01646500") + + def test_valid_list(self): + """A list of correctly formatted strings passes without error.""" + _check_monitoring_location_id(["USGS-01646500", "USGS-02238500"]) + + def test_none_passes(self): + """None is allowed (optional parameter).""" + _check_monitoring_location_id(None) + + def test_integer_raises_type_error(self): + """An integer ID raises TypeError with a helpful message.""" + with pytest.raises(TypeError, match="not int"): + _check_monitoring_location_id(5129115) + + def test_integer_in_list_raises_type_error(self): + """An integer inside a list raises TypeError.""" + with pytest.raises(TypeError, match="not int"): + _check_monitoring_location_id(["USGS-01646500", 5129115]) + + def test_missing_agency_prefix_raises_value_error(self): + """A string without the AGENCY- prefix raises ValueError.""" + with pytest.raises(ValueError, match="Invalid monitoring_location_id"): + _check_monitoring_location_id("dog") + + def test_bare_site_number_raises_value_error(self): + """A bare site number string (no agency prefix) raises ValueError.""" + with pytest.raises(ValueError, match="Invalid monitoring_location_id"): + _check_monitoring_location_id("01646500") + + def test_get_daily_integer_id_raises(self): + """get_daily raises TypeError before making any network call.""" + with pytest.raises(TypeError): + get_daily(monitoring_location_id=5129115, parameter_code="00060") + + def test_get_daily_malformed_id_raises(self): + """get_daily raises ValueError for a malformed string ID.""" + with pytest.raises(ValueError): + get_daily(monitoring_location_id="dog", parameter_code="00060")