Skip to content

Commit 739c56d

Browse files
committed
Merge pull request astropy#2619 from eteq/fix-rot-matrix
Fix quirks in rotation_matrix
2 parents fb46015 + 26f0552 commit 739c56d

File tree

3 files changed

+67
-41
lines changed

3 files changed

+67
-41
lines changed

CHANGES.rst

+4
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ New Features
4141
- The deprecated functions for pre-0.3 coordinate object names like
4242
``ICRSCoordinates`` have been removed. [#2422]
4343

44+
- The ``rotation_matrix`` and ``angle_axis`` functions in
45+
``astropy.coordinates.angles`` were made more numerically consistent and
46+
are now tested explicitly [#2619]
47+
4448
- ``astropy.cosmology``
4549

4650
- Added ``z_at_value`` function to find the redshift at which a cosmology

astropy/coordinates/angles.py

+19-41
Original file line numberDiff line numberDiff line change
@@ -701,54 +701,41 @@ def rotation_matrix(angle, axis='z', unit=None):
701701
rmat: `numpy.matrix`
702702
A unitary rotation matrix.
703703
"""
704-
# TODO: This doesn't handle arrays of angles
705-
706704
if unit is None:
707705
unit = u.degree
708706

709707
angle = Angle(angle, unit=unit)
710708

709+
s = np.sin(angle)
710+
c = np.cos(angle)
711+
712+
# use optimized implementations for x/y/z
711713
if axis == 'z':
712-
s = np.sin(angle)
713-
c = np.cos(angle)
714714
return np.matrix(((c, s, 0),
715715
(-s, c, 0),
716716
(0, 0, 1)))
717717
elif axis == 'y':
718-
s = np.sin(angle)
719-
c = np.cos(angle)
720718
return np.matrix(((c, 0, -s),
721719
(0, 1, 0),
722720
(s, 0, c)))
723721
elif axis == 'x':
724-
s = np.sin(angle)
725-
c = np.cos(angle)
726722
return np.matrix(((1, 0, 0),
727723
(0, c, s),
728724
(0, -s, c)))
729725
else:
730-
x, y, z = axis
731-
w = np.cos(angle / 2)
732-
733-
# normalize
734-
if w == 1:
735-
x = y = z = 0
736-
else:
737-
l = np.sqrt((x * x + y * y + z * z) / (1 - w * w))
738-
x /= l
739-
y /= l
740-
z /= l
726+
axis = np.asarray(axis)
727+
axis = axis / np.sqrt((axis * axis).sum())
741728

742-
wsq = w * w
743-
xsq = x * x
744-
ysq = y * y
745-
zsq = z * z
746-
return np.matrix(((wsq + xsq - ysq - zsq, 2 * x * y - 2 * w * z, 2 * x * z + 2 * w * y),
747-
(2 * x * y + 2 * w * z, wsq - xsq + ysq - zsq, 2 * y * z - 2 * w * x),
748-
(2 * x * z - 2 * w * y, 2 * y * z + 2 * w * x, wsq - xsq - ysq + zsq)))
729+
R = np.diag((c, c, c))
730+
R += np.outer(axis, axis) * (1. - c)
731+
axis *= s
732+
R += np.array([[0., axis[2], -axis[1]],
733+
[-axis[2], 0., axis[0]],
734+
[axis[1], -axis[0], 0.]])
735+
return R.view(np.matrix)
749736

750737

751-
def angle_axis(matrix, unit=None):
738+
def angle_axis(matrix):
752739
"""
753740
Computes the angle of rotation and the rotation axis for a given rotation
754741
matrix.
@@ -758,29 +745,20 @@ def angle_axis(matrix, unit=None):
758745
matrix : array-like
759746
A 3 x 3 unitary rotation matrix.
760747
761-
unit : UnitBase
762-
The output unit. If `None`, the output unit is degrees.
763-
764748
Returns
765749
-------
766750
angle : `Angle`
767751
The angle of rotation for this matrix.
768752
769753
axis : array (length 3)
770-
The axis of rotation for this matrix.
754+
The (normalized) axis of rotation for this matrix.
771755
"""
772-
# TODO: This doesn't handle arrays of angles
773-
774756
m = np.asmatrix(matrix)
775757
if m.shape != (3, 3):
776758
raise ValueError('matrix is not 3x3')
777759

778-
angle = np.acos((m[0, 0] + m[1, 1] + m[2, 2] - 1) / 2)
779-
denom = np.sqrt(2 * ((m[2, 1] - m[1, 2]) + (m[0, 2] - m[2, 0]) + (m[1, 0] - m[0, 1])))
780-
axis = np.array((m[2, 1] - m[1, 2], m[0, 2] - m[2, 0], m[1, 0] - m[0, 1])) / denom
781-
axis /= np.sqrt(np.sum(axis ** 2))
760+
axis = np.array((m[2, 1] - m[1, 2], m[0, 2] - m[2, 0], m[1, 0] - m[0, 1]))
761+
r = np.sqrt((axis * axis).sum())
762+
angle = np.arctan2(r, np.trace(m) - 1)
782763

783-
angle = Angle(angle, u.radian)
784-
if unit is None:
785-
unit = u.degree
786-
return angle.to(unit), axis
764+
return Angle(angle, u.radian), -axis / r

astropy/coordinates/tests/test_angles.py

+44
Original file line numberDiff line numberDiff line change
@@ -753,3 +753,47 @@ def test_mixed_string_and_quantity():
753753
a2 = Angle(['1d', 1 * u.rad * np.pi, '3d'])
754754
assert_array_equal(a2.value, [1., 180., 3.])
755755
assert a2.unit == u.deg
756+
757+
def test_rotation_matrix():
758+
from ..angles import rotation_matrix
759+
760+
assert_array_equal(rotation_matrix(0*u.deg, 'x'), np.eye(3))
761+
762+
assert_allclose(rotation_matrix(90*u.deg, 'y'), [[ 0, 0,-1],
763+
[ 0, 1, 0],
764+
[ 1, 0, 0]], atol=1e-12)
765+
766+
assert_allclose(rotation_matrix(-90*u.deg, 'z'), [[ 0,-1, 0],
767+
[ 1, 0, 0],
768+
[ 0, 0, 1]], atol=1e-12)
769+
770+
assert_allclose(rotation_matrix(45*u.deg, 'x'),
771+
rotation_matrix(45*u.deg, [1, 0, 0]))
772+
assert_allclose(rotation_matrix(125*u.deg, 'y'),
773+
rotation_matrix(125*u.deg, [0, 1, 0]))
774+
assert_allclose(rotation_matrix(-30*u.deg, 'z'),
775+
rotation_matrix(-30*u.deg, [0, 0, 1]))
776+
777+
assert_allclose(np.dot(rotation_matrix(180*u.deg, [1, 1, 0]).A, [1, 0, 0]),
778+
[0, 1, 0], atol=1e-12)
779+
780+
#make sure it also works for very small angles
781+
assert_allclose(rotation_matrix(0.000001*u.deg, 'x'),
782+
rotation_matrix(0.000001*u.deg, [1, 0, 0]))
783+
784+
def test_angle_axis():
785+
from ..angles import rotation_matrix, angle_axis
786+
787+
m1 = rotation_matrix(35*u.deg, 'x')
788+
an1, ax1 = angle_axis(m1)
789+
790+
assert an1 - 35*u.deg < 1e-10*u.deg
791+
assert_allclose(ax1, [1, 0, 0])
792+
793+
794+
m2 = rotation_matrix(-89*u.deg, [1, 1, 0])
795+
an2, ax2 = angle_axis(m2)
796+
797+
assert an2 - 89*u.deg < 1e-10*u.deg
798+
assert_allclose(ax2, [-2**-0.5, -2**-0.5, 0])
799+

0 commit comments

Comments
 (0)