1919from torch .distributions .exp_family import ExponentialFamily
2020from torch .distributions .utils import broadcast_all , lazy_property
2121from torch .types import _size
22- from torch .distributions .distribution import Distribution
2322
2423default_size = torch .Size ()
2524
@@ -1041,11 +1040,11 @@ def _log_cdf(self, x: torch.Tensor) -> torch.Tensor:
10411040def _eval_poly (y : torch .Tensor , coef : torch .Tensor ) -> torch .Tensor :
10421041 """
10431042 Evaluate a polynomial at given points.
1044-
1043+
10451044 Args:
10461045 y: Input tensor.
10471046 coeffs: Polynomial coefficients.
1048-
1047+
10491048 Returns:
10501049 Evaluated polynomial tensor.
10511050 """
@@ -1108,7 +1107,7 @@ def _log_modified_bessel_fn(x: torch.Tensor, order: int = 0) -> torch.Tensor:
11081107 Args:
11091108 x: Input tensor, must be positive.
11101109 order: Order of the Bessel function (0 or 1).
1111-
1110+
11121111 Returns:
11131112 Logarithm of the Bessel function.
11141113 """
@@ -1133,20 +1132,17 @@ def _log_modified_bessel_fn(x: torch.Tensor, order: int = 0) -> torch.Tensor:
11331132
11341133@torch .jit .script_if_tracing
11351134def _rejection_sample (
1136- loc : torch .Tensor ,
1137- concentration : torch .Tensor ,
1138- proposal_r : torch .Tensor ,
1139- x : torch .Tensor
1135+ loc : torch .Tensor , concentration : torch .Tensor , proposal_r : torch .Tensor , x : torch .Tensor
11401136) -> torch .Tensor :
11411137 """
11421138 Perform rejection sampling for the von Mises distribution.
1143-
1139+
11441140 Args:
11451141 loc: Location parameter.
11461142 concentration: Concentration parameter.
11471143 proposal_r: Precomputed proposal parameter.
11481144 x: Tensor to fill with samples.
1149-
1145+
11501146 Returns:
11511147 Tensor of samples.
11521148 """
@@ -1165,9 +1161,7 @@ def _rejection_sample(
11651161
11661162
11671163class VonMises (Distribution ):
1168- """
1169- Von Mises distribution class for circular data.
1170- """
1164+ """Von Mises distribution class for circular data."""
11711165
11721166 arg_constraints = {
11731167 "loc" : constraints .real ,
@@ -1181,33 +1175,37 @@ def __init__(
11811175 loc : torch .Tensor ,
11821176 concentration : torch .Tensor ,
11831177 validate_args : bool = None ,
1184- ):
1178+ ) -> None :
1179+ """
1180+ Args:
1181+ loc: loc parameter of the distribution.
1182+ concentration: concentration parameter of the distribution.
1183+ validate_args: If True, checks the distribution parameters for validity.
1184+ """
11851185 self .loc , self .concentration = broadcast_all (loc , concentration )
11861186 batch_shape = self .loc .shape
11871187 super ().__init__ (batch_shape , torch .Size (), validate_args )
1188-
1188+
11891189 @lazy_property
11901190 @torch .no_grad ()
11911191 def _proposal_r (self ) -> torch .Tensor :
1192- """
1193- Compute the proposal parameter for sampling.
1194- """
1192+ """Compute the proposal parameter for sampling."""
11951193 kappa = self ._concentration
11961194 tau = 1 + (1 + 4 * kappa ** 2 ).sqrt ()
11971195 rho = (tau - (2 * tau ).sqrt ()) / (2 * kappa )
11981196 _proposal_r = (1 + rho ** 2 ) / (2 * rho )
1199-
1197+
12001198 # second order Taylor expansion around 0 for small kappa
12011199 _proposal_r_taylor = 1 / kappa + kappa
12021200 return torch .where (kappa < 1e-5 , _proposal_r_taylor , _proposal_r )
12031201
1204- def log_prob (self , value ) :
1202+ def log_prob (self , value : torch . Tensor ) -> torch . Tensor :
12051203 """
12061204 Compute the log probability of the given value.
12071205
12081206 Args:
12091207 value: Tensor of values.
1210-
1208+
12111209 Returns:
12121210 Tensor of log probabilities.
12131211 """
@@ -1218,15 +1216,15 @@ def log_prob(self, value):
12181216 return log_prob
12191217
12201218 @lazy_property
1221- def _loc (self ):
1219+ def _loc (self ) -> torch . Tensor :
12221220 return self .loc .to (torch .double )
12231221
12241222 @lazy_property
1225- def _concentration (self ):
1223+ def _concentration (self ) -> torch . Tensor :
12261224 return self .concentration .to (torch .double )
1227-
1225+
12281226 @torch .no_grad ()
1229- def sample (self , sample_shape = torch .Size ()) :
1227+ def sample (self , sample_shape : _size = default_size ) -> torch .Tensor :
12301228 """
12311229 The sampling algorithm for the von Mises distribution is based on the
12321230 following paper: D.J. Best and N.I. Fisher, "Efficient simulation of the
@@ -1238,33 +1236,33 @@ def sample(self, sample_shape=torch.Size()):
12381236 """
12391237 shape = self ._extended_shape (sample_shape )
12401238 x = torch .empty (shape , dtype = self ._loc .dtype , device = self .loc .device )
1241- return _rejection_sample (
1242- self ._loc , self ._concentration , self ._proposal_r , x
1243- ).to (self .loc .dtype )
1239+ return _rejection_sample (self ._loc , self ._concentration , self ._proposal_r , x ).to (self .loc .dtype )
12441240
1245- def rsample (self , sample_shape = torch .Size ()):
1246- """
1247- Generate reparameterized samples from the distribution.
1248- """
1241+ def rsample (self , sample_shape : _size = default_size ) -> torch .Tensor :
1242+ """Generate reparameterized samples from the distribution"""
12491243 shape = self ._extended_shape (sample_shape )
12501244 samples = _VonMisesSampler .apply (self .concentration , self ._proposal_r , shape )
12511245 samples = samples + self .loc
1252-
1246+
12531247 # Map the samples to [-pi, pi].
1254- return samples - 2. * torch .pi * torch .round (samples / (2. * torch .pi ))
1248+ return samples - 2.0 * torch .pi * torch .round (samples / (2.0 * torch .pi ))
12551249
12561250 @property
1257- def mean (self ):
1251+ def mean (self ) -> torch . Tensor :
12581252 """Mean of the distribution."""
12591253 return self .loc
12601254
12611255 @property
1262- def variance (self ):
1256+ def variance (self ) -> torch . Tensor :
12631257 """Variance of the distribution."""
1264- return 1 - (
1265- _log_modified_bessel_fn (self .concentration , order = 1 )
1266- - _log_modified_bessel_fn (self .concentration , order = 0 )
1267- ).exp ()
1258+ return (
1259+ 1
1260+ - (
1261+ _log_modified_bessel_fn (self .concentration , order = 1 )
1262+ - _log_modified_bessel_fn (self .concentration , order = 0 )
1263+ ).exp ()
1264+ )
1265+
12681266
12691267@torch .jit .script_if_tracing
12701268@torch .no_grad ()
@@ -1282,7 +1280,7 @@ def _rejection_rsample(concentration: torch.Tensor, proposal_r: torch.Tensor, sh
12821280 """
12831281 x = torch .empty (shape , dtype = concentration .dtype , device = concentration .device )
12841282 done = torch .zeros (x .shape , dtype = torch .bool , device = concentration .device )
1285-
1283+
12861284 while not done .all ():
12871285 u = torch .rand ((3 ,) + x .shape , dtype = concentration .dtype , device = concentration .device )
12881286 u1 , u2 , u3 = u .unbind ()
@@ -1295,6 +1293,7 @@ def _rejection_rsample(concentration: torch.Tensor, proposal_r: torch.Tensor, sh
12951293 done = done | accept
12961294 return x
12971295
1296+
12981297def cosxm1 (x : torch .Tensor ) -> torch .Tensor :
12991298 """
13001299 Compute cos(x) - 1 using a numerically stable formula.
@@ -1307,6 +1306,7 @@ def cosxm1(x: torch.Tensor) -> torch.Tensor:
13071306 """
13081307 return - 2 * torch .square (torch .sin (x / 2.0 ))
13091308
1309+
13101310class _VonMisesSampler (torch .autograd .Function ):
13111311 @staticmethod
13121312 def forward (
@@ -1329,7 +1329,7 @@ def forward(
13291329 """
13301330 samples = _rejection_rsample (concentration , proposal_r , shape )
13311331 ctx .save_for_backward (concentration , proposal_r , samples )
1332-
1332+
13331333 return samples
13341334
13351335 @staticmethod
@@ -1348,29 +1348,27 @@ def backward(
13481348 Tuple[torch.Tensor, None, None]: Gradients with respect to the input tensors.
13491349 """
13501350 concentration , proposal_r , samples = ctx .saved_tensors
1351-
1352- num_periods = torch .round (samples / (2. * torch .pi ))
1353- x_mapped = samples - (2. * torch .pi ) * num_periods
1354-
1355- ## Parameters from the paper
1351+
1352+ num_periods = torch .round (samples / (2.0 * torch .pi ))
1353+ x_mapped = samples - (2.0 * torch .pi ) * num_periods
1354+
1355+ # Parameters from the paper
13561356 ck = 10.5
13571357 num_terms = 20
1358-
1359- ## Compute series and normal approximation
1358+
1359+ # Compute series and normal approximation
13601360 cdf_series , dcdf_dconcentration_series = von_mises_cdf_series (x_mapped , concentration , num_terms )
13611361 cdf_normal , dcdf_dconcentration_normal = von_mises_cdf_normal (x_mapped , concentration )
13621362 use_series = concentration < ck
1363- cdf = torch .where (use_series , cdf_series , cdf_normal ) + num_periods
1363+ # cdf = torch.where(use_series, cdf_series, cdf_normal) + num_periods
13641364 dcdf_dconcentration = torch .where (use_series , dcdf_dconcentration_series , dcdf_dconcentration_normal )
1365-
1366- ## Compute CDF gradient terms
1367- inv_prob = torch .exp (concentration * cosxm1 (samples )) / (
1368- 2 * math .pi * torch .special .i0e (concentration )
1369- )
1370- grad_concentration = grad_output * (- dcdf_dconcentration / inv_prob )
1371-
1365+
1366+ # Compute CDF gradient terms
1367+ inv_prob = torch .exp (concentration * cosxm1 (samples )) / (2 * math .pi * torch .special .i0e (concentration ))
1368+ grad_concentration = grad_output * (- dcdf_dconcentration / inv_prob )
1369+
13721370 return grad_concentration , None , None
1373-
1371+
13741372
13751373def von_mises_cdf_series (
13761374 x : torch .Tensor , concentration : torch .Tensor , num_terms : int
@@ -1394,25 +1392,26 @@ def von_mises_cdf_series(
13941392 drn_dconcentration = torch .zeros_like (x )
13951393
13961394 while n > 0 :
1397- denominator = 2. * n / concentration + rn
1398- ddenominator_dk = - 2. * n / concentration ** 2 + drn_dconcentration
1399- rn = 1. / denominator
1400- drn_dconcentration = - ddenominator_dk / denominator ** 2
1395+ denominator = 2.0 * n / concentration + rn
1396+ ddenominator_dk = - 2.0 * n / concentration ** 2 + drn_dconcentration
1397+ rn = 1.0 / denominator
1398+ drn_dconcentration = - ddenominator_dk / denominator ** 2
14011399
14021400 multiplier = torch .sin (n * x ) / n + vn
14031401 vn = rn * multiplier
1404- dvn_dconcentration = ( drn_dconcentration * multiplier + rn * dvn_dconcentration )
1405-
1402+ dvn_dconcentration = drn_dconcentration * multiplier + rn * dvn_dconcentration
1403+
14061404 n -= 1
14071405
1408- cdf = 0.5 + x / (2. * torch .pi ) + vn / torch .pi
1406+ cdf = 0.5 + x / (2.0 * torch .pi ) + vn / torch .pi
14091407 dcdf_dconcentration = dvn_dconcentration / torch .pi
14101408
1411- cdf_clipped = torch .clamp (cdf , 0. , 1. )
1412- dcdf_dconcentration *= (cdf >= 0. ) & (cdf <= 1. )
1409+ cdf_clipped = torch .clamp (cdf , 0.0 , 1.0 )
1410+ dcdf_dconcentration *= (cdf >= 0.0 ) & (cdf <= 1.0 )
14131411
14141412 return cdf_clipped , dcdf_dconcentration
1415-
1413+
1414+
14161415def cdf_func (concentration : torch .Tensor , x : torch .Tensor ) -> torch .Tensor :
14171416 """
14181417 Approximate the CDF of the von Mises distribution.
@@ -1424,32 +1423,26 @@ def cdf_func(concentration: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
14241423 Returns:
14251424 torch.Tensor: Approximate CDF values.
14261425 """
1427-
14281426 # Calculate the z value based on the approximation
1429- z = (torch .sqrt (torch .tensor (2. / torch .pi )) / torch .special .i0e (concentration )) * torch .sin (0.5 * x )
1427+ z = (torch .sqrt (torch .tensor (2.0 / torch .pi )) / torch .special .i0e (concentration )) * torch .sin (0.5 * x )
14301428 # Apply corrections to z to improve the approximation
1431- z2 = z ** 2
1429+ z2 = z ** 2
14321430 z3 = z2 * z
1433- z4 = z2 ** 2
1434- c = 24. * concentration
1435- c1 = 56.
1431+ z4 = z2 ** 2
1432+ c = 24.0 * concentration
1433+ c1 = 56.0
14361434
1437- xi = z - z3 / (
1438- ((c - 2. * z2 - 16. ) / 3. ) -
1439- (z4 + (7. / 4. ) * z2 + 167. / 2. ) / (c - c1 - z2 + 3. )
1440- ) ** 2
1435+ xi = z - z3 / (((c - 2.0 * z2 - 16.0 ) / 3.0 ) - (z4 + (7.0 / 4.0 ) * z2 + 167.0 / 2.0 ) / (c - c1 - z2 + 3.0 )) ** 2
14411436
14421437 # Use the standard normal distribution for the approximation
14431438 distrib = torch .distributions .Normal (
1444- torch .tensor (0. , dtype = x .dtype , device = x .device ),
1445- torch .tensor (1. , dtype = x .dtype , device = x .device )
1439+ torch .tensor (0.0 , dtype = x .dtype , device = x .device ), torch .tensor (1.0 , dtype = x .dtype , device = x .device )
14461440 )
1447-
1441+
14481442 return distrib .cdf (xi )
14491443
1450- def von_mises_cdf_normal (
1451- x : torch .Tensor , concentration : torch .Tensor
1452- ) -> Tuple [torch .Tensor , torch .Tensor ]:
1444+
1445+ def von_mises_cdf_normal (x : torch .Tensor , concentration : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor ]:
14531446 """
14541447 Compute the CDF of the von Mises distribution using a normal approximation.
14551448
@@ -1467,4 +1460,4 @@ def von_mises_cdf_normal(
14671460 dcdf_dconcentration = concentration_ .grad .clone () # Copy the gradient
14681461 # Detach gradients to prevent further autograd tracking
14691462 concentration_ .grad = None
1470- return cdf , dcdf_dconcentration
1463+ return cdf , dcdf_dconcentration
0 commit comments