Hey! Thanks for releasing your work!
I was wondering whether you had looked at k-means for 1D data. As I understand it, in one dimension the globally optimal centroids can be found exactly (for example with dynamic programming), so I thought it might be interesting.
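For reference, the 1D case is special because there is always an optimal clustering made of contiguous runs of the sorted values. A dynamic program over the split points therefore finds the exact minimum of the weighted sum of squared errors. The sketch below is a minimal pure-Python version under stated assumptions: `kmeans1d_exact` is a hypothetical name, weights are assumed non-negative, and it uses the simple O(k·n²) recurrence rather than the faster algorithm the kmeans1d package implements. It returns `(labels, centroids)`, the same pair shape the diff unpacks from `kmeans1d.cluster`.

```python
from itertools import accumulate


def kmeans1d_exact(values, k, weights=None):
    """Exact weighted 1D k-means via an O(k * n^2) dynamic program.

    Hypothetical helper for illustration. Returns (labels, centroids):
    labels follow the order of `values`, centroids are sorted ascending.
    Assumes non-negative weights.
    """
    n = len(values)
    if not 1 <= k <= n:
        raise ValueError("need 1 <= k <= len(values)")
    if weights is None:
        weights = [1.0] * n
    # In 1D an optimal clustering can use contiguous runs of sorted values.
    order = sorted(range(n), key=lambda i: values[i])
    xs = [values[i] for i in order]
    ws = [weights[i] for i in order]
    # Prefix sums of w, w*x and w*x^2 give any run's weighted SSE in O(1).
    pw = [0.0] + list(accumulate(ws))
    px = [0.0] + list(accumulate(w * x for w, x in zip(ws, xs)))
    pxx = [0.0] + list(accumulate(w * x * x for w, x in zip(ws, xs)))

    def cost(i, j):
        # Weighted SSE of xs[i:j] around its weighted mean.
        w = pw[j] - pw[i]
        s = px[j] - px[i]
        return (pxx[j] - pxx[i]) - s * s / w if w > 0 else 0.0

    inf = float("inf")
    # best[m][j]: minimal SSE for splitting xs[:j] into m runs;
    # split[m][j]: start index of the last run in that optimum.
    best = [[inf] * (n + 1) for _ in range(k + 1)]
    split = [[0] * (n + 1) for _ in range(k + 1)]
    best[0][0] = 0.0
    for m in range(1, k + 1):
        for j in range(m, n + 1):
            for i in range(m - 1, j):
                c = best[m - 1][i] + cost(i, j)
                if c < best[m][j]:
                    best[m][j], split[m][j] = c, i

    # Walk the split points backwards to recover runs and centroids.
    sorted_labels = [0] * n
    centroids = [0.0] * k
    j = n
    for m in range(k, 0, -1):
        i = split[m][j]
        w = pw[j] - pw[i]
        centroids[m - 1] = (px[j] - px[i]) / w if w > 0 else xs[i]
        for t in range(i, j):
            sorted_labels[t] = m - 1
        j = i

    # Map labels back to the caller's original ordering.
    labels = [0] * n
    for rank, idx in enumerate(order):
        labels[idx] = sorted_labels[rank]
    return labels, centroids


if __name__ == "__main__":
    print(kmeans1d_exact([0, 1, 10, 11, 12, 13], 3))  # ([0, 0, 1, 1, 2, 2], [0.5, 10.5, 12.5])
```

The prefix sums are what keep each run's cost O(1); without them the recurrence would be O(k·n³).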
I ran some tests with Llama 3.1 8B Instruct (the models are on Hugging Face):
| Method | Bits/weight | ARC-Challenge | ARC-Easy | MMLU STEM | MMLU Humanities | MMLU Social Sciences | MMLU Other | Avg |
|---|---|---|---|---|---|---|---|---|
| float16 | 16 | 51.62 | 81.86 | 58.61 | 64.42 | 75.88 | 74.31 | 67.78 |
| bfloat16 | 16 | 51.79 | 81.86 | 58.64 | 64.25 | 76.86 | 74.18 | 67.93 |
| sklearn | 4.05 | 51.02 | 81.52 | 56.58 | 60.31 | 75.88 | 72.96 | 66.37 |
| kmeans1d | 4.05 | 52.04 | 82.36 | 57.24 | 62.23 | 75.33 | 73.09 | 67.04 |
| sklearn | 3.02 | 46.67 | 78.74 | 52.39 | 54.79 | 71.01 | 70.09 | 62.28 |
| kmeans1d | 3.02 | 42.49 | 74.53 | 53.82 | 55.70 | 69.48 | 68.13 | 60.69 |
I'm curious what you make of it: kmeans1d comes out ahead of sklearn on average at 4.05 bits but behind it at 3.02 bits.
The change is minimal. Install kmeans1d (credit to apple/coremltools):

```shell
pip install git+https://github.com/smpanaro/kmeans1d@master
```

Then apply this diff:
```diff
diff --git a/lean_quantizer.py b/lean_quantizer.py
index a860f64..250beb1 100644
--- a/lean_quantizer.py
+++ b/lean_quantizer.py
@@ -6,6 +6,7 @@ import numpy as np
 from sklearn.cluster import KMeans
 from multiprocessing import Pool
 from tqdm import tqdm
+import kmeans1d
 import torch
 import torch.nn as nn
@@ -14,13 +15,18 @@ import transformers
 from quant import *
-DEBUG = False
+DEBUG = False
+USE_KMEANS1D = True
 torch.backends.cuda.matmul.allow_tf32 = False
 torch.backends.cudnn.allow_tf32 = False
 def kmeans_fit(row_data):
     weights_np, sample_weight, n_cluster, random_seed = row_data
+    if USE_KMEANS1D:
+        _, centroids = kmeans1d.cluster(weights_np, n_cluster, weights=sample_weight)
+        return np.array(centroids, dtype=np.float32)
+
     kmeans = KMeans(
         n_clusters=n_cluster,
         init=np.linspace(weights_np.min(), weights_np.max(), num=n_cluster)[:, None] if n_cluster <= 8 else 'k-means++',
```
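To illustrate why the exact solver can give different centroids than the sklearn path: sklearn's KMeans refines centroids with Lloyd-style local search, which can stop at a local optimum. The pure-Python sketch below is illustrative only. The helpers `lloyd_1d`, `exact_1d` and `weighted_sse` are hypothetical, weights are assumed positive, and it omits sklearn's refinements such as empty-cluster handling and restarts. It starts Lloyd iterations from a deliberately poor initialization on six points and compares the result with a brute-force search over contiguous splits, which is exact in 1D but only feasible for tiny inputs.

```python
from itertools import combinations


def weighted_sse(values, weights, centroids):
    # Each point costs weight * squared distance to its nearest centroid.
    return sum(w * min((x - c) ** 2 for c in centroids)
               for x, w in zip(values, weights))


def lloyd_1d(values, weights, init, max_iter=100):
    """Plain weighted Lloyd iterations from a given init (hypothetical helper)."""
    centroids = list(init)
    for _ in range(max_iter):
        sums = [0.0] * len(centroids)
        mass = [0.0] * len(centroids)
        for x, w in zip(values, weights):
            # Assign each point to its nearest current centroid.
            c = min(range(len(centroids)), key=lambda i: (x - centroids[i]) ** 2)
            sums[c] += w * x
            mass[c] += w
        # Move each centroid to its weighted mean; an empty one stays put.
        updated = [s / m if m > 0 else old
                   for s, m, old in zip(sums, mass, centroids)]
        if updated == centroids:
            break
        centroids = updated
    return centroids


def exact_1d(values, weights, k):
    """Brute force over contiguous splits of the sorted data (tiny inputs only)."""
    pairs = sorted(zip(values, weights))
    n = len(pairs)
    best_cost, best_centroids = float("inf"), None
    for cuts in combinations(range(1, n), k - 1):
        bounds = (0, *cuts, n)
        centroids = []
        for i, j in zip(bounds, bounds[1:]):
            run = pairs[i:j]
            centroids.append(sum(w * x for x, w in run) / sum(w for _, w in run))
        cost = weighted_sse(values, weights, centroids)
        if cost < best_cost:
            best_cost, best_centroids = cost, centroids
    return best_centroids


if __name__ == "__main__":
    values = [0, 1, 10, 11, 12, 13]
    weights = [1] * len(values)
    # A poor start: two centroids on the left pair, one for the right group.
    local = lloyd_1d(values, weights, init=[0.0, 1.0, 11.5])
    exact = exact_1d(values, weights, k=3)
    print(local, weighted_sse(values, weights, local))  # [0.0, 1.0, 11.5] 5.0
    print(exact, weighted_sse(values, weights, exact))  # [0.5, 10.5, 12.5] 1.5
```

The poor start is already a fixed point of Lloyd's update, so local search never leaves it. Whether this happens often enough to matter for the quantizer depends on the data and initialization; the sketch only shows that it can happen.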