diff --git a/gptq.py b/gptq.py index 1fa90c4..9a609a2 100644 --- a/gptq.py +++ b/gptq.py @@ -53,8 +53,8 @@ def add_batch(self, inp, out): self.H *= self.nsamples / (self.nsamples + tmp) self.nsamples += tmp # inp = inp.float() - inp = math.sqrt(2 / self.nsamples) * inp.float() - # self.H += 2 / self.nsamples * inp.matmul(inp.t()) + inp = math.sqrt(2*tmp / self.nsamples) * inp.float() + # self.H += 2*tmp / self.nsamples * inp.matmul(inp.t()) self.H += inp.matmul(inp.t()) def fasterquant(