Unofficial Python implementation of "Can a Fruit Fly Learn Word Embeddings?" with a PyTorch-flavored API.
pip install git+https://github.com/Ramos-Ramos/fruit-fly-net
Check out our Colab demo!
import numpy as xp
from scipy.special import softmax
from fruit_fly_net import FruitFlyNet
model = FruitFlyNet(
input_dim=40000, # input dimension size (vocab_size * 2)
output_dim=600, # output dimension size
k=16, # top k cells to be left active in output layer
lr=1e-4 # learning rate (learning is performed internally)
)
# batch of 2000 random binary inputs: the first 20000 columns have 15 active entries (context words),
# the last 20000 columns have 1 active entry (target word)
x = xp.concatenate([xp.argsort(xp.random.rand(2000, 20000)) < i for i in (15, 1)], axis=1)
# word occurrence probabilities, tiled to cover both halves of the 40000-dim input
probs = xp.tile(softmax(xp.random.rand(20000)), 2)
output = model(x, probs)
Learning is performed internally as long as the model is in train mode; there is no need to call .backward() or instantiate optimizers. To set the mode, use .train() and .eval().
model.train() # will update weights on forward pass
model.eval() # will not update weights on forward pass
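As a quick check of the two modes, the sketch below compares the weights before and after a forward pass; it assumes model.weights is the weight matrix used in the loss example further down, and reuses the x and probs from the usage example above.

model.train()
model(x, probs)                       # forward pass updates model.weights in place

model.eval()
weights_before = model.weights.copy()
model(x, probs)                       # no update in eval mode
assert xp.allclose(weights_before, model.weights)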
To get the loss, use bio_hash_loss.
from fruit_fly_net import bio_hash_loss
loss = bio_hash_loss(model.weights, x, probs)
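For example, the loss can be logged while training; the loop below is a minimal sketch, with an arbitrary epoch count and the same x and probs batch as above.

model.train()
for epoch in range(10):
    model(x, probs)                                # forward pass doubles as the update step
    loss = bio_hash_loss(model.weights, x, probs)  # evaluate the current weights
    print(f'epoch {epoch}: loss {loss}')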
To enable GPU learning, move the model to the GPU via .to and use cupy instead of numpy.
import cupy as xp
from fruit_fly_net import FruitFlyNet
model = FruitFlyNet(
input_dim=40000,
output_dim=600,
k=16,
lr=1e-4
)
model.to('gpu')
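The inputs then need to be cupy arrays as well; a minimal sketch, assuming a batch built on the CPU with numpy and converted with cupy.asarray (everything else mirrors the CPU example above).

import numpy as np
from scipy.special import softmax

# build the batch on the CPU exactly as in the CPU example...
x_np = np.concatenate([np.argsort(np.random.rand(2000, 20000)) < i for i in (15, 1)], axis=1)
probs_np = np.tile(softmax(np.random.rand(20000)), 2)

# ...then move it to the GPU before the forward pass
x = xp.asarray(x_np)
probs = xp.asarray(probs_np)
output = model(x, probs)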
@misc{liang2021fruit,
title={Can a Fruit Fly Learn Word Embeddings?},
author={Yuchen Liang and Chaitanya K. Ryali and Benjamin Hoover and Leopold Grinberg and Saket Navlakha and Mohammed J. Zaki and Dmitry Krotov},
year={2021},
eprint={2101.06887},
archivePrefix={arXiv},
primaryClass={cs.CL}
}