diff --git a/alf/ext/fused_linear_act.py b/alf/ext/fused_linear_act.py index 34187e9db..87537b309 100644 --- a/alf/ext/fused_linear_act.py +++ b/alf/ext/fused_linear_act.py @@ -29,10 +29,7 @@ class StaticState: workspace = {} workspace_size = 1024 * 1024 * 8 - bias_g = { - idx: torch.tensor([], dtype=torch.float16).cuda(idx) - for idx in range(torch.cuda.device_count()) - } + bias_g = {} @classmethod def get(cls, name: str, device: torch.device) -> Any: @@ -41,6 +38,7 @@ def get(cls, name: str, device: torch.device) -> Any: cls.workspace[idx] = torch.empty((cls.workspace_size, ), dtype=torch.uint8, device=device).cuda(idx) + cls.bias_g[idx] = torch.tensor([], dtype=torch.float16).cuda(idx) if name == "bias": return cls.bias_g[idx] if name == "workspace":