diff --git a/.claude/commands/implement-feature.md b/.claude/commands/implement-feature.md
new file mode 100644
index 000000000..83aa2aa9a
--- /dev/null
+++ b/.claude/commands/implement-feature.md
@@ -0,0 +1,7 @@
+You will be implementing a new feature in this codebase.
+
+$ARGUMENTS
+
+IMPORTANT: Only do this for front-end features.
+Once this feature is built, make sure to write the changes you made to a file called frontend-changes.md.
+Do not ask for permissions to modify this file, assume you can always do it.
\ No newline at end of file
diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 000000000..17434eb96
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,75 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Commands
+
+All commands use `uv` as the package manager. Dependencies are declared in `pyproject.toml`. **Never use `pip` directly — always use `uv run` or `uv sync`.**
+
+```bash
+# Install dependencies
+uv sync
+
+# Run the server (from repo root)
+./run.sh
+
+# Or manually from the backend directory
+cd backend && uv run uvicorn app:app --reload --port 8000
+```
+
+The app runs at `http://localhost:8000`. API docs at `http://localhost:8000/docs`.
+
+**Environment:** Create a `.env` file in the repo root with `ANTHROPIC_API_KEY=...` before running.
+
+## Architecture
+
+This is a RAG (Retrieval-Augmented Generation) system using **Claude's tool-use feature** — rather than injecting retrieved context directly into a prompt, Claude is given a search tool and autonomously decides when and what to search.
+
+### Request Flow
+
+```
+POST /api/query
+ → RAGSystem.query()
+ → AIGenerator.generate_response() [first Claude call]
+ → Claude decides to call search_course_content tool
+ → CourseSearchTool.execute()
+ → VectorStore.search() [ChromaDB semantic search]
+ → AIGenerator tool-use loop [follow-up Claude call(s) with results]
+ → SessionManager.add_exchange() [store to history]
+ → return (answer, sources)
+```
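The two Claude calls in this flow exchange an Anthropic-style messages array. A minimal sketch of how it grows across the round trip (IDs and strings are illustrative, not taken from the codebase):

```python
# Call 1: the user's question only.
messages = [{"role": "user", "content": "What does Lesson 2 of MCP cover?"}]

# Claude replies with stop_reason == "tool_use"; its content blocks are
# appended verbatim as the assistant turn.
messages.append({"role": "assistant", "content": [
    {"type": "tool_use", "id": "toolu_01", "name": "search_course_content",
     "input": {"query": "MCP lesson 2"}},
]})

# The tool's output goes back as a user turn holding tool_result blocks,
# keyed to the tool_use id.
messages.append({"role": "user", "content": [
    {"type": "tool_result", "tool_use_id": "toolu_01",
     "content": "Lesson 2 covers ..."},
]})

# Call 2 sends this full array; Claude now answers in plain text.
roles = [m["role"] for m in messages]
print(roles)  # → ['user', 'assistant', 'user']
```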
+
+### Key Components (`backend/`)
+
+- **`rag_system.py`** — Top-level orchestrator. Owns all components and exposes `query()` and `add_course_folder()`.
+- **`ai_generator.py`** — Wraps the Anthropic SDK. Runs the tool-use loop: up to two sequential tool rounds (call → tool execution → follow-up call), then a final response.
+- **`vector_store.py`** — ChromaDB wrapper with two collections:
+ - `course_catalog`: course-level metadata for fuzzy course name resolution
+ - `course_content`: chunked lesson text for semantic similarity search
+- **`document_processor.py`** — Parses structured `.txt` course files into `Course`/`Lesson`/`CourseChunk` objects, then splits content into overlapping chunks.
+- **`search_tools.py`** — Defines the `search_course_content` tool in Anthropic's tool-calling schema. `ToolManager` registers tools and routes execution.
+- **`session_manager.py`** — In-memory conversation history, keyed by session ID. History is appended to the system prompt as plain text.
+- **`config.py`** — Single `Config` dataclass. Key tunables: `CHUNK_SIZE=800`, `CHUNK_OVERLAP=100`, `MAX_RESULTS=5`, `MAX_HISTORY=2`, model set by `ANTHROPIC_MODEL` (currently `claude-haiku-4-5-20251001`).
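The tunables listed above can be sketched as a dataclass (a minimal sketch; only the chunking/search/history fields are shown, and the model string is left to `config.py`):

```python
import os
from dataclasses import dataclass

@dataclass
class Config:
    # Values mirror the tunables listed in the bullet above.
    ANTHROPIC_API_KEY: str = os.getenv("ANTHROPIC_API_KEY", "")
    CHUNK_SIZE: int = 800      # characters per text chunk in the vector store
    CHUNK_OVERLAP: int = 100   # characters shared between adjacent chunks
    MAX_RESULTS: int = 5       # top-k results per vector search
    MAX_HISTORY: int = 2       # conversation exchanges kept per session

config = Config()
```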
+
+### Course Document Format
+
+Files in `docs/` must follow this structure for `DocumentProcessor` to parse them correctly:
+
+```
+Course Title:
+Course Link:
+Course Instructor:
+
+Lesson 1:
+Lesson Link:
+
+
+Lesson 2:
+...
+```
+
+The course title doubles as the unique ID in ChromaDB. On server startup, existing courses are skipped (deduplication by title).
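A minimal sketch of parsing those header lines (the function name and regex are illustrative; the real logic lives in `backend/document_processor.py`):

```python
import re

def parse_course_header(text: str) -> dict:
    """Illustrative parser for the course document format shown above."""
    meta = {"title": None, "link": None, "instructor": None, "lessons": []}
    for line in text.splitlines():
        line = line.strip()
        if line.startswith("Course Title:"):
            meta["title"] = line.split(":", 1)[1].strip()
        elif line.startswith("Course Link:"):
            meta["link"] = line.split(":", 1)[1].strip()
        elif line.startswith("Course Instructor:"):
            meta["instructor"] = line.split(":", 1)[1].strip()
        else:
            m = re.match(r"Lesson (\d+):\s*(.*)", line)
            if m:
                meta["lessons"].append((int(m.group(1)), m.group(2)))
    return meta

sample = """Course Title: Python Basics
Course Link: https://example.com/python
Course Instructor: Jane Doe

Lesson 1: Introduction
Lesson 2: Variables
"""
print(parse_course_header(sample)["title"])  # → Python Basics
```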
+
+### Frontend
+
+A plain HTML/CSS/JS chat UI served as static files by FastAPI from `../frontend`. No build step required.
diff --git a/backend/ai_generator.py b/backend/ai_generator.py
index 0363ca90c..c317caa73 100644
--- a/backend/ai_generator.py
+++ b/backend/ai_generator.py
@@ -3,15 +3,20 @@
class AIGenerator:
"""Handles interactions with Anthropic's Claude API for generating responses"""
-
+
+ MAX_TOOL_ROUNDS = 2
+
# Static system prompt to avoid rebuilding on each call
SYSTEM_PROMPT = """ You are an AI assistant specialized in course materials and educational content with access to a comprehensive search tool for course information.
Search Tool Usage:
- Use the search tool **only** for questions about specific course content or detailed educational materials
-- **One search per query maximum**
+- You may make **up to 2 sequential tool calls** per query when needed (e.g. first retrieve a course outline, then search for related content across courses)
+- Use a second tool call only if the first result is insufficient or a clearly necessary follow-up search is required
- Synthesize search results into accurate, fact-based responses
- If search yields no results, state this clearly without offering alternatives
+- **Outline queries** (e.g. "what lessons are in X?", "give me the outline of X"):
+ Use `get_course_outline`. Return the course title, course link (if present), and every lesson as "Lesson <number>: <title>".
Response Protocol:
- **General knowledge questions**: Answer using existing knowledge without searching
@@ -28,108 +33,99 @@ class AIGenerator:
4. **Example-supported** - Include relevant examples when they aid understanding
Provide only the direct answer to what was asked.
"""
-
+
def __init__(self, api_key: str, model: str):
self.client = anthropic.Anthropic(api_key=api_key)
self.model = model
-
+
# Pre-build base API parameters
self.base_params = {
"model": self.model,
"temperature": 0,
"max_tokens": 800
}
-
+
def generate_response(self, query: str,
conversation_history: Optional[str] = None,
tools: Optional[List] = None,
tool_manager=None) -> str:
"""
Generate AI response with optional tool usage and conversation context.
-
+ Supports up to MAX_TOOL_ROUNDS sequential tool-call rounds.
+
Args:
query: The user's question or request
conversation_history: Previous messages for context
tools: Available tools the AI can use
tool_manager: Manager to execute tools
-
+
Returns:
Generated response as string
"""
-
- # Build system content efficiently - avoid string ops when possible
system_content = (
f"{self.SYSTEM_PROMPT}\n\nPrevious conversation:\n{conversation_history}"
- if conversation_history
+ if conversation_history
else self.SYSTEM_PROMPT
)
-
- # Prepare API call parameters efficiently
+
api_params = {
**self.base_params,
"messages": [{"role": "user", "content": query}],
"system": system_content
}
-
- # Add tools if available
+
if tools:
api_params["tools"] = tools
api_params["tool_choice"] = {"type": "auto"}
-
- # Get response from Claude
- response = self.client.messages.create(**api_params)
-
- # Handle tool execution if needed
- if response.stop_reason == "tool_use" and tool_manager:
- return self._handle_tool_execution(response, api_params, tool_manager)
-
- # Return direct response
- return response.content[0].text
-
- def _handle_tool_execution(self, initial_response, base_params: Dict[str, Any], tool_manager):
- """
- Handle execution of tool calls and get follow-up response.
-
- Args:
- initial_response: The response containing tool use requests
- base_params: Base API parameters
- tool_manager: Manager to execute tools
-
- Returns:
- Final response text after tool execution
- """
- # Start with existing messages
- messages = base_params["messages"].copy()
-
- # Add AI's tool use response
- messages.append({"role": "assistant", "content": initial_response.content})
-
- # Execute all tool calls and collect results
- tool_results = []
- for content_block in initial_response.content:
- if content_block.type == "tool_use":
- tool_result = tool_manager.execute_tool(
- content_block.name,
- **content_block.input
- )
-
- tool_results.append({
- "type": "tool_result",
- "tool_use_id": content_block.id,
- "content": tool_result
- })
-
- # Add tool results as single message
- if tool_results:
- messages.append({"role": "user", "content": tool_results})
-
- # Prepare final API call without tools
- final_params = {
- **self.base_params,
- "messages": messages,
- "system": base_params["system"]
- }
-
- # Get final response
- final_response = self.client.messages.create(**final_params)
- return final_response.content[0].text
\ No newline at end of file
+
+ round_count = 0
+
+ while True:
+ response = self.client.messages.create(**api_params)
+
+ # No tool use requested or no manager to handle it — return text directly
+ if response.stop_reason != "tool_use" or not tool_manager:
+ return self._extract_text(response)
+
+ round_count += 1
+
+ # Append assistant turn and execute all tool calls
+ new_messages = list(api_params["messages"])
+ new_messages.append({"role": "assistant", "content": response.content})
+
+ tool_results = []
+ error_occurred = False
+ for block in response.content:
+ if block.type == "tool_use":
+ try:
+ result = tool_manager.execute_tool(block.name, **block.input)
+ except Exception as e:
+ result = f"Error executing tool: {e}"
+ error_occurred = True
+ tool_results.append({
+ "type": "tool_result",
+ "tool_use_id": block.id,
+ "content": result
+ })
+
+ if tool_results:
+ new_messages.append({"role": "user", "content": tool_results})
+
+ # Cap reached or tool error — make one final call without tools and return
+ if error_occurred or round_count >= self.MAX_TOOL_ROUNDS:
+ final_params = {
+ **self.base_params,
+ "messages": new_messages,
+ "system": system_content
+ }
+ return self._extract_text(self.client.messages.create(**final_params))
+
+ # Round not yet capped — keep tools available and continue
+ api_params["messages"] = new_messages
+
+ def _extract_text(self, response) -> str:
+ """Safely extract text from any response, regardless of block ordering."""
+ for block in response.content:
+ if hasattr(block, "text"):
+ return block.text
+ return ""
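The ordering-safe extraction matters because a tool_use turn can place a tool_use block before any text block, so indexing `content[0].text` would fail. A standalone check of the same loop, using toy block classes rather than the SDK's types:

```python
from dataclasses import dataclass

@dataclass
class ToolUseBlock:
    type: str = "tool_use"
    name: str = "search_course_content"

@dataclass
class TextBlock:
    type: str = "text"
    text: str = ""

def extract_text(content) -> str:
    # Same logic as _extract_text: the first block exposing .text wins.
    for block in content:
        if hasattr(block, "text"):
            return block.text
    return ""

# A response leading with a tool_use block would crash content[0].text;
# the loop skips it and returns the text block instead.
print(extract_text([ToolUseBlock(), TextBlock(text="final answer")]))  # → final answer
```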
diff --git a/backend/app.py b/backend/app.py
index 5a69d741d..c53b42308 100644
--- a/backend/app.py
+++ b/backend/app.py
@@ -85,6 +85,12 @@ async def get_course_stats():
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
+@app.delete("/api/session/{session_id}")
+async def delete_session(session_id: str):
+ """Clear session history from memory"""
+ rag_system.session_manager.clear_session(session_id)
+ return {"status": "cleared"}
+
@app.on_event("startup")
async def startup_event():
"""Load initial documents on startup"""
diff --git a/backend/config.py b/backend/config.py
index d9f6392ef..ff188020a 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -10,10 +10,10 @@ class Config:
"""Configuration settings for the RAG system"""
# Anthropic API settings
ANTHROPIC_API_KEY: str = os.getenv("ANTHROPIC_API_KEY", "")
- ANTHROPIC_MODEL: str = "claude-sonnet-4-20250514"
+ ANTHROPIC_MODEL: str = "claude-haiku-4-5-20251001"
# Embedding model settings
- EMBEDDING_MODEL: str = "all-MiniLM-L6-v2"
+ EMBEDDING_MODEL: str = os.getenv("EMBEDDING_MODEL", "all-MiniLM-L6-v2")  # set env var to point at a local snapshot path
# Document processing settings
CHUNK_SIZE: int = 800 # Size of text chunks for vector storage
diff --git a/backend/rag_system.py b/backend/rag_system.py
index 50d848c8e..443649f0e 100644
--- a/backend/rag_system.py
+++ b/backend/rag_system.py
@@ -4,7 +4,7 @@
from vector_store import VectorStore
from ai_generator import AIGenerator
from session_manager import SessionManager
-from search_tools import ToolManager, CourseSearchTool
+from search_tools import ToolManager, CourseSearchTool, CourseOutlineTool
from models import Course, Lesson, CourseChunk
class RAGSystem:
@@ -23,6 +23,8 @@ def __init__(self, config):
self.tool_manager = ToolManager()
self.search_tool = CourseSearchTool(self.vector_store)
self.tool_manager.register_tool(self.search_tool)
+ self.outline_tool = CourseOutlineTool(self.vector_store)
+ self.tool_manager.register_tool(self.outline_tool)
def add_course_document(self, file_path: str) -> Tuple[Course, int]:
"""
diff --git a/backend/search_tools.py b/backend/search_tools.py
index adfe82352..11ddcdf03 100644
--- a/backend/search_tools.py
+++ b/backend/search_tools.py
@@ -104,7 +104,17 @@ def _format_results(self, results: SearchResults) -> str:
source = course_title
if lesson_num is not None:
source += f" - Lesson {lesson_num}"
- sources.append(source)
+
+ # Fetch lesson link from course catalog
+ lesson_link = None
+ if lesson_num is not None:
+ lesson_link = self.store.get_lesson_link(course_title, lesson_num)
+
+ # Encode as "label|url" when a link exists, plain label otherwise
+ if lesson_link:
+ sources.append(f"{source}|{lesson_link}")
+ else:
+ sources.append(source)
formatted.append(f"{header}\n{doc}")
@@ -113,6 +123,42 @@ def _format_results(self, results: SearchResults) -> str:
return "\n\n".join(formatted)
+class CourseOutlineTool(Tool):
+ """Tool for retrieving a course outline (title, link, and lesson list)"""
+
+ def __init__(self, vector_store: VectorStore):
+ self.store = vector_store
+
+ def get_tool_definition(self) -> Dict[str, Any]:
+ return {
+ "name": "get_course_outline",
+ "description": "Get the full outline of a course: title, link, and numbered lesson list",
+ "input_schema": {
+ "type": "object",
+ "properties": {
+ "course_title": {
+ "type": "string",
+ "description": "Course title to look up (partial matches work)"
+ }
+ },
+ "required": ["course_title"]
+ }
+ }
+
+ def execute(self, course_title: str) -> str:
+ outline = self.store.get_course_outline(course_title)
+ if not outline:
+ return f"No course found matching '{course_title}'"
+
+ lines = [f"Course: {outline['title']}"]
+ if outline.get('course_link'):
+ lines.append(f"Link: {outline['course_link']}")
+ lines.append(f"\nLessons ({len(outline['lessons'])} total):")
+ for lesson in outline['lessons']:
+ lines.append(f" Lesson {lesson['lesson_number']}: {lesson['lesson_title']}")
+ return "\n".join(lines)
+
+
class ToolManager:
"""Manages available tools for the AI"""
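The `label|url` encoding produced by `_format_results` is meant to be split on the first `|` by the consumer. A sketch of the decode side (shown in Python for consistency, though the actual frontend is JS; it assumes labels themselves never contain `|`):

```python
def decode_source(source: str):
    """Split a 'label|url' source into (label, url-or-None).

    Assumes labels never contain '|', matching the encoder's convention.
    """
    label, sep, url = source.partition("|")
    return (label, url if sep else None)

print(decode_source("Course A - Lesson 1|https://example.com/l1"))
# → ('Course A - Lesson 1', 'https://example.com/l1')
print(decode_source("Course A - Lesson 1"))
# → ('Course A - Lesson 1', None)
```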
diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py
new file mode 100644
index 000000000..9ae4ceec0
--- /dev/null
+++ b/backend/tests/conftest.py
@@ -0,0 +1,174 @@
+"""
+Shared fixtures for the RAG system test suite.
+
+The production app.py mounts static files from ../frontend and initialises
+RAGSystem at import time, both of which fail in the test environment.
+To avoid that, conftest.py defines a create_test_app() factory that
+mirrors every API route with a caller-supplied (mock) RAGSystem and
+no static-file mount. All test modules should use the test_client
+fixture rather than importing app directly.
+"""
+import sys
+import os
+import pytest
+from fastapi import FastAPI, HTTPException
+from fastapi.testclient import TestClient
+from unittest.mock import MagicMock
+from pydantic import BaseModel
+from typing import List, Optional
+
+# Make the backend package importable from within the tests/ sub-directory.
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+from models import Course, Lesson, CourseChunk # noqa: E402
+
+
+# ---------------------------------------------------------------------------
+# Pydantic request / response models (mirrored from app.py)
+# ---------------------------------------------------------------------------
+
+class QueryRequest(BaseModel):
+ query: str
+ session_id: Optional[str] = None
+
+
+class QueryResponse(BaseModel):
+ answer: str
+ sources: List[str]
+ session_id: str
+
+
+class CourseStats(BaseModel):
+ total_courses: int
+ course_titles: List[str]
+
+
+# ---------------------------------------------------------------------------
+# Test-app factory
+# ---------------------------------------------------------------------------
+
+def create_test_app(rag_system) -> FastAPI:
+ """Return a FastAPI app wired to *rag_system* with no static-file mount."""
+ app = FastAPI(title="Test RAG App")
+
+ @app.post("/api/query", response_model=QueryResponse)
+ async def query_documents(request: QueryRequest):
+ try:
+ session_id = request.session_id
+ if not session_id:
+ session_id = rag_system.session_manager.create_session()
+ answer, sources = rag_system.query(request.query, session_id)
+ return QueryResponse(answer=answer, sources=sources, session_id=session_id)
+ except Exception as exc:
+ raise HTTPException(status_code=500, detail=str(exc))
+
+ @app.get("/api/courses", response_model=CourseStats)
+ async def get_course_stats():
+ try:
+ analytics = rag_system.get_course_analytics()
+ return CourseStats(
+ total_courses=analytics["total_courses"],
+ course_titles=analytics["course_titles"],
+ )
+ except Exception as exc:
+ raise HTTPException(status_code=500, detail=str(exc))
+
+ @app.delete("/api/session/{session_id}")
+ async def delete_session(session_id: str):
+ rag_system.session_manager.clear_session(session_id)
+ return {"status": "cleared"}
+
+ return app
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+@pytest.fixture
+def mock_rag_system():
+ """MagicMock standing in for RAGSystem with sensible default return values."""
+ mock = MagicMock()
+ mock.session_manager = MagicMock()
+ mock.session_manager.create_session.return_value = "session_1"
+ mock.query.return_value = ("Test answer about Python.", ["Course A - Lesson 1"])
+ mock.get_course_analytics.return_value = {
+ "total_courses": 2,
+ "course_titles": ["Course A", "Course B"],
+ }
+ return mock
+
+
+@pytest.fixture
+def test_client(mock_rag_system):
+ """Starlette TestClient backed by the test app and a fresh mock RAGSystem."""
+ app = create_test_app(mock_rag_system)
+ with TestClient(app) as client:
+ yield client
+
+
+@pytest.fixture
+def sample_query_request():
+ """Minimal valid /api/query payload."""
+ return {"query": "What is Python?"}
+
+
+@pytest.fixture
+def sample_course():
+ """A fully-populated Course model for unit tests that need one."""
+ return Course(
+ title="Python Basics",
+ course_link="https://example.com/python",
+ instructor="Jane Doe",
+ lessons=[
+ Lesson(lesson_number=1, title="Introduction", lesson_link="https://example.com/l1"),
+ Lesson(lesson_number=2, title="Variables", lesson_link="https://example.com/l2"),
+ ],
+ )
+
+
+@pytest.fixture
+def sample_chunk():
+ """A single CourseChunk for unit tests that need vector-store content."""
+ return CourseChunk(
+ content="Python is a high-level programming language.",
+ course_title="Python Basics",
+ lesson_number=1,
+ chunk_index=0,
+ )
+
+
+@pytest.fixture
+def sample_course_no_optionals():
+ """A Course where instructor and course_link are None."""
+ return Course(
+ title="Sparse Course",
+ course_link=None,
+ instructor=None,
+ lessons=[],
+ )
+
+
+@pytest.fixture
+def sample_chunks():
+ """A list of CourseChunks including one with lesson_number=None."""
+ return [
+ CourseChunk(content="chunk 0 text", course_title="Python Basics", lesson_number=1, chunk_index=0),
+ CourseChunk(content="chunk 1 text", course_title="Python Basics", lesson_number=2, chunk_index=1),
+ CourseChunk(content="chunk 2 no lesson", course_title="Python Basics", lesson_number=None, chunk_index=2),
+ ]
+
+
+@pytest.fixture
+def mock_vector_store():
+ """A MagicMock that mimics VectorStore's public interface."""
+ store = MagicMock()
+ store.search.return_value = MagicMock(
+ documents=["result doc"],
+ metadata=[{"course_title": "Python Basics", "lesson_number": 1}],
+ distances=[0.1],
+ error=None,
+ is_empty=MagicMock(return_value=False),
+ )
+ store.get_lesson_link.return_value = "https://example.com/python/1"
+ return store
diff --git a/backend/tests/test_ai_generator.py b/backend/tests/test_ai_generator.py
new file mode 100644
index 000000000..4a6ec874f
--- /dev/null
+++ b/backend/tests/test_ai_generator.py
@@ -0,0 +1,256 @@
+"""
+Tests for AIGenerator.
+
+anthropic.Anthropic is mocked at import time — no real API calls made.
+All tests verify external behavior: API call count, arguments passed, text returned.
+"""
+import pytest
+from unittest.mock import MagicMock, patch
+from ai_generator import AIGenerator
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _text_block(text):
+ block = MagicMock()
+ block.type = "text"
+ block.text = text
+ return block
+
+
+def _tool_use_block(name, tool_id, input_dict):
+ block = MagicMock()
+ block.type = "tool_use"
+ block.name = name
+ block.id = tool_id
+ block.input = input_dict
+ return block
+
+
+def _response(stop_reason, content):
+ r = MagicMock()
+ r.stop_reason = stop_reason
+ r.content = content
+ return r
+
+
+def _tool_use_response(name="search_course_content", tool_id="tu_001", query="test query"):
+ block = _tool_use_block(name, tool_id, {"query": query})
+ return _response(stop_reason="tool_use", content=[block])
+
+
+def _end_turn_response(text="final answer"):
+ return _response(stop_reason="end_turn", content=[_text_block(text)])
+
+
+@pytest.fixture
+def generator():
+ with patch("ai_generator.anthropic.Anthropic") as mock_cls:
+ mock_client = MagicMock()
+ mock_cls.return_value = mock_client
+ gen = AIGenerator(api_key="test-key", model="claude-test-model")
+ gen._mock_client = mock_client # expose for assertions
+ yield gen
+
+
+@pytest.fixture
+def tool_manager():
+ tm = MagicMock()
+ tm.execute_tool.return_value = "search results text"
+ return tm
+
+
+# ---------------------------------------------------------------------------
+# 0 tool rounds — direct response
+# ---------------------------------------------------------------------------
+
+class TestDirectResponse:
+ def test_returns_text_on_end_turn(self, generator):
+ generator._mock_client.messages.create.return_value = _end_turn_response("Hello!")
+ assert generator.generate_response(query="Say hi") == "Hello!"
+
+ def test_makes_exactly_one_api_call(self, generator):
+ generator._mock_client.messages.create.return_value = _end_turn_response()
+ generator.generate_response(query="test")
+ assert generator._mock_client.messages.create.call_count == 1
+
+ def test_passes_tools_and_tool_choice_to_api(self, generator):
+ generator._mock_client.messages.create.return_value = _end_turn_response()
+ tools = [{"name": "search", "description": "..."}]
+ generator.generate_response(query="test", tools=tools)
+ kwargs = generator._mock_client.messages.create.call_args.kwargs
+ assert kwargs["tools"] == tools
+ assert kwargs["tool_choice"] == {"type": "auto"}
+
+ def test_does_not_pass_tools_when_none_provided(self, generator):
+ generator._mock_client.messages.create.return_value = _end_turn_response()
+ generator.generate_response(query="test")
+ kwargs = generator._mock_client.messages.create.call_args.kwargs
+ assert "tools" not in kwargs
+ assert "tool_choice" not in kwargs
+
+ def test_conversation_history_appears_in_system_prompt(self, generator):
+ generator._mock_client.messages.create.return_value = _end_turn_response()
+ generator.generate_response(query="test", conversation_history="User: hi\nAssistant: hello")
+ kwargs = generator._mock_client.messages.create.call_args.kwargs
+ assert "User: hi" in kwargs["system"]
+ assert "Assistant: hello" in kwargs["system"]
+
+ def test_tool_manager_not_called_on_end_turn(self, generator, tool_manager):
+ generator._mock_client.messages.create.return_value = _end_turn_response()
+ generator.generate_response(query="test", tool_manager=tool_manager)
+ tool_manager.execute_tool.assert_not_called()
+
+
+# ---------------------------------------------------------------------------
+# 1 tool round — Claude calls a tool then answers
+# ---------------------------------------------------------------------------
+
+class TestOneToolRound:
+ def _setup(self, generator, tool_manager, final_text="final answer"):
+ generator._mock_client.messages.create.side_effect = [
+ _tool_use_response(query="functions"),
+ _end_turn_response(final_text),
+ ]
+ return [{"name": "search_course_content"}]
+
+ def test_makes_two_api_calls(self, generator, tool_manager):
+ tools = self._setup(generator, tool_manager)
+ generator.generate_response(query="what are functions", tools=tools, tool_manager=tool_manager)
+ assert generator._mock_client.messages.create.call_count == 2
+
+ def test_tool_manager_called_once_with_correct_args(self, generator, tool_manager):
+ tools = self._setup(generator, tool_manager)
+ generator.generate_response(query="what are functions", tools=tools, tool_manager=tool_manager)
+ tool_manager.execute_tool.assert_called_once_with("search_course_content", query="functions")
+
+ def test_second_call_includes_tool_result_as_user_message(self, generator, tool_manager):
+ """The second API call must carry the tool result in a user message."""
+ tools = self._setup(generator, tool_manager)
+ generator.generate_response(query="what are functions", tools=tools, tool_manager=tool_manager)
+ second_kwargs = generator._mock_client.messages.create.call_args_list[1].kwargs
+ user_msgs = [m for m in second_kwargs["messages"] if m["role"] == "user"]
+ last_user = user_msgs[-1]
+ assert isinstance(last_user["content"], list)
+ assert last_user["content"][0]["type"] == "tool_result"
+
+ def test_second_call_keeps_tools_available(self, generator, tool_manager):
+ """Tools remain in the second call so Claude could make a second tool call."""
+ tools = self._setup(generator, tool_manager)
+ generator.generate_response(query="what are functions", tools=tools, tool_manager=tool_manager)
+ second_kwargs = generator._mock_client.messages.create.call_args_list[1].kwargs
+ assert "tools" in second_kwargs
+
+ def test_returns_text_from_final_response(self, generator, tool_manager):
+ tools = self._setup(generator, tool_manager, final_text="The answer is 42")
+ result = generator.generate_response(query="what's the answer", tools=tools, tool_manager=tool_manager)
+ assert result == "The answer is 42"
+
+
+# ---------------------------------------------------------------------------
+# 2 tool rounds — cap reached, forced final call without tools
+# ---------------------------------------------------------------------------
+
+class TestTwoToolRounds:
+ def _setup(self, generator, tool_manager, final_text="synthesized answer"):
+ generator._mock_client.messages.create.side_effect = [
+ _tool_use_response(tool_id="tu_001", query="outline query"),
+ _tool_use_response(tool_id="tu_002", query="content query"),
+ _end_turn_response(final_text),
+ ]
+ return [{"name": "search_course_content"}]
+
+ def test_makes_three_api_calls(self, generator, tool_manager):
+ tools = self._setup(generator, tool_manager)
+ generator.generate_response(query="complex query", tools=tools, tool_manager=tool_manager)
+ assert generator._mock_client.messages.create.call_count == 3
+
+ def test_tool_manager_called_twice(self, generator, tool_manager):
+ tools = self._setup(generator, tool_manager)
+ generator.generate_response(query="complex query", tools=tools, tool_manager=tool_manager)
+ assert tool_manager.execute_tool.call_count == 2
+
+ def test_third_call_has_no_tools(self, generator, tool_manager):
+ """After the cap, the forced final call must NOT include tools."""
+ tools = self._setup(generator, tool_manager)
+ generator.generate_response(query="complex query", tools=tools, tool_manager=tool_manager)
+ third_kwargs = generator._mock_client.messages.create.call_args_list[2].kwargs
+ assert "tools" not in third_kwargs
+ assert "tool_choice" not in third_kwargs
+
+ def test_messages_accumulate_across_rounds(self, generator, tool_manager):
+ """The third call must carry the full conversation: user + 2×(assistant+tool_result)."""
+ tools = self._setup(generator, tool_manager)
+ generator.generate_response(query="complex query", tools=tools, tool_manager=tool_manager)
+ third_kwargs = generator._mock_client.messages.create.call_args_list[2].kwargs
+ msgs = third_kwargs["messages"]
+ roles = [m["role"] for m in msgs]
+ # Expected: user, assistant, user(tool_result), assistant, user(tool_result) = 5 messages
+ assert roles == ["user", "assistant", "user", "assistant", "user"]
+
+ def test_returns_text_from_third_response(self, generator, tool_manager):
+ tools = self._setup(generator, tool_manager, final_text="complete answer")
+ result = generator.generate_response(query="complex query", tools=tools, tool_manager=tool_manager)
+ assert result == "complete answer"
+
+
+# ---------------------------------------------------------------------------
+# Tool execution errors
+# ---------------------------------------------------------------------------
+
+class TestToolExecutionError:
+ def _setup_with_error(self, generator, tool_manager, final_text="sorry, error"):
+ tool_manager.execute_tool.side_effect = RuntimeError("DB exploded")
+ generator._mock_client.messages.create.side_effect = [
+ _tool_use_response(query="failing query"),
+ _end_turn_response(final_text),
+ ]
+ return [{"name": "search_course_content"}]
+
+ def test_error_does_not_propagate(self, generator, tool_manager):
+ tools = self._setup_with_error(generator, tool_manager)
+ # Should not raise
+ result = generator.generate_response(query="test", tools=tools, tool_manager=tool_manager)
+ assert isinstance(result, str)
+
+ def test_error_triggers_final_api_call(self, generator, tool_manager):
+ """Even on tool error a final API call is made so Claude can respond."""
+ tools = self._setup_with_error(generator, tool_manager)
+ generator.generate_response(query="test", tools=tools, tool_manager=tool_manager)
+ assert generator._mock_client.messages.create.call_count == 2
+
+ def test_error_final_call_has_no_tools(self, generator, tool_manager):
+ tools = self._setup_with_error(generator, tool_manager)
+ generator.generate_response(query="test", tools=tools, tool_manager=tool_manager)
+ second_kwargs = generator._mock_client.messages.create.call_args_list[1].kwargs
+ assert "tools" not in second_kwargs
+
+ def test_error_string_appears_as_tool_result(self, generator, tool_manager):
+ """Claude must receive the error as a tool_result so it can acknowledge it."""
+ tools = self._setup_with_error(generator, tool_manager)
+ generator.generate_response(query="test", tools=tools, tool_manager=tool_manager)
+ second_kwargs = generator._mock_client.messages.create.call_args_list[1].kwargs
+ user_msgs = [m for m in second_kwargs["messages"] if m["role"] == "user"]
+ last_user = user_msgs[-1]
+ tool_result_content = last_user["content"][0]["content"]
+ assert "Error" in tool_result_content
+
+ def test_error_returns_text_from_final_response(self, generator, tool_manager):
+ tools = self._setup_with_error(generator, tool_manager, final_text="could not retrieve")
+ result = generator.generate_response(query="test", tools=tools, tool_manager=tool_manager)
+ assert result == "could not retrieve"
+
+
+# ---------------------------------------------------------------------------
+# System prompt
+# ---------------------------------------------------------------------------
+
+class TestSystemPrompt:
+ def test_prompt_mentions_two_sequential_tool_calls(self):
+ assert "2" in AIGenerator.SYSTEM_PROMPT
+ assert "sequential" in AIGenerator.SYSTEM_PROMPT.lower()
+
+ def test_max_tool_rounds_constant_is_two(self):
+ assert AIGenerator.MAX_TOOL_ROUNDS == 2
diff --git a/backend/tests/test_api.py b/backend/tests/test_api.py
new file mode 100644
index 000000000..bd04decdf
--- /dev/null
+++ b/backend/tests/test_api.py
@@ -0,0 +1,135 @@
+"""
+API endpoint tests for the RAG chatbot.
+
+All tests use the test_client and mock_rag_system fixtures defined in
+conftest.py. The test app mirrors every route in app.py but omits the
+static-file mount and the module-level RAGSystem initialisation, so these
+tests run without a real database, Anthropic key, or frontend directory.
+"""
+import pytest
+
+
+# ---------------------------------------------------------------------------
+# POST /api/query
+# ---------------------------------------------------------------------------
+
+class TestQueryEndpoint:
+
+ def test_returns_200_with_answer_and_sources(self, test_client):
+ response = test_client.post("/api/query", json={"query": "What is Python?"})
+
+ assert response.status_code == 200
+ data = response.json()
+ assert data["answer"] == "Test answer about Python."
+ assert data["sources"] == ["Course A - Lesson 1"]
+
+ def test_auto_creates_session_when_none_provided(self, test_client, mock_rag_system):
+ response = test_client.post("/api/query", json={"query": "What is Python?"})
+
+ assert response.status_code == 200
+ assert response.json()["session_id"] == "session_1"
+ mock_rag_system.session_manager.create_session.assert_called_once()
+
+ def test_uses_caller_supplied_session_id(self, test_client, mock_rag_system):
+ response = test_client.post(
+ "/api/query",
+ json={"query": "What is Python?", "session_id": "existing_session"},
+ )
+
+ assert response.status_code == 200
+ assert response.json()["session_id"] == "existing_session"
+ # No new session should have been created
+ mock_rag_system.session_manager.create_session.assert_not_called()
+
+ def test_passes_query_and_session_to_rag(self, test_client, mock_rag_system):
+ test_client.post(
+ "/api/query",
+ json={"query": "What is Python?", "session_id": "session_1"},
+ )
+
+ mock_rag_system.query.assert_called_once_with("What is Python?", "session_1")
+
+ def test_returns_500_when_rag_raises(self, test_client, mock_rag_system):
+ mock_rag_system.query.side_effect = RuntimeError("Vector store unavailable")
+
+ response = test_client.post("/api/query", json={"query": "crash?"})
+
+ assert response.status_code == 500
+ assert "Vector store unavailable" in response.json()["detail"]
+
+ def test_returns_422_when_query_field_missing(self, test_client):
+ response = test_client.post("/api/query", json={"session_id": "s1"})
+
+ assert response.status_code == 422
+
+ def test_empty_sources_list_is_valid(self, test_client, mock_rag_system):
+ mock_rag_system.query.return_value = ("No sources answer.", [])
+
+ response = test_client.post("/api/query", json={"query": "obscure question"})
+
+ assert response.status_code == 200
+ assert response.json()["sources"] == []
+
+
+# ---------------------------------------------------------------------------
+# GET /api/courses
+# ---------------------------------------------------------------------------
+
+class TestCoursesEndpoint:
+
+ def test_returns_200_with_course_stats(self, test_client):
+ response = test_client.get("/api/courses")
+
+ assert response.status_code == 200
+ data = response.json()
+ assert data["total_courses"] == 2
+ assert data["course_titles"] == ["Course A", "Course B"]
+
+ def test_calls_get_course_analytics(self, test_client, mock_rag_system):
+ test_client.get("/api/courses")
+
+ mock_rag_system.get_course_analytics.assert_called_once()
+
+ def test_returns_500_when_analytics_raises(self, test_client, mock_rag_system):
+ mock_rag_system.get_course_analytics.side_effect = Exception("DB connection error")
+
+ response = test_client.get("/api/courses")
+
+ assert response.status_code == 500
+ assert "DB connection error" in response.json()["detail"]
+
+ def test_empty_catalog_returns_zero_courses(self, test_client, mock_rag_system):
+ mock_rag_system.get_course_analytics.return_value = {
+ "total_courses": 0,
+ "course_titles": [],
+ }
+
+ response = test_client.get("/api/courses")
+
+ assert response.status_code == 200
+ data = response.json()
+ assert data["total_courses"] == 0
+ assert data["course_titles"] == []
+
+
+# ---------------------------------------------------------------------------
+# DELETE /api/session/{session_id}
+# ---------------------------------------------------------------------------
+
+class TestSessionEndpoint:
+
+ def test_returns_200_with_cleared_status(self, test_client):
+ response = test_client.delete("/api/session/session_1")
+
+ assert response.status_code == 200
+ assert response.json() == {"status": "cleared"}
+
+ def test_passes_session_id_to_clear_session(self, test_client, mock_rag_system):
+ test_client.delete("/api/session/my_session")
+
+ mock_rag_system.session_manager.clear_session.assert_called_once_with("my_session")
+
+ def test_clears_arbitrary_session_id(self, test_client, mock_rag_system):
+ test_client.delete("/api/session/some-uuid-1234")
+
+ mock_rag_system.session_manager.clear_session.assert_called_once_with("some-uuid-1234")
diff --git a/backend/tests/test_rag_system.py b/backend/tests/test_rag_system.py
new file mode 100644
index 000000000..1358e96bd
--- /dev/null
+++ b/backend/tests/test_rag_system.py
@@ -0,0 +1,212 @@
+"""
+Integration tests for RAGSystem.
+
+VectorStore, AIGenerator, DocumentProcessor, and SessionManager are all
+patched at the module level — no real I/O happens.
+"""
+import pytest
+from unittest.mock import MagicMock, patch
+
+
+# ---------------------------------------------------------------------------
+# Fixture
+# ---------------------------------------------------------------------------
+
+@pytest.fixture
+def rag():
+ """
+ Return a RAGSystem instance with all external dependencies mocked.
+    The sys.path setup in conftest.py makes `from rag_system import ...` resolvable.
+ """
+ with patch("rag_system.VectorStore") as MockVectorStore, \
+ patch("rag_system.AIGenerator") as MockAIGenerator, \
+ patch("rag_system.DocumentProcessor") as MockDocProcessor, \
+ patch("rag_system.SessionManager") as MockSessionManager:
+
+ # Build mock instances
+ mock_vs = MagicMock()
+ mock_ai = MagicMock()
+ mock_dp = MagicMock()
+ mock_sm = MagicMock()
+
+ MockVectorStore.return_value = mock_vs
+ MockAIGenerator.return_value = mock_ai
+ MockDocProcessor.return_value = mock_dp
+ MockSessionManager.return_value = mock_sm
+
+ # Reasonable defaults
+ mock_ai.generate_response.return_value = "Claude says hello"
+ mock_sm.get_conversation_history.return_value = "previous: hi"
+
+ # Config stub
+ config = MagicMock()
+ config.ANTHROPIC_API_KEY = "test-key"
+ config.ANTHROPIC_MODEL = "claude-test-model"
+ config.CHROMA_PATH = "/tmp/chroma"
+ config.EMBEDDING_MODEL = "all-MiniLM-L6-v2"
+ config.MAX_RESULTS = 5
+ config.CHUNK_SIZE = 800
+ config.CHUNK_OVERLAP = 100
+ config.MAX_HISTORY = 2
+
+ from rag_system import RAGSystem
+ system = RAGSystem(config)
+
+ # Expose mocks for assertions
+ system._mock_vs = mock_vs
+ system._mock_ai = mock_ai
+ system._mock_sm = mock_sm
+
+ yield system
+
+
+# ---------------------------------------------------------------------------
+# query() tests
+# ---------------------------------------------------------------------------
+
+class TestRagSystemQuery:
+ def test_query_returns_response_and_sources(self, rag):
+ """Happy path: returns the AI's text and whatever sources tool_manager has."""
+ # Patch tool_manager on the instance
+ rag.tool_manager = MagicMock()
+ rag.tool_manager.get_tool_definitions.return_value = [{"name": "search"}]
+ rag.tool_manager.get_last_sources.return_value = ["Course A - Lesson 1"]
+ rag._mock_ai.generate_response.return_value = "Python uses indentation"
+
+ response, sources = rag.query("What is Python indentation?")
+
+ assert response == "Python uses indentation"
+ assert sources == ["Course A - Lesson 1"]
+
+ def test_query_passes_tools_to_ai_generator(self, rag):
+ """generate_response must be called with tools= and tool_manager=."""
+ tool_defs = [{"name": "search_course_content"}]
+ rag.tool_manager = MagicMock()
+ rag.tool_manager.get_tool_definitions.return_value = tool_defs
+ rag.tool_manager.get_last_sources.return_value = []
+
+ rag.query("What is a list?")
+
+ call_kwargs = rag._mock_ai.generate_response.call_args.kwargs
+ assert call_kwargs["tools"] == tool_defs
+ assert call_kwargs["tool_manager"] is rag.tool_manager
+
+ def test_query_passes_conversation_history_for_known_session(self, rag):
+ """When session_id is provided, history is fetched and forwarded."""
+ rag.tool_manager = MagicMock()
+ rag.tool_manager.get_tool_definitions.return_value = []
+ rag.tool_manager.get_last_sources.return_value = []
+ rag._mock_sm.get_conversation_history.return_value = "User: hello\nAssistant: hi"
+
+ rag.query("Follow-up question", session_id="sess-001")
+
+ call_kwargs = rag._mock_ai.generate_response.call_args.kwargs
+ assert "User: hello" in call_kwargs["conversation_history"]
+
+ def test_query_with_no_session_skips_history(self, rag):
+ """Without session_id, conversation_history should be None."""
+ rag.tool_manager = MagicMock()
+ rag.tool_manager.get_tool_definitions.return_value = []
+ rag.tool_manager.get_last_sources.return_value = []
+
+ rag.query("What is a dict?") # no session_id
+
+ call_kwargs = rag._mock_ai.generate_response.call_args.kwargs
+ assert call_kwargs["conversation_history"] is None
+ rag._mock_sm.get_conversation_history.assert_not_called()
+
+ def test_query_resets_sources_after_retrieval(self, rag):
+ """reset_sources() must be called after get_last_sources()."""
+ rag.tool_manager = MagicMock()
+ rag.tool_manager.get_tool_definitions.return_value = []
+ rag.tool_manager.get_last_sources.return_value = ["Source X"]
+
+ rag.query("test query")
+
+ rag.tool_manager.reset_sources.assert_called_once()
+
+ def test_query_updates_session_after_response(self, rag):
+ """add_exchange() must be called with session_id, query, and response."""
+ rag.tool_manager = MagicMock()
+ rag.tool_manager.get_tool_definitions.return_value = []
+ rag.tool_manager.get_last_sources.return_value = []
+ rag._mock_ai.generate_response.return_value = "answer text"
+
+ rag.query("user question", session_id="sess-xyz")
+
+ rag._mock_sm.add_exchange.assert_called_once_with(
+ "sess-xyz", "user question", "answer text"
+ )
+
+ def test_query_does_not_update_session_without_session_id(self, rag):
+ """Without session_id, add_exchange() should NOT be called."""
+ rag.tool_manager = MagicMock()
+ rag.tool_manager.get_tool_definitions.return_value = []
+ rag.tool_manager.get_last_sources.return_value = []
+
+ rag.query("anonymous question")
+
+ rag._mock_sm.add_exchange.assert_not_called()
+
+
+# ---------------------------------------------------------------------------
+# add_course_folder() tests
+# ---------------------------------------------------------------------------
+
+class TestRagSystemAddCourseFolder:
+ def test_nonexistent_folder_returns_zero(self, rag):
+ """If the folder does not exist, return (0, 0) without touching the store."""
+        courses, chunks = rag.add_course_folder("/tmp/does_not_exist_xyzzy")
+        assert courses == 0
+        assert chunks == 0
+        rag._mock_vs.add_course_metadata.assert_not_called()
+
+ def test_skips_courses_that_already_exist(self, rag, tmp_path):
+ """Courses already in the vector store must not be re-added."""
+ # Create a dummy .txt file so the folder is non-empty
+ course_file = tmp_path / "course.txt"
+ course_file.write_text("dummy")
+
+ # DocumentProcessor returns a course whose title is already indexed
+ mock_course = MagicMock()
+ mock_course.title = "Existing Course"
+ rag.document_processor.process_course_document.return_value = (mock_course, [])
+ rag._mock_vs.get_existing_course_titles.return_value = ["Existing Course"]
+
+ courses, chunks = rag.add_course_folder(str(tmp_path))
+
+ assert courses == 0
+ rag._mock_vs.add_course_metadata.assert_not_called()
+
+ def test_adds_new_course_to_vector_store(self, rag, tmp_path):
+ """A course title not yet in the store should be added."""
+ course_file = tmp_path / "new_course.txt"
+ course_file.write_text("dummy")
+
+ mock_course = MagicMock()
+ mock_course.title = "Brand New Course"
+ mock_chunks = [MagicMock(), MagicMock()]
+ rag.document_processor.process_course_document.return_value = (mock_course, mock_chunks)
+ rag._mock_vs.get_existing_course_titles.return_value = []
+
+ courses, chunks = rag.add_course_folder(str(tmp_path))
+
+ assert courses == 1
+ assert chunks == 2
+ rag._mock_vs.add_course_metadata.assert_called_once_with(mock_course)
+ rag._mock_vs.add_course_content.assert_called_once_with(mock_chunks)
+
+ def test_clear_existing_calls_clear_all_data(self, rag, tmp_path):
+ """clear_existing=True must call clear_all_data before processing."""
+ rag._mock_vs.get_existing_course_titles.return_value = []
+
+ rag.add_course_folder(str(tmp_path), clear_existing=True)
+
+ rag._mock_vs.clear_all_data.assert_called_once()
+
+ def test_clear_existing_false_does_not_clear(self, rag, tmp_path):
+ """clear_existing=False (default) must NOT call clear_all_data."""
+ rag._mock_vs.get_existing_course_titles.return_value = []
+
+ rag.add_course_folder(str(tmp_path), clear_existing=False)
+
+ rag._mock_vs.clear_all_data.assert_not_called()
diff --git a/backend/tests/test_search_tool.py b/backend/tests/test_search_tool.py
new file mode 100644
index 000000000..beb3d6041
--- /dev/null
+++ b/backend/tests/test_search_tool.py
@@ -0,0 +1,313 @@
+"""
+Tests for CourseSearchTool and ToolManager.
+
+VectorStore is mocked — no ChromaDB needed.
+"""
+import pytest
+from unittest.mock import MagicMock
+from vector_store import SearchResults
+from search_tools import CourseSearchTool, CourseOutlineTool, ToolManager
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _make_results(docs, metas, error=None):
+ sr = MagicMock(spec=SearchResults)
+ sr.documents = docs
+ sr.metadata = metas
+ sr.distances = [0.1] * len(docs)
+ sr.error = error
+ sr.is_empty = MagicMock(return_value=(len(docs) == 0))
+ return sr
+
+
+def _make_tool(store_search_return=None, lesson_link=None):
+ store = MagicMock()
+ if store_search_return is not None:
+ store.search.return_value = store_search_return
+ store.get_lesson_link.return_value = lesson_link
+ return CourseSearchTool(store), store
+
+
+# ---------------------------------------------------------------------------
+# execute() happy-path tests
+# ---------------------------------------------------------------------------
+
+class TestCourseSearchToolExecute:
+ def test_execute_returns_formatted_results(self):
+ results = _make_results(
+ docs=["Variables hold data"],
+ metas=[{"course_title": "Python Basics", "lesson_number": 1}],
+ )
+ tool, _ = _make_tool(store_search_return=results)
+
+ output = tool.execute(query="what are variables")
+
+ assert "Python Basics" in output
+ assert "Variables hold data" in output
+
+ def test_execute_with_course_name_filter_passes_through(self):
+ results = _make_results(
+ docs=["doc"],
+ metas=[{"course_title": "Python Basics", "lesson_number": 1}],
+ )
+ tool, store = _make_tool(store_search_return=results)
+
+ tool.execute(query="variables", course_name="Python")
+
+ store.search.assert_called_once_with(
+ query="variables",
+ course_name="Python",
+ lesson_number=None,
+ )
+
+ def test_execute_with_lesson_number_filter_passes_through(self):
+ results = _make_results(
+ docs=["doc"],
+ metas=[{"course_title": "Python Basics", "lesson_number": 2}],
+ )
+ tool, store = _make_tool(store_search_return=results)
+
+ tool.execute(query="functions", lesson_number=2)
+
+ store.search.assert_called_once_with(
+ query="functions",
+ course_name=None,
+ lesson_number=2,
+ )
+
+ def test_execute_returns_error_message_on_search_error(self):
+ results = _make_results(docs=[], metas=[], error="No course found matching 'XYZ'")
+ tool, _ = _make_tool(store_search_return=results)
+
+ output = tool.execute(query="anything", course_name="XYZ")
+
+ assert "No course found" in output
+
+ def test_execute_returns_no_content_message_on_empty_results(self):
+ results = _make_results(docs=[], metas=[])
+ tool, _ = _make_tool(store_search_return=results)
+
+ output = tool.execute(query="obscure topic")
+
+ assert "No relevant content found" in output
+
+ def test_execute_includes_course_and_lesson_in_filter_message(self):
+ results = _make_results(docs=[], metas=[])
+ tool, _ = _make_tool(store_search_return=results)
+
+ output = tool.execute(query="obscure topic", course_name="Python", lesson_number=3)
+
+ assert "Python" in output
+ assert "3" in output
+
+
+# ---------------------------------------------------------------------------
+# Source tracking tests
+# ---------------------------------------------------------------------------
+
+class TestCourseSearchToolSources:
+ def test_last_sources_includes_url_when_lesson_link_available(self):
+ results = _make_results(
+ docs=["doc"],
+ metas=[{"course_title": "Python Basics", "lesson_number": 1}],
+ )
+ tool, store = _make_tool(store_search_return=results, lesson_link="https://example.com/1")
+
+ tool.execute(query="variables")
+
+ assert tool.last_sources == ["Python Basics - Lesson 1|https://example.com/1"]
+
+ def test_last_sources_plain_label_when_no_lesson_link(self):
+ results = _make_results(
+ docs=["doc"],
+ metas=[{"course_title": "Python Basics", "lesson_number": 1}],
+ )
+ tool, store = _make_tool(store_search_return=results, lesson_link=None)
+
+ tool.execute(query="variables")
+
+ assert tool.last_sources == ["Python Basics - Lesson 1"]
+
+ def test_last_sources_reset_between_calls(self):
+ results_1 = _make_results(
+ docs=["doc1"],
+ metas=[{"course_title": "Course A", "lesson_number": 1}],
+ )
+ results_2 = _make_results(
+ docs=["doc2"],
+ metas=[{"course_title": "Course B", "lesson_number": 2}],
+ )
+ store = MagicMock()
+ store.search.side_effect = [results_1, results_2]
+ store.get_lesson_link.return_value = None
+ tool = CourseSearchTool(store)
+
+ tool.execute(query="first query")
+ sources_after_first = list(tool.last_sources)
+
+ tool.execute(query="second query")
+ sources_after_second = list(tool.last_sources)
+
+ assert sources_after_first == ["Course A - Lesson 1"]
+ assert sources_after_second == ["Course B - Lesson 2"]
+ # Crucially: no bleed-through from call 1
+ assert "Course A" not in " ".join(sources_after_second)
+
+ def test_last_sources_no_lesson_number(self):
+ """When lesson_number is None in metadata, source label has no 'Lesson N' suffix."""
+ results = _make_results(
+ docs=["doc"],
+ metas=[{"course_title": "Python Basics", "lesson_number": None}],
+ )
+ tool, _ = _make_tool(store_search_return=results, lesson_link=None)
+
+ tool.execute(query="variables")
+
+ assert tool.last_sources == ["Python Basics"]
+
+
+# ---------------------------------------------------------------------------
+# ToolManager tests
+# ---------------------------------------------------------------------------
+
+class TestToolManager:
+ def test_register_and_execute_tool(self):
+ manager = ToolManager()
+ mock_tool = MagicMock()
+ mock_tool.get_tool_definition.return_value = {"name": "my_tool"}
+ mock_tool.execute.return_value = "tool output"
+
+ manager.register_tool(mock_tool)
+ result = manager.execute_tool("my_tool", foo="bar")
+
+ mock_tool.execute.assert_called_once_with(foo="bar")
+ assert result == "tool output"
+
+ def test_execute_unknown_tool_returns_error(self):
+ manager = ToolManager()
+ result = manager.execute_tool("nonexistent_tool")
+ assert "not found" in result
+
+ def test_get_last_sources_aggregates_across_tools(self):
+ manager = ToolManager()
+ mock_tool = MagicMock()
+ mock_tool.get_tool_definition.return_value = {"name": "search_tool"}
+ mock_tool.last_sources = ["Source A", "Source B"]
+
+ manager.register_tool(mock_tool)
+ sources = manager.get_last_sources()
+
+ assert sources == ["Source A", "Source B"]
+
+ def test_reset_sources_clears_all_tools(self):
+ manager = ToolManager()
+ mock_tool = MagicMock()
+ mock_tool.get_tool_definition.return_value = {"name": "search_tool"}
+ mock_tool.last_sources = ["Source A"]
+
+ manager.register_tool(mock_tool)
+ manager.reset_sources()
+
+ assert mock_tool.last_sources == []
+
+ def test_get_tool_definitions_returns_all_registered(self):
+ manager = ToolManager()
+ for name in ("tool_a", "tool_b"):
+ t = MagicMock()
+ t.get_tool_definition.return_value = {"name": name}
+ manager.register_tool(t)
+
+ defs = manager.get_tool_definitions()
+
+ names = [d["name"] for d in defs]
+ assert "tool_a" in names
+ assert "tool_b" in names
+
+ def test_register_tool_without_name_raises_value_error(self):
+ manager = ToolManager()
+ bad_tool = MagicMock()
+ bad_tool.get_tool_definition.return_value = {} # no "name" key
+
+ with pytest.raises(ValueError):
+ manager.register_tool(bad_tool)
+
+
+# ---------------------------------------------------------------------------
+# CourseOutlineTool tests
+# ---------------------------------------------------------------------------
+
+class TestCourseOutlineTool:
+ def _make_outline_tool(self, outline_return):
+ store = MagicMock()
+ store.get_course_outline.return_value = outline_return
+ return CourseOutlineTool(store), store
+
+ def test_execute_no_course_returns_not_found_message(self):
+ tool, _ = self._make_outline_tool(None)
+ result = tool.execute(course_title="Unknown Course")
+ assert "No course found" in result
+ assert "Unknown Course" in result
+
+ def test_execute_includes_course_title(self):
+ outline = {
+ "title": "Python Basics",
+ "course_link": "https://example.com/python",
+ "lessons": [
+ {"lesson_number": 1, "lesson_title": "Variables", "lesson_link": None}
+ ]
+ }
+ tool, _ = self._make_outline_tool(outline)
+ result = tool.execute(course_title="Python")
+ assert "Python Basics" in result
+
+ def test_execute_includes_course_link_when_present(self):
+ outline = {
+ "title": "Python Basics",
+ "course_link": "https://example.com/python",
+ "lessons": []
+ }
+ tool, _ = self._make_outline_tool(outline)
+ result = tool.execute(course_title="Python")
+ assert "https://example.com/python" in result
+
+ def test_execute_omits_link_line_when_absent(self):
+ outline = {
+ "title": "Python Basics",
+ "course_link": None,
+ "lessons": []
+ }
+ tool, _ = self._make_outline_tool(outline)
+ result = tool.execute(course_title="Python")
+ assert "Link:" not in result
+
+ def test_execute_lists_all_lessons(self):
+ outline = {
+ "title": "Python Basics",
+ "course_link": None,
+ "lessons": [
+ {"lesson_number": 1, "lesson_title": "Variables", "lesson_link": None},
+ {"lesson_number": 2, "lesson_title": "Functions", "lesson_link": None},
+ {"lesson_number": 3, "lesson_title": "Classes", "lesson_link": None},
+ ]
+ }
+ tool, _ = self._make_outline_tool(outline)
+ result = tool.execute(course_title="Python")
+ assert "Lesson 1: Variables" in result
+ assert "Lesson 2: Functions" in result
+ assert "Lesson 3: Classes" in result
+
+ def test_execute_shows_lesson_count(self):
+ outline = {
+ "title": "Python Basics",
+ "course_link": None,
+ "lessons": [
+ {"lesson_number": 1, "lesson_title": "Variables", "lesson_link": None},
+ {"lesson_number": 2, "lesson_title": "Functions", "lesson_link": None},
+ ]
+ }
+ tool, _ = self._make_outline_tool(outline)
+ result = tool.execute(course_title="Python")
+ assert "2" in result # lesson count appears somewhere
diff --git a/backend/tests/test_session_manager.py b/backend/tests/test_session_manager.py
new file mode 100644
index 000000000..be3e04b48
--- /dev/null
+++ b/backend/tests/test_session_manager.py
@@ -0,0 +1,127 @@
+"""
+Tests for SessionManager.
+
+No external dependencies — SessionManager is pure in-memory state.
+"""
+import pytest
+from session_manager import SessionManager
+
+
+class TestCreateSession:
+ def test_returns_a_string_id(self):
+ sm = SessionManager()
+ session_id = sm.create_session()
+ assert isinstance(session_id, str)
+ assert len(session_id) > 0
+
+ def test_each_call_returns_unique_id(self):
+ sm = SessionManager()
+ ids = {sm.create_session() for _ in range(5)}
+ assert len(ids) == 5
+
+ def test_new_session_has_empty_history(self):
+ sm = SessionManager()
+ session_id = sm.create_session()
+ assert sm.get_conversation_history(session_id) is None
+
+
+class TestAddMessage:
+ def test_message_appears_in_history(self):
+ sm = SessionManager()
+ session_id = sm.create_session()
+ sm.add_message(session_id, "user", "hello")
+ history = sm.get_conversation_history(session_id)
+ assert "hello" in history
+
+ def test_auto_creates_session_for_unknown_id(self):
+ sm = SessionManager()
+ sm.add_message("ghost-session", "user", "hi")
+ history = sm.get_conversation_history("ghost-session")
+ assert "hi" in history
+
+ def test_history_trimmed_to_max_history_times_two(self):
+ sm = SessionManager(max_history=2) # keeps last 4 messages
+ session_id = sm.create_session()
+ for i in range(6):
+ sm.add_message(session_id, "user", f"message {i}")
+ messages = sm.sessions[session_id]
+ assert len(messages) <= 4
+
+ def test_trim_keeps_most_recent_messages(self):
+ sm = SessionManager(max_history=2)
+ session_id = sm.create_session()
+ for i in range(6):
+ sm.add_message(session_id, "user", f"msg {i}")
+ history = sm.get_conversation_history(session_id)
+ # Oldest messages should be gone
+ assert "msg 0" not in history
+ assert "msg 1" not in history
+ # Most recent should remain
+ assert "msg 5" in history
+
+
+class TestAddExchange:
+ def test_adds_both_user_and_assistant_messages(self):
+ sm = SessionManager()
+ session_id = sm.create_session()
+ sm.add_exchange(session_id, "What is Python?", "A programming language.")
+ messages = sm.sessions[session_id]
+ assert len(messages) == 2
+ assert messages[0].role == "user"
+ assert messages[1].role == "assistant"
+
+ def test_content_stored_correctly(self):
+ sm = SessionManager()
+ session_id = sm.create_session()
+ sm.add_exchange(session_id, "user question", "assistant answer")
+ messages = sm.sessions[session_id]
+ assert messages[0].content == "user question"
+ assert messages[1].content == "assistant answer"
+
+
+class TestGetConversationHistory:
+ def test_returns_none_for_unknown_session(self):
+ sm = SessionManager()
+ assert sm.get_conversation_history("does-not-exist") is None
+
+ def test_returns_none_for_none_session_id(self):
+ sm = SessionManager()
+ assert sm.get_conversation_history(None) is None
+
+ def test_returns_none_for_empty_session(self):
+ sm = SessionManager()
+ session_id = sm.create_session()
+ assert sm.get_conversation_history(session_id) is None
+
+ def test_formats_role_as_title_case(self):
+ sm = SessionManager()
+ session_id = sm.create_session()
+ sm.add_message(session_id, "user", "hi")
+ sm.add_message(session_id, "assistant", "hello")
+ history = sm.get_conversation_history(session_id)
+ assert "User:" in history
+ assert "Assistant:" in history
+
+ def test_multiple_exchanges_all_appear(self):
+ sm = SessionManager()
+ session_id = sm.create_session()
+ sm.add_exchange(session_id, "Q1", "A1")
+ sm.add_exchange(session_id, "Q2", "A2")
+ history = sm.get_conversation_history(session_id)
+ assert "Q1" in history
+ assert "A1" in history
+ assert "Q2" in history
+ assert "A2" in history
+
+
+class TestClearSession:
+ def test_clears_all_messages(self):
+ sm = SessionManager()
+ session_id = sm.create_session()
+ sm.add_exchange(session_id, "question", "answer")
+ sm.clear_session(session_id)
+ assert sm.get_conversation_history(session_id) is None
+
+ def test_clear_nonexistent_session_does_not_raise(self):
+ sm = SessionManager()
+ sm.clear_session("nonexistent") # should not raise
diff --git a/backend/tests/test_vector_store.py b/backend/tests/test_vector_store.py
new file mode 100644
index 000000000..49bbc4b6b
--- /dev/null
+++ b/backend/tests/test_vector_store.py
@@ -0,0 +1,243 @@
+"""
+Tests for VectorStore — pure logic and None-metadata safety.
+
+ChromaDB collection interactions are mocked so no real DB is needed.
+"""
+import pytest
+from unittest.mock import MagicMock, patch
+from models import Course, Lesson, CourseChunk
+from vector_store import VectorStore, SearchResults
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def make_store(max_results=5):
+ """Return a VectorStore with all ChromaDB I/O mocked out."""
+ with patch("vector_store.chromadb.PersistentClient") as mock_client_cls, \
+ patch("vector_store.chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction"), \
+ patch("vector_store.SentenceTransformer"):
+ mock_client = MagicMock()
+ mock_client_cls.return_value = mock_client
+ mock_client.get_or_create_collection.return_value = MagicMock()
+ store = VectorStore(
+ chroma_path="/tmp/test_chroma",
+ embedding_model="all-MiniLM-L6-v2",
+ max_results=max_results,
+ )
+ return store
+
+
+# ---------------------------------------------------------------------------
+# _build_filter pure-logic tests
+# ---------------------------------------------------------------------------
+
+class TestBuildFilter:
+ def setup_method(self):
+ self.store = make_store()
+
+ def test_no_filter_when_no_params(self):
+ assert self.store._build_filter(None, None) is None
+
+ def test_filter_with_course_title_only(self):
+ result = self.store._build_filter("Python Basics", None)
+ assert result == {"course_title": "Python Basics"}
+
+ def test_filter_with_lesson_number_only(self):
+ result = self.store._build_filter(None, 3)
+ assert result == {"lesson_number": 3}
+
+ def test_filter_with_both(self):
+ result = self.store._build_filter("Python Basics", 2)
+ assert result == {"$and": [
+ {"course_title": "Python Basics"},
+ {"lesson_number": 2},
+ ]}
+
+
+# ---------------------------------------------------------------------------
+# add_course_metadata None-safety tests
+# ---------------------------------------------------------------------------
+
+class TestAddCourseMetadataNoneSafety:
+ """
+ Verify that None optional fields are sanitised before reaching ChromaDB.
+ After the fix, None values must be replaced with empty strings so the
+ strict ChromaDB 1.0.x validator never sees them.
+ """
+
+ def _make_strict_collection_mock(self):
+ """
+ Return a mock whose .add() side-effect mimics ChromaDB 1.0.x
+ behaviour: raises ValueError when any metadata value is None.
+ """
+ def _strict_add(documents, metadatas, ids):
+ for meta in metadatas:
+                for value in meta.values():
+                    if value is None:
+                        raise ValueError(
+                            "Expected metadata value to be a str, int, float or bool, got None"
+ )
+
+ collection = MagicMock()
+ collection.add.side_effect = _strict_add
+ return collection
+
+ def test_add_course_metadata_with_none_instructor(self):
+ """None instructor must be sanitised to '' — should NOT raise after fix."""
+ store = make_store()
+ store.course_catalog = self._make_strict_collection_mock()
+
+ course = Course(
+ title="Test Course",
+ course_link="https://example.com",
+ instructor=None,
+ lessons=[],
+ )
+ store.add_course_metadata(course) # must not raise
+
+ def test_add_course_metadata_with_none_course_link(self):
+ """None course_link must be sanitised to '' — should NOT raise after fix."""
+ store = make_store()
+ store.course_catalog = self._make_strict_collection_mock()
+
+ course = Course(
+ title="Test Course 2",
+ course_link=None,
+ instructor="Someone",
+ lessons=[],
+ )
+ store.add_course_metadata(course) # must not raise
+
+ def test_add_course_metadata_with_all_fields(self, sample_course):
+ """Fully-populated course should not raise."""
+ store = make_store()
+ store.course_catalog = self._make_strict_collection_mock()
+
+ store.add_course_metadata(sample_course) # should not raise
+
+
+# ---------------------------------------------------------------------------
+# add_course_content None-safety tests
+# ---------------------------------------------------------------------------
+
+class TestAddCourseContentNoneSafety:
+ def _make_strict_collection_mock(self):
+ def _strict_add(documents, metadatas, ids):
+ for meta in metadatas:
+                for value in meta.values():
+                    if value is None:
+                        raise ValueError(
+                            "Expected metadata value to be a str, int, float or bool, got None"
+ )
+
+ collection = MagicMock()
+ collection.add.side_effect = _strict_add
+ return collection
+
+ def test_add_course_content_with_none_lesson_number(self):
+ """None lesson_number must be sanitised to -1 — should NOT raise after fix."""
+ store = make_store()
+ store.course_content = self._make_strict_collection_mock()
+
+ chunks = [
+ CourseChunk(
+ content="some text",
+ course_title="Python Basics",
+ lesson_number=None,
+ chunk_index=0,
+ )
+ ]
+ store.add_course_content(chunks) # must not raise
+
+ def test_add_course_content_with_lesson_number(self):
+ """Integer lesson_number should not raise."""
+ store = make_store()
+ store.course_content = self._make_strict_collection_mock()
+
+ chunks = [
+ CourseChunk(
+ content="some text",
+ course_title="Python Basics",
+ lesson_number=1,
+ chunk_index=0,
+ )
+ ]
+ store.add_course_content(chunks) # should not raise
+
+
+# ---------------------------------------------------------------------------
+# search tests
+# ---------------------------------------------------------------------------
+
+class TestSearch:
+ def _make_store_with_content_mock(self, query_return, doc_count=10):
+ store = make_store()
+ store.course_content = MagicMock()
+ store.course_content.count.return_value = doc_count
+ store.course_content.query.return_value = query_return
+ return store
+
+ def _chroma_result(self, docs, metas=None, dists=None):
+ if metas is None:
+ metas = [{"course_title": "Python Basics", "lesson_number": 1}] * len(docs)
+ if dists is None:
+ dists = [0.1] * len(docs)
+ return {
+ "documents": [docs],
+ "metadatas": [metas],
+ "distances": [dists],
+ }
+
+ def test_search_returns_results(self):
+ chroma_result = self._chroma_result(["doc 1", "doc 2"])
+ store = self._make_store_with_content_mock(chroma_result)
+
+ results = store.search("python variables")
+
+ assert not results.is_empty()
+ assert len(results.documents) == 2
+
+ def test_search_with_fewer_results_than_n_results(self):
+ """When ChromaDB returns fewer docs than requested, should not raise."""
+ chroma_result = self._chroma_result(["only one doc"])
+ store = self._make_store_with_content_mock(chroma_result)
+
+ results = store.search("python variables")
+
+ assert not results.is_empty()
+ assert len(results.documents) == 1
+
+ def test_search_with_no_results_returns_empty(self):
+ chroma_result = self._chroma_result([])
+ store = self._make_store_with_content_mock(chroma_result)
+
+ results = store.search("obscure query")
+
+ assert results.is_empty()
+
+ def test_search_passes_filter_to_chroma(self):
+ """Filter built from course_title should be forwarded to ChromaDB."""
+ chroma_result = self._chroma_result(["doc"])
+ store = self._make_store_with_content_mock(chroma_result)
+ # Bypass course resolution by making _resolve_course_name return a title
+ store._resolve_course_name = MagicMock(return_value="Python Basics")
+
+ store.search("variables", course_name="Python")
+
+ call_kwargs = store.course_content.query.call_args.kwargs
+ assert call_kwargs["where"] == {"course_title": "Python Basics"}
+
+ def test_search_exception_returns_error_result(self):
+ """If ChromaDB raises, search() should return SearchResults with error set."""
+ store = make_store()
+ store.course_content = MagicMock()
+ store.course_content.count.return_value = 10
+ store.course_content.query.side_effect = RuntimeError("chroma exploded")
+
+ results = store.search("anything")
+
+ assert results.error is not None
+ assert "chroma exploded" in results.error
+ assert results.is_empty()
diff --git a/backend/vector_store.py b/backend/vector_store.py
index 390abe71c..d0f67db81 100644
--- a/backend/vector_store.py
+++ b/backend/vector_store.py
@@ -90,9 +90,11 @@ def search(self,
search_limit = limit if limit is not None else self.max_results
try:
+ count = self.course_content.count()
+ safe_limit = min(search_limit, count) if count > 0 else 1
results = self.course_content.query(
query_texts=[query],
- n_results=search_limit,
+ n_results=safe_limit,
where=filter_dict
)
return SearchResults.from_chroma(results)
@@ -151,8 +153,8 @@ def add_course_metadata(self, course: Course):
documents=[course_text],
metadatas=[{
"title": course.title,
- "instructor": course.instructor,
- "course_link": course.course_link,
+ "instructor": course.instructor or "",
+ "course_link": course.course_link or "",
"lessons_json": json.dumps(lessons_metadata), # Serialize as JSON string
"lesson_count": len(course.lessons)
}],
@@ -167,7 +169,7 @@ def add_course_content(self, chunks: List[CourseChunk]):
documents = [chunk.content for chunk in chunks]
metadatas = [{
"course_title": chunk.course_title,
- "lesson_number": chunk.lesson_number,
+ "lesson_number": chunk.lesson_number if chunk.lesson_number is not None else -1,
"chunk_index": chunk.chunk_index
} for chunk in chunks]
# Use title with chunk index for unique IDs
@@ -246,6 +248,28 @@ def get_course_link(self, course_title: str) -> Optional[str]:
print(f"Error getting course link: {e}")
return None
+ def get_course_outline(self, course_name: str) -> Optional[Dict[str, Any]]:
+ """Get course outline (title, link, lessons) with fuzzy name matching"""
+ import json
+ course_title = self._resolve_course_name(course_name)
+ if not course_title:
+ return None
+ try:
+ results = self.course_catalog.get(ids=[course_title])
+ if results and results['metadatas']:
+ meta = results['metadatas'][0]
+ outline = {
+ 'title': meta.get('title'),
+ 'course_link': meta.get('course_link'),
+ 'lessons': []
+ }
+ if meta.get('lessons_json'):
+ outline['lessons'] = json.loads(meta['lessons_json'])
+ return outline
+ except Exception as e:
+ print(f"Error getting course outline: {e}")
+ return None
+
def get_lesson_link(self, course_title: str, lesson_number: int) -> Optional[str]:
"""Get lesson link for a given course title and lesson number"""
import json
diff --git a/frontend-changes.md b/frontend-changes.md
new file mode 100644
index 000000000..58b36976e
--- /dev/null
+++ b/frontend-changes.md
@@ -0,0 +1,165 @@
+# Frontend Code Quality Changes
+
+## Summary
+
+Added frontend code quality tooling (Prettier for formatting, ESLint for linting) and applied consistent formatting across all frontend files.
+
+---
+
+## New Files
+
+### `frontend/package.json`
+Declares the frontend as a Node project and wires up quality check scripts:
+- `npm run format` — formats all JS/CSS/HTML with Prettier (write mode)
+- `npm run format:check` — checks formatting without modifying files (CI-safe)
+- `npm run lint` — lints `script.js` with ESLint
+- `npm run lint:fix` — auto-fixes ESLint issues
+- `npm run quality` — runs both `format:check` and `lint` (full check)
+- `npm run quality:fix` — runs both `format` and `lint:fix` (full auto-fix)
+
+Dev dependencies: `prettier@^3.3.3`, `eslint@^8.57.0`
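+
+The script list above maps onto a `package.json` roughly like this (a sketch; the Prettier glob patterns are assumptions, not taken from the repo):
+
+```json
+{
+  "name": "frontend",
+  "private": true,
+  "scripts": {
+    "format": "prettier --write \"**/*.{js,css,html}\"",
+    "format:check": "prettier --check \"**/*.{js,css,html}\"",
+    "lint": "eslint script.js",
+    "lint:fix": "eslint script.js --fix",
+    "quality": "npm run format:check && npm run lint",
+    "quality:fix": "npm run format && npm run lint:fix"
+  },
+  "devDependencies": {
+    "prettier": "^3.3.3",
+    "eslint": "^8.57.0"
+  }
+}
+```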
+
+### `frontend/.prettierrc`
+Prettier configuration:
+- 4-space indentation, 100-char print width
+- Single quotes, trailing commas (ES5), LF line endings
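+
+Those settings correspond to a `.prettierrc` along these lines (a sketch of the stated options, not the literal file):
+
+```json
+{
+  "tabWidth": 4,
+  "printWidth": 100,
+  "singleQuote": true,
+  "trailingComma": "es5",
+  "endOfLine": "lf"
+}
+```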
+
+### `frontend/.prettierignore`
+Excludes `node_modules/` from Prettier.
+
+### `frontend/.eslintrc.json`
+ESLint configuration targeting browser ES2021:
+- Errors on `no-undef`, `eqeqeq` (strict equality), `no-var`, `curly`
+- Warns on `no-unused-vars`, `prefer-const`
+- Registers `marked` as a known read-only global (loaded via CDN)
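+
+A `.eslintrc.json` matching that description would look roughly like this (severity levels inferred from the notes above, not copied from the file):
+
+```json
+{
+  "env": { "browser": true, "es2021": true },
+  "rules": {
+    "no-undef": "error",
+    "eqeqeq": "error",
+    "no-var": "error",
+    "curly": "error",
+    "no-unused-vars": "warn",
+    "prefer-const": "warn"
+  },
+  "globals": {
+    "marked": "readonly"
+  }
+}
+```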
+
+### `scripts/check-frontend.sh`
+Shell script to run all frontend quality checks from the repo root:
+```bash
+# Check only (exits non-zero if anything fails):
+./scripts/check-frontend.sh
+
+# Auto-fix formatting and lint issues:
+./scripts/check-frontend.sh --fix
+```
+
+---
+
+## Modified Files
+
+### `frontend/script.js`
+Applied Prettier-consistent formatting:
+- Single quotes throughout
+- Trailing commas on multi-line function arguments and object literals
+- Explicit `curly` braces on all `if` bodies
+- Arrow function parentheses around single parameters
+- Consistent blank lines between logical sections
+
+### `frontend/style.css`
+Applied Prettier-consistent formatting:
+- Each CSS selector on its own line (e.g. `*,\n*::before,\n*::after`)
+- `h1`, `h2`, `h3` font-size rules expanded to separate blocks
+- `@keyframes bounce` selector list expanded (`0%,\n80%,\n100%`)
+- `.no-courses, .loading, .error` selector list expanded
+- Removed stale inline comment on `.course-titles` block
+- Consistent blank lines between rule blocks
+
+### `frontend/index.html`
+Applied Prettier-consistent formatting:
+- `` lowercased
+- Self-closing void elements (``, ``, ``)
+- 4-space indentation throughout
+- Long `