Skip to content

Commit 68f2e4b

Browse files
dan-onlineWilsontheWolf
authored andcommitted
fix(mongo): random sample when unique is false
1 parent f489b59 commit 68f2e4b

File tree

1 file changed

+52
-14
lines changed

1 file changed

+52
-14
lines changed

packages/mongo/src/lib/MongoProvider.ts

+52-14
Original file line numberDiff line numberDiff line change
@@ -599,43 +599,81 @@ export class MongoProvider<StoredValue = unknown> extends JoshProvider<StoredVal
599599

600600
// Due to the use of $sample, the output will never have duplicates
601601
public async [Method.Random](payload: Payload.Random<StoredValue>): Promise<Payload.Random<StoredValue>> {
602-
const docCount = await this.collection.countDocuments({});
602+
let { count, unique } = payload;
603+
const size = await this.collection.countDocuments({});
603604

604605
// TODO: @dan-online fix this yourself idk how this work
605606
// Basically just this:
606607
// if(unique && size < count) throw InvalidCount
607608
// if (size === 0) throw MissingData
608609
// Also try no to get an infinite loop with unique off and count > size
609610

610-
if (docCount === 0) return { ...payload, data: [] };
611-
if (docCount < payload.count) {
612-
payload.errors.push(this.error({ identifier: CommonIdentifiers.InvalidCount, method: Method.Random }));
611+
if (unique && size < count) {
612+
payload.errors.push(this.error({ identifier: CommonIdentifiers.InvalidCount, method: Method.Random }, { size }));
613613

614614
return payload;
615615
}
616616

617-
const aggr: Document[] = [{ $sample: { size: payload.count } }];
618-
const docs = (await this.collection.aggregate(aggr).toArray()) || [];
617+
if (size === 0) {
618+
payload.errors.push(this.error({ identifier: CommonIdentifiers.MissingData, method: Method.Random }, { unique, count }));
619619

620-
if (docs.length > 0) payload.data = docs.map((doc) => this.deserialize(doc.value));
620+
return payload;
621+
}
622+
623+
payload.data = [];
624+
625+
if (unique) {
626+
const aggr: Document[] = [{ $sample: { size: payload.count } }];
627+
const docs = (await this.collection.aggregate(aggr).toArray()) || [];
628+
629+
payload.data = docs.map((doc) => this.deserialize(doc.value));
630+
} else {
631+
while (count > 0) {
632+
const aggr: Document[] = [{ $sample: { size: 1 } }];
633+
const docs = (await this.collection.aggregate(aggr).toArray()) || [];
634+
635+
payload.data.push(this.deserialize(docs[0].value));
636+
637+
count--;
638+
}
639+
}
621640

622641
return payload;
623642
}
624643

625644
public async [Method.RandomKey](payload: Payload.RandomKey): Promise<Payload.RandomKey> {
626-
const docCount = await this.collection.countDocuments({});
645+
const size = await this.collection.countDocuments({});
646+
let { count, unique } = payload;
627647

628-
if (docCount === 0) return { ...payload, data: [] };
629-
if (docCount < payload.count) {
630-
payload.errors.push(this.error({ identifier: CommonIdentifiers.InvalidCount, method: Method.RandomKey }));
648+
if (unique && size < count) {
649+
payload.errors.push(this.error({ identifier: CommonIdentifiers.InvalidCount, method: Method.Random }, { size }));
631650

632651
return payload;
633652
}
634653

635-
const aggr: Document[] = [{ $sample: { size: payload.count } }];
636-
const docs = (await this.collection.aggregate(aggr).toArray()) || [];
654+
if (size === 0) {
655+
payload.errors.push(this.error({ identifier: CommonIdentifiers.MissingData, method: Method.Random }, { unique, count }));
637656

638-
if (docs.length > 0) payload.data = docs.map((doc) => doc.key);
657+
return payload;
658+
}
659+
660+
payload.data = [];
661+
662+
if (unique) {
663+
const aggr: Document[] = [{ $sample: { size: payload.count } }];
664+
const docs = (await this.collection.aggregate(aggr).toArray()) || [];
665+
666+
payload.data = docs.map((doc) => doc.key);
667+
} else {
668+
while (count > 0) {
669+
const aggr: Document[] = [{ $sample: { size: 1 } }];
670+
const docs = (await this.collection.aggregate(aggr).toArray()) || [];
671+
672+
payload.data.push(docs[0].key);
673+
674+
count--;
675+
}
676+
}
639677

640678
return payload;
641679
}

0 commit comments

Comments
 (0)