Commit

black formatted
JinZr committed Oct 21, 2024
1 parent d752230 · commit 01a003a
Showing 5 changed files with 13 additions and 13 deletions.
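All 13 changed lines apply a single Black rule: whitespace around `**` is dropped when both operands are simple (literals, names, or attribute accesses), so `2 ** bits` becomes `2**bits` with identical semantics. A minimal sketch of reproducing the rule through Black's Python API; the snippet is illustrative, assumes `black` is installed, and is not part of the recipe:

```python
# Illustrative only: show Black's power-operator hugging on one of the changed lines.
# Requires `pip install black`.
import black

src = "tokens: List[int] = torch.randint(2 ** bits, (length,)).tolist()\n"
print(black.format_str(src, mode=black.Mode()), end="")
# -> tokens: List[int] = torch.randint(2**bits, (length,)).tolist()
```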
2 changes: 1 addition & 1 deletion egs/libritts/CODEC/encodec/binary.py
@@ -132,7 +132,7 @@ def test():
     for rep in range(4):
         length: int = torch.randint(10, 2_000, (1,)).item()
         bits: int = torch.randint(1, 16, (1,)).item()
-        tokens: List[int] = torch.randint(2 ** bits, (length,)).tolist()
+        tokens: List[int] = torch.randint(2**bits, (length,)).tolist()
         rebuilt: List[int] = []
         buf = io.BytesIO()
         packer = BitPacker(bits, buf)
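For context, the test above draws `length` random tokens, each in `[0, 2**bits)`, packs them through `BitPacker`, and checks that unpacking reproduces them. A standalone sketch of the same round trip, using hypothetical `pack_bits`/`unpack_bits` helpers rather than the module's own classes:

```python
# A standalone sketch of the round trip exercised by test(): pack integers of
# `bits` bits each into a byte stream (LSB first) and read them back.
# pack_bits/unpack_bits are hypothetical helpers, not the repo's BitPacker/BitUnpacker.
from typing import List


def pack_bits(tokens: List[int], bits: int) -> bytes:
    acc, n, out = 0, 0, bytearray()
    for t in tokens:
        acc |= t << n  # append `bits` new bits above those already buffered
        n += bits
        while n >= 8:  # flush whole bytes
            out.append(acc & 0xFF)
            acc >>= 8
            n -= 8
    if n:  # flush the zero-padded remainder
        out.append(acc & 0xFF)
    return bytes(out)


def unpack_bits(data: bytes, bits: int, count: int) -> List[int]:
    acc, n, out = 0, 0, []
    it = iter(data)
    while len(out) < count:
        while n < bits:  # refill the buffer one byte at a time
            acc |= next(it) << n
            n += 8
        out.append(acc & ((1 << bits) - 1))
        acc >>= bits
        n -= bits
    return out


tokens = [3, 7, 1, 6]
assert unpack_bits(pack_bits(tokens, bits=3), bits=3, count=len(tokens)) == tokens
```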
2 changes: 1 addition & 1 deletion egs/libritts/CODEC/encodec/loss.py
@@ -235,7 +235,7 @@ def __init__(
         super().__init__()
         self.wav_to_specs = []
         for i in range(5, 12):
-            s = 2 ** i
+            s = 2**i
             self.wav_to_specs.append(
                 MelSpectrogram(
                     sample_rate=sampling_rate,
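The touched loop builds one `MelSpectrogram` per window size `2**5` through `2**11` (32 to 2048 samples), i.e. a multi-scale mel reconstruction loss. A rough sketch of that construction; hop length, the `n_fft` floor, and `n_mels` are illustrative assumptions, not necessarily the recipe's exact values:

```python
# Sketch of a multi-scale mel front end: one transform per window size 2**5 .. 2**11.
# Parameter values are assumptions for illustration.
import torch
from torchaudio.transforms import MelSpectrogram

sampling_rate = 24000
wav_to_specs = [
    MelSpectrogram(
        sample_rate=sampling_rate,
        n_fft=max(2**i, 512),  # keep enough FFT bins even for the shortest windows
        win_length=2**i,
        hop_length=2**i // 4,
        n_mels=64,
    )
    for i in range(5, 12)
]

wav = torch.randn(1, sampling_rate)  # one second of dummy audio
print([tuple(spec(wav).shape) for spec in wav_to_specs])
```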
4 changes: 2 additions & 2 deletions egs/libritts/CODEC/encodec/modules/seanet.py
@@ -161,7 +161,7 @@ def __init__(
                     SEANetResnetBlock(
                         mult * n_filters,
                         kernel_sizes=[residual_kernel_size, 1],
-                        dilations=[dilation_base ** j, 1],
+                        dilations=[dilation_base**j, 1],
                         norm=norm,
                         norm_params=norm_params,
                         activation=activation,
@@ -311,7 +311,7 @@ def __init__(
                     SEANetResnetBlock(
                         mult * n_filters // 2,
                         kernel_sizes=[residual_kernel_size, 1],
-                        dilations=[dilation_base ** j, 1],
+                        dilations=[dilation_base**j, 1],
                         activation=activation,
                         activation_params=activation_params,
                         norm=norm,
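In both hunks, `dilation_base**j` gives the j-th residual block of each SEANet stage an exponentially growing dilation, so the receptive field widens geometrically with depth at constant parameter count. A toy calculation under assumed values (base 2, kernel size 3, three blocks), counting only the dilated convolution in each block:

```python
# Toy receptive-field calculation for a stack of dilated residual blocks.
# Values are illustrative; only the dilated conv in each block is counted.
dilation_base = 2
residual_kernel_size = 3
n_residual_layers = 3

receptive_field = 1
for j in range(n_residual_layers):
    dilation = dilation_base**j
    receptive_field += (residual_kernel_size - 1) * dilation
    print(f"block {j}: dilation={dilation}, receptive field ~{receptive_field} samples")
# block 0: dilation=1, receptive field ~3 samples
# block 1: dilation=2, receptive field ~7 samples
# block 2: dilation=4, receptive field ~15 samples
```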
16 changes: 8 additions & 8 deletions egs/libritts/CODEC/encodec/quantization/ac.py
@@ -41,7 +41,7 @@ def build_stable_quantized_cdf(
     if roundoff:
         pdf = (pdf / roundoff).floor() * roundoff
     # interpolate with uniform distribution to achieve desired minimum probability.
-    total_range = 2 ** total_range_bits
+    total_range = 2**total_range_bits
     cardinality = len(pdf)
     alpha = min_range * cardinality / total_range
     assert alpha <= 1, "you must reduce min_range"
@@ -51,7 +51,7 @@
     if min_range < 2:
         raise ValueError("min_range must be at least 2.")
     if check:
-        assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
+        assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
         if (
             (quantized_cdf[1:] - quantized_cdf[:-1]) < min_range
         ).any() or quantized_cdf[0] < min_range:
@@ -142,18 +142,18 @@ def push(self, symbol: int, quantized_cdf: Tensor):
             quantized_cdf (Tensor): use `build_stable_quantized_cdf`
                 to build this from your pdf estimate.
         """
-        while self.delta < 2 ** self.total_range_bits:
+        while self.delta < 2**self.total_range_bits:
             self.low *= 2
             self.high = self.high * 2 + 1
             self.max_bit += 1

         range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
         range_high = quantized_cdf[symbol].item() - 1
         effective_low = int(
-            math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))
+            math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
         )
         effective_high = int(
-            math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))
+            math.floor(range_high * (self.delta / (2**self.total_range_bits)))
         )
         assert self.low <= self.high
         self.high = self.low + effective_high
@@ -238,7 +238,7 @@ def pull(self, quantized_cdf: Tensor) -> Optional[int]:
                 to build this from your pdf estimate. This must be **exatly**
                 the same cdf as the one used at encoding time.
         """
-        while self.delta < 2 ** self.total_range_bits:
+        while self.delta < 2**self.total_range_bits:
             bit = self.unpacker.pull()
             if bit is None:
                 return None
@@ -255,10 +255,10 @@ def bin_search(low_idx: int, high_idx: int):
             range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
             range_high = quantized_cdf[mid].item() - 1
             effective_low = int(
-                math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))
+                math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
             )
             effective_high = int(
-                math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))
+                math.floor(range_high * (self.delta / (2**self.total_range_bits)))
             )
             low = effective_low + self.low
             high = effective_high + self.low
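The expression touched throughout this file, `range * (self.delta / (2**self.total_range_bits))`, rescales a symbol's slice of the quantized CDF (defined on a grid of `2**total_range_bits` steps) onto the coder's current interval of width `delta`; taking `ceil` on the lower bound and `floor` on the upper bound keeps adjacent symbols' sub-intervals disjoint. A worked toy example with made-up numbers:

```python
# Worked toy example of the interval rescaling in push()/pull().
# All values are made up; total_range_bits=8 puts the CDF on a grid of 2**8 = 256.
import math

total_range_bits = 8
delta = 1000                    # current width of the coder's interval
quantized_cdf = [60, 160, 256]  # cumulative counts for 3 symbols on the 256 grid

symbol = 1
range_low = quantized_cdf[symbol - 1] if symbol > 0 else 0  # 60
range_high = quantized_cdf[symbol] - 1                      # 159

effective_low = int(math.ceil(range_low * (delta / (2**total_range_bits))))
effective_high = int(math.floor(range_high * (delta / (2**total_range_bits))))
print(effective_low, effective_high)  # 235 621: symbol 1 owns that slice of the 1000-wide interval
```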
2 changes: 1 addition & 1 deletion egs/libritts/CODEC/encodec/quantization/core_vq.py
@@ -76,7 +76,7 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10):

     for _ in range(num_iters):
         diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
-        dists = -(diffs ** 2).sum(dim=-1)
+        dists = -(diffs**2).sum(dim=-1)

         buckets = dists.max(dim=-1).indices
         bins = torch.bincount(buckets, minlength=num_clusters)
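For reference, the touched line sits in the assignment step of k-means: squared distances come from broadcasting, they are negated so `max` picks the nearest centroid, and `bincount` then gives cluster occupancy. A self-contained sketch with illustrative shapes:

```python
# Self-contained sketch of the k-means assignment step shown above.
# Shapes and cluster count are illustrative.
import torch
from einops import rearrange

samples = torch.randn(100, 8)  # (n, d) vectors to cluster
means = torch.randn(4, 8)      # (c, d) current centroids

diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")  # (n, c, d)
dists = -(diffs**2).sum(dim=-1)              # (n, c); larger means closer
buckets = dists.max(dim=-1).indices          # nearest centroid index per sample
bins = torch.bincount(buckets, minlength=4)  # number of samples per cluster
print(bins)
```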
