@@ -128,7 +128,6 @@ var metadataStartMarker = []byte("\xAB\xCD\xEFMaxMind.com")
128128// All of the methods on Reader are thread-safe. The struct may be safely
129129// shared across goroutines.
130130type Reader struct {
131- nodeReader nodeReader
132131 buffer []byte
133132 decoder decoder.ReflectionDecoder
134133 Metadata Metadata
@@ -312,25 +311,8 @@ func FromBytes(buffer []byte, options ...ReaderOption) (*Reader, error) {
312311 buffer [searchTreeSize + dataSectionSeparatorSize : metadataStart - len (metadataStartMarker )],
313312 )
314313
315- nodeBuffer := buffer [:searchTreeSize ]
316- var nodeReader nodeReader
317- switch metadata .RecordSize {
318- case 24 :
319- nodeReader = nodeReader24 {buffer : nodeBuffer }
320- case 28 :
321- nodeReader = nodeReader28 {buffer : nodeBuffer }
322- case 32 :
323- nodeReader = nodeReader32 {buffer : nodeBuffer }
324- default :
325- return nil , mmdberrors .NewInvalidDatabaseError (
326- "unknown record size: %d" ,
327- metadata .RecordSize ,
328- )
329- }
330-
331314 reader := & Reader {
332315 buffer : buffer ,
333- nodeReader : nodeReader ,
334316 decoder : d ,
335317 Metadata : metadata ,
336318 ipv4Start : 0 ,
@@ -394,7 +376,7 @@ func (r *Reader) setIPv4Start() {
394376 node := uint (0 )
395377 i := 0
396378 for ; i < 96 && node < nodeCount ; i ++ {
397- node = r . nodeReader . readLeft ( node * r .nodeOffsetMult )
379+ node = readNodeBySize ( r . buffer , node * r .nodeOffsetMult , 0 , r . Metadata . RecordSize )
398380 }
399381 r .ipv4Start = node
400382 r .ipv4StartBitDepth = i
@@ -410,7 +392,10 @@ func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, error) {
410392 )
411393 }
412394
413- node , prefixLength := r .traverseTree (ip , 0 , 128 )
395+ node , prefixLength , err := r .traverseTree (ip , 0 , 128 )
396+ if err != nil {
397+ return 0 , 0 , err
398+ }
414399
415400 nodeCount := r .Metadata .NodeCount
416401 if node == nodeCount {
@@ -423,25 +408,134 @@ func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, error) {
423408 return 0 , prefixLength , mmdberrors .NewInvalidDatabaseError ("invalid node in search tree" )
424409}
425410
426- func (r * Reader ) traverseTree (ip netip.Addr , node uint , stopBit int ) (uint , int ) {
411+ // readNodeBySize reads a node value from the buffer based on record size and bit.
412+ func readNodeBySize (buffer []byte , offset , bit , recordSize uint ) uint {
413+ switch recordSize {
414+ case 24 :
415+ offset += bit * 3
416+ return (uint (buffer [offset ]) << 16 ) |
417+ (uint (buffer [offset + 1 ]) << 8 ) |
418+ uint (buffer [offset + 2 ])
419+ case 28 :
420+ if bit == 0 {
421+ return ((uint (buffer [offset + 3 ]) & 0xF0 ) << 20 ) |
422+ (uint (buffer [offset ]) << 16 ) |
423+ (uint (buffer [offset + 1 ]) << 8 ) |
424+ uint (buffer [offset + 2 ])
425+ }
426+ return ((uint (buffer [offset + 3 ]) & 0x0F ) << 24 ) |
427+ (uint (buffer [offset + 4 ]) << 16 ) |
428+ (uint (buffer [offset + 5 ]) << 8 ) |
429+ uint (buffer [offset + 6 ])
430+ case 32 :
431+ offset += bit * 4
432+ return (uint (buffer [offset ]) << 24 ) |
433+ (uint (buffer [offset + 1 ]) << 16 ) |
434+ (uint (buffer [offset + 2 ]) << 8 ) |
435+ uint (buffer [offset + 3 ])
436+ default :
437+ return 0
438+ }
439+ }
440+
441+ func (r * Reader ) traverseTree (ip netip.Addr , node uint , stopBit int ) (uint , int , error ) {
442+ switch r .Metadata .RecordSize {
443+ case 24 :
444+ n , i := r .traverseTree24 (ip , node , stopBit )
445+ return n , i , nil
446+ case 28 :
447+ n , i := r .traverseTree28 (ip , node , stopBit )
448+ return n , i , nil
449+ case 32 :
450+ n , i := r .traverseTree32 (ip , node , stopBit )
451+ return n , i , nil
452+ default :
453+ return 0 , 0 , mmdberrors .NewInvalidDatabaseError (
454+ "unsupported record size: %d" ,
455+ r .Metadata .RecordSize ,
456+ )
457+ }
458+ }
459+
460+ func (r * Reader ) traverseTree24 (ip netip.Addr , node uint , stopBit int ) (uint , int ) {
427461 i := 0
428462 if ip .Is4 () {
429463 i = r .ipv4StartBitDepth
430464 node = r .ipv4Start
431465 }
432466 nodeCount := r .Metadata .NodeCount
467+ buffer := r .buffer
468+ ip16 := ip .As16 ()
469+
470+ for ; i < stopBit && node < nodeCount ; i ++ {
471+ byteIdx := i >> 3
472+ bitPos := 7 - (i & 7 )
473+ bit := (uint (ip16 [byteIdx ]) >> bitPos ) & 1
433474
475+ baseOffset := node * 6
476+ offset := baseOffset + bit * 3
477+
478+ node = (uint (buffer [offset ]) << 16 ) |
479+ (uint (buffer [offset + 1 ]) << 8 ) |
480+ uint (buffer [offset + 2 ])
481+ }
482+
483+ return node , i
484+ }
485+
486+ func (r * Reader ) traverseTree28 (ip netip.Addr , node uint , stopBit int ) (uint , int ) {
487+ i := 0
488+ if ip .Is4 () {
489+ i = r .ipv4StartBitDepth
490+ node = r .ipv4Start
491+ }
492+ nodeCount := r .Metadata .NodeCount
493+ buffer := r .buffer
434494 ip16 := ip .As16 ()
435495
436496 for ; i < stopBit && node < nodeCount ; i ++ {
437- bit := uint (1 ) & (uint (ip16 [i >> 3 ]) >> (7 - (i % 8 )))
497+ byteIdx := i >> 3
498+ bitPos := 7 - (i & 7 )
499+ bit := (uint (ip16 [byteIdx ]) >> bitPos ) & 1
500+
501+ baseOffset := node * 7
502+ sharedByte := uint (buffer [baseOffset + 3 ])
503+ mask := uint (0xF0 >> (bit * 4 ))
504+ shift := 20 + bit * 4
505+ nibble := ((sharedByte & mask ) << shift )
506+ offset := baseOffset + bit * 4
507+
508+ node = nibble |
509+ (uint (buffer [offset ]) << 16 ) |
510+ (uint (buffer [offset + 1 ]) << 8 ) |
511+ uint (buffer [offset + 2 ])
512+ }
438513
439- offset := node * r .nodeOffsetMult
440- if bit == 0 {
441- node = r .nodeReader .readLeft (offset )
442- } else {
443- node = r .nodeReader .readRight (offset )
444- }
514+ return node , i
515+ }
516+
517+ func (r * Reader ) traverseTree32 (ip netip.Addr , node uint , stopBit int ) (uint , int ) {
518+ i := 0
519+ if ip .Is4 () {
520+ i = r .ipv4StartBitDepth
521+ node = r .ipv4Start
522+ }
523+ nodeCount := r .Metadata .NodeCount
524+ buffer := r .buffer
525+ ip16 := ip .As16 ()
526+
527+ for ; i < stopBit && node < nodeCount ; i ++ {
528+ byteIdx := i >> 3
529+ bitPos := 7 - (i & 7 )
530+ bit := (uint (ip16 [byteIdx ]) >> bitPos ) & 1
531+
532+ baseOffset := node * 8
533+ offset := baseOffset + bit * 4
534+
535+ node = (uint (buffer [offset ]) << 24 ) |
536+ (uint (buffer [offset + 1 ]) << 16 ) |
537+ (uint (buffer [offset + 2 ]) << 8 ) |
538+ uint (buffer [offset + 3 ])
445539 }
446540
447541 return node , i
0 commit comments