@@ -360,17 +360,22 @@ impl<'a> VQSKMeansBuilder<'a> {
360
360
VQSKMeansBuilderError :: UninitializedFieldError ( "shallow_levels" . to_string ( ) )
361
361
} ) ?;
362
362
Self :: validate_shallow_levels ( shallow_levels) ?;
363
- let max_levels = self . max_levels . ok_or_else ( || {
364
- VQSKMeansBuilderError :: UninitializedFieldError ( "max_levels" . to_string ( ) )
365
- } ) ?;
366
- Self :: validate_max_levels ( shallow_levels, max_levels) ?;
363
+ // let max_levels = self.max_levels.ok_or_else(|| {
364
+ // VQSKMeansBuilderError::UninitializedFieldError("max_levels".to_string())
365
+ // })?;
366
+ let max_levels = match self . max_levels {
367
+ Some ( max_levels) => * max_levels,
368
+ None => Self :: calculate_max_levels ( shallow_levels, nclusters) ,
369
+ } ;
370
+ Self :: validate_max_levels ( shallow_levels, & max_levels) ?;
371
+
367
372
let dz_bottom_min = self . dz_bottom_min . ok_or_else ( || {
368
373
VQSKMeansBuilderError :: UninitializedFieldError ( "dz_bottom_min" . to_string ( ) )
369
374
} ) ?;
370
375
let mut hsm = kmeans_hsm ( hgrid, nclusters, etal) ?;
371
376
hsm. iter_mut ( ) . for_each ( |depth| * depth = depth. abs ( ) ) ;
372
377
let mut nlevels = Vec :: < usize > :: with_capacity ( * nclusters) ;
373
- let levels = Array :: linspace ( * shallow_levels as f64 , * max_levels as f64 , * nclusters) ;
378
+ let levels = Array :: linspace ( * shallow_levels as f64 , max_levels as f64 , * nclusters) ;
374
379
for level in levels. iter ( ) {
375
380
let mut level = level. round ( ) as usize ;
376
381
if level < * shallow_levels {
@@ -421,9 +426,13 @@ impl<'a> VQSKMeansBuilder<'a> {
421
426
}
422
427
Ok ( ( ) )
423
428
}
429
+
430
+ fn calculate_max_levels ( shallow_levels : & usize , clusters : & usize ) -> usize {
431
+ shallow_levels + clusters - 1
432
+ }
424
433
fn validate_max_levels (
425
- shallow_levels : & ' a usize ,
426
- max_levels : & ' a usize ,
434
+ shallow_levels : & usize ,
435
+ max_levels : & usize ,
427
436
) -> Result < ( ) , VQSKMeansBuilderError > {
428
437
if * shallow_levels > * max_levels {
429
438
return Err ( VQSKMeansBuilderError :: InvalidMaxLevels (
@@ -484,12 +493,17 @@ impl<'a> VQSAutoBuilder<'a> {
484
493
VQSAutoBuilderError :: UninitializedFieldError ( "shallow_levels" . to_string ( ) )
485
494
} ) ?;
486
495
Self :: validate_shallow_levels ( shallow_levels) ?;
487
- let max_levels = self . max_levels . ok_or_else ( || {
488
- VQSAutoBuilderError :: UninitializedFieldError ( "max_levels" . to_string ( ) )
489
- } ) ?;
490
- Self :: validate_max_levels ( shallow_levels, max_levels) ?;
496
+ let max_levels = match self . max_levels {
497
+ Some ( max_levels) => * max_levels,
498
+ None => Self :: calculate_max_levels ( shallow_levels, ngrids) ,
499
+ } ;
500
+ Self :: validate_max_levels ( shallow_levels, & max_levels) ?;
501
+ // let max_levels = self.max_levels.ok_or_else(|| {
502
+ // VQSAutoBuilderError::UninitializedFieldError("max_levels".to_string())
503
+ // })?;
504
+ // Self::validate_max_levels(shallow_levels, max_levels)?;
491
505
let ( hsm, nlevels) =
492
- Self :: build_hsm_and_nlevels ( hgrid, ngrids, initial_depth, shallow_levels, max_levels) ?;
506
+ Self :: build_hsm_and_nlevels ( hgrid, ngrids, initial_depth, shallow_levels, & max_levels) ?;
493
507
Ok ( VQSBuilder :: default ( )
494
508
. hgrid ( & hgrid)
495
509
. depths ( & hsm)
@@ -525,6 +539,10 @@ impl<'a> VQSAutoBuilder<'a> {
525
539
Ok ( ( ) )
526
540
}
527
541
542
+ fn calculate_max_levels ( shallow_levels : & usize , clusters : & usize ) -> usize {
543
+ shallow_levels + clusters - 1
544
+ }
545
+
528
546
fn exponential_samples ( start : f64 , end : f64 , steps : usize ) -> Vec < f64 > {
529
547
let mut samples = Vec :: with_capacity ( steps) ;
530
548
let scale = ( end / start) . powf ( 1.0 / ( steps as f64 - 1.0 ) ) ;
0 commit comments