Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
name: Tests

on:
push:
branches: [main]
pull_request:
branches: [main]

jobs:
# -----------------------------------------------------------------------
# Downloader tests (~1 min)
# -----------------------------------------------------------------------
test-downloader:
runs-on: ubuntu-latest
defaults:
run:
working-directory: src/downloader
steps:
- uses: actions/checkout@v4

- uses: actions/setup-python@v5
with:
python-version: "3.12"

- name: Install dependencies
run: |
pip install -r requirements.txt
pip install pytest pyyaml

- name: Run tests
run: pytest tests/ -v --tb=short

# -----------------------------------------------------------------------
# Simulator tests - pure Python logic only (~1 min)
# -----------------------------------------------------------------------
test-simulator:
runs-on: ubuntu-latest
defaults:
run:
working-directory: src/simulator
steps:
- uses: actions/checkout@v4

- uses: actions/setup-python@v5
with:
python-version: "3.12"

- name: Install dependencies
run: pip install numpy pytest pymatgen

- name: Run tests
run: pytest tests/ -v --tb=short

# -----------------------------------------------------------------------
# Trainer tests (~3 min)
# -----------------------------------------------------------------------
test-trainer:
runs-on: ubuntu-latest
defaults:
run:
working-directory: src/trainer
steps:
- uses: actions/checkout@v4

- uses: actions/setup-python@v5
with:
python-version: "3.12"

- name: Install CPU-only PyTorch and dependencies
run: |
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install pytorch-lightning numpy pyyaml tqdm matplotlib scikit-learn pytest

- name: Run tests
run: pytest tests/ -v --tb=short

# -----------------------------------------------------------------------
# UI backend tests (~3 min)
# -----------------------------------------------------------------------
test-ui-backend:
runs-on: ubuntu-latest
defaults:
run:
working-directory: src/ui
steps:
- uses: actions/checkout@v4

- uses: actions/setup-python@v5
with:
python-version: "3.12"

- name: Install CPU-only PyTorch and dependencies
run: |
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install pytest httpx

- name: Run tests
run: pytest tests/ -v --tb=short

# -----------------------------------------------------------------------
# UI frontend tests (~2 min)
# -----------------------------------------------------------------------
test-ui-frontend:
runs-on: ubuntu-latest
defaults:
run:
working-directory: src/ui/frontend
steps:
- uses: actions/checkout@v4

- uses: actions/setup-node@v4
with:
node-version: "18"
cache: "npm"
cache-dependency-path: src/ui/frontend/package-lock.json

- name: Install dependencies
run: npm ci

- name: Install Vitest
run: npm install --save-dev vitest

- name: Run tests
run: npx vitest run --reporter=verbose
8 changes: 8 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[pytest]
# Root-level pytest config ensures rootdir is the repo root,
# preventing src/trainer/pyproject.toml from being used as config.
testpaths =
src/downloader/tests
src/simulator/tests
src/trainer/tests
src/ui/tests
226 changes: 226 additions & 0 deletions src/downloader/tests/test_downloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
"""Unit tests for the downloader module."""

import os
import sys
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock

import pytest
import yaml

# Add parent directory to path so we can import the module
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from downloader import (
load_config,
ensure_output_dir,
get_api_key,
passes_filters,
_get_crystal_system_from_sg,
is_spacegroup_stable,
_write_cif_from_struct,
)


# ---------------------------------------------------------------------------
# load_config
# ---------------------------------------------------------------------------
class TestLoadConfig:
def test_valid_yaml(self, tmp_path):
cfg_file = tmp_path / "config.yaml"
cfg_file.write_text("output_directory: /data/raw_cif\nfilters:\n max_atoms: 500\n")
cfg = load_config(cfg_file)
assert cfg["output_directory"] == "/data/raw_cif"
assert cfg["filters"]["max_atoms"] == 500

def test_empty_yaml_returns_empty_dict(self, tmp_path):
cfg_file = tmp_path / "empty.yaml"
cfg_file.write_text("")
cfg = load_config(cfg_file)
assert cfg == {}

def test_missing_file_raises(self, tmp_path):
with pytest.raises(FileNotFoundError):
load_config(tmp_path / "nonexistent.yaml")


# ---------------------------------------------------------------------------
# ensure_output_dir
# ---------------------------------------------------------------------------
class TestEnsureOutputDir:
def test_creates_directory(self, tmp_path):
new_dir = tmp_path / "a" / "b" / "c"
assert not new_dir.exists()
ensure_output_dir(new_dir)
assert new_dir.is_dir()

def test_existing_directory_no_error(self, tmp_path):
ensure_output_dir(tmp_path) # already exists


# ---------------------------------------------------------------------------
# get_api_key
# ---------------------------------------------------------------------------
class TestGetApiKey:
def test_returns_key_from_env(self, monkeypatch):
monkeypatch.setenv("MP_API_KEY", "test-key-123")
assert get_api_key() == "test-key-123"

def test_raises_when_missing(self, monkeypatch):
monkeypatch.delenv("MP_API_KEY", raising=False)
with pytest.raises(RuntimeError, match="MP_API_KEY"):
get_api_key()

def test_raises_when_empty(self, monkeypatch):
monkeypatch.setenv("MP_API_KEY", " ")
with pytest.raises(RuntimeError, match="MP_API_KEY"):
get_api_key()


# ---------------------------------------------------------------------------
# _get_crystal_system_from_sg
# ---------------------------------------------------------------------------
class TestGetCrystalSystemFromSg:
@pytest.mark.parametrize(
"sg_num, expected",
[
(1, 1), # Triclinic
(2, 1), # Triclinic boundary
(3, 2), # Monoclinic
(15, 2), # Monoclinic boundary
(16, 3), # Orthorhombic
(74, 3), # Orthorhombic boundary
(75, 4), # Tetragonal
(142, 4), # Tetragonal boundary
(143, 5), # Trigonal
(167, 5), # Trigonal boundary
(168, 6), # Hexagonal
(194, 6), # Hexagonal boundary
(195, 7), # Cubic
(230, 7), # Cubic boundary
],
)
def test_valid_space_groups(self, sg_num, expected):
assert _get_crystal_system_from_sg(sg_num) == expected

def test_invalid_returns_none(self):
assert _get_crystal_system_from_sg(231) is None
assert _get_crystal_system_from_sg(None) is None
assert _get_crystal_system_from_sg("abc") is None

def test_zero_and_negative_map_to_triclinic(self):
# Code treats sg_num <= 2 as Triclinic (no lower-bound guard)
assert _get_crystal_system_from_sg(0) == 1
assert _get_crystal_system_from_sg(-1) == 1


# ---------------------------------------------------------------------------
# passes_filters
# ---------------------------------------------------------------------------
class TestPassesFilters:
def _make_mock_structure(self, num_atoms=10, volume=100.0):
"""Create a mock structure with controllable atom count and volume."""
mock = MagicMock()
mock.__len__ = MagicMock(return_value=num_atoms)
mock.volume = volume
return mock

def test_no_filters_passes(self):
struct = self._make_mock_structure()
assert passes_filters(struct, {}) is True

def test_max_atoms_pass(self):
struct = self._make_mock_structure(num_atoms=100)
assert passes_filters(struct, {"max_atoms": 500}) is True

def test_max_atoms_fail(self):
struct = self._make_mock_structure(num_atoms=600)
assert passes_filters(struct, {"max_atoms": 500}) is False

def test_max_atoms_boundary(self):
struct = self._make_mock_structure(num_atoms=500)
assert passes_filters(struct, {"max_atoms": 500}) is True

def test_min_volume_pass(self):
struct = self._make_mock_structure(volume=200.0)
assert passes_filters(struct, {"min_volume": 100.0}) is True

def test_min_volume_fail(self):
struct = self._make_mock_structure(volume=50.0)
assert passes_filters(struct, {"min_volume": 100.0}) is False

def test_max_volume_pass(self):
struct = self._make_mock_structure(volume=500.0)
assert passes_filters(struct, {"max_volume": 1000.0}) is True

def test_max_volume_fail(self):
struct = self._make_mock_structure(volume=1500.0)
assert passes_filters(struct, {"max_volume": 1000.0}) is False

def test_combined_filters(self):
struct = self._make_mock_structure(num_atoms=100, volume=500.0)
filters = {"max_atoms": 200, "min_volume": 100.0, "max_volume": 1000.0}
assert passes_filters(struct, filters) is True

def test_combined_filters_fail_atoms(self):
struct = self._make_mock_structure(num_atoms=300, volume=500.0)
filters = {"max_atoms": 200, "min_volume": 100.0, "max_volume": 1000.0}
assert passes_filters(struct, filters) is False


# ---------------------------------------------------------------------------
# is_spacegroup_stable
# ---------------------------------------------------------------------------
class TestIsSpacegroupStable:
def test_all_same(self):
grid = [[225, 225], [225, 225]]
assert is_spacegroup_stable(grid) is True

def test_different_values(self):
grid = [[225, 225], [225, 226]]
assert is_spacegroup_stable(grid) is False

def test_single_value(self):
grid = [[225]]
assert is_spacegroup_stable(grid) is True

def test_empty_grid(self):
grid = [[]]
assert is_spacegroup_stable(grid) is False

def test_all_zeros(self):
grid = [[0, 0], [0, 0]]
assert is_spacegroup_stable(grid) is True

def test_zero_and_nonzero(self):
grid = [[0, 225], [225, 225]]
assert is_spacegroup_stable(grid) is False


# ---------------------------------------------------------------------------
# _write_cif_from_struct
# ---------------------------------------------------------------------------
class TestWriteCifFromStruct:
def test_writes_file_with_comment(self, tmp_path):
"""Test that CIF writing adds space group comment."""
mock_structure = MagicMock()
out_path = tmp_path / "test.cif"

# Mock CifWriter to write a minimal CIF file
with patch("downloader.CifWriter") as MockCifWriter:
mock_writer = MagicMock()
MockCifWriter.return_value = mock_writer

def write_side_effect(path):
with open(path, "w") as f:
f.write("data_test\n_cell_length_a 5.0\n")

mock_writer.write_file.side_effect = write_side_effect

_write_cif_from_struct(mock_structure, out_path, sg_num=225, sg_symbol="Fm-3m")

assert out_path.exists()
content = out_path.read_text()
assert "Fm-3m" in content
assert "_original_symmetry_space_group_name_H-M" in content
Loading