Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Pytest configuration: limits collection to the tests/ directory and spells
# out the naming patterns pytest uses to discover test files, classes and
# functions.
[pytest]
# Directory searched when pytest is invoked without explicit paths.
testpaths = tests
# Only modules named test_*.py are collected as test files.
python_files = test_*.py
# Only classes whose names start with "Test" are collected.
python_classes = Test*
# Only functions and methods whose names start with "test_" are collected.
python_functions = test_*
Binary file not shown.
Binary file not shown.
Binary file not shown.
126 changes: 126 additions & 0 deletions tests/test_hyperspherical_descent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import pytest
import torch
import math
from src.hyperspherical_descent import hyperspherical_descent


class TestHypersphericalDescent:
    """Behavioural tests for ``hyperspherical_descent(W, G, eta=...)``.

    Every test draws ``W`` and ``G`` from ``torch.randn`` and inspects the
    tensor returned by the call. The autouse ``seed`` fixture keeps those
    draws reproducible from run to run.
    """

    @pytest.fixture(autouse=True)
    def seed(self):
        """Seed torch's global RNG before (and after) every test in the class.

        ``autouse=True`` is required: no test requests this fixture by name,
        so without it the fixture never runs and each test would operate on
        unseeded random tensors, making failures impossible to reproduce.
        """
        torch.manual_seed(42)
        yield
        torch.manual_seed(42)

    def test_hyperspherical_descent_preserves_shape(self):
        """hyperspherical_descent should preserve input shape."""
        for shape in [(4, 4), (4, 3), (3, 4), (5, 2), (2, 5), (10,)]:
            W = torch.randn(*shape)
            G = torch.randn(*shape)
            result = hyperspherical_descent(W, G)
            assert result.shape == shape

    def test_hyperspherical_descent_1d_input(self):
        """hyperspherical_descent should work with 1D input."""
        W = torch.randn(10)
        G = torch.randn(10)
        result = hyperspherical_descent(W, G)
        assert result.shape == (10,)

    def test_hyperspherical_descent_2d_input(self):
        """hyperspherical_descent should work with 2D input."""
        W = torch.randn(4, 4)
        G = torch.randn(4, 4)
        result = hyperspherical_descent(W, G)
        assert result.shape == (4, 4)

    def test_hyperspherical_descent_no_nan(self):
        """hyperspherical_descent should not produce NaN values."""
        W = torch.randn(4, 4)
        G = torch.randn(4, 4)
        result = hyperspherical_descent(W, G)
        assert not torch.isnan(result).any()

    def test_hyperspherical_descent_no_inf(self):
        """hyperspherical_descent should not produce Inf values."""
        W = torch.randn(4, 4)
        G = torch.randn(4, 4)
        result = hyperspherical_descent(W, G)
        assert not torch.isinf(result).any()

    def test_hyperspherical_descent_on_unit_sphere(self):
        """Each row of a 2D result should have unit Euclidean norm."""
        W = torch.randn(4, 4)
        G = torch.randn(4, 4)
        result = hyperspherical_descent(W, G)
        # NOTE(review): this checks every row separately, whereas the 1D tests
        # in this class check the norm of the whole tensor -- confirm which
        # normalisation the implementation applies to 2D inputs.
        for row in result:
            # .item() hands math.isclose a plain float, matching the other
            # norm check in this class.
            assert math.isclose(row.norm().item(), 1.0, rel_tol=1e-5)

    def test_hyperspherical_descent_unit_vector(self):
        """hyperspherical_descent should keep unit vectors on unit sphere."""
        W = torch.randn(10)
        W = W / W.norm()  # Make it a unit vector
        G = torch.randn(10)
        result = hyperspherical_descent(W, G)
        # Result should also be a unit vector
        assert math.isclose(result.norm().item(), 1.0, rel_tol=1e-5)

    def test_hyperspherical_descent_custom_eta(self):
        """hyperspherical_descent should respect eta parameter."""
        W = torch.randn(10)
        G = torch.randn(10)
        result1 = hyperspherical_descent(W, G, eta=0.01)
        result2 = hyperspherical_descent(W, G, eta=1.0)
        # Different eta should give different results
        assert not torch.allclose(result1, result2)

    def test_hyperspherical_descent_eta_zero(self):
        """hyperspherical_descent with eta=0 should return normalized input."""
        W = torch.randn(10) * 5  # Non-unit vector
        G = torch.randn(10)
        result = hyperspherical_descent(W, G, eta=0.0)
        # With eta=0, result should just be normalized W
        expected = W / W.norm()
        assert torch.allclose(result, expected, atol=1e-5)

    def test_hyperspherical_descent_deterministic(self):
        """hyperspherical_descent should be deterministic."""
        # Re-seed explicitly so both calls see identical W and G.
        torch.manual_seed(42)
        W = torch.randn(10)
        G = torch.randn(10)
        result1 = hyperspherical_descent(W, G)

        torch.manual_seed(42)
        W = torch.randn(10)
        G = torch.randn(10)
        result2 = hyperspherical_descent(W, G)

        assert torch.allclose(result1, result2)

    def test_hyperspherical_descent_gradient_descent(self):
        """A step taken from a point on the sphere should oppose the gradient."""
        W = torch.randn(10)
        # Start on the unit sphere. The eta=0 test above expects the output to
        # be renormalised, so with an un-normalised W the difference
        # ``result - W`` would be dominated by that rescaling rather than by
        # the descent step, and its sign relative to G would be arbitrary.
        W = W / W.norm()
        G = torch.randn(10)
        result = hyperspherical_descent(W, G, eta=1.0)
        # Displacement produced by the update.
        diff = result - W
        correlation = torch.dot(diff, G)
        # A descent step has a negative inner product with the gradient.
        assert correlation < 0

    def test_hyperspherical_descent_tensor_input(self):
        """hyperspherical_descent should accept torch tensors."""
        W = torch.randn(4, 4)
        G = torch.randn(4, 4)
        assert isinstance(W, torch.Tensor) and isinstance(G, torch.Tensor)
        result = hyperspherical_descent(W, G)
        assert isinstance(result, torch.Tensor)

    def test_hyperspherical_descent_large_eta(self):
        """hyperspherical_descent should handle large eta."""
        W = torch.randn(10)
        G = torch.randn(10)
        result = hyperspherical_descent(W, G, eta=10.0)
        assert not torch.isnan(result).any() and not torch.isinf(result).any()
42 changes: 42 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Tests for main.py entry points."""

import pytest
import torch


class TestHypersphericalDescent:
    """Smoke tests for the hyperspherical descent update function."""

    def test_hyperspherical_descent_preserves_shape(self):
        """Output of a 1D update has the same shape as its input."""
        from src.hyperspherical_descent import hyperspherical_descent

        for dim in (4, 8, 16):
            weights = torch.randn(dim)
            grads = torch.randn(dim)
            updated = hyperspherical_descent(weights, grads)
            assert updated.shape == (dim,)

    def test_hyperspherical_descent_unit_norm(self):
        """Output of a 1D update has Euclidean norm 1 (within 1e-5)."""
        from src.hyperspherical_descent import hyperspherical_descent

        weights = torch.randn(8)
        grads = torch.randn(8)
        updated = hyperspherical_descent(weights, grads)
        length = updated.norm()
        one = torch.tensor(1.0)
        assert torch.allclose(length, one, atol=1e-5)

    def test_hyperspherical_descent_no_nan(self):
        """Output of a 1D update contains no NaN entries."""
        from src.hyperspherical_descent import hyperspherical_descent

        weights = torch.randn(8)
        grads = torch.randn(8)
        updated = hyperspherical_descent(weights, grads)
        has_nan = torch.isnan(updated).any()
        assert not has_nan

    def test_hyperspherical_descent_2d(self):
        """A 4x4 input yields a NaN-free 4x4 output."""
        from src.hyperspherical_descent import hyperspherical_descent

        dims = (4, 4)
        weights = torch.randn(*dims)
        grads = torch.randn(*dims)
        updated = hyperspherical_descent(weights, grads)
        assert updated.shape == dims
        has_nan = torch.isnan(updated).any()
        assert not has_nan
110 changes: 110 additions & 0 deletions tests/test_manifold_muon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import pytest
import torch
import math
from src.manifold_muon import manifold_muon


class TestManifoldMuon:
    """Behavioural tests for ``manifold_muon(W, G, ...)``.

    Every test draws ``W`` and ``G`` from ``torch.randn`` and inspects the
    tensor returned by the call. The keyword arguments exercised here are
    ``eta``, ``alpha``, ``steps`` and ``tol``. The autouse ``seed`` fixture
    keeps the random draws reproducible from run to run.
    """

    @pytest.fixture(autouse=True)
    def seed(self):
        """Seed torch's global RNG before (and after) every test in the class.

        ``autouse=True`` is required: no test requests this fixture by name,
        so without it the fixture never runs and each test would operate on
        unseeded random tensors, making failures impossible to reproduce.
        """
        torch.manual_seed(42)
        yield
        torch.manual_seed(42)

    def test_manifold_muon_preserves_shape(self):
        """manifold_muon should preserve input shape."""
        for shape in [(4, 4), (4, 3), (3, 4), (5, 2), (2, 5)]:
            W = torch.randn(*shape)
            G = torch.randn(*shape)
            result = manifold_muon(W, G)
            assert result.shape == shape

    def test_manifold_muon_wide_matrix(self):
        """manifold_muon should handle wide matrices."""
        W = torch.randn(3, 5)
        G = torch.randn(3, 5)
        result = manifold_muon(W, G)
        assert result.shape == (3, 5)

    def test_manifold_muon_tall_matrix(self):
        """manifold_muon should handle tall matrices."""
        W = torch.randn(5, 3)
        G = torch.randn(5, 3)
        result = manifold_muon(W, G)
        assert result.shape == (5, 3)

    def test_manifold_muon_no_nan(self):
        """manifold_muon should not produce NaN values."""
        W = torch.randn(4, 4)
        G = torch.randn(4, 4)
        result = manifold_muon(W, G, steps=10)
        assert not torch.isnan(result).any()

    def test_manifold_muon_no_inf(self):
        """manifold_muon should not produce Inf values."""
        W = torch.randn(4, 4)
        G = torch.randn(4, 4)
        result = manifold_muon(W, G, steps=10)
        assert not torch.isinf(result).any()

    def test_manifold_muon_convergence(self):
        """After a long, tight-tolerance run the result satisfies W.T @ W = I."""
        W = torch.randn(4, 4)
        G = torch.randn(4, 4)
        result = manifold_muon(W, G, steps=100, tol=1e-6)
        # Check that final result is on manifold (W.T @ W = I)
        W_result = result
        metric = W_result.T @ W_result
        identity = torch.eye(W_result.shape[1])
        assert torch.allclose(metric, identity, atol=1e-3)

    def test_manifold_muon_custom_eta(self):
        """manifold_muon should respect eta parameter."""
        W = torch.randn(4, 4)
        G = torch.randn(4, 4)
        result1 = manifold_muon(W, G, eta=0.01)
        result2 = manifold_muon(W, G, eta=1.0)
        # Different eta should give different results
        assert not torch.allclose(result1, result2)

    def test_manifold_muon_custom_alpha(self):
        """manifold_muon should respect alpha parameter."""
        W = torch.randn(4, 4)
        G = torch.randn(4, 4)
        result1 = manifold_muon(W, G, alpha=0.001)
        result2 = manifold_muon(W, G, alpha=0.1)
        # Different alpha should give different results
        assert not torch.allclose(result1, result2)

    def test_manifold_muon_custom_steps(self):
        """manifold_muon should respect steps parameter."""
        W = torch.randn(4, 4)
        G = torch.randn(4, 4)
        result1 = manifold_muon(W, G, steps=5)
        result2 = manifold_muon(W, G, steps=50)
        # More steps should give different results
        assert not torch.allclose(result1, result2)

    def test_manifold_muon_result_is_orthogonal(self):
        """Result should be orthogonal (columns are orthonormal)."""
        W = torch.randn(4, 3)
        G = torch.randn(4, 3)
        result = manifold_muon(W, G, steps=50)
        # Check that columns are orthonormal
        metric = result.T @ result
        assert torch.allclose(metric, torch.eye(3), atol=1e-3)

    def test_manifold_muon_tensor_input(self):
        """manifold_muon should accept torch tensors."""
        W = torch.randn(4, 4)
        G = torch.randn(4, 4)
        assert isinstance(W, torch.Tensor) and isinstance(G, torch.Tensor)
        result = manifold_muon(W, G)
        assert isinstance(result, torch.Tensor)

    def test_manifold_muon_square_matrix(self):
        """manifold_muon should work with square matrices."""
        W = torch.randn(4, 4)
        G = torch.randn(4, 4)
        result = manifold_muon(W, G)
        assert result.shape == (4, 4)
43 changes: 43 additions & 0 deletions tests/test_manifold_muon_extra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Additional tests for manifold_muon."""

import pytest
import torch


class TestManifoldMuon:
    """Extra shape, NaN and orthonormality checks for manifold_muon."""

    def test_manifold_muon_preserves_shape(self):
        """A tall 8x4 input (rows > cols) yields an 8x4 output."""
        from src.manifold_muon import manifold_muon

        dims = (8, 4)
        weights = torch.randn(*dims)
        grads = torch.randn(*dims)
        updated = manifold_muon(weights, grads)
        assert updated.shape == dims

    def test_manifold_muon_wide_matrix(self):
        """A wide 4x8 input (cols > rows) yields a 4x8 output."""
        from src.manifold_muon import manifold_muon

        dims = (4, 8)
        weights = torch.randn(*dims)
        grads = torch.randn(*dims)
        updated = manifold_muon(weights, grads)
        assert updated.shape == dims

    def test_manifold_muon_no_nan(self):
        """The output for an 8x4 input contains no NaN entries."""
        from src.manifold_muon import manifold_muon

        weights = torch.randn(8, 4)
        grads = torch.randn(8, 4)
        updated = manifold_muon(weights, grads)
        has_nan = torch.isnan(updated).any()
        assert not has_nan

    def test_manifold_muon_orthogonality(self):
        """Columns of the output are orthonormal to within atol=0.1."""
        from src.manifold_muon import manifold_muon

        weights = torch.randn(8, 4)
        grads = torch.randn(8, 4)
        updated = manifold_muon(weights, grads, steps=50)
        # The Gram matrix of the columns should be close to the 4x4 identity.
        gram = updated.T @ updated
        target = torch.eye(4)
        assert torch.allclose(gram, target, atol=0.1)
Loading