diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..d6643b7 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,125 @@ +name: Tests + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + # ----------------------------------------------------------------------- + # Downloader tests (~1 min) + # ----------------------------------------------------------------------- + test-downloader: + runs-on: ubuntu-latest + defaults: + run: + working-directory: src/downloader + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install dependencies + run: | + pip install -r requirements.txt + pip install pytest pyyaml + + - name: Run tests + run: pytest tests/ -v --tb=short + + # ----------------------------------------------------------------------- + # Simulator tests - pure Python logic only (~1 min) + # ----------------------------------------------------------------------- + test-simulator: + runs-on: ubuntu-latest + defaults: + run: + working-directory: src/simulator + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install dependencies + run: pip install numpy pytest pymatgen + + - name: Run tests + run: pytest tests/ -v --tb=short + + # ----------------------------------------------------------------------- + # Trainer tests (~3 min) + # ----------------------------------------------------------------------- + test-trainer: + runs-on: ubuntu-latest + defaults: + run: + working-directory: src/trainer + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install CPU-only PyTorch and dependencies + run: | + pip install torch --index-url https://download.pytorch.org/whl/cpu + pip install pytorch-lightning numpy pyyaml tqdm matplotlib scikit-learn pytest + + - name: Run tests + run: pytest tests/ -v --tb=short + + # ----------------------------------------------------------------------- + # UI backend tests (~3 min) + # ----------------------------------------------------------------------- + test-ui-backend: + runs-on: ubuntu-latest + defaults: + run: + working-directory: src/ui + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install CPU-only PyTorch and dependencies + run: | + pip install torch --index-url https://download.pytorch.org/whl/cpu + pip install -r requirements.txt + pip install pytest httpx + + - name: Run tests + run: pytest tests/ -v --tb=short + + # ----------------------------------------------------------------------- + # UI frontend tests (~2 min) + # ----------------------------------------------------------------------- + test-ui-frontend: + runs-on: ubuntu-latest + defaults: + run: + working-directory: src/ui/frontend + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-node@v4 + with: + node-version: "18" + cache: "npm" + cache-dependency-path: src/ui/frontend/package-lock.json + + - name: Install dependencies + run: npm ci + + - name: Install Vitest + run: npm install --save-dev vitest + + - name: Run tests + run: npx vitest run --reporter=verbose diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..8c20255 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,8 @@ +[pytest] +# Root-level pytest config ensures rootdir is the repo root, +# preventing src/trainer/pyproject.toml from being used as config. +testpaths = + src/downloader/tests + src/simulator/tests + src/trainer/tests + src/ui/tests diff --git a/src/downloader/tests/test_downloader.py b/src/downloader/tests/test_downloader.py new file mode 100644 index 0000000..0675876 --- /dev/null +++ b/src/downloader/tests/test_downloader.py @@ -0,0 +1,226 @@ +"""Unit tests for the downloader module.""" + +import os +import sys +import tempfile +from pathlib import Path +from unittest.mock import patch, MagicMock + +import pytest +import yaml + +# Add parent directory to path so we can import the module +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from downloader import ( + load_config, + ensure_output_dir, + get_api_key, + passes_filters, + _get_crystal_system_from_sg, + is_spacegroup_stable, + _write_cif_from_struct, +) + + +# --------------------------------------------------------------------------- +# load_config +# --------------------------------------------------------------------------- +class TestLoadConfig: + def test_valid_yaml(self, tmp_path): + cfg_file = tmp_path / "config.yaml" + cfg_file.write_text("output_directory: /data/raw_cif\nfilters:\n max_atoms: 500\n") + cfg = load_config(cfg_file) + assert cfg["output_directory"] == "/data/raw_cif" + assert cfg["filters"]["max_atoms"] == 500 + + def test_empty_yaml_returns_empty_dict(self, tmp_path): + cfg_file = tmp_path / "empty.yaml" + cfg_file.write_text("") + cfg = load_config(cfg_file) + assert cfg == {} + + def test_missing_file_raises(self, tmp_path): + with pytest.raises(FileNotFoundError): + load_config(tmp_path / "nonexistent.yaml") + + +# --------------------------------------------------------------------------- +# ensure_output_dir +# --------------------------------------------------------------------------- +class TestEnsureOutputDir: + def test_creates_directory(self, tmp_path): + new_dir = tmp_path / "a" / "b" / "c" + assert not new_dir.exists() + ensure_output_dir(new_dir) + assert new_dir.is_dir() + + def test_existing_directory_no_error(self, tmp_path): + ensure_output_dir(tmp_path) # already exists + + +# --------------------------------------------------------------------------- +# get_api_key +# --------------------------------------------------------------------------- +class TestGetApiKey: + def test_returns_key_from_env(self, monkeypatch): + monkeypatch.setenv("MP_API_KEY", "test-key-123") + assert get_api_key() == "test-key-123" + + def test_raises_when_missing(self, monkeypatch): + monkeypatch.delenv("MP_API_KEY", raising=False) + with pytest.raises(RuntimeError, match="MP_API_KEY"): + get_api_key() + + def test_raises_when_empty(self, monkeypatch): + monkeypatch.setenv("MP_API_KEY", " ") + with pytest.raises(RuntimeError, match="MP_API_KEY"): + get_api_key() + + +# --------------------------------------------------------------------------- +# _get_crystal_system_from_sg +# --------------------------------------------------------------------------- +class TestGetCrystalSystemFromSg: + @pytest.mark.parametrize( + "sg_num, expected", + [ + (1, 1), # Triclinic + (2, 1), # Triclinic boundary + (3, 2), # Monoclinic + (15, 2), # Monoclinic boundary + (16, 3), # Orthorhombic + (74, 3), # Orthorhombic boundary + (75, 4), # Tetragonal + (142, 4), # Tetragonal boundary + (143, 5), # Trigonal + (167, 5), # Trigonal boundary + (168, 6), # Hexagonal + (194, 6), # Hexagonal boundary + (195, 7), # Cubic + (230, 7), # Cubic boundary + ], + ) + def test_valid_space_groups(self, sg_num, expected): + assert _get_crystal_system_from_sg(sg_num) == expected + + def test_invalid_returns_none(self): + assert _get_crystal_system_from_sg(231) is None + assert _get_crystal_system_from_sg(None) is None + assert _get_crystal_system_from_sg("abc") is None + + def test_zero_and_negative_map_to_triclinic(self): + # Code treats sg_num <= 2 as Triclinic (no lower-bound guard) + assert _get_crystal_system_from_sg(0) == 1 + assert _get_crystal_system_from_sg(-1) == 1 + + +# --------------------------------------------------------------------------- +# passes_filters +# --------------------------------------------------------------------------- +class TestPassesFilters: + def _make_mock_structure(self, num_atoms=10, volume=100.0): + """Create a mock structure with controllable atom count and volume.""" + mock = MagicMock() + mock.__len__ = MagicMock(return_value=num_atoms) + mock.volume = volume + return mock + + def test_no_filters_passes(self): + struct = self._make_mock_structure() + assert passes_filters(struct, {}) is True + + def test_max_atoms_pass(self): + struct = self._make_mock_structure(num_atoms=100) + assert passes_filters(struct, {"max_atoms": 500}) is True + + def test_max_atoms_fail(self): + struct = self._make_mock_structure(num_atoms=600) + assert passes_filters(struct, {"max_atoms": 500}) is False + + def test_max_atoms_boundary(self): + struct = self._make_mock_structure(num_atoms=500) + assert passes_filters(struct, {"max_atoms": 500}) is True + + def test_min_volume_pass(self): + struct = self._make_mock_structure(volume=200.0) + assert passes_filters(struct, {"min_volume": 100.0}) is True + + def test_min_volume_fail(self): + struct = self._make_mock_structure(volume=50.0) + assert passes_filters(struct, {"min_volume": 100.0}) is False + + def test_max_volume_pass(self): + struct = self._make_mock_structure(volume=500.0) + assert passes_filters(struct, {"max_volume": 1000.0}) is True + + def test_max_volume_fail(self): + struct = self._make_mock_structure(volume=1500.0) + assert passes_filters(struct, {"max_volume": 1000.0}) is False + + def test_combined_filters(self): + struct = self._make_mock_structure(num_atoms=100, volume=500.0) + filters = {"max_atoms": 200, "min_volume": 100.0, "max_volume": 1000.0} + assert passes_filters(struct, filters) is True + + def test_combined_filters_fail_atoms(self): + struct = self._make_mock_structure(num_atoms=300, volume=500.0) + filters = {"max_atoms": 200, "min_volume": 100.0, "max_volume": 1000.0} + assert passes_filters(struct, filters) is False + + +# --------------------------------------------------------------------------- +# is_spacegroup_stable +# --------------------------------------------------------------------------- +class TestIsSpacegroupStable: + def test_all_same(self): + grid = [[225, 225], [225, 225]] + assert is_spacegroup_stable(grid) is True + + def test_different_values(self): + grid = [[225, 225], [225, 226]] + assert is_spacegroup_stable(grid) is False + + def test_single_value(self): + grid = [[225]] + assert is_spacegroup_stable(grid) is True + + def test_empty_grid(self): + grid = [[]] + assert is_spacegroup_stable(grid) is False + + def test_all_zeros(self): + grid = [[0, 0], [0, 0]] + assert is_spacegroup_stable(grid) is True + + def test_zero_and_nonzero(self): + grid = [[0, 225], [225, 225]] + assert is_spacegroup_stable(grid) is False + + +# --------------------------------------------------------------------------- +# _write_cif_from_struct +# --------------------------------------------------------------------------- +class TestWriteCifFromStruct: + def test_writes_file_with_comment(self, tmp_path): + """Test that CIF writing adds space group comment.""" + mock_structure = MagicMock() + out_path = tmp_path / "test.cif" + + # Mock CifWriter to write a minimal CIF file + with patch("downloader.CifWriter") as MockCifWriter: + mock_writer = MagicMock() + MockCifWriter.return_value = mock_writer + + def write_side_effect(path): + with open(path, "w") as f: + f.write("data_test\n_cell_length_a 5.0\n") + + mock_writer.write_file.side_effect = write_side_effect + + _write_cif_from_struct(mock_structure, out_path, sg_num=225, sg_symbol="Fm-3m") + + assert out_path.exists() + content = out_path.read_text() + assert "Fm-3m" in content + assert "_original_symmetry_space_group_name_H-M" in content diff --git a/src/simulator/tests/test_simulation_worker.py b/src/simulator/tests/test_simulation_worker.py new file mode 100644 index 0000000..0f48fa2 --- /dev/null +++ b/src/simulator/tests/test_simulation_worker.py @@ -0,0 +1,245 @@ +"""Unit tests for simulator pure-Python logic (no GSAS-II dependency).""" + +import sys +import tempfile +from pathlib import Path +from unittest.mock import patch, MagicMock + +import numpy as np +import pytest + +# We need to mock the GSAS-II import before importing the module +# since it's a hard dependency that won't be available in CI. +sys.modules["GSASII"] = MagicMock() +sys.modules["GSASII.GSASIIscriptable"] = MagicMock() + +# Mock pymatgen.symmetry.groups before import +mock_sg = MagicMock() +mock_sg.SpaceGroup.SYMM_OPS = [ + {"short_h_m": "P1", "number": 1}, + {"short_h_m": "P-1", "number": 2}, + {"short_h_m": "Fm-3m", "number": 225}, + {"short_h_m": "Im-3m", "number": 229}, + {"short_h_m": "Pn-3m", "number": 224}, + {"short_h_m": "Ia-3d", "number": 230}, + {"short_h_m": "Cmce", "number": 64}, +] +mock_sg.SpaceGroup.sg_encoding = {} +sys.modules["pymatgen"] = MagicMock() +sys.modules["pymatgen.symmetry"] = MagicMock() +sys.modules["pymatgen.symmetry.groups"] = mock_sg + +# Now import the worker module +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) +from simulation_worker import ( + get_crystal_system, + _compute_invariants, + _parse_cif_lattice_params, + _parse_cif_file, + convert_space_group_for_gsas, + no_extra_chars, + extra_chars, + get_sg_num_from_symbol, +) + + +# --------------------------------------------------------------------------- +# get_crystal_system +# --------------------------------------------------------------------------- +class TestGetCrystalSystem: + @pytest.mark.parametrize( + "sg_num, expected", + [ + (1, 1), # Triclinic + (2, 1), + (3, 2), # Monoclinic + (15, 2), + (16, 3), # Orthorhombic + (74, 3), + (75, 4), # Tetragonal + (142, 4), + (143, 5), # Trigonal + (167, 5), + (168, 6), # Hexagonal + (194, 6), + (195, 7), # Cubic + (230, 7), + ], + ) + def test_valid_space_groups(self, sg_num, expected): + assert get_crystal_system(sg_num) == expected + + def test_float_input(self): + assert get_crystal_system(225.0) == 7 + + def test_invalid_input(self): + assert get_crystal_system(None) is None + assert get_crystal_system("abc") is None + assert get_crystal_system(231) is None + + +# --------------------------------------------------------------------------- +# _compute_invariants +# --------------------------------------------------------------------------- +class TestComputeInvariants: + def test_cubic_cell(self): + """Cubic cell: a=b=c, alpha=beta=gamma=90.""" + a = b = c = 5.0 + alpha = beta = gamma = 90.0 + I1, I2 = _compute_invariants(a, b, c, alpha, beta, gamma) + + # I1 = sqrt(3 * 5^2) = sqrt(75) = 5*sqrt(3) + expected_I1 = np.sqrt(3 * 25.0) + np.testing.assert_allclose(I1, expected_I1, rtol=1e-10) + + # For 90-degree angles: sin(90)=1, so + # I2 = sqrt(b^2*c^2 + a^2*c^2 + a^2*b^2) = sqrt(3*25^2) = 5^2*sqrt(3) + expected_I2 = np.sqrt(3 * 625.0) + np.testing.assert_allclose(I2, expected_I2, rtol=1e-10) + + def test_orthorhombic_cell(self): + """Orthorhombic: a!=b!=c, alpha=beta=gamma=90.""" + a, b, c = 3.0, 4.0, 5.0 + alpha = beta = gamma = 90.0 + I1, I2 = _compute_invariants(a, b, c, alpha, beta, gamma) + + expected_I1 = np.sqrt(9 + 16 + 25) + np.testing.assert_allclose(I1, expected_I1, rtol=1e-10) + + # sin(90)=1 + expected_I2 = np.sqrt(16 * 25 + 9 * 25 + 9 * 16) + np.testing.assert_allclose(I2, expected_I2, rtol=1e-10) + + def test_returns_numpy_float(self): + I1, I2 = _compute_invariants(5.0, 5.0, 5.0, 90.0, 90.0, 90.0) + assert isinstance(I1, (np.floating, float)) + assert isinstance(I2, (np.floating, float)) + + +# --------------------------------------------------------------------------- +# _parse_cif_lattice_params +# --------------------------------------------------------------------------- +class TestParseCifLatticeParams: + def test_standard_cif(self, tmp_path): + cif_file = tmp_path / "test.cif" + cif_file.write_text( + "data_test\n" + "_cell_length_a 5.431\n" + "_cell_length_b 5.431\n" + "_cell_length_c 5.431\n" + "_cell_angle_alpha 90.0\n" + "_cell_angle_beta 90.0\n" + "_cell_angle_gamma 90.0\n" + ) + result = _parse_cif_lattice_params(str(cif_file)) + assert result is not None + a, b, c, alpha, beta, gamma = result + assert abs(a - 5.431) < 1e-6 + assert abs(alpha - 90.0) < 1e-6 + + def test_cif_with_uncertainty(self, tmp_path): + """Test that uncertainty parentheses (e.g., 5.431(2)) are handled.""" + cif_file = tmp_path / "test.cif" + cif_file.write_text( + "data_test\n" + "_cell_length_a 5.431(2)\n" + "_cell_length_b 5.431(3)\n" + "_cell_length_c 7.123(1)\n" + "_cell_angle_alpha 90.0\n" + "_cell_angle_beta 90.0\n" + "_cell_angle_gamma 120.0(5)\n" + ) + result = _parse_cif_lattice_params(str(cif_file)) + assert result is not None + a, b, c, alpha, beta, gamma = result + assert abs(a - 5.431) < 1e-6 + assert abs(gamma - 120.0) < 1e-6 + + def test_missing_fields_returns_none(self, tmp_path): + cif_file = tmp_path / "incomplete.cif" + cif_file.write_text("data_test\n_cell_length_a 5.431\n") + result = _parse_cif_lattice_params(str(cif_file)) + assert result is None + + def test_nonexistent_file_returns_none(self): + result = _parse_cif_lattice_params("/nonexistent/path.cif") + assert result is None + + +# --------------------------------------------------------------------------- +# _parse_cif_file (space group parsing) +# --------------------------------------------------------------------------- +class TestParseCifFile: + def test_parse_from_int_tables_number(self, tmp_path): + cif_file = tmp_path / "test.cif" + cif_file.write_text( + "data_test\n" + "_symmetry_Int_Tables_number 225\n" + ) + sg_num, cs = _parse_cif_file(str(cif_file)) + assert sg_num == 225 + assert cs == 7 # Cubic + + def test_parse_from_comment(self, tmp_path): + cif_file = tmp_path / "test.cif" + cif_file.write_text( + "data_test\n" + "# _original_symmetry_space_group_name_H-M 'P1'\n" + "_cell_length_a 5.0\n" + ) + sg_num, cs = _parse_cif_file(str(cif_file), parse_from_comment=True) + assert sg_num == 1 + assert cs == 1 # Triclinic + + def test_no_spacegroup_returns_none(self, tmp_path): + cif_file = tmp_path / "test.cif" + cif_file.write_text("data_test\n_cell_length_a 5.0\n") + sg_num, cs = _parse_cif_file(str(cif_file)) + assert sg_num is None + assert cs is None + + +# --------------------------------------------------------------------------- +# Space group conversion functions +# --------------------------------------------------------------------------- +class TestSpaceGroupConversion: + def test_no_extra_chars(self): + assert no_extra_chars("P1") == "P 1" + assert no_extra_chars("Fm") == "F m" + + def test_extra_chars_with_subscript(self): + result = extra_chars("P2_1") + assert "21" in result + + def test_extra_chars_with_negative(self): + result = extra_chars("P-1") + assert "-1" in result + + def test_extra_chars_with_division(self): + result = extra_chars("P6/m") + assert "6/m" in result + + def test_convert_space_group_for_gsas_simple(self): + result = convert_space_group_for_gsas("P1") + assert result == "P 1" + + def test_convert_space_group_for_gsas_with_subscript(self): + result = convert_space_group_for_gsas("P2_1") + assert "21" in result + + def test_strips_h_r_suffix(self): + result = convert_space_group_for_gsas("R3-H") + # Should strip the -H suffix + assert "H" not in result + + def test_get_sg_num_from_symbol_known(self): + assert get_sg_num_from_symbol("P1") == 1 + + def test_get_sg_num_from_symbol_alt(self): + """Test alternative symbol mapping (e.g., Fm3m -> Fm-3m).""" + result = get_sg_num_from_symbol("Fm3m") + assert result == 225 + + def test_get_sg_num_from_symbol_unknown(self): + result = get_sg_num_from_symbol("ZZZZZ") + assert result is None diff --git a/src/trainer/tests/test_dataset.py b/src/trainer/tests/test_dataset.py new file mode 100644 index 0000000..92a623b --- /dev/null +++ b/src/trainer/tests/test_dataset.py @@ -0,0 +1,427 @@ +"""Unit tests for the trainer dataset and manifest utilities.""" + +import json +import os +import sys +import tempfile +from pathlib import Path + +import numpy as np +import pytest +import torch + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from dataset.dataset import ( + NpyManifestDataset, + _read_jsonl_manifest, + make_poisson_gaussian_noise_transform, + default_manifest_paths, +) +from dataset.manifest_utils import ( + scan_dataset_root, + split_materials, + write_jsonl_manifest, + generate_manifests, + ManifestStats, + _is_material_dir, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- +def _make_npy_sample(path, dp_len=100, cs=7, sg=225, a=5.0, b=5.0, c=5.0, + alpha=90.0, beta=90.0, gamma=90.0): + """Create a synthetic .npy file matching the simulator's output format.""" + payload = { + "dp": np.random.rand(dp_len).astype(np.float32), + "cs": cs, + "sg": sg, + "_cell_length_a": a, + "_cell_length_b": b, + "_cell_length_c": c, + "_cell_angle_alpha": alpha, + "_cell_angle_beta": beta, + "_cell_angle_gamma": gamma, + } + np.save(path, payload, allow_pickle=True) + + +def _make_dataset_tree(root, materials_count=5, files_per_material=3, dp_len=100): + """Create a dataset tree with material directories and .npy files.""" + root = Path(root) + root.mkdir(parents=True, exist_ok=True) + for i in range(materials_count): + mid = f"mp-{1000 + i}" + mat_dir = root / mid + mat_dir.mkdir(exist_ok=True) + for j in range(files_per_material): + _make_npy_sample(mat_dir / f"{mid}-{j + 1}.npy", dp_len=dp_len, cs=(i % 7) + 1, sg=(i % 230) + 1) + return root + + +def _make_manifest(path, materials, base_dir=None): + """Write a JSONL manifest file.""" + with open(path, "w") as f: + if base_dir: + f.write(json.dumps({"__meta__": {"version": 1, "base_dir": base_dir}}) + "\n") + for mid, files in materials.items(): + f.write(json.dumps({"material_id": mid, "files": files}) + "\n") + + +# --------------------------------------------------------------------------- +# _read_jsonl_manifest +# --------------------------------------------------------------------------- +class TestReadJsonlManifest: + def test_basic_manifest(self, tmp_path): + manifest = tmp_path / "test.jsonl" + _make_manifest(str(manifest), { + "mp-1000": ["mp-1000/mp-1000-1.npy"], + "mp-1001": ["mp-1001/mp-1001-1.npy"], + }) + records, meta = _read_jsonl_manifest(str(manifest)) + assert len(records) == 2 + assert meta is None + + def test_manifest_with_meta(self, tmp_path): + manifest = tmp_path / "test.jsonl" + _make_manifest(str(manifest), {"mp-1000": ["mp-1000-1.npy"]}, base_dir="../dataset") + records, meta = _read_jsonl_manifest(str(manifest)) + assert len(records) == 1 + assert meta is not None + assert meta["base_dir"] == "../dataset" + + def test_missing_file_raises(self): + with pytest.raises(FileNotFoundError): + _read_jsonl_manifest("/nonexistent/manifest.jsonl") + + def test_invalid_record_raises(self, tmp_path): + manifest = tmp_path / "bad.jsonl" + manifest.write_text('{"bad_key": "bad_value"}\n') + with pytest.raises(ValueError, match="material_id"): + _read_jsonl_manifest(str(manifest)) + + def test_empty_manifest_raises(self, tmp_path): + manifest = tmp_path / "empty.jsonl" + manifest.write_text("") + with pytest.raises(RuntimeError, match="No records"): + _read_jsonl_manifest(str(manifest)) + + +# --------------------------------------------------------------------------- +# NpyManifestDataset +# --------------------------------------------------------------------------- +class TestNpyManifestDataset: + @pytest.fixture + def dataset_dir(self, tmp_path): + return _make_dataset_tree(tmp_path / "dataset", materials_count=3, files_per_material=2) + + @pytest.fixture + def manifest_file(self, dataset_dir, tmp_path): + """Create a manifest pointing to the dataset.""" + manifest_path = tmp_path / "manifests" / "test.jsonl" + manifest_path.parent.mkdir(parents=True, exist_ok=True) + + materials = {} + for mat_dir in sorted(dataset_dir.iterdir()): + if mat_dir.is_dir(): + mid = mat_dir.name + files = [str(f) for f in sorted(mat_dir.glob("*.npy"))] + materials[mid] = files + + _make_manifest(str(manifest_path), materials) + return str(manifest_path) + + def test_len(self, manifest_file): + ds = NpyManifestDataset( + manifest_path=manifest_file, dtype=torch.float32, mmap_mode=None, + return_meta=False, validate_paths=True, extract_labels=False, + allow_pickle=True, floor_at_zero=False, normalize_log1p=False, + shift_labels=False, + ) + assert len(ds) == 6 # 3 materials * 2 files each + + def test_getitem_returns_tensor(self, manifest_file): + ds = NpyManifestDataset( + manifest_path=manifest_file, dtype=torch.float32, mmap_mode=None, + return_meta=False, validate_paths=True, extract_labels=False, + allow_pickle=True, floor_at_zero=False, normalize_log1p=False, + shift_labels=False, + ) + sample = ds[0] + assert isinstance(sample, torch.Tensor) + assert sample.dtype == torch.float32 + assert sample.shape == (100,) + + def test_getitem_with_meta(self, manifest_file): + ds = NpyManifestDataset( + manifest_path=manifest_file, dtype=torch.float32, mmap_mode=None, + return_meta=True, validate_paths=True, extract_labels=True, + allow_pickle=True, floor_at_zero=False, normalize_log1p=False, + shift_labels=False, + ) + sample = ds[0] + assert isinstance(sample, dict) + assert "x" in sample + assert "material_id" in sample + assert "path" in sample + assert "cs" in sample + assert "sg" in sample + + def test_extract_labels(self, manifest_file): + ds = NpyManifestDataset( + manifest_path=manifest_file, dtype=torch.float32, mmap_mode=None, + return_meta=True, validate_paths=True, extract_labels=True, + allow_pickle=True, floor_at_zero=False, normalize_log1p=False, + shift_labels=False, + ) + sample = ds[0] + assert "cs" in sample + assert "sg" in sample + assert "lattice_params" in sample + assert sample["lattice_params"].shape == (6,) + + def test_shift_labels(self, manifest_file): + ds = NpyManifestDataset( + manifest_path=manifest_file, dtype=torch.float32, mmap_mode=None, + return_meta=True, validate_paths=True, extract_labels=True, + allow_pickle=True, floor_at_zero=False, normalize_log1p=False, + shift_labels=True, + ) + sample = ds[0] + # Labels should be shifted by -1 + assert sample["cs"].item() >= 0 # originally 1-indexed, now 0-indexed + + def test_floor_at_zero(self, dataset_dir, tmp_path): + """Test that floor_at_zero clamps negative values.""" + # Create a sample with negative values + mat_dir = dataset_dir / "mp-neg" + mat_dir.mkdir(exist_ok=True) + payload = { + "dp": np.array([-1.0, 0.5, 2.0, -0.5], dtype=np.float32), + "cs": 1, "sg": 1, + "_cell_length_a": 5.0, "_cell_length_b": 5.0, "_cell_length_c": 5.0, + "_cell_angle_alpha": 90.0, "_cell_angle_beta": 90.0, "_cell_angle_gamma": 90.0, + } + np.save(mat_dir / "mp-neg-1.npy", payload, allow_pickle=True) + + manifest_path = tmp_path / "neg_manifest.jsonl" + _make_manifest(str(manifest_path), {"mp-neg": [str(mat_dir / "mp-neg-1.npy")]}) + + ds = NpyManifestDataset( + manifest_path=str(manifest_path), dtype=torch.float32, mmap_mode=None, + return_meta=False, validate_paths=True, extract_labels=False, + allow_pickle=True, floor_at_zero=True, normalize_log1p=False, + shift_labels=False, + ) + sample = ds[0] + assert (sample >= 0).all() + + def test_normalize_log1p(self, manifest_file): + ds = NpyManifestDataset( + manifest_path=manifest_file, dtype=torch.float32, mmap_mode=None, + return_meta=False, validate_paths=True, extract_labels=False, + allow_pickle=True, floor_at_zero=True, normalize_log1p=True, + shift_labels=False, + ) + sample = ds[0] + # log1p of non-negative values should be non-negative + assert (sample >= 0).all() + + def test_materials_property(self, manifest_file): + ds = NpyManifestDataset( + manifest_path=manifest_file, dtype=torch.float32, mmap_mode=None, + return_meta=False, validate_paths=True, extract_labels=False, + allow_pickle=True, floor_at_zero=False, normalize_log1p=False, + shift_labels=False, + ) + mats = ds.materials + assert len(mats) == 3 + assert all(m.startswith("mp-") for m in mats) + + def test_count_by_material(self, manifest_file): + ds = NpyManifestDataset( + manifest_path=manifest_file, dtype=torch.float32, mmap_mode=None, + return_meta=False, validate_paths=True, extract_labels=False, + allow_pickle=True, floor_at_zero=False, normalize_log1p=False, + shift_labels=False, + ) + counts = ds.count_by_material() + assert all(c == 2 for c in counts.values()) + + +# --------------------------------------------------------------------------- +# make_poisson_gaussian_noise_transform +# --------------------------------------------------------------------------- +class TestNoiseTransform: + def test_output_shape_preserved(self): + transform = make_poisson_gaussian_noise_transform((50, 200), (0.01, 0.05)) + x = torch.rand(100) * 1000 + out = transform(x) + assert out.shape == x.shape + + def test_output_non_negative(self): + transform = make_poisson_gaussian_noise_transform((50, 200), (0.01, 0.05)) + x = torch.rand(100) * 1000 + out = transform(x) + assert (out >= 0).all() + + def test_zero_input(self): + transform = make_poisson_gaussian_noise_transform((50, 200), (0.01, 0.05)) + x = torch.zeros(50) + out = transform(x) + assert (out >= 0).all() + + def test_deterministic_with_same_seed(self): + """Noise transform uses torch random state; setting seed should make it reproducible.""" + transform = make_poisson_gaussian_noise_transform((100, 100), (0.0, 0.0)) + x = torch.ones(50) * 100 + torch.manual_seed(42) + out1 = transform(x) + torch.manual_seed(42) + out2 = transform(x) + torch.testing.assert_close(out1, out2) + + +# --------------------------------------------------------------------------- +# scan_dataset_root +# --------------------------------------------------------------------------- +class TestScanDatasetRoot: + def test_scans_materials(self, tmp_path): + _make_dataset_tree(tmp_path / "dataset", materials_count=3) + materials = scan_dataset_root(str(tmp_path / "dataset")) + assert len(materials) == 3 + + def test_nonexistent_root_raises(self): + with pytest.raises(FileNotFoundError): + scan_dataset_root("/nonexistent/path") + + def test_empty_root_raises(self, tmp_path): + empty_dir = tmp_path / "empty" + empty_dir.mkdir() + with pytest.raises(RuntimeError, match="No .npy files"): + scan_dataset_root(str(empty_dir)) + + +# --------------------------------------------------------------------------- +# split_materials +# --------------------------------------------------------------------------- +class TestSplitMaterials: + def test_split_ratios(self): + materials = {f"mp-{i}": [f"f{i}.npy"] for i in range(100)} + train, val, test = split_materials(materials, 0.8, 0.1, 0.1, seed=42) + assert len(train) == 80 + assert len(val) == 10 + assert len(test) == 10 + + def test_no_overlap(self): + materials = {f"mp-{i}": [f"f{i}.npy"] for i in range(50)} + train, val, test = split_materials(materials, 0.8, 0.1, 0.1, seed=42) + train_ids = set(train.keys()) + val_ids = set(val.keys()) + test_ids = set(test.keys()) + assert train_ids.isdisjoint(val_ids) + assert train_ids.isdisjoint(test_ids) + assert val_ids.isdisjoint(test_ids) + + def test_all_materials_assigned(self): + materials = {f"mp-{i}": [f"f{i}.npy"] for i in range(50)} + train, val, test = split_materials(materials, 0.8, 0.1, 0.1, seed=42) + total = len(train) + len(val) + len(test) + assert total == 50 + + def test_invalid_ratios_raises(self): + materials = {f"mp-{i}": [f"f{i}.npy"] for i in range(10)} + with pytest.raises(ValueError, match="sum to 1.0"): + split_materials(materials, 0.5, 0.5, 0.5) + + def test_reproducible_with_same_seed(self): + materials = {f"mp-{i}": [f"f{i}.npy"] for i in range(50)} + t1, v1, te1 = split_materials(materials, 0.8, 0.1, 0.1, seed=42) + t2, v2, te2 = split_materials(materials, 0.8, 0.1, 0.1, seed=42) + assert set(t1.keys()) == set(t2.keys()) + assert set(v1.keys()) == set(v2.keys()) + assert set(te1.keys()) == set(te2.keys()) + + +# --------------------------------------------------------------------------- +# write_jsonl_manifest / generate_manifests +# --------------------------------------------------------------------------- +class TestManifestWriting: + def test_write_jsonl_manifest(self, tmp_path): + manifest_path = str(tmp_path / "test.jsonl") + materials = {"mp-1": ["/data/mp-1/f1.npy", "/data/mp-1/f2.npy"]} + count = write_jsonl_manifest(manifest_path, materials) + assert count == 1 + assert os.path.isfile(manifest_path) + + records, meta = _read_jsonl_manifest(manifest_path) + assert len(records) == 1 + assert records[0]["material_id"] == "mp-1" + + def test_write_with_base_dir(self, tmp_path): + manifest_path = str(tmp_path / "manifests" / "test.jsonl") + materials = {"mp-1": ["/data/dataset/mp-1/f1.npy"]} + write_jsonl_manifest(manifest_path, materials, base_dir="../dataset") + + records, meta = _read_jsonl_manifest(manifest_path) + assert meta is not None + assert meta["base_dir"] == "../dataset" + + def test_generate_manifests(self, tmp_path): + dataset_root = _make_dataset_tree(tmp_path / "dataset", materials_count=10) + manifest_dir = str(tmp_path / "manifests") + + stats = generate_manifests( + dataset_root=str(dataset_root), + manifest_dir=manifest_dir, + train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, + seed=42, + ) + + assert isinstance(stats, ManifestStats) + assert stats.train_materials + stats.val_materials + stats.test_materials == 10 + assert os.path.isfile(os.path.join(manifest_dir, "train.jsonl")) + assert os.path.isfile(os.path.join(manifest_dir, "val.jsonl")) + assert os.path.isfile(os.path.join(manifest_dir, "test.jsonl")) + + +# --------------------------------------------------------------------------- +# _is_material_dir +# --------------------------------------------------------------------------- +class TestIsMaterialDir: + def test_dir_with_npy(self, tmp_path): + mat_dir = tmp_path / "mp-1" + mat_dir.mkdir() + (mat_dir / "test.npy").write_bytes(b"data") + assert _is_material_dir(str(mat_dir)) is True + + def test_empty_dir(self, tmp_path): + mat_dir = tmp_path / "mp-1" + mat_dir.mkdir() + assert _is_material_dir(str(mat_dir)) is False + + def test_nonexistent(self, tmp_path): + assert _is_material_dir(str(tmp_path / "nope")) is False + + def test_file_not_dir(self, tmp_path): + f = tmp_path / "file.txt" + f.write_text("data") + assert _is_material_dir(str(f)) is False + + +# --------------------------------------------------------------------------- +# default_manifest_paths +# --------------------------------------------------------------------------- +class TestDefaultManifestPaths: + def test_default_paths(self): + paths = default_manifest_paths() + assert "train" in paths + assert "val" in paths + assert "test" in paths + assert paths["train"].endswith("train.jsonl") + + def test_custom_dir(self): + paths = default_manifest_paths("/custom/dir") + assert paths["train"] == "/custom/dir/train.jsonl" diff --git a/src/trainer/tests/test_model.py b/src/trainer/tests/test_model.py new file mode 100644 index 0000000..01f9be8 --- /dev/null +++ b/src/trainer/tests/test_model.py @@ -0,0 +1,357 @@ +"""Unit tests for the trainer model architecture.""" + +import sys +from pathlib import Path + +import pytest +import torch +import torch.nn as nn + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from model.model import ( + drop_path, + DropPath, + ConvNeXtBlock1D, + ConvNextBlock1DAdaptor, + MultiscaleCNNBackbone1D, + AlphaDiffractMultiscaleLightning, + make_mlp, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def _default_model_kwargs(): + """Return a minimal set of kwargs to instantiate AlphaDiffractMultiscaleLightning.""" + return dict( + dim_in=8192, + channels=(32, 32), + kernel_sizes=(7, 7), + strides=(2, 2), + dropout_rate=0.0, + ramped_dropout_rate=False, + block_type="convnext", + pooling_type="average", + final_pool=True, + use_batchnorm=False, + activation=nn.GELU, + output_type="gap", + layer_scale_init_value=1e-6, + drop_path_rate=0.0, + head_dropout=0.0, + cs_hidden=(64,), + sg_hidden=(64,), + lp_hidden=(64,), + num_cs_classes=7, + num_sg_classes=230, + num_lp_outputs=6, + lp_bounds_min=(0.0, 0.0, 0.0, 0.0, 0.0, 0.0), + lp_bounds_max=(100.0, 100.0, 100.0, 180.0, 180.0, 180.0), + bound_lp_with_sigmoid=True, + lambda_cs=1.0, + lambda_sg=1.0, + lambda_lp=1.0, + lr=0.001, + weight_decay=0.01, + use_adamw=True, + steps_per_epoch=100, + ) + + +# --------------------------------------------------------------------------- +# drop_path utility +# --------------------------------------------------------------------------- +class TestDropPath: + def test_no_drop(self): + x = torch.randn(2, 4, 8) + out = drop_path(x, drop_prob=0.0, training=True) + assert torch.equal(out, x) + + def test_no_drop_eval(self): + x = torch.randn(2, 4, 8) + out = drop_path(x, drop_prob=0.5, training=False) + assert torch.equal(out, x) + + def test_drop_path_module(self): + dp = DropPath(drop_prob=0.0) + x = torch.randn(2, 4, 8) + assert torch.equal(dp(x), x) + + +# --------------------------------------------------------------------------- +# ConvNeXtBlock1D +# --------------------------------------------------------------------------- +class TestConvNeXtBlock1D: + def test_output_shape(self): + block = ConvNeXtBlock1D( + dim=32, kernel_size=7, drop_path=0.0, + layer_scale_init_value=1e-6, activation=nn.GELU, + ) + x = torch.randn(2, 32, 64) + out = block(x) + assert out.shape == x.shape + + def test_residual_connection(self): + """With zero init gamma, output should be close to input.""" + block = ConvNeXtBlock1D( + dim=16, kernel_size=3, drop_path=0.0, + layer_scale_init_value=0.0, activation=nn.GELU, + ) + # layer_scale_init_value=0 means gamma is None, so no scaling + # but with layer_scale_init_value=0, gamma is not created (> 0 check) + # Actually checking the code: gamma is only created if value > 0 + # So with 0, gamma is None and the block output before residual is unscaled + x = torch.randn(2, 16, 32) + out = block(x) + assert out.shape == x.shape + + +# --------------------------------------------------------------------------- +# ConvNextBlock1DAdaptor +# --------------------------------------------------------------------------- +class TestConvNextBlock1DAdaptor: + def test_same_channels_no_stride(self): + adaptor = ConvNextBlock1DAdaptor( + in_channels=32, out_channels=32, kernel_size=7, stride=1, + dropout=0.0, use_batchnorm=False, activation=nn.GELU, + layer_scale_init_value=1e-6, drop_path_rate=0.0, block_type="convnext", + ) + x = torch.randn(2, 32, 64) + out = adaptor(x) + assert out.shape == (2, 32, 64) + + def test_channel_change(self): + adaptor = ConvNextBlock1DAdaptor( + in_channels=16, out_channels=32, kernel_size=7, stride=1, + dropout=0.0, use_batchnorm=False, activation=nn.GELU, + layer_scale_init_value=1e-6, drop_path_rate=0.0, block_type="convnext", + ) + x = torch.randn(2, 16, 64) + out = adaptor(x) + assert out.shape == (2, 32, 64) + + def test_stride_downsamples(self): + adaptor = ConvNextBlock1DAdaptor( + in_channels=32, out_channels=32, kernel_size=7, stride=4, + dropout=0.0, use_batchnorm=False, activation=nn.GELU, + layer_scale_init_value=1e-6, drop_path_rate=0.0, block_type="convnext", + ) + x = torch.randn(2, 32, 64) + out = adaptor(x) + assert out.shape == (2, 32, 16) # 64 / 4 = 16 + + def test_non_convnext_block_type(self): + adaptor = ConvNextBlock1DAdaptor( + in_channels=32, out_channels=32, kernel_size=7, stride=1, + dropout=0.0, use_batchnorm=False, activation=nn.GELU, + layer_scale_init_value=1e-6, drop_path_rate=0.0, block_type="none", + ) + assert adaptor.block is None + x = torch.randn(2, 32, 64) + out = adaptor(x) + assert out.shape == (2, 32, 64) + + +# --------------------------------------------------------------------------- +# MultiscaleCNNBackbone1D +# --------------------------------------------------------------------------- +class TestMultiscaleCNNBackbone1D: + def test_gap_output(self): + backbone = MultiscaleCNNBackbone1D( + dim_in=512, channels=(16, 32), kernel_sizes=(7, 5), + strides=(2, 2), dropout_rate=0.0, ramped_dropout_rate=False, + block_type="convnext", pooling_type="average", final_pool=True, + use_batchnorm=False, activation=nn.GELU, output_type="gap", + layer_scale_init_value=1e-6, drop_path_rate=0.0, + ) + x = torch.randn(4, 1, 512) + out = backbone(x) + assert out.shape == (4, 32) # GAP over last channel dim + assert backbone.dim_output == 32 + + def test_accepts_2d_input(self): + backbone = MultiscaleCNNBackbone1D( + dim_in=256, channels=(16,), kernel_sizes=(7,), + strides=(1,), dropout_rate=0.0, ramped_dropout_rate=False, + block_type="convnext", pooling_type="average", final_pool=True, + use_batchnorm=False, activation=nn.GELU, output_type="gap", + layer_scale_init_value=1e-6, drop_path_rate=0.0, + ) + x = torch.randn(4, 256) # 2D input, no channel dim + out = backbone(x) + assert out.ndim == 2 + assert out.shape[0] == 4 + + def test_flatten_output(self): + backbone = MultiscaleCNNBackbone1D( + dim_in=256, channels=(16,), kernel_sizes=(7,), + strides=(1,), dropout_rate=0.0, ramped_dropout_rate=False, + block_type="convnext", pooling_type="average", final_pool=True, + use_batchnorm=False, activation=nn.GELU, output_type="flatten", + layer_scale_init_value=1e-6, drop_path_rate=0.0, + ) + x = torch.randn(4, 1, 256) + out = backbone(x) + assert out.ndim == 2 + assert out.shape[0] == 4 + assert out.shape[1] == backbone.dim_output + + def test_max_pooling(self): + backbone = MultiscaleCNNBackbone1D( + dim_in=256, channels=(16,), kernel_sizes=(7,), + strides=(1,), dropout_rate=0.0, ramped_dropout_rate=False, + block_type="convnext", pooling_type="max", final_pool=True, + use_batchnorm=False, activation=nn.GELU, output_type="gap", + layer_scale_init_value=1e-6, drop_path_rate=0.0, + ) + x = torch.randn(2, 1, 256) + out = backbone(x) + assert out.shape == (2, 16) + + def test_mismatched_lengths_raises(self): + with pytest.raises(AssertionError): + MultiscaleCNNBackbone1D( + dim_in=256, channels=(16, 32), kernel_sizes=(7,), # mismatch + strides=(1, 1), dropout_rate=0.0, ramped_dropout_rate=False, + block_type="convnext", pooling_type="average", final_pool=True, + use_batchnorm=False, activation=nn.GELU, output_type="gap", + layer_scale_init_value=1e-6, drop_path_rate=0.0, + ) + + +# --------------------------------------------------------------------------- +# make_mlp +# --------------------------------------------------------------------------- +class TestMakeMlp: + def test_simple_mlp(self): + mlp = make_mlp(64, (32, 16), 7, dropout=0.1) + x = torch.randn(4, 64) + out = mlp(x) + assert out.shape == (4, 7) + + def test_no_hidden(self): + mlp = make_mlp(64, None, 7, dropout=0.0) + x = torch.randn(4, 64) + out = mlp(x) + assert out.shape == (4, 7) + + def test_with_output_activation(self): + mlp = make_mlp(64, (32,), 7, dropout=0.0, output_activation=nn.Softmax(dim=-1)) + x = torch.randn(4, 64) + out = mlp(x) + assert out.shape == (4, 7) + # Softmax means rows should sum to 1 + torch.testing.assert_close(out.sum(dim=-1), torch.ones(4), atol=1e-5, rtol=1e-5) + + +# --------------------------------------------------------------------------- +# AlphaDiffractMultiscaleLightning +# --------------------------------------------------------------------------- +class TestAlphaDiffractMultiscaleLightning: + def test_forward_output_keys(self): + model = AlphaDiffractMultiscaleLightning(**_default_model_kwargs()) + x = torch.randn(2, 8192) + out = model(x) + assert isinstance(out, dict) + assert "features" in out + assert "cs_logits" in out + assert "sg_logits" in out + assert "lp" in out + + def test_forward_output_shapes(self): + model = AlphaDiffractMultiscaleLightning(**_default_model_kwargs()) + x = torch.randn(2, 8192) + out = model(x) + assert out["cs_logits"].shape == (2, 7) + assert out["sg_logits"].shape == (2, 230) + assert out["lp"].shape == (2, 6) + + def test_lp_bounded_with_sigmoid(self): + kwargs = _default_model_kwargs() + kwargs["bound_lp_with_sigmoid"] = True + model = AlphaDiffractMultiscaleLightning(**kwargs) + x = torch.randn(2, 8192) + out = model(x) + lp = out["lp"] + # All LP values should be within bounds + assert (lp >= 0.0).all() + assert (lp[:, :3] <= 100.0).all() # lengths bounded by max + assert (lp[:, 3:] <= 180.0).all() # angles bounded by max + + def test_lp_unbounded(self): + kwargs = _default_model_kwargs() + kwargs["bound_lp_with_sigmoid"] = False + model = AlphaDiffractMultiscaleLightning(**kwargs) + x = torch.randn(2, 8192) + out = model(x) + # Should still produce valid tensor + assert out["lp"].shape == (2, 6) + + def test_training_step(self): + model = AlphaDiffractMultiscaleLightning(**_default_model_kwargs()) + batch = ( + torch.randn(4, 8192), # x + torch.randint(0, 7, (4,)), # cs labels + torch.randint(0, 230, (4,)), # sg labels + torch.randn(4, 6), # lp labels + ) + loss = model.training_step(batch, 0) + assert isinstance(loss, torch.Tensor) + assert loss.ndim == 0 # scalar + assert loss.item() > 0 + + def test_training_step_dict_batch(self): + model = AlphaDiffractMultiscaleLightning(**_default_model_kwargs()) + batch = { + "x": torch.randn(4, 8192), + "cs": torch.randint(0, 7, (4,)), + "sg": torch.randint(0, 230, (4,)), + "lattice_params": torch.randn(4, 6), + } + loss = model.training_step(batch, 0) + assert isinstance(loss, torch.Tensor) + assert loss.ndim == 0 + + def test_training_step_no_lp(self): + model = AlphaDiffractMultiscaleLightning(**_default_model_kwargs()) + batch = ( + torch.randn(4, 8192), + torch.randint(0, 7, (4,)), + torch.randint(0, 230, (4,)), + ) + loss = model.training_step(batch, 0) + assert isinstance(loss, torch.Tensor) + + def test_configure_optimizers(self): + model = AlphaDiffractMultiscaleLightning(**_default_model_kwargs()) + opt_config = model.configure_optimizers() + assert "optimizer" in opt_config + assert "lr_scheduler" in opt_config + assert isinstance(opt_config["optimizer"], torch.optim.AdamW) + + def test_configure_optimizers_adam(self): + kwargs = _default_model_kwargs() + kwargs["use_adamw"] = False + model = AlphaDiffractMultiscaleLightning(**kwargs) + opt_config = model.configure_optimizers() + assert isinstance(opt_config["optimizer"], torch.optim.Adam) + + def test_to_index_scalar_labels(self): + idx = AlphaDiffractMultiscaleLightning._to_index(torch.tensor([2, 5, 0]), 7) + assert idx.dtype == torch.long + assert torch.equal(idx, torch.tensor([2, 5, 0])) + + def test_to_index_onehot_labels(self): + onehot = torch.zeros(3, 7) + onehot[0, 2] = 1.0 + onehot[1, 5] = 1.0 + onehot[2, 0] = 1.0 + idx = AlphaDiffractMultiscaleLightning._to_index(onehot, 7) + assert torch.equal(idx, torch.tensor([2, 5, 0])) + + def test_to_index_clamps(self): + """Labels exceeding num_classes should be clamped.""" + idx = AlphaDiffractMultiscaleLightning._to_index(torch.tensor([300]), 230) + assert idx.item() == 229 # clamped to max valid index diff --git a/src/ui/frontend/src/utils/xrd-processing.js b/src/ui/frontend/src/utils/xrd-processing.js new file mode 100644 index 0000000..bff9464 --- /dev/null +++ b/src/ui/frontend/src/utils/xrd-processing.js @@ -0,0 +1,225 @@ +/** + * Pure utility functions extracted from XRDContext for testability. + * These functions contain no React state or side effects. + */ + +/** + * Convert wavelength using Bragg's law: lambda = 2d*sin(theta) + * For same d-spacing: sin(theta2) = (lambda2/lambda1) * sin(theta1) + * + * @param {number} theta_deg - 2theta angle in degrees + * @param {number} sourceWavelength - Source wavelength in Angstroms + * @param {number} targetWavelength - Target wavelength in Angstroms + * @returns {number|null} Converted 2theta angle, or null if physically impossible + */ +export function convertWavelength(theta_deg, sourceWavelength, targetWavelength) { + if (Math.abs(sourceWavelength - targetWavelength) < 0.0001) { + return theta_deg + } + + const theta_rad = (theta_deg * Math.PI) / 180 + const sin_theta2 = (targetWavelength / sourceWavelength) * Math.sin(theta_rad) + + if (Math.abs(sin_theta2) > 1) { + return null + } + + const theta2_rad = Math.asin(sin_theta2) + return (theta2_rad * 180) / Math.PI +} + +/** + * Interpolate data to fixed size for model input. + * + * @param {number[]} x - Input x values (2theta) + * @param {number[]} y - Input y values (intensity) + * @param {number} targetSize - Desired output length + * @param {number} [xMin] - Minimum x value for output grid + * @param {number} [xMax] - Maximum x value for output grid + * @param {string} [strategy='linear'] - Interpolation strategy: 'linear' or 'cubic' + * @returns {{x: number[], y: number[]}} Interpolated data + */ +export function interpolateData(x, y, targetSize, xMin, xMax, strategy = 'linear') { + if (x.length === targetSize && xMin === undefined) { + return { x, y } + } + + const minX = xMin !== undefined ? xMin : Math.min(...x) + const maxX = xMax !== undefined ? xMax : Math.max(...x) + const step = (maxX - minX) / (targetSize - 1) + + const newX = Array.from({ length: targetSize }, (_, i) => minX + i * step) + const newY = new Array(targetSize) + + const dataMinX = Math.min(...x) + const dataMaxX = Math.max(...x) + + if (strategy === 'linear') { + for (let i = 0; i < targetSize; i++) { + const targetX = newX[i] + + if (targetX < dataMinX || targetX > dataMaxX) { + newY[i] = 0 + continue + } + + let idx = x.findIndex(val => val >= targetX) + if (idx === -1) idx = x.length - 1 + if (idx === 0) idx = 1 + + const x0 = x[idx - 1] + const x1 = x[idx] + const y0 = y[idx - 1] + const y1 = y[idx] + + newY[i] = y0 + ((targetX - x0) * (y1 - y0)) / (x1 - x0) + } + } else if (strategy === 'cubic') { + for (let i = 0; i < targetSize; i++) { + const targetX = newX[i] + + if (targetX < dataMinX || targetX > dataMaxX) { + newY[i] = 0 + continue + } + + let idx = x.findIndex(val => val >= targetX) + if (idx === -1) idx = x.length - 1 + if (idx === 0) idx = 1 + + const i0 = Math.max(0, idx - 2) + const i1 = Math.max(0, idx - 1) + const i2 = Math.min(x.length - 1, idx) + const i3 = Math.min(x.length - 1, idx + 1) + + if (i2 === i1) { + newY[i] = y[i1] + } else { + const t = (targetX - x[i1]) / (x[i2] - x[i1]) + const t2 = t * t + const t3 = t2 * t + + const v0 = y[i0] + const v1 = y[i1] + const v2 = y[i2] + const v3 = y[i3] + + newY[i] = 0.5 * ( + 2 * v1 + + (-v0 + v2) * t + + (2 * v0 - 5 * v1 + 4 * v2 - v3) * t2 + + (-v0 + 3 * v1 - 3 * v2 + v3) * t3 + ) + } + } + } + + return { x: newX, y: newY } +} + +/** + * Parse DIF or XY format (space-separated 2theta intensity). + * + * @param {string} text - Raw file content + * @returns {{x: number[], y: number[]}} Parsed data points + */ +export function parseDIF(text) { + const lines = text.split('\n') + const x = [] + const y = [] + + for (const line of lines) { + const trimmed = line.trim() + + if (!trimmed || + trimmed.startsWith('#') || + trimmed.startsWith('_') || + trimmed.startsWith('CELL') || + trimmed.startsWith('SPACE') || + /^[a-zA-Z]/.test(trimmed)) { + continue + } + + const parts = trimmed.split(/\s+/) + if (parts.length >= 2) { + const xVal = parseFloat(parts[0]) + const yVal = parseFloat(parts[1]) + + if (!isNaN(xVal) && !isNaN(yVal)) { + x.push(xVal) + y.push(yVal) + } + } + } + + return { x, y } +} + +/** + * Extract metadata from CIF/DIF file text. + * + * @param {string} text - Raw file content + * @returns {{wavelength: number|null, cellParams: string|null, spaceGroup: string|null, crystalSystem: string|null}} + */ +export function extractMetadata(text) { + const metadata = { + wavelength: null, + cellParams: null, + spaceGroup: null, + crystalSystem: null + } + + const lines = text.split('\n') + + const wavelengthPatterns = [ + /wavelength[:\s=]+([0-9.]+)/i, + /lambda[:\s=]+([0-9.]+)/i, + /wave[:\s=]+([0-9.]+)/i, + /_pd_wavelength[:\s]+([0-9.]+)/i, + /_diffrn_radiation_wavelength[:\s]+([0-9.]+)/i, + /radiation.*?([0-9.]+)\s*[AÅ]/i, + ] + + for (const line of lines) { + if (!metadata.wavelength) { + for (const pattern of wavelengthPatterns) { + const match = line.match(pattern) + if (match && match[1]) { + const wavelength = parseFloat(match[1]) + if (wavelength > 0.1 && wavelength < 3.0) { + metadata.wavelength = wavelength + break + } + } + } + + if (/Cu\s*K[αa]/i.test(line)) metadata.wavelength = 1.5406 + else if (/Mo\s*K[αa]/i.test(line)) metadata.wavelength = 0.7107 + else if (/Co\s*K[αa]/i.test(line)) metadata.wavelength = 1.7889 + else if (/Cr\s*K[αa]/i.test(line)) metadata.wavelength = 2.2897 + } + + if (/CELL PARAMETERS:/i.test(line)) { + const match = line.match(/CELL PARAMETERS:\s*([\d.\s]+)/) + if (match) { + metadata.cellParams = match[1].trim() + } + } + + if (/SPACE GROUP:/i.test(line) || /_symmetry_Int_Tables_number/i.test(line)) { + const match = line.match(/(?:SPACE GROUP:|_symmetry_Int_Tables_number)[:\s]+(\d+)/) + if (match) { + metadata.spaceGroup = match[1] + } + } + + if (/Crystal System:/i.test(line)) { + const match = line.match(/Crystal System:\s*(\d+)/) + if (match) { + metadata.crystalSystem = match[1] + } + } + } + + return metadata +} diff --git a/src/ui/frontend/src/utils/xrd-processing.test.js b/src/ui/frontend/src/utils/xrd-processing.test.js new file mode 100644 index 0000000..258a44f --- /dev/null +++ b/src/ui/frontend/src/utils/xrd-processing.test.js @@ -0,0 +1,197 @@ +import { describe, it, expect } from 'vitest' +import { + convertWavelength, + interpolateData, + parseDIF, + extractMetadata, +} from './xrd-processing' + +// --------------------------------------------------------------------------- +// convertWavelength +// --------------------------------------------------------------------------- +describe('convertWavelength', () => { + it('returns same angle when wavelengths match', () => { + const result = convertWavelength(10.0, 0.6199, 0.6199) + expect(result).toBeCloseTo(10.0, 5) + }) + + it('returns null for physically impossible conversion', () => { + // Large angle with large wavelength ratio can exceed sin > 1 + const result = convertWavelength(80.0, 0.5, 2.0) + expect(result).toBeNull() + }) + + it('converts Cu Ka to synchrotron wavelength', () => { + const cuKa = 1.5406 + const synchrotron = 0.6199 + const theta = 20.0 + + const result = convertWavelength(theta, cuKa, synchrotron) + expect(result).not.toBeNull() + // Shorter wavelength -> smaller 2theta + expect(result).toBeLessThan(theta) + expect(result).toBeGreaterThan(0) + }) + + it('converts synchrotron to Cu Ka wavelength', () => { + const cuKa = 1.5406 + const synchrotron = 0.6199 + const theta = 10.0 + + const result = convertWavelength(theta, synchrotron, cuKa) + expect(result).not.toBeNull() + // Longer wavelength -> larger 2theta + expect(result).toBeGreaterThan(theta) + }) + + it('handles zero angle', () => { + const result = convertWavelength(0, 1.0, 2.0) + expect(result).toBeCloseTo(0, 5) + }) +}) + +// --------------------------------------------------------------------------- +// interpolateData +// --------------------------------------------------------------------------- +describe('interpolateData', () => { + it('returns same data when target size matches', () => { + const x = [1, 2, 3] + const y = [10, 20, 30] + const result = interpolateData(x, y, 3) + expect(result.x).toEqual(x) + expect(result.y).toEqual(y) + }) + + it('interpolates to larger size with linear strategy', () => { + const x = [0, 10] + const y = [0, 100] + const result = interpolateData(x, y, 11, 0, 10, 'linear') + expect(result.x).toHaveLength(11) + expect(result.y).toHaveLength(11) + // Midpoint should be ~50 + expect(result.y[5]).toBeCloseTo(50, 0) + }) + + it('sets out-of-range values to zero', () => { + const x = [5, 10, 15] + const y = [100, 200, 300] + const result = interpolateData(x, y, 20, 0, 20, 'linear') + // Points before x=5 should be 0 + expect(result.y[0]).toBe(0) + // Points after x=15 should be 0 + expect(result.y[19]).toBe(0) + }) + + it('supports cubic interpolation', () => { + const x = [0, 2, 4, 6, 8, 10] + const y = [0, 4, 16, 36, 64, 100] + const result = interpolateData(x, y, 11, 0, 10, 'cubic') + expect(result.x).toHaveLength(11) + expect(result.y).toHaveLength(11) + }) + + it('handles single point input', () => { + const x = [5] + const y = [100] + const result = interpolateData(x, y, 10, 0, 10, 'linear') + expect(result.x).toHaveLength(10) + }) +}) + +// --------------------------------------------------------------------------- +// parseDIF +// --------------------------------------------------------------------------- +describe('parseDIF', () => { + it('parses space-separated data', () => { + const text = '5.0 100.0\n10.0 200.0\n15.0 300.0\n' + const result = parseDIF(text) + expect(result.x).toEqual([5.0, 10.0, 15.0]) + expect(result.y).toEqual([100.0, 200.0, 300.0]) + }) + + it('skips comment lines', () => { + const text = '# This is a comment\n5.0 100.0\n10.0 200.0\n' + const result = parseDIF(text) + expect(result.x).toHaveLength(2) + }) + + it('skips metadata lines', () => { + const text = 'CELL PARAMETERS: 5.0 5.0 5.0 90 90 90\nSPACE GROUP: Fm-3m\n5.0 100.0\n' + const result = parseDIF(text) + expect(result.x).toEqual([5.0]) + }) + + it('skips lines starting with underscore', () => { + const text = '_cell_length_a 5.431\n5.0 100.0\n' + const result = parseDIF(text) + expect(result.x).toEqual([5.0]) + }) + + it('handles empty input', () => { + const result = parseDIF('') + expect(result.x).toEqual([]) + expect(result.y).toEqual([]) + }) + + it('handles tab-separated data', () => { + const text = '5.0\t100.0\n10.0\t200.0\n' + const result = parseDIF(text) + expect(result.x).toEqual([5.0, 10.0]) + expect(result.y).toEqual([100.0, 200.0]) + }) + + it('skips lines starting with letters', () => { + const text = 'Header line\n5.0 100.0\nAnother header\n10.0 200.0\n' + const result = parseDIF(text) + expect(result.x).toEqual([5.0, 10.0]) + }) +}) + +// --------------------------------------------------------------------------- +// extractMetadata +// --------------------------------------------------------------------------- +describe('extractMetadata', () => { + it('detects wavelength from numeric pattern', () => { + const text = '_diffrn_radiation_wavelength 0.6199\n' + const result = extractMetadata(text) + expect(result.wavelength).toBeCloseTo(0.6199, 4) + }) + + it('detects Cu Ka radiation', () => { + const text = 'Radiation: Cu Ka\n' + const result = extractMetadata(text) + expect(result.wavelength).toBeCloseTo(1.5406, 4) + }) + + it('detects Mo Ka radiation', () => { + const text = 'Radiation: Mo Ka\n' + const result = extractMetadata(text) + expect(result.wavelength).toBeCloseTo(0.7107, 4) + }) + + it('extracts space group number', () => { + const text = '_symmetry_Int_Tables_number 225\n' + const result = extractMetadata(text) + expect(result.spaceGroup).toBe('225') + }) + + it('extracts cell parameters', () => { + const text = 'CELL PARAMETERS: 5.431 5.431 5.431 90.0 90.0 90.0\n' + const result = extractMetadata(text) + expect(result.cellParams).not.toBeNull() + expect(result.cellParams).toContain('5.431') + }) + + it('returns null for missing data', () => { + const result = extractMetadata('Some random text without metadata\n') + expect(result.wavelength).toBeNull() + expect(result.spaceGroup).toBeNull() + expect(result.cellParams).toBeNull() + }) + + it('rejects wavelengths outside X-ray range', () => { + const text = 'wavelength: 0.001\n' + const result = extractMetadata(text) + expect(result.wavelength).toBeNull() + }) +}) diff --git a/src/ui/requirements.txt b/src/ui/requirements.txt index 8e64b2a..2dd13b9 100644 --- a/src/ui/requirements.txt +++ b/src/ui/requirements.txt @@ -1,8 +1,8 @@ -fastapi==0.104.1 -uvicorn[standard]==0.24.0 -python-multipart==0.0.6 -numpy==1.24.3 -torch==2.1.0 -pytorch-lightning==2.5.5 -pydantic==2.5.0 -spglib==2.4.0 +fastapi>=0.115.0 +uvicorn[standard]>=0.24.0 +python-multipart>=0.0.6 +numpy>=1.24.3 +torch>=2.1.0 +pytorch-lightning>=2.5.5 +pydantic>=2.5.0 +spglib>=2.4.0 diff --git a/src/ui/tests/test_api.py b/src/ui/tests/test_api.py new file mode 100644 index 0000000..f9e6f86 --- /dev/null +++ b/src/ui/tests/test_api.py @@ -0,0 +1,257 @@ +"""Unit tests for the UI FastAPI backend.""" + +import sys +from pathlib import Path +from unittest.mock import patch, MagicMock + +import numpy as np +import pytest +import torch + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from app.model_inference import XRDModelInference + + +# --------------------------------------------------------------------------- +# XRDModelInference.preprocess_data +# --------------------------------------------------------------------------- +class TestPreprocessData: + def setup_method(self): + self.inference = XRDModelInference() + # Force CPU for tests so .numpy() works regardless of GPU availability + self.inference.device = torch.device("cpu") + + def test_output_shape(self): + x = list(range(100)) + y = [float(i) for i in range(100)] + tensor = self.inference.preprocess_data(x, y) + assert tensor.shape == (1, 100) + + def test_floor_at_zero(self): + x = [1.0, 2.0, 3.0] + y = [-10.0, 5.0, 100.0] + tensor = self.inference.preprocess_data(x, y) + # After flooring at zero: [0, 5, 100] + # After rescaling: [0, 5, 100] -> min=0, max=100 -> [0, 5, 100] + assert (tensor >= 0).all() + + def test_rescale_to_0_100(self): + x = [1.0, 2.0, 3.0] + y = [0.0, 50.0, 200.0] + tensor = self.inference.preprocess_data(x, y) + # Should normalize to [0, 100] + values = tensor.squeeze().numpy() + np.testing.assert_allclose(values[0], 0.0, atol=1e-5) + np.testing.assert_allclose(values[-1], 100.0, atol=1e-5) + + def test_constant_intensity(self): + x = [1.0, 2.0, 3.0] + y = [5.0, 5.0, 5.0] + tensor = self.inference.preprocess_data(x, y) + # Constant intensity -> zeros + values = tensor.squeeze().numpy() + np.testing.assert_allclose(values, 0.0, atol=1e-5) + + def test_all_zeros(self): + x = [1.0, 2.0] + y = [0.0, 0.0] + tensor = self.inference.preprocess_data(x, y) + values = tensor.squeeze().numpy() + np.testing.assert_allclose(values, 0.0, atol=1e-5) + + def test_negative_all_floor(self): + x = [1.0, 2.0, 3.0] + y = [-10.0, -5.0, -1.0] + tensor = self.inference.preprocess_data(x, y) + # All negatives -> floored to 0 -> constant -> zeros + values = tensor.squeeze().numpy() + np.testing.assert_allclose(values, 0.0, atol=1e-5) + + +# --------------------------------------------------------------------------- +# XRDModelInference._process_model_output +# --------------------------------------------------------------------------- +class TestProcessModelOutput: + def setup_method(self): + self.inference = XRDModelInference() + + def test_dict_output_with_all_heads(self): + output = { + "cs_logits": torch.randn(1, 7), + "sg_logits": torch.randn(1, 230), + "lp": torch.tensor([[5.0, 5.0, 5.0, 90.0, 90.0, 90.0]]), + "features": torch.randn(1, 64), + } + result = self.inference._process_model_output(output) + assert "phase_predictions" in result + preds = result["phase_predictions"] + assert len(preds) == 3 # cs, sg, lp + + def test_cs_prediction_structure(self): + output = { + "cs_logits": torch.zeros(1, 7), + "sg_logits": torch.randn(1, 230), + "lp": torch.randn(1, 6), + "features": torch.randn(1, 64), + } + result = self.inference._process_model_output(output) + cs_pred = result["phase_predictions"][0] + assert cs_pred["phase"] == "Crystal System" + assert "predicted_class" in cs_pred + assert "confidence" in cs_pred + assert "all_probabilities" in cs_pred + assert len(cs_pred["all_probabilities"]) == 7 + + def test_cs_probabilities_sum_to_one(self): + output = { + "cs_logits": torch.randn(1, 7), + "sg_logits": torch.randn(1, 230), + "lp": torch.randn(1, 6), + "features": torch.randn(1, 64), + } + result = self.inference._process_model_output(output) + cs_pred = result["phase_predictions"][0] + total_prob = sum(p["probability"] for p in cs_pred["all_probabilities"]) + np.testing.assert_allclose(total_prob, 1.0, atol=1e-5) + + def test_sg_prediction_structure(self): + output = { + "cs_logits": torch.randn(1, 7), + "sg_logits": torch.randn(1, 230), + "lp": torch.randn(1, 6), + "features": torch.randn(1, 64), + } + result = self.inference._process_model_output(output) + sg_pred = result["phase_predictions"][1] + assert sg_pred["phase"] == "Space Group" + assert "predicted_class" in sg_pred + assert "confidence" in sg_pred + assert "top_probabilities" in sg_pred + assert len(sg_pred["top_probabilities"]) == 10 + + def test_lp_prediction_structure(self): + lp_values = torch.tensor([[5.43, 5.43, 5.43, 90.0, 90.0, 90.0]]) + output = { + "cs_logits": torch.randn(1, 7), + "sg_logits": torch.randn(1, 230), + "lp": lp_values, + "features": torch.randn(1, 64), + } + result = self.inference._process_model_output(output) + lp_pred = result["phase_predictions"][2] + assert lp_pred["phase"] == "Lattice Parameters" + assert lp_pred["is_lattice"] is True + params = lp_pred["lattice_params"] + assert len(params) == 6 + np.testing.assert_allclose(params["a"], 5.43, atol=1e-5) + + def test_tensor_output_fallback(self): + output = torch.randn(1, 7) + result = self.inference._process_model_output(output) + assert "phase_predictions" in result + + def test_unknown_output_type(self): + result = self.inference._process_model_output("unexpected") + assert result["phase_predictions"] == [] + + +# --------------------------------------------------------------------------- +# XRDModelInference._compute_overall_confidence +# --------------------------------------------------------------------------- +class TestComputeOverallConfidence: + def setup_method(self): + self.inference = XRDModelInference() + + def test_averages_confidences(self): + processed = { + "phase_predictions": [ + {"confidence": 0.8}, + {"confidence": 0.6}, + ] + } + result = self.inference._compute_overall_confidence(processed) + np.testing.assert_allclose(result, 0.7, atol=1e-5) + + def test_no_confidences_returns_none(self): + processed = {"phase_predictions": [{"phase": "test"}]} + result = self.inference._compute_overall_confidence(processed) + assert result is None + + def test_empty_predictions(self): + result = self.inference._compute_overall_confidence({"phase_predictions": []}) + assert result is None + + +# --------------------------------------------------------------------------- +# XRDModelInference.predict (model not loaded) +# --------------------------------------------------------------------------- +class TestPredictNoModel: + def test_returns_error_when_no_model(self): + inference = XRDModelInference() + result = inference.predict([1.0, 2.0], [10.0, 20.0]) + assert result["status"] == "error" + assert "not loaded" in result["error"].lower() or "not loaded" in result["error"] + + +# --------------------------------------------------------------------------- +# XRDModelInference.is_loaded +# --------------------------------------------------------------------------- +class TestIsLoaded: + def test_false_initially(self): + inference = XRDModelInference() + assert inference.is_loaded() is False + + def test_true_after_mock_load(self): + inference = XRDModelInference() + inference.model = MagicMock() + assert inference.is_loaded() is True + + +# --------------------------------------------------------------------------- +# FastAPI endpoint tests +# --------------------------------------------------------------------------- +class TestFastAPIEndpoints: + """Test the FastAPI endpoints using TestClient.""" + + @pytest.fixture + def client(self): + """Create a test client with mocked model.""" + # We need to patch the model inference before importing the app + with patch("app.main.model_inference") as mock_inference: + mock_inference.is_loaded.return_value = True + mock_inference.predict.return_value = { + "status": "success", + "predictions": {"phase_predictions": []}, + } + + from fastapi.testclient import TestClient + from app.main import app + with TestClient(app) as c: + yield c, mock_inference + + def test_health_check(self, client): + c, mock_inference = client + response = c.get("/api/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + + def test_predict_success(self, client): + c, mock_inference = client + response = c.post("/api/predict", json={ + "x": [1.0, 2.0, 3.0], + "y": [10.0, 20.0, 30.0], + }) + assert response.status_code == 200 + mock_inference.predict.assert_called_once() + + def test_predict_missing_data(self, client): + c, mock_inference = client + response = c.post("/api/predict", json={"x": [], "y": []}) + assert response.status_code == 400 + + def test_predict_mismatched_lengths(self, client): + c, mock_inference = client + response = c.post("/api/predict", json={"x": [1.0, 2.0], "y": [10.0]}) + assert response.status_code == 400