Skip to content

Commit 97e6660

Browse files
0.9.18
1 parent f3834dd commit 97e6660

20 files changed

+232
-26
lines changed

torchstudio/analyzers/multiclass.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def generate_report(self, size, dpi):
100100

101101
canvas = plt.get_current_fig_manager().canvas
102102
canvas.draw()
103-
img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
103+
img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
104104
plt.close()
105105
return img
106106

torchstudio/analyzers/multilabel.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def generate_report(self, size, dpi):
8787

8888
canvas = plt.get_current_fig_manager().canvas
8989
canvas.draw()
90-
img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
90+
img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
9191
plt.close()
9292
return img
9393

torchstudio/analyzers/valuesdistribution.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def generate_report(self, size, dpi):
105105

106106
canvas = plt.get_current_fig_manager().canvas
107107
canvas.draw()
108-
img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
108+
img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
109109
plt.close()
110110
return img
111111

torchstudio/metricsplot.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def inverse(x):
130130

131131
canvas = plt.get_current_fig_manager().canvas
132132
canvas.draw()
133-
img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
133+
img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
134134
plt.close()
135135
return img
136136

torchstudio/modelbuild.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,8 @@ def level_trace(root):
238238
for tensor in output_tensors:
239239
metric.append("Accuracy")
240240

241-
tc.send_msg(app_socket, 'SetHyperParametersValues', tc.encode_ints([64,0,100,20]))
242-
tc.send_msg(app_socket, 'SetHyperParametersNames', tc.encode_strings(loss+metric+['Adam','Step']))
241+
tc.send_msg(app_socket, 'SetHyperParametersValues', tc.encode_ints([64,0,100,30]))
242+
tc.send_msg(app_socket, 'SetHyperParametersNames', tc.encode_strings(loss+metric+['AdamWScheduleFree','NoSchedule']))
243243

244244
if msg_type == 'Exit':
245245
break

torchstudio/modeltrain.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from tqdm.auto import tqdm
1717
from collections.abc import Iterable
1818
import threading
19-
19+
import math
2020

2121
class CachedDataset(Dataset):
2222
def __init__(self, train=True, hash=None):
@@ -294,6 +294,8 @@ def deepcopy_cpu(value):
294294

295295
#training
296296
model.train()
297+
if hasattr(optimizer, 'train'):
298+
optimizer.train()
297299
train_loss = 0
298300
train_metrics = []
299301
for metric in metrics:
@@ -334,6 +336,8 @@ def deepcopy_cpu(value):
334336

335337
#validation
336338
model.eval()
339+
if hasattr(optimizer, 'eval'):
340+
optimizer.eval()
337341
valid_loss = 0
338342
valid_metrics = []
339343
for metric in metrics:
+162
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import torch
7+
import torch.optim
8+
import math
9+
10+
class AdamWScheduleFree(torch.optim.Optimizer):
    r"""
    Schedule-Free AdamW
    As the name suggests, no scheduler is needed with this optimizer.
    To add warmup, rather than using a learning rate schedule you can just
    set the warmup_steps parameter.

    This optimizer requires that .train() and .eval() be called before the
    beginning of training and evaluation respectively.

    Three sequences are involved (notation used in the comments below):
    ``z`` is the base iterate stored in the optimizer state, ``x`` is the
    weighted average of the ``z`` iterates (the parameters to evaluate
    with), and ``y = (1-beta1)*z + beta1*x`` is the point where gradients
    are taken. The parameters hold ``y`` in train mode and ``x`` in eval
    mode; .train() / .eval() convert between the two in place.

    Arguments:
        params (iterable):
            Iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float):
            Learning rate parameter (default 0.0025)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999)).
        eps (float):
            Term added to the denominator outside of the root operation to
            improve numerical stability. (default: 1e-8).
        weight_decay (float):
            Weight decay coefficient; ``weight_decay * y`` is added to the
            normalized gradient (default: 0).
        warmup_steps (int): Enables a linear learning rate warmup (default 0).
        r (float): Use polynomial weighting in the average
            with power r (default 0).
        weight_lr_power (float): During warmup, the weights in the average will
            be equal to lr raised to this power. Set to 0 for no weighting
            (default 2.0).
    """
    def __init__(self,
                 params,
                 lr=0.0025,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0,
                 warmup_steps=0,
                 r=0.0,
                 weight_lr_power=2.0,
                 ):

        # Besides the hyperparameters, each param group also carries the
        # running state of the schedule: step counter k, the train/eval flag,
        # the running sum of averaging weights and the largest lr seen so far.
        defaults = dict(lr=lr,
                        betas=betas,
                        eps=eps,
                        r=r,
                        k=0,
                        warmup_steps=warmup_steps,
                        train_mode=True,
                        weight_sum=0.0,
                        lr_max=-1.0,
                        weight_lr_power=weight_lr_power,
                        weight_decay=weight_decay)
        super().__init__(params, defaults)

    def eval(self):
        """Switch the parameters from y (training point) to x (averaged point).

        Groups that are already in eval mode are left untouched, so calling
        this repeatedly is harmless.
        """
        for group in self.param_groups:
            train_mode = group['train_mode']
            beta1, _ = group['betas']
            if train_mode:
                for p in group['params']:
                    state = self.state[p]
                    if 'z' in state:
                        # Set p.data to x: inverts y = (1-beta1)*z + beta1*x
                        p.data.lerp_(end=state['z'], weight=1-1/beta1)
                group['train_mode'] = False

    def train(self):
        """Switch the parameters from x (averaged point) back to y.

        Groups that are already in train mode are left untouched.
        """
        for group in self.param_groups:
            train_mode = group['train_mode']
            beta1, _ = group['betas']
            if not train_mode:
                for p in group['params']:
                    state = self.state[p]
                    if 'z' in state:
                        # Set p.data to y = (1-beta1)*z + beta1*x
                        p.data.lerp_(end=state['z'], weight=1-beta1)
                group['train_mode'] = True

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: if any parameter group is in eval mode. The check
                runs before anything is modified, so a rejected step leaves
                the optimizer state untouched.
        """
        # Validate up front: stepping in eval mode would apply the update to
        # x instead of y, and bailing out half-way would leave lr_max and
        # weight_sum advanced for a step that never happened.
        for group in self.param_groups:
            if not group['train_mode']:
                raise RuntimeError("Not in train mode!")

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            eps = group['eps']
            beta1, beta2 = group['betas']
            decay = group['weight_decay']
            k = group['k']
            r = group['r']
            warmup_steps = group['warmup_steps']
            weight_lr_power = group['weight_lr_power']

            # Linear warmup factor; the k < warmup_steps guard also keeps
            # warmup_steps == 0 from dividing by zero.
            if k < warmup_steps:
                sched = (k+1) / warmup_steps
            else:
                sched = 1.0

            # Adam second-moment bias correction folded into the step size.
            bias_correction2 = 1 - beta2 ** (k+1)
            lr = group['lr']*sched*math.sqrt(bias_correction2)

            lr_max = group['lr_max'] = max(lr, group['lr_max'])

            # Weight of this iterate in the running average x.
            weight = ((k+1)**r) * (lr_max**weight_lr_power)
            weight_sum = group['weight_sum'] = group['weight_sum'] + weight

            # With a zero effective learning rate every weight so far is 0;
            # treat that as "do not move the average" instead of raising
            # ZeroDivisionError.
            if weight_sum != 0:
                ckp1 = weight/weight_sum
            else:
                ckp1 = 0.0

            for p in group['params']:
                if p.grad is None:
                    continue

                y = p.data  # Notation to match theory
                grad = p.grad.data

                state = self.state[p]

                # Lazy state init: z starts at the current parameters.
                if 'z' not in state:
                    state['z'] = torch.clone(y)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                z = state['z']
                exp_avg_sq = state['exp_avg_sq']

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2)
                denom = exp_avg_sq.sqrt().add_(eps)

                # Reuse grad buffer for memory efficiency
                # (p.grad is overwritten with the normalized gradient).
                grad_normalized = grad.div_(denom)

                # Weight decay calculated at y
                if decay != 0:
                    grad_normalized.add_(y, alpha=decay)

                # These operations update y in-place,
                # without computing x explicitly.
                y.lerp_(end=z, weight=ckp1)
                y.add_(grad_normalized, alpha=lr*(beta1*(1-ckp1)-1))

                # z step
                z.sub_(grad_normalized, alpha=lr)

            group['k'] = k+1
        return loss

torchstudio/parametersplot.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def plot_parameters(size, dpi,
134134

135135
canvas = plt.get_current_fig_manager().canvas
136136
canvas.draw()
137-
img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
137+
img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
138138
plt.close()
139139
return img
140140

torchstudio/pythonparse.py

+2
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def filter_parent_objects(objects:List[Dict]) -> List[Dict]:
157157

158158
generated_class="""\
159159
import typing
160+
from typing import Any, Callable, List, Tuple, Union, Sequence, Optional
160161
import pathlib
161162
import torch
162163
import torch.nn as nn
@@ -172,6 +173,7 @@ def __init__({4}):
172173

173174
generated_function="""\
174175
import typing
176+
from typing import Any, Callable, List, Tuple, Union, Sequence, Optional
175177
import pathlib
176178
import torch
177179
import torch.nn as nn

torchstudio/renderers/bitmap.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,6 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
118118

119119
canvas = plt.get_current_fig_manager().canvas
120120
canvas.draw()
121-
img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
121+
img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
122122
plt.close()
123123
return img

torchstudio/renderers/boundingbox.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
122122

123123
canvas = plt.get_current_fig_manager().canvas
124124
canvas.draw()
125-
img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
125+
img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
126126
plt.close()
127127
return img
128128

torchstudio/renderers/labels.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
129129

130130
canvas = plt.get_current_fig_manager().canvas
131131
canvas.draw()
132-
img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
132+
img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
133133
plt.close()
134134
return img
135135

torchstudio/renderers/signal.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
9292

9393
canvas = plt.get_current_fig_manager().canvas
9494
canvas.draw()
95-
img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
95+
img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
9696
plt.close()
9797
return img
9898

torchstudio/renderers/spectrogram.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
124124

125125
canvas = plt.get_current_fig_manager().canvas
126126
canvas.draw()
127-
img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
127+
img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
128128
plt.close()
129129
return img
130130

torchstudio/renderers/volume.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
122122

123123
canvas = plt.get_current_fig_manager().canvas
124124
canvas.draw()
125-
img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
125+
img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
126126
plt.close()
127127
return img
128128

torchstudio/schedulers/multistep.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,15 @@ class MultiStep(lr_scheduler.MultiStepLR):
1212
gamma (float): Multiplicative factor of learning rate decay.
1313
Default: 0.1.
1414
last_epoch (int): The index of last epoch. Default: -1.
15+
verbose (bool): If ``True``, prints a message to stdout for
16+
each update. Default: ``False``.
17+
18+
.. deprecated:: 2.2
19+
``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
20+
learning rate.
1521
1622
Example:
23+
>>> # xdoctest: +SKIP
1724
>>> # Assuming optimizer uses lr = 0.05 for all groups
1825
>>> # lr = 0.05 if epoch < 30
1926
>>> # lr = 0.005 if 30 <= epoch < 80
@@ -22,6 +29,8 @@ class MultiStep(lr_scheduler.MultiStepLR):
2229
>>> for epoch in range(100):
2330
>>> train(...)
2431
>>> validate(...)
25-
>>> scheduler.step()"""
26-
def __init__(self, optimizer, milestones=[75, 100, 125], gamma=0.1, last_epoch=-1):
27-
super().__init__(optimizer, milestones, gamma, last_epoch, verbose=False)
32+
>>> scheduler.step()
33+
"""
34+
35+
def __init__(self, optimizer, milestones=[75, 100, 125], gamma=0.1, last_epoch=-1, verbose="deprecated"):
36+
super().__init__(optimizer, milestones, gamma, last_epoch, verbose)

torchstudio/schedulers/noschedule.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
class NoSchedule():
    """No Schedule

    Placeholder scheduler that never changes the learning rate; it only
    counts how many times step() has been called. It exposes the same
    basic methods as the torch lr_scheduler classes (step, get_last_lr,
    state_dict, load_state_dict) so it can be used wherever one of those
    is expected.

    Args:
        optimizer: Wrapped optimizer. Only read by get_last_lr(); may be
            any object exposing ``param_groups`` (or None).
        last_epoch (int): The index of last epoch. Negative values
            (including the default -1) start the counter at 0.
    """
    def __init__(self,
                 optimizer,
                 last_epoch=-1):
        # Kept so get_last_lr() can report the optimizer's own rates.
        self.optimizer = optimizer
        self.last_epoch = 0 if last_epoch < 0 else last_epoch

    def step(self):
        """Advance the epoch counter; the learning rate is left untouched."""
        self.last_epoch += 1

    def get_last_lr(self):
        """Return the current learning rate of each optimizer param group."""
        groups = getattr(self.optimizer, 'param_groups', ())
        return [group['lr'] for group in groups]

    def state_dict(self):
        """Return the scheduler state (the epoch counter) as a dict."""
        return {'last_epoch': self.last_epoch}

    def load_state_dict(self, state_dict):
        """Restore the epoch counter from a dict produced by state_dict()."""
        self.last_epoch = state_dict['last_epoch']
11+
12+

0 commit comments

Comments
 (0)