run pre-commit on all files, convert to ruff #29

Merged (5 commits) on Mar 11, 2025

3 changes: 1 addition & 2 deletions .github/workflows/docs.yml
@@ -13,7 +13,7 @@ jobs:
- uses: actions/setup-python@v4
with:
python-version: 3.x
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
- uses: actions/cache@v3
with:
key: mkdocs-material-${{ env.cache_id }}
@@ -22,4 +22,3 @@ jobs:
mkdocs-material-
- run: pip install mkdocs-material mkdocstrings mkdocs-literate-nav mkdocs-section-index mkdocs-gen-files mkdocstrings-python
- run: mkdocs gh-deploy --force

8 changes: 4 additions & 4 deletions .github/workflows/tests.yml
@@ -18,7 +18,7 @@ jobs:
steps:
- name: Check out repository
uses: actions/checkout@v4

- name: Set up python
id: setup-python
uses: actions/setup-python@v5
@@ -27,7 +27,7 @@ jobs:

- name: Install Poetry
uses: snok/install-poetry@v1

- name: Load cached venv
id: cached-poetry-dependencies
uses: actions/cache@v3
@@ -49,5 +49,5 @@ jobs:
fail_ci_if_error: false
disable_search: true
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.json
slug: probcomp/hfppl
files: ./coverage.json
slug: probcomp/hfppl
19 changes: 11 additions & 8 deletions .pre-commit-config.yaml
@@ -1,11 +1,14 @@
repos:
- repo: https://github.com/pycqa/isort
rev: 5.13.2
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: isort
args: [--profile, black, --force-single-line-imports]
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.4.2
- id: check-yaml
args: [--unsafe]
- id: end-of-file-fixer
- id: trailing-whitespace

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.9
hooks:
- id: black
language_version: python3.10
- id: ruff-format
types_or: [ python, pyi, jupyter ]
42 changes: 26 additions & 16 deletions benchmark/benchmark_backend.py
@@ -4,31 +4,40 @@
Example usage: pytest benchmark/benchmark_backend.py --benchmark-only --benchmark-group-by=func -v
"""

import torch
import pytest
import asyncio
from hfppl.llms import CachedCausalLM

import pytest
import torch

from examples.haiku import run_example as run_haiku
from examples.hard_constraints import run_example as run_hard_constraints
from hfppl.llms import CachedCausalLM

backends = [
'hf',
"hf",
pytest.param(
'vllm',
"vllm",
marks=pytest.mark.skipif(
not torch.cuda.is_available(),
reason="vLLM backend requires CUDA"
)
)
not torch.cuda.is_available(), reason="vLLM backend requires CUDA"
),
),
]


@pytest.fixture
def LLM(backend):
# Set lower gpu_memory_utilization in vllm so that we can fit both models on the GPU
kwargs = {'engine_opts' : {'gpu_memory_utilization' : 0.45}, 'cache_size' : 100} if backend == 'vllm' else {}
return CachedCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", backend=backend, **kwargs)
kwargs = (
{"engine_opts": {"gpu_memory_utilization": 0.45}, "cache_size": 100}
if backend == "vllm"
else {}
)
return CachedCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3-8B", backend=backend, **kwargs
)

@pytest.mark.parametrize('backend', backends)

@pytest.mark.parametrize("backend", backends)
def test_hard_constraints_benchmark(LLM, benchmark, n_particles=20, max_tokens=50):
def run_with_clear_cache():
LLM.clear_cache()
@@ -38,24 +47,25 @@ def run_with_clear_cache():

# warmup
run_with_clear_cache()

benchmark.pedantic(
run_with_clear_cache,
iterations=1,
rounds=3,
)

@pytest.mark.parametrize('backend', backends)

@pytest.mark.parametrize("backend", backends)
def test_haiku_benchmark(LLM, benchmark, n_particles=20):
def run_with_clear_cache():
LLM.clear_cache()
return asyncio.run(
run_haiku(LLM, poem_title='The beauty of testing', n_particles=n_particles)
run_haiku(LLM, poem_title="The beauty of testing", n_particles=n_particles)
)

# warmup
run_with_clear_cache()

benchmark.pedantic(
run_with_clear_cache,
iterations=1,
2 changes: 1 addition & 1 deletion docs/anatomy.md
@@ -1 +1 @@
# Anatomy of a LLaMPPL model
# Anatomy of a LLaMPPL model
8 changes: 4 additions & 4 deletions docs/batching.md
@@ -2,13 +2,13 @@

If running in a GPU-accelerated environment, LLaMPPL supports **auto-batching**.

The `step` method of a LLaMPPL model describes how to advance a *single* particle one step of generation.
But inference methods must maintain many particles at once.
The `step` method of a LLaMPPL model describes how to advance a *single* particle one step of generation.
But inference methods must maintain many particles at once.

With auto-batching, LLaMPPL will execute particles' `step` methods concurrently, and automatically batch calls
to large language models. This batching is handled by the `CachedCausalLM` object, and its behavior is controlled by two parameters:

* `lm.batch_size`: the maximum number of requests to batch. The default value is 20.
* `lm.timeout`: if `lm.timeout` seconds pass with no new request, the current batch is processed even if not full. The default value is 0.02.
You may want to set the batch size (`#!python lm.batch_size`) to the number of particles you are using (if the number of particles is not too large).

You may want to set the batch size (`#!python lm.batch_size`) to the number of particles you are using (if the number of particles is not too large).
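
A minimal sketch of how these two knobs are typically set, assuming the `CachedCausalLM` attributes behave as documented above (the checkpoint name is only illustrative):

```python
from hfppl.llms import CachedCausalLM

# Load a model (illustrative checkpoint; any supported causal LM works).
lm = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

n_particles = 20

# Flush a batched LLM call once n_particles requests have accumulated...
lm.batch_size = n_particles

# ...or after 0.05 seconds with no new request, whichever comes first
# (the documented default timeout is 0.02 seconds).
lm.timeout = 0.05
```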
12 changes: 6 additions & 6 deletions docs/caching.md
@@ -7,7 +7,7 @@ Next-token log probabilities are always cached, whenever they are computed.
This way, if different particles make exactly the same log probability queries,
the Transformer is run only once. This is primarily beneficial when:

* particles are cloned during resampling: when each particle is
* particles are cloned during resampling: when each particle is

* cloned particles happen to sample the same next token: if the next-token distribution is concentrated,
it is likely that multiple copies of a particle will sample the same next token. Log probability caching
@@ -30,15 +30,15 @@ In principle, key-value caching is most useful when:
cost to cache *different* key-value sequences for *each* particle, to speed up future next-token
queries.

Currently, only the first use case is well-supported by the LLaMPPL library, via the
Currently, only the first use case is well-supported by the LLaMPPL library, via the
[`lm.cache_kv(prompt)`][hfppl.llms.CachedCausalLM.cache_kv] method. This method computes and caches key and value vectors
for every token in `prompt`. Future calls to [`lm.next_token_logprobs`][hfppl.llms.CachedCausalLM.next_token_logprobs] and [`lm.next_token_logprobs_unbatched`][hfppl.llms.CachedCausalLM.next_token_logprobs_unbatched]
will automatically recognize when `prompt` is a prefix of the new query, and automatically
exploit incremental computation. Multiple prompts can be cached, and [`lm.clear_kv_cache()`][hfppl.llms.CachedCausalLM.clear_kv_cache] can
be used to clear the KV-cache without clearing the log probability cache.

Because [`lm.cache_kv`][hfppl.llms.CachedCausalLM.cache_kv] is not a batched call,
it is not well-suited to caching
different strings for different particles.
Because [`lm.cache_kv`][hfppl.llms.CachedCausalLM.cache_kv] is not a batched call,
it is not well-suited to caching
different strings for different particles.
Rather, it is best used in the `__init__` method of a model--or even
outside of a model--on fixed prompt strings that every particle will share.
outside of a model--on fixed prompt strings that every particle will share.
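
As a usage note for the API described above, a minimal sketch of caching a shared prompt prefix (method names are those documented above; the prompt text is illustrative):

```python
from hfppl.llms import CachedCausalLM

lm = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

# A fixed prompt that every particle will share.
prompt = "Here are four example poems:\n"

# Compute and cache key/value vectors for every token of the prompt.
# Later calls to lm.next_token_logprobs / lm.next_token_logprobs_unbatched
# whose token sequences extend this prefix reuse the cached vectors.
lm.cache_kv(lm.tokenizer.encode(prompt))

# Drop the KV cache without discarding the log-probability cache.
lm.clear_kv_cache()
```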
8 changes: 4 additions & 4 deletions docs/getting_started.md
@@ -46,18 +46,18 @@ class MyModel(Model):
# A stateful context object for the LLM, initialized with the prompt
self.context = LMContext(lm, prompt)
self.eos_token = lm.tokenizer.eos_token_id

# The forbidden letter
self.forbidden_tokens = set(i for (i, v) in enumerate(lm.vocab)
if forbidden_letter in v)

# The step method is used to perform a single 'step' of generation.
# This might be a single token, a single phrase, or any other division.
# Here, we generate one token at a time.
async def step(self):
# Condition on the next token *not* being a forbidden token.
await self.observe(self.context.mask_dist(self.forbidden_tokens), False)

# Sample the next token from the LLM -- automatically extends `self.context`.
token = await self.sample(self.context.next_token())

@@ -98,4 +98,4 @@ for particle in particles:

## Learning more

For more intuition on language model probabilistic programming, see [our paper](https://arxiv.org/abs/2306.03081), or the rest of this documentation.
For more intuition on language model probabilistic programming, see [our paper](https://arxiv.org/abs/2306.03081), or the rest of this documentation.
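
For context, a model like the one above is typically driven by the SMC entry point; the sketch below assumes `smc_standard(model, n_particles)` is the calling convention (the exact signature may differ):

```python
import asyncio

from hfppl.inference import smc_standard
from hfppl.llms import CachedCausalLM

lm = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

# MyModel is the class defined in the snippet above
# (prompt and forbidden letter are illustrative).
model = MyModel(lm, "Once upon a time, ", "e")

# Run sequential Monte Carlo with 20 particles.
particles = asyncio.run(smc_standard(model, 20))

for particle in particles:
    print(f"{particle.weight}: {particle.context}")
```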
6 changes: 3 additions & 3 deletions docs/immutability.md
@@ -1,7 +1,7 @@
# Immutability

When a particle is promising, the sequential Monte Carlo algorithm may _clone_ it, by calling `copy.deepcopy`.
When a particle is promising, the sequential Monte Carlo algorithm may _clone_ it, by calling `copy.deepcopy`.

Depending on your model, this may be more or less expensive.
Depending on your model, this may be more or less expensive.

To make it faster, override the `immutable_properties(self)` method of your Model class, to return a `set[str]` of property names that are guaranteed not to change during `step`. For all properties in this set, LLaMPPL will use shared memory across particles, and avoid copying when cloning particles.
To make it faster, override the `immutable_properties(self)` method of your Model class, to return a `set[str]` of property names that are guaranteed not to change during `step`. For all properties in this set, LLaMPPL will use shared memory across particles, and avoid copying when cloning particles.
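
A minimal sketch of the override described above (class and property names are illustrative, and the import path for `Model` is assumed):

```python
from hfppl import Model  # import path assumed
from hfppl.distributions.lmcontext import LMContext


class CappedGeneration(Model):
    """Illustrative model: the LM handle and prompt never change during step."""

    def __init__(self, lm, prompt):
        super().__init__()
        self.lm = lm          # shared, never reassigned in step
        self.prompt = prompt  # shared, never reassigned in step
        self.context = LMContext(lm, prompt)  # mutated in step, so not listed below

    async def step(self):
        # Termination logic omitted for brevity.
        await self.sample(self.context.next_token())

    def immutable_properties(self):
        # Properties listed here are shared across particles rather than
        # deep-copied when a particle is cloned.
        return {"lm", "prompt"}
```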
4 changes: 2 additions & 2 deletions docs/index.md
@@ -1,5 +1,5 @@
# Home

[LLaMPPL](https://github.com/probcomp/hfppl) is a research prototype for language model probabilistic programming: specifying language generation tasks by writing probabilistic programs that combine calls to LLMs, symbolic program logic, and probabilistic conditioning. To solve these tasks, LLaMPPL uses a specialized sequential Monte Carlo inference algorithm.
[LLaMPPL](https://github.com/probcomp/hfppl) is a research prototype for language model probabilistic programming: specifying language generation tasks by writing probabilistic programs that combine calls to LLMs, symbolic program logic, and probabilistic conditioning. To solve these tasks, LLaMPPL uses a specialized sequential Monte Carlo inference algorithm.

This technique, SMC steering, is described in our workshop abstract, [Sequential Monte Carlo Steering of Large Language Models using Probabilistic Programs](https://arxiv.org/abs/2306.03081).
This technique, SMC steering, is described in our workshop abstract, [Sequential Monte Carlo Steering of Large Language Models using Probabilistic Programs](https://arxiv.org/abs/2306.03081).
2 changes: 1 addition & 1 deletion docs/performance.md
@@ -4,4 +4,4 @@ If your LLaMPPL model is running slowly, consider exploiting the following featu

- [Auto-Batching](batching.md) — to run multiple particles concurrently, with batched LLM calls
- [Caching](caching.md) - to cache key and value vectors for long prompts
- [Immutability hinting](immutability.md) - to significantly speed up the bookkeeping performed by SMC inference
- [Immutability hinting](immutability.md) - to significantly speed up the bookkeeping performed by SMC inference
6 changes: 3 additions & 3 deletions docs/transformers.md
@@ -14,7 +14,7 @@ Alternatively, you can initialize an [`LMContext`][hfppl.distributions.lmcontext

## Create custom token distributions with `TokenCategorical`

You may also create a custom distribution over the vocabulary of a language model using the [`TokenCategorical`][hfppl.distributions.tokencategorical.TokenCategorical] distribution. It is parameterized by a [`CachedCausalLM`][hfppl.llms.CachedCausalLM] instance, and an array of logits equal in length to the language model's vocabulary size.
This distribution is particularly useful as a proposal distribution; for example, a model might `sample` with `dist` set
You may also create a custom distribution over the vocabulary of a language model using the [`TokenCategorical`][hfppl.distributions.tokencategorical.TokenCategorical] distribution. It is parameterized by a [`CachedCausalLM`][hfppl.llms.CachedCausalLM] instance, and an array of logits equal in length to the language model's vocabulary size.
This distribution is particularly useful as a proposal distribution; for example, a model might `sample` with `dist` set
to the LM's next token distribution, but with `proposal` set to a modified distribution that uses a heuristic to upweight
'good' tokens and downweight 'bad' ones.
'good' tokens and downweight 'bad' ones.
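
To make the proposal pattern concrete, here is a sketch of a `step` method that keeps the LM's next-token distribution as the target but proposes from a reweighted `TokenCategorical` (the heuristic, the class name, and the `Model` import path are illustrative assumptions):

```python
import numpy as np

from hfppl import Model  # import path assumed
from hfppl.distributions.lmcontext import LMContext
from hfppl.distributions.tokencategorical import TokenCategorical


class HeuristicProposalModel(Model):
    def __init__(self, lm, prompt, good_tokens):
        super().__init__()
        self.lm = lm
        self.context = LMContext(lm, prompt)
        self.good_tokens = list(good_tokens)  # token ids to favor in the proposal

    async def step(self):
        # Proposal logits over the full vocabulary: flat, with 'good' tokens boosted.
        proposal_logits = np.zeros(len(self.lm.vocab))
        proposal_logits[self.good_tokens] += 5.0  # illustrative upweighting

        # The target stays the LM's next-token distribution; the proposal only
        # changes how candidate tokens are suggested, and the importance-weight
        # correction is handled by the inference algorithm.
        await self.sample(
            self.context.next_token(),
            proposal=TokenCategorical(self.lm, proposal_logits),
        )
```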
5 changes: 2 additions & 3 deletions docs/visualization.md
@@ -9,7 +9,6 @@ Return a string that summarizes the particle's current state.
To run the interface, change to the `html` directory and run `python -m http.server`. This will start serving
the files in the `html` directory at localhost:8000. (If you are SSH-ing onto a remote machine, you may need
port forwarding. Visual Studio Code automatically handles this for some ports, including 8000.)
Then, when calling [`smc_standard`](hfppl.inference.smc_standard), set `visualization_dir`
to the path to the `html` directory. A JSON record of the run will automatically be saved
Then, when calling [`smc_standard`](hfppl.inference.smc_standard), set `visualization_dir`
to the path to the `html` directory. A JSON record of the run will automatically be saved
to that directory, and a URL will be printed to the console (`http://localhost:8000/smc.html?path=$json_file`).
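
As a usage sketch of the call described above (only the `visualization_dir` argument is documented here; the remaining arguments to `smc_standard` are assumptions):

```python
import asyncio

from hfppl.inference import smc_standard

# `model` stands for any LLaMPPL Model that implements string_for_serialization(),
# e.g. the Haiku model from examples/haiku.py.
particles = asyncio.run(
    smc_standard(
        model,                     # a previously constructed Model instance
        20,                        # number of particles (positional use assumed)
        visualization_dir="html",  # path to the repository's html/ directory
    )
)
# A JSON record of the run is written into html/, and a URL of the form
# http://localhost:8000/smc.html?path=<json_file> is printed to the console.
```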

4 changes: 1 addition & 3 deletions examples/grammar_constraint.py
@@ -11,8 +11,6 @@
"""

import asyncio
import os
from typing import List

from synchromesh.completion_engine import LarkCompletionEngine
from synchromesh.synchromesh import StreamingCSD
@@ -126,7 +124,7 @@ async def run_generation(
verbose: bool = False,
):
LLM = CachedCausalLM.from_pretrained(args.model)
if LLM.backend == 'hf':
if LLM.backend == "hf":
LLM.batch_size = args.batch_size
model = GrammarConstrainedSMC(
lm=LLM,
26 changes: 15 additions & 11 deletions examples/haiku.py
@@ -1,5 +1,4 @@
import asyncio
import os

import nltk

@@ -19,7 +18,6 @@


def count_syllables(word, unknown_word_syllables=100):

# Use the dictionary to get the list of possible phonetic representations for the word
phonetic_transcriptions = CMUDICT.get(word.strip().lower(), [])

@@ -34,6 +32,7 @@ def count_syllables(word, unknown_word_syllables=100):

return syllable_count


# Example poems for the prompt.
# Authors:
# - Amy Lowell
@@ -65,9 +64,9 @@ def count_syllables(word, unknown_word_syllables=100):
this deep in fall,
still not a butterfly."""


# LLaMPPL model
class Haiku(Model):

def __init__(self, LLM, prompt, syllable_pattern=[5, 7, 5]):
super().__init__()
self.context = LMContext(LLM, prompt)
@@ -84,7 +83,6 @@ async def step(self):

# Loop to sample words until this line is over
while syllables_remaining > 0:

# Sample a word
word, punctuation = await self.call(sample_word(self.context))

@@ -116,13 +114,16 @@ def string_for_serialization(self):
)
return s.replace("\n", "/")

async def run_example(LLM, poem_title, syllable_pattern=[5, 7, 5], n_particles=20, ess_threshold=0.5):

async def run_example(
LLM, poem_title, syllable_pattern=[5, 7, 5], n_particles=20, ess_threshold=0.5
):
# Construct prompt
prompt = f"""{EXAMPLE_POEMS}

5. "{poem_title}"
"""

# Cache the key value vectors for the prompt
LLM.cache_kv(LLM.tokenizer.encode(prompt))

@@ -136,6 +137,7 @@ async def run_example(LLM, poem_title, syllable_pattern=[5, 7, 5], n_particles=2

return particles


def main():
# Load the language model.
# Mistral is an open model; to use a model with restricted access, like LLaMA 3,
@@ -144,22 +146,24 @@ def main():
# LLM = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

# Set batch size if using HuggingFace backend
if LLM.backend == 'hf':
if LLM.backend == "hf":
LLM.batch_size = 40

# Get poem title from user
poem_title = input("Enter a title for your Haiku: ")

syllables_per_line = [5, 7, 5] # [5, 3, 5] for a Lune
syllables_per_line = [5, 7, 5] # [5, 3, 5] for a Lune

# Run the example
particles = asyncio.run(run_example(LLM, poem_title, syllable_pattern=syllables_per_line))
particles = asyncio.run(
run_example(LLM, poem_title, syllable_pattern=syllables_per_line)
)

print("--------")
for i, particle in enumerate(particles):
print(f"\nPoem {i} (weight {particle.weight}):")
print(f"{particle.context}")


if __name__ == "__main__":
main()
