Skip to content

Commit

Permalink
Allow overweight items in AChao sampler (stanford-futuredata#189)
Browse files Browse the repository at this point in the history
  • Loading branch information
pbailis authored Mar 15, 2017
1 parent f6f24ea commit 485162a
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 3 deletions.
70 changes: 68 additions & 2 deletions core/src/main/java/macrobase/analysis/sample/AChao.java
Original file line number Diff line number Diff line change
@@ -1,18 +1,45 @@
package macrobase.analysis.sample;

import com.google.common.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.stream.Collectors;

/**
* See http://arxiv.org/pdf/1012.0256.pdf
*/
class AChao<T> {
private class OverweightItem<T> implements Comparable<OverweightItem> {
public Double weight;
public T item;

public OverweightItem(T item,
double weight) {
this.item = item;
this.weight = weight;
}

@Override
public int compareTo(OverweightItem o) {
return weight.compareTo(o.weight);
}
}

private final List<T> reservoir;
double runningCount;
private final int reservoirCapacity;
private final Random random;
private final PriorityQueue<OverweightItem<T>> overweightItems = new PriorityQueue<>();

public AChao(int capacity) {
this(capacity, new Random());
Expand All @@ -24,17 +51,56 @@ public AChao(int capacity, Random random) {
this.random = random;
}

private void updateOverweightItems() {
while(!overweightItems.isEmpty()) {
OverweightItem<T> ow = overweightItems.peek();
if(reservoirCapacity * ow.weight / runningCount <= 1) {
overweightItems.poll();
insert(ow.item, ow.weight);
} else {
break;
}
}
}

public final List<T> getReservoir() {
updateOverweightItems();

if(!overweightItems.isEmpty()) {
// overweight items always make it in the sample
List<T> ret = overweightItems.stream().map(i -> i.item).collect(Collectors.toList());

assert (ret.size() <= reservoirCapacity);

// fill the return value with a sample of non-overweight elements
Collections.shuffle(reservoir, random);
ret.addAll(reservoir.subList(0, reservoirCapacity-ret.size()));
return ret;
}

return reservoir;
}

protected void decayWeights(double decay) {
runningCount *= decay;
overweightItems.forEach(i -> i.weight *= decay);
}

public void insert(T ele, double weight) {
runningCount += weight;

updateOverweightItems();

if (reservoir.size() < reservoirCapacity) {
reservoir.add(ele);
} else if (random.nextDouble() < weight / runningCount) {
reservoir.set(random.nextInt(reservoirCapacity), ele);
} else {
double pInsertion = reservoirCapacity * weight / runningCount;

if(pInsertion > 1) {
overweightItems.add(new OverweightItem(ele, weight));
} else if (random.nextDouble() < pInsertion) {
reservoir.set(random.nextInt(reservoirCapacity), ele);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public void advancePeriod() {
}

public void advancePeriod(int numPeriods) {
runningCount *= Math.pow(1 - bias, numPeriods);
decayWeights(Math.pow(1 - bias, numPeriods));
}

public void insert(T ele) {
Expand Down
86 changes: 86 additions & 0 deletions core/src/test/java/macrobase/analysis/sample/AChaoTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package macrobase.analysis.sample;

import com.google.common.collect.Lists;
import macrobase.datamodel.Datum;
import org.junit.Test;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

import static org.junit.Assert.*;


public class AChaoTest {
@Test
public void simpleTest() throws Exception {
int[] sample = {1, 2, 3, 4, 5, 6, 7};

Random r = new Random(0);
AChao<Integer> ac = new AChao<>(2, r);

for(int i : sample) {
ac.insert(i, 1);
}

assertEquals((Integer) 2, (Integer) ac.getReservoir().size());
assertEquals((Integer) 5, ac.getReservoir().get(0));
assertEquals((Integer) 4, ac.getReservoir().get(1));
}

@Test
public void testOverweightItems() throws Exception {
int[] sample = {1, 2, 3, 4, 5, 6, 7};

Random r = new Random(0);
AChao<Integer> ac = new AChao<>(2, r);

for(int i : sample) {
ac.insert(i, 1);
}

assertEquals((Integer) 2, (Integer) ac.getReservoir().size());
assertEquals((Integer) 5, ac.getReservoir().get(0));
assertEquals((Integer) 4, ac.getReservoir().get(1));

ac.decayWeights(.1);
ac.insert(100, 1000);

assertEquals((Integer) 2, (Integer) ac.getReservoir().size());
assertTrue(ac.getReservoir().contains(100));

ac.decayWeights(.00001);

ac.insert(200, 1000);
assertTrue(ac.getReservoir().contains(200));
}

@Test
public void testOverweightItemSequential() throws Exception {
int[] sample = {1, 2, 3, 4, 5, 6, 7};

Random r = new Random(0);
AChao<Integer> ac = new AChao<>(100, r);

for(int j = 0; j < 100; ++j) {
for (int i : sample) {
ac.insert(i, 1);
}
}

ac.decayWeights(.00001);
ac.insert(100, 1);
ac.insert(200, 1);
ac.insert(300, 1);

assertEquals((Integer) 100, (Integer) ac.getReservoir().size());
assertTrue(ac.getReservoir().contains(100));

ac.decayWeights(.0000001);
ac.insert(400, 1);

assertTrue(ac.getReservoir().contains(400));
}
}

0 comments on commit 485162a

Please sign in to comment.