@@ -797,6 +797,7 @@ class exponential(BaseDist_Mixin):
797797 numpy.random.exponential
798798
799799 """
800+
800801 dist = stats .expon
801802 param_template = namedtuple ('params' , ['lamda' , 'loc' ])
802803 name = 'exponential'
@@ -833,7 +834,7 @@ class rice(BaseDist_Mixin):
833834 R : float
834835 The shape parameter of the distribution.
835836 sigma : float
836- The standard deviate of the distribution.
837+ The standard deviation of the distribution.
837838 loc : float, optional
838839 Location parameter of the distribution. This defaults to, and
839840 should probably be left at, 0.
@@ -879,6 +880,7 @@ class rice(BaseDist_Mixin):
879880 numpy.random.exponential
880881
881882 """
883+
882884 dist = stats .rice
883885 param_template = namedtuple ('params' , ['R' , 'sigma' , 'loc' ])
884886 name = 'rice'
@@ -904,6 +906,114 @@ def fit(cls, data, **guesses):
904906 return cls .param_template (R = b * sigma , loc = loc , sigma = sigma )
905907
906908
909+ class truncated_normal (BaseDist_Mixin ):
910+ """
911+ Create and fit data to a truncated normal distribution.
912+
913+ Methods
914+ -------
915+ fit
916+ Use scipy's maximum likelihood estimation methods to estimate
917+ the parameters of the data's distribution.
918+ from_params
919+ Create a new distribution instances from the ``namedtuple``
920+ result of the :meth:`~fit` method.
921+
922+ Parameters
923+ ----------
924+ lower, upper : float
925+ The lower and upper limits of the distribution that serve as its
926+ shape parameters.
927+ mu : float, optional (default = 0)
928+ The expected value (mean) of the underlying normal distribution.
929+ Acts as the location parameter of the distribution.
930+ sigma : float, optional (default = 1)
931+ The standard deviation of the underlying normal distribution.
932+ Also acts as the scale parameter of distribution.
933+
934+ Examples
935+ --------
936+ >>> import numpy
937+ >>> import paramnormal as pn
938+ >>> numpy.random.seed(0)
939+ >>> pn.truncated_normal(lower=-0.5, upper=0.5).rvs(size=3)
940+ array([ 0.04687082, 0.20804061, 0.09879796])
941+
942+ >>> # you can also use greek letters
943+ >>> numpy.random.seed(0)
944+ >>> pn.truncated_normal(lower=-0.5, upper=2.5, σ=2).rvs(size=3)
945+ array([ 0.8902748 , 1.37377049, 1.04012565])
946+
947+ >>> # silly fake data
948+ >>> numpy.random.seed(0)
949+ >>> data = pn.truncated_normal(lower=-0.5, upper=2.5, mu=0, sigma=2).rvs(size=37)
950+ >>> # pretend `data` is unknown and we want to fit a dist. to it
951+ >>> pn.truncated_normal.fit(data)
952+ params(lower=1.040124, upper=1.082447, mu=-8.097877e-06, sigma=1.033405)
953+
954+ In scipy, the distribution is defined as
955+ ``stats.truncnorm(a, b, loc, scale)`` where
956+
957+ .. math::
958+
959+ a = \f rac{\mathrm{lower bound}} - \mu}{\sigma}
960+
961+ and
962+
963+ .. math::
964+
965+ b = \f rac{x_{\mathrm{upper bound}} - \mu}{\sigma}
966+
967+ Since ``a`` and ``b`` are directly linked to the location and scale
968+ of the distribution as well as the lower and upper limits,
969+ respectively, it's difficult to use the ``fit`` method of this
970+ distirbution without either knowing a lot about it `a priori` or
971+ assuming just as much.
972+
973+ References
974+ ----------
975+ http://scipy.github.io/devdocs/generated/scipy.stats.truncnorm
976+ https://en.wikipedia.org/wiki/Rice_distribution
977+
978+ See Also
979+ --------
980+ scipy.stats.rice
981+ numpy.random.exponential
982+
983+ """
984+
985+ dist = stats .truncnorm
986+ param_template = namedtuple ('params' , ['lower' , 'upper' , 'mu' , 'sigma' ])
987+ name = 'truncated normal'
988+
989+ @staticmethod
990+ @utils .greco_deco
991+ def _process_args (lower = None , upper = None , mu = None , sigma = None , fit = False ):
992+ a = None
993+ b = None
994+ if lower is not None and mu is not None and sigma is not None :
995+ a = (lower - mu ) / sigma
996+
997+ if upper is not None and mu is not None and sigma is not None :
998+ b = (upper - mu ) / sigma
999+
1000+ loc_key , scale_key = utils ._get_loc_scale_keys (fit = fit )
1001+ if fit :
1002+ akey = 'f0'
1003+ bkey = 'f1'
1004+ else :
1005+ akey = 'a'
1006+ bkey = 'b'
1007+ return {akey : a , bkey : b , loc_key : mu , scale_key : sigma }
1008+
1009+ @classmethod
1010+ def fit (cls , data , ** guesses ):
1011+ a , b , mu , sigma = cls ._fit (data , ** guesses )
1012+ lower = a * sigma + mu
1013+ upper = b * sigma + mu
1014+ return cls .param_template (lower = lower , upper = upper , mu = mu , sigma = sigma )
1015+
1016+
9071017__all__ = [
9081018 'normal' ,
9091019 'lognormal' ,
@@ -915,4 +1025,5 @@ def fit(cls, data, **guesses):
9151025 'pareto' ,
9161026 'exponential' ,
9171027 'rice' ,
1028+ 'truncated_normal' ,
9181029]
0 commit comments