Skip to content

Commit d1e7006

Browse files
committed
Implement conversion to explicit subword embeddings.
1 parent 6cd72a8 commit d1e7006

File tree

10 files changed

+241
-36
lines changed

10 files changed

+241
-36
lines changed

src/chunks/storage/array.rs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::mem::size_of;
55
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
66
use ndarray::{Array2, ArrayView2, ArrayViewMut2, Axis, CowArray, Ix1};
77

8-
use super::{Storage, StorageView, StorageViewMut};
8+
use super::{sealed::CloneFromMapping, Storage, StorageView, StorageViewMut};
99
use crate::chunks::io::{ChunkIdentifier, ReadChunk, TypeId, WriteChunk};
1010
use crate::error::{Error, Result};
1111
use crate::util::padding;
@@ -27,14 +27,13 @@ mod mmap {
2727
use ndarray::{Array2, ArrayView2, Axis, CowArray, Ix1};
2828
use ndarray::{Dimension, Ix2};
2929

30-
#[cfg(target_endian = "little")]
3130
use super::NdArray;
3231
#[cfg(target_endian = "little")]
3332
use crate::chunks::io::WriteChunk;
3433
use crate::chunks::io::{ChunkIdentifier, TypeId};
35-
use crate::chunks::storage::Storage;
3634
#[cfg(target_endian = "little")]
3735
use crate::chunks::storage::StorageView;
36+
use crate::chunks::storage::{sealed::CloneFromMapping, Storage};
3837
use crate::error::{Error, Result};
3938
use crate::util::padding;
4039

@@ -103,6 +102,14 @@ mod mmap {
103102
}
104103
}
105104

105+
impl CloneFromMapping for MmapArray {
106+
type Result = NdArray;
107+
108+
fn clone_from_mapping(&self, mapping: &[usize]) -> Self::Result {
109+
NdArray::new(self.embeddings(mapping))
110+
}
111+
}
112+
106113
impl MmapChunk for MmapArray {
107114
fn mmap_chunk(read: &mut BufReader<File>) -> Result<Self> {
108115
ChunkIdentifier::ensure_chunk_type(read, ChunkIdentifier::NdArray)?;
@@ -287,6 +294,14 @@ impl StorageViewMut for NdArray {
287294
}
288295
}
289296

297+
impl CloneFromMapping for NdArray {
298+
type Result = NdArray;
299+
300+
fn clone_from_mapping(&self, mapping: &[usize]) -> Self::Result {
301+
NdArray::new(self.embeddings(mapping))
302+
}
303+
}
304+
290305
impl ReadChunk for NdArray {
291306
fn read_chunk<R>(read: &mut R) -> Result<Self>
292307
where

src/chunks/storage/mod.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,20 @@ pub trait Storage {
2929
fn shape(&self) -> (usize, usize);
3030
}
3131

32+
pub(crate) mod sealed {
33+
use crate::chunks::storage::Storage;
34+
35+
/// Return a new storage from an existing Storage based on a mapping.
36+
pub trait CloneFromMapping {
37+
type Result: Storage;
38+
39+
/// Construct a new Storage based on a mapping.
40+
///
41+
/// The `i`th entry in the returned storage is based on `self.embedding(mapping[i])`.
42+
fn clone_from_mapping(&self, mapping: &[usize]) -> Self::Result;
43+
}
44+
}
45+
3246
/// Storage that provide a view of the embedding matrix.
3347
pub trait StorageView: Storage {
3448
/// Get a view of the embedding matrix.

src/chunks/storage/quantized.rs

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use rand::{RngCore, SeedableRng};
88
use rand_xorshift::XorShiftRng;
99
use reductive::pq::{QuantizeVector, ReconstructVector, TrainPQ, PQ};
1010

11-
use super::{Storage, StorageView};
11+
use super::{sealed::CloneFromMapping, Storage, StorageView};
1212
use crate::chunks::io::{ChunkIdentifier, ReadChunk, TypeId, WriteChunk};
1313
use crate::error::{Error, Result};
1414
use crate::storage::NdArray;
@@ -280,6 +280,20 @@ impl Storage for QuantizedArray {
280280
}
281281
}
282282

283+
impl CloneFromMapping for QuantizedArray {
284+
type Result = QuantizedArray;
285+
286+
fn clone_from_mapping(&self, mapping: &[usize]) -> Self::Result {
287+
let quantized_embeddings = self.quantized_embeddings.select(Axis(0), &mapping);
288+
let norms = self.norms.as_ref().map(|n| n.select(Axis(0), &mapping));
289+
QuantizedArray {
290+
quantizer: self.quantizer.clone(),
291+
quantized_embeddings,
292+
norms,
293+
}
294+
}
295+
}
296+
283297
impl ReadChunk for QuantizedArray {
284298
fn read_chunk<R>(read: &mut R) -> Result<Self>
285299
where
@@ -472,7 +486,7 @@ mod mmap {
472486
use super::{PQRead, QuantizedArray, Storage};
473487
use crate::chunks::io::MmapChunk;
474488
use crate::chunks::io::{ChunkIdentifier, WriteChunk};
475-
use crate::chunks::storage::NdArray;
489+
use crate::chunks::storage::{sealed::CloneFromMapping, NdArray};
476490
use crate::error::{Error, Result};
477491
use byteorder::{LittleEndian, ReadBytesExt};
478492

@@ -570,6 +584,21 @@ mod mmap {
570584
}
571585
}
572586

587+
impl CloneFromMapping for MmapQuantizedArray {
588+
type Result = QuantizedArray;
589+
590+
fn clone_from_mapping(&self, mapping: &[usize]) -> Self::Result {
591+
let quantized_embeddings =
592+
unsafe { self.quantized_embeddings() }.select(Axis(0), &mapping);
593+
let norms = self.norms.as_ref().map(|n| n.select(Axis(0), &mapping));
594+
QuantizedArray {
595+
quantizer: self.quantizer.clone(),
596+
quantized_embeddings,
597+
norms,
598+
}
599+
}
600+
}
601+
573602
impl MmapChunk for MmapQuantizedArray {
574603
fn mmap_chunk(read: &mut BufReader<File>) -> Result<Self> {
575604
ChunkIdentifier::ensure_chunk_type(read, ChunkIdentifier::QuantizedArray)?;

src/chunks/storage/wrappers.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use super::{NdArray, QuantizedArray, Storage, StorageView};
1414
#[cfg(feature = "memmap")]
1515
use crate::chunks::io::MmapChunk;
1616
use crate::chunks::io::{ChunkIdentifier, ReadChunk, WriteChunk};
17+
use crate::chunks::storage::sealed::CloneFromMapping;
1718
use crate::error::{Error, Result};
1819

1920
/// Storage types wrapper.
@@ -72,6 +73,27 @@ impl Storage for StorageWrap {
7273
}
7374
}
7475

76+
impl CloneFromMapping for StorageWrap {
77+
type Result = StorageWrap;
78+
79+
fn clone_from_mapping(&self, mapping: &[usize]) -> Self::Result {
80+
match self {
81+
StorageWrap::QuantizedArray(quant) => {
82+
StorageWrap::QuantizedArray(Box::new(quant.clone_from_mapping(mapping)))
83+
}
84+
#[cfg(feature = "memmap")]
85+
StorageWrap::MmapQuantizedArray(quant) => {
86+
StorageWrap::QuantizedArray(Box::new(quant.clone_from_mapping(mapping)))
87+
}
88+
#[cfg(feature = "memmap")]
89+
StorageWrap::MmapArray(mmapped) => {
90+
StorageWrap::NdArray(mmapped.clone_from_mapping(mapping))
91+
}
92+
StorageWrap::NdArray(array) => StorageWrap::NdArray(array.clone_from_mapping(mapping)),
93+
}
94+
}
95+
}
96+
7597
#[cfg(feature = "memmap")]
7698
impl From<MmapArray> for StorageWrap {
7799
fn from(s: MmapArray) -> Self {
@@ -236,6 +258,22 @@ impl StorageView for StorageViewWrap {
236258
}
237259
}
238260

261+
impl CloneFromMapping for StorageViewWrap {
262+
type Result = StorageViewWrap;
263+
264+
fn clone_from_mapping(&self, mapping: &[usize]) -> Self::Result {
265+
match self {
266+
#[cfg(all(feature = "memmap", target_endian = "little"))]
267+
StorageViewWrap::MmapArray(mmapped) => {
268+
StorageViewWrap::NdArray(mmapped.clone_from_mapping(mapping))
269+
}
270+
StorageViewWrap::NdArray(array) => {
271+
StorageViewWrap::NdArray(array.clone_from_mapping(mapping))
272+
}
273+
}
274+
}
275+
}
276+
239277
#[cfg(all(feature = "memmap", target_endian = "little"))]
240278
impl From<MmapArray> for StorageViewWrap {
241279
fn from(s: MmapArray) -> Self {

src/chunks/vocab/subword.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,10 @@ where
357357
/// N-grams in the range `(self.min_n..self.max_n)` are extracted from the words in the
358358
/// vocabulary, each of these gets assigned an index from the `BucketIndexer` which is used to
359359
/// determine the index in the explicit subword vocab.
360-
pub fn to_explicit(&self) -> ExplicitSubwordVocab {
360+
///
361+
/// The second item in the returned tuple holds the `bucket -> explicit_index` mapping for
362+
/// all buckets hit by `self`.
363+
pub fn to_explicit(&self) -> (ExplicitSubwordVocab, HashMap<u64, usize>) {
361364
let mut ngram_index = HashMap::new();
362365
let SubwordVocab {
363366
words,
@@ -375,8 +378,11 @@ where
375378
ngram_index.entry(ngram.into()).or_insert(idx);
376379
}
377380
}
378-
let indexer = ExplicitIndexer::new_with_indices(ngram_index);
379-
ExplicitSubwordVocab::new(words.to_owned(), *min_n, *max_n, indexer)
381+
let (indexer, mapping) = ExplicitIndexer::new_with_indices(ngram_index);
382+
(
383+
ExplicitSubwordVocab::new(words.to_owned(), *min_n, *max_n, indexer),
384+
mapping,
385+
)
380386
}
381387
}
382388

@@ -407,7 +413,7 @@ impl SubwordVocab<ExplicitIndexer> {
407413

408414
let words = read_vocab_items(read, words_len as usize)?;
409415
let ngrams = read_ngrams_with_indices(read, ngrams_len as usize)?;
410-
let indexer = ExplicitIndexer::new_with_indices(ngrams);
416+
let (indexer, _) = ExplicitIndexer::new_with_indices(ngrams);
411417
Ok(SubwordVocab::new(words, min_n, max_n, indexer))
412418
}
413419

@@ -560,15 +566,15 @@ mod tests {
560566
("<t".to_owned(), 2),
561567
];
562568

563-
ExplicitSubwordVocab::new(words, 2, 3, ExplicitIndexer::new_with_indices(ngrams))
569+
ExplicitSubwordVocab::new(words, 2, 3, ExplicitIndexer::new_with_indices(ngrams).0)
564570
}
565571

566572
#[test]
567573
fn test_conversion() {
568574
let words = vec!["groß".to_owned(), "allerdings".to_owned()];
569575
let indexer = FinalfusionHashIndexer::new(21);
570576
let bucket_vocab = SubwordVocab::new(words, 3, 6, indexer);
571-
let explicit = bucket_vocab.to_explicit();
577+
let (explicit, _) = bucket_vocab.to_explicit();
572578
let dings = StrWithCharLen::new("dings");
573579
let gro = StrWithCharLen::new("<gro");
574580
let dings_expl_idx = explicit.indexer().index_ngram(&dings);

src/compat/text.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,10 +265,11 @@ where
265265

266266
let matrix = Array2::from_shape_vec(shape, data).map_err(Error::Shape)?;
267267

268-
Ok(Embeddings::new_without_norms(
268+
Ok(Embeddings::new_with_maybe_norms(
269269
None,
270270
SimpleVocab::new(words),
271271
NdArray::new(matrix),
272+
None,
272273
))
273274
}
274275

src/compat/word2vec.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,11 @@ where
107107
.map_err(|e| Error::io_error("Cannot read word embedding", e))?;
108108
}
109109

110-
Ok(Embeddings::new_without_norms(
110+
Ok(Embeddings::new_with_maybe_norms(
111111
None,
112112
SimpleVocab::new(words),
113113
NdArray::new(matrix),
114+
None,
114115
))
115116
}
116117
}

0 commit comments

Comments
 (0)