Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions monai/bundle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
get_bundle_info,
get_bundle_versions,
init_bundle,
inspect_ckpt,
load,
onnx_export,
push_to_hf_hub,
Expand Down
1 change: 1 addition & 0 deletions monai/bundle/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
download,
download_large_files,
init_bundle,
inspect_ckpt,
onnx_export,
run,
run_workflow,
Expand Down
73 changes: 73 additions & 0 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2013,3 +2013,76 @@ def download_large_files(bundle_path: str | None = None, large_file_name: str |
lf_data["filepath"] = os.path.join(bundle_path, lf_data["path"])
lf_data.pop("path")
download_url(**lf_data)

def inspect_ckpt(
path: str,
print_all_vars: bool = True,
compute_hash: bool = False,
hash_type: str = "md5",
) -> dict:
"""
Inspect the variables and shapes saved in a checkpoint file.
Prints a human-readable summary of the tensor names, shapes, and dtypes
stored in the checkpoint, similar to TensorFlow's inspect_checkpoint tool.
Optionally also computes the hash value of the file (useful when creating
a ``large_files.yml`` for model-zoo bundles).

Typical usage examples:

.. code-block:: bash

# Display all tensor names, shapes, and dtypes:
python -m monai.bundle inspect_ckpt --path model.pt

# Suppress individual variable printing (only show file-level info):
python -m monai.bundle inspect_ckpt --path model.pt --print_all_vars false

# Also compute md5 hash of the checkpoint file:
python -m monai.bundle inspect_ckpt --path model.pt --compute_hash true

# Use sha256 hash instead of md5:
python -m monai.bundle inspect_ckpt --path model.pt --compute_hash true --hash_type sha256

Args:
path: path to the checkpoint file to inspect.
print_all_vars: whether to print individual variable names, shapes,
and dtypes. Default to ``True``.
compute_hash: whether to compute and print the hash value of the
checkpoint file. Default to ``False``.
hash_type: the hash type to use when ``compute_hash`` is ``True``.
Should be ``"md5"`` or ``"sha256"``. Default to ``"md5"``.

Returns:
A dictionary mapping variable names to a dict containing
``"shape"`` (tuple) and ``"dtype"`` (str) for each tensor.
"""
import hashlib

_log_input_summary(tag="inspect_ckpt", args={"path": path, "print_all_vars": print_all_vars, "compute_hash": compute_hash})

ckpt = torch.load(path, map_location="cpu", weights_only=True)
if not isinstance(ckpt, Mapping):
ckpt = get_state_dict(ckpt)

var_info: dict = {}
for name, val in ckpt.items():
if isinstance(val, torch.Tensor):
var_info[name] = {"shape": tuple(val.shape), "dtype": str(val.dtype)}
else:
var_info[name] = {"shape": None, "dtype": type(val).__name__}

logger.info(f"checkpoint file: {path}")
logger.info(f"total variables: {len(var_info)}")
if print_all_vars:
for name, info in var_info.items():
logger.info(f" {name}: shape={info['shape']}, dtype={info['dtype']}")

if compute_hash:
h = hashlib.new(hash_type)
with open(path, "rb") as f:
for chunk in iter(lambda: f.read(1 << 20), b""):
h.update(chunk)
digest = h.hexdigest()
logger.info(f"{hash_type} hash: {digest}")

return var_info
4 changes: 2 additions & 2 deletions monai/losses/spectral_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def __init__(
self.fft_norm = fft_norm

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
input_amplitude = self._get_fft_amplitude(target)
target_amplitude = self._get_fft_amplitude(input)
input_amplitude = self._get_fft_amplitude(input)
target_amplitude = self._get_fft_amplitude(target)

# Compute distance between amplitude of frequency components
# See Section 3.3 from https://arxiv.org/abs/2005.00341
Expand Down
6 changes: 3 additions & 3 deletions monai/losses/ssim_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
# 2D data
x = torch.ones([1,1,10,10])/2
y = torch.ones([1,1,10,10])/2
print(1-SSIMLoss(spatial_dims=2)(x,y))
print(SSIMLoss(spatial_dims=2)(x,y))

# pseudo-3D data
x = torch.ones([1,5,10,10])/2 # 5 could represent number of slices
y = torch.ones([1,5,10,10])/2
print(1-SSIMLoss(spatial_dims=2)(x,y))
print(SSIMLoss(spatial_dims=2)(x,y))

# 3D data
x = torch.ones([1,1,10,10,10])/2
y = torch.ones([1,1,10,10,10])/2
print(1-SSIMLoss(spatial_dims=3)(x,y))
print(SSIMLoss(spatial_dims=3)(x,y))
"""
ssim_value = self.ssim_metric._compute_tensor(input, target).view(-1, 1)
loss: torch.Tensor = 1 - ssim_value
Expand Down
70 changes: 70 additions & 0 deletions tests/bundle/test_bundle_inspect_ckpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import os
import tempfile
import unittest

import torch

from monai.bundle import inspect_ckpt


class TestInspectCkpt(unittest.TestCase):
def setUp(self):
# Create a temporary checkpoint file with a simple state dict
self.tmp_dir = tempfile.mkdtemp()
self.ckpt_path = os.path.join(self.tmp_dir, "model.pt")
state_dict = {
"layer1.weight": torch.randn(4, 3),
"layer1.bias": torch.zeros(4),
"layer2.weight": torch.randn(2, 4),
}
torch.save(state_dict, self.ckpt_path)

def test_returns_dict_with_correct_keys(self):
result = inspect_ckpt(path=self.ckpt_path, print_all_vars=False)
self.assertIsInstance(result, dict)
self.assertIn("layer1.weight", result)
self.assertIn("layer1.bias", result)
self.assertIn("layer2.weight", result)

def test_shapes_are_correct(self):
result = inspect_ckpt(path=self.ckpt_path, print_all_vars=False)
self.assertEqual(result["layer1.weight"]["shape"], (4, 3))
self.assertEqual(result["layer1.bias"]["shape"], (4,))
self.assertEqual(result["layer2.weight"]["shape"], (2, 4))

def test_dtype_is_reported(self):
result = inspect_ckpt(path=self.ckpt_path, print_all_vars=False)
self.assertIn("dtype", result["layer1.weight"])
self.assertTrue(result["layer1.weight"]["dtype"].startswith("torch."))

def test_compute_hash_md5(self):
# Should not raise; hash value is logged but not returned in dict
result = inspect_ckpt(path=self.ckpt_path, print_all_vars=False, compute_hash=True, hash_type="md5")
self.assertIsInstance(result, dict)

def test_compute_hash_sha256(self):
result = inspect_ckpt(path=self.ckpt_path, print_all_vars=False, compute_hash=True, hash_type="sha256")
self.assertIsInstance(result, dict)

def test_print_all_vars_true_does_not_raise(self):
# Should log each variable without raising
try:
inspect_ckpt(path=self.ckpt_path, print_all_vars=True)
except Exception as e:
self.fail(f"inspect_ckpt raised an exception with print_all_vars=True: {e}")


if __name__ == "__main__":
unittest.main()
Loading