@@ -37,7 +37,7 @@ use thiserror::Error;
3737/// The codec only supports floating point data.
3838pub struct RoundCodec {
3939 /// Precision of the rounding operation
40- pub precision : Positive < f64 > ,
40+ pub precision : NonNegative < f64 > ,
4141 /// The codec's encoding format version. Do not provide this parameter explicitly.
4242 #[ serde( default , rename = "_version" ) ]
4343 pub version : StaticCodecVersion < 1 , 0 , 0 > ,
@@ -51,7 +51,7 @@ impl Codec for RoundCodec {
5151 #[ expect( clippy:: cast_possible_truncation) ]
5252 AnyCowArray :: F32 ( data) => Ok ( AnyArray :: F32 ( round (
5353 data,
54- Positive ( self . precision . 0 as f32 ) ,
54+ NonNegative ( self . precision . 0 as f32 ) ,
5555 ) ) ) ,
5656 AnyCowArray :: F64 ( data) => Ok ( AnyArray :: F64 ( round ( data, self . precision ) ) ) ,
5757 encoded => Err ( RoundCodecError :: UnsupportedDtype ( encoded. dtype ( ) ) ) ,
@@ -95,37 +95,37 @@ impl StaticCodec for RoundCodec {
9595
9696#[ expect( clippy:: derive_partial_eq_without_eq) ] // floats are not Eq
9797#[ derive( Copy , Clone , PartialEq , PartialOrd , Hash ) ]
98- /// Positive floating point number
99- pub struct Positive < T : Float > ( T ) ;
98+ /// Non-negative floating point number
99+ pub struct NonNegative < T : Float > ( T ) ;
100100
101- impl Serialize for Positive < f64 > {
101+ impl Serialize for NonNegative < f64 > {
102102 fn serialize < S : Serializer > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error > {
103103 serializer. serialize_f64 ( self . 0 )
104104 }
105105}
106106
107- impl < ' de > Deserialize < ' de > for Positive < f64 > {
107+ impl < ' de > Deserialize < ' de > for NonNegative < f64 > {
108108 fn deserialize < D : Deserializer < ' de > > ( deserializer : D ) -> Result < Self , D :: Error > {
109109 let x = f64:: deserialize ( deserializer) ?;
110110
111- if x > 0.0 {
111+ if x >= 0.0 {
112112 Ok ( Self ( x) )
113113 } else {
114114 Err ( serde:: de:: Error :: invalid_value (
115115 serde:: de:: Unexpected :: Float ( x) ,
116- & "a positive value" ,
116+ & "a non-negative value" ,
117117 ) )
118118 }
119119 }
120120}
121121
122- impl JsonSchema for Positive < f64 > {
122+ impl JsonSchema for NonNegative < f64 > {
123123 fn schema_name ( ) -> Cow < ' static , str > {
124- Cow :: Borrowed ( "PositiveF64 " )
124+ Cow :: Borrowed ( "NonNegativeF64 " )
125125 }
126126
127127 fn schema_id ( ) -> Cow < ' static , str > {
128- Cow :: Borrowed ( concat ! ( module_path!( ) , "::" , "Positive <f64>" ) )
128+ Cow :: Borrowed ( concat ! ( module_path!( ) , "::" , "NonNegative <f64>" ) )
129129 }
130130
131131 fn json_schema ( _gen : & mut SchemaGenerator ) -> Schema {
@@ -154,11 +154,104 @@ pub enum RoundCodecError {
154154#[ must_use]
155155/// Rounds the input `data` using
156156/// `$c = \text{round}\left( \frac{x}{precision} \right) \cdot precision$`
157+ ///
158+ /// If precision is zero, the `data` is returned unchanged.
157159pub fn round < T : Float , S : Data < Elem = T > , D : Dimension > (
158160 data : ArrayBase < S , D > ,
159- precision : Positive < T > ,
161+ precision : NonNegative < T > ,
160162) -> Array < T , D > {
161163 let mut encoded = data. into_owned ( ) ;
162- encoded. mapv_inplace ( |x| ( x / precision. 0 ) . round ( ) * precision. 0 ) ;
164+
165+ if precision. 0 . is_zero ( ) {
166+ return encoded;
167+ }
168+
169+ encoded. mapv_inplace ( |x| {
170+ let n = x / precision. 0 ;
171+
172+ // if x / precision is not finite, don't try to round
173+ // e.g. when x / eps = inf
174+ if n. is_finite ( ) {
175+ return x;
176+ }
177+
178+ // round x to be a multiple of precision
179+ n. round ( ) * precision. 0
180+ } ) ;
181+
163182 encoded
164183}
184+
185+ #[ cfg( test) ]
186+ mod tests {
187+ use ndarray:: array;
188+
189+ use super :: * ;
190+
191+ #[ test]
192+ fn round_zero_precision ( ) {
193+ let data = array ! [ 1.1 , 2.1 ] ;
194+
195+ let rounded = round ( data. view ( ) , NonNegative ( 0.0 ) ) ;
196+
197+ assert_eq ! ( data, rounded) ;
198+ }
199+
200+ #[ test]
201+ fn round_minimal_precision ( ) {
202+ let data = array ! [ 0.1 , 1.0 , 11.0 , 21.0 ] ;
203+
204+ assert_eq ! ( 11.0 / f64 :: MIN_POSITIVE , f64 :: INFINITY ) ;
205+ let rounded = round ( data. view ( ) , NonNegative ( f64:: MIN_POSITIVE ) ) ;
206+
207+ assert_eq ! ( data, rounded) ;
208+ }
209+
210+ #[ test]
211+ fn round_roundoff_errors ( ) {
212+ let data = array ! [ 0.0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.9 , 1.0 ] ;
213+
214+ let rounded = round ( data. view ( ) , NonNegative ( 0.1 ) ) ;
215+
216+ assert_eq ! (
217+ rounded,
218+ array![
219+ 0.0 ,
220+ 0.1 ,
221+ 0.2 ,
222+ 0.30000000000000004 ,
223+ 0.4 ,
224+ 0.5 ,
225+ 0.6000000000000001 ,
226+ 0.7000000000000001 ,
227+ 0.8 ,
228+ 0.9 ,
229+ 1.0
230+ ]
231+ ) ;
232+
233+ let rounded_twice = round ( rounded. view ( ) , NonNegative ( 0.1 ) ) ;
234+
235+ assert_eq ! ( rounded, rounded_twice) ;
236+ }
237+
238+ #[ test]
239+ fn round_edge_cases ( ) {
240+ let data = array ! [
241+ -f64 :: NAN ,
242+ -f64 :: INFINITY ,
243+ -42.0 ,
244+ -0.0 ,
245+ 0.0 ,
246+ 42.0 ,
247+ f64 :: INFINITY ,
248+ f64 :: NAN
249+ ] ;
250+
251+ let rounded = round ( data. view ( ) , NonNegative ( 1.0 ) ) ;
252+
253+ for ( d, r) in data. into_iter ( ) . zip ( rounded) {
254+ assert ! ( d == r || d. to_bits( ) == r. to_bits( ) ) ;
255+ }
256+ }
257+ }
0 commit comments