Commit 958b97a: fixed imports

Parent: 0ea9db0

14 files changed (+39, -39 lines)
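Every hunk in this commit follows one pattern: imports qualified with the old `param_bench.train.comms.pt` package prefix become top-level imports, matching the `package_base = ""` change in setup.py near the end of the diff. A minimal before/after sketch, assuming the modules are importable from the top level once the package is reinstalled (module names are taken from the hunks below):

# Before: modules addressed through the full package path
from param_bench.train.comms.pt import comms_utils
from param_bench.train.comms.pt.comms_utils import commsArgs

# After: the same modules resolved as top-level imports
import comms_utils
from comms_utils import commsArgs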

et_replay/comm/backend/pytorch_dist_backend.py (+1, -1)

@@ -20,7 +20,7 @@
 from et_replay.comm.param_profile import paramProfile

 try:
-    from param_bench.train.comms.pt.fb.internals import (
+    from fb.internals import (
         all_to_all_internal,
         all_to_allv_internal,
         extend_distributed,
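These `fb.internals` imports sit inside `try` blocks, so the open-source build keeps working when the Meta-internal helpers are absent; only the import path changes in this commit. A minimal sketch of that guarded-import pattern, where the fallback flag is an illustrative assumption rather than code from the diff:

# Guarded optional import: use the internal fast path when it is available.
try:
    from fb.internals import all_to_all_internal
    _has_internal = True  # hypothetical flag, for illustration only
except ImportError:
    _has_internal = False  # open-source build: fall back to plain torch.distributed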

et_replay/comm/comms_utils.py (+1, -1)

@@ -25,7 +25,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 try:
-    from param_bench.train.comms.pt.fb.internals import (
+    from fb.internals import (
         fbInitProfiler,
         fbSampleProfiler,
         fbStartProfiler,

et_replay/tools/comm_replay.py (+2, -2)

@@ -66,7 +66,7 @@ def writeCommDetails(commsTracePerf: List, rank: int, folder: str = "./") -> None:
     if "://" in comms_file:  # assume that "://" in directory path means remote store
         saveToLocal = False
         try:
-            from param_bench.train.comms.pt.fb.internals import (
+            from fb.internals import (
                 writeRemoteTrace as writeFbRemoteTrace,
             )

@@ -1528,7 +1528,7 @@ def readRawTrace(self, remotePath: str, rank: int) -> None:
             raw_comms_trace = comms_utils.commonUrlRead(remotePath=remotePath)
         else:
             try:
-                from param_bench.train.comms.pt.fb.internals import (
+                from fb.internals import (
                     readRemoteTrace as readFbRemoteTrace,
                 )


train/comms/pt/comms.py (+7, -7)

@@ -16,24 +16,24 @@
 # pytorch
 import torch

-from param_bench.train.comms.pt import comms_utils
-from param_bench.train.comms.pt.comms_utils import (
+import comms_utils
+from comms_utils import (
     bootstrap_info_holder,
     commsParamsHolderBase,
     ensureTensorFlush,
     paramCommsBench,
     paramDeviceTimer,
     paramStreamGuard,
 )
-from param_bench.train.comms.pt.logger_utils import (
+from logger_utils import (
     benchType,
     commsCollPerfMetrics,
     commsPt2PtPerfMetrics,
     commsQuantCollPerfMetrics,
     customized_perf_loggers,
 )

-from param_bench.train.comms.pt.pytorch_backend_utils import (
+from pytorch_backend_utils import (
     pt2ptPatterns,
     supportedC10dBackends,
     supportedCollectives,
@@ -1824,13 +1824,13 @@ def initBackend(
         commsParams.nw_stack == "pytorch-dist"
         and commsParams.backend in supportedC10dBackends
     ):
-        from param_bench.train.comms.pt.pytorch_dist_backend import (
+        from pytorch_dist_backend import (
             PyTorchDistBackend,
         )

         backendObj = PyTorchDistBackend(bootstrap_info, commsParams)
     elif commsParams.nw_stack == "pytorch-xla-tpu":
-        from param_bench.train.comms.pt.pytorch_tpu_backend import PyTorchTPUBackend
+        from pytorch_tpu_backend import PyTorchTPUBackend

         backendObj = PyTorchTPUBackend(bootstrap_info, commsParams)
     else:
@@ -1839,7 +1839,7 @@
         logging.warning(
             f"Attempt loading customized backend {commsParams.backend} if registered. Note that this is not officially supported. Use it with caution and at your own risk."
         )
-        from param_bench.train.comms.pt.pytorch_backend_utils import (
+        from pytorch_backend_utils import (
             customized_backend,
         )


train/comms/pt/commsTraceParser.py (+3, -3)

@@ -7,9 +7,9 @@

 from et_replay import ExecutionTrace

-from param_bench.train.comms.pt import comms_utils
-from param_bench.train.comms.pt.comms_utils import commsArgs
-from param_bench.train.comms.pt.pytorch_backend_utils import supportedP2pOps
+import comms_utils
+from comms_utils import commsArgs
+from pytorch_backend_utils import supportedP2pOps

 tensorDtypeMap = {
     "Tensor(int)": "int",

train/comms/pt/commsTraceReplay.py (+9, -9)

@@ -17,17 +17,17 @@

 import numpy as np
 import torch
-from param_bench.train.comms.pt import comms_utils
-from param_bench.train.comms.pt.comms_utils import (
+import comms_utils
+from comms_utils import (
     bootstrap_info_holder,
     commsArgs,
     commsParamsHolderBase,
     paramCommsBench,
     paramStreamGuard,
     paramToCommName,
 )
-from param_bench.train.comms.pt.param_profile import paramProfile, paramTimer
-from param_bench.train.comms.pt.pytorch_backend_utils import supportedP2pOps
+from param_profile import paramProfile, paramTimer
+from pytorch_backend_utils import supportedP2pOps

 try:
     from trainer_iteration_wrapper import setTrainingIteration  # @manual
@@ -64,7 +64,7 @@ def writeCommDetails(commsTracePerf: List, rank: int, folder: str = "./") -> None:
     if "://" in comms_file:  # assume that "://" in directory path means remote store
         saveToLocal = False
         try:
-            from param_bench.train.comms.pt.fb.internals import (
+            from fb.internals import (
                 writeRemoteTrace as writeFbRemoteTrace,
             )

@@ -1388,13 +1388,13 @@ def initBackend(
         """
         # init backend and corresponding function pointers
         if commsParams.nw_stack == "pytorch-dist":
-            from param_bench.train.comms.pt.pytorch_dist_backend import (
+            from pytorch_dist_backend import (
                 PyTorchDistBackend,
             )

             self.backendFuncs = PyTorchDistBackend(bootstrap_info, commsParams)
         elif commsParams.nw_stack == "pytorch-xla-tpu":
-            from param_bench.train.comms.pt.pytorch_tpu_backend import PyTorchTPUBackend
+            from pytorch_tpu_backend import PyTorchTPUBackend

             self.backendFuncs = PyTorchTPUBackend(bootstrap_info, commsParams)
         else:
@@ -1526,7 +1526,7 @@ def readRawTrace(self, remotePath: str, rank: int) -> None:
             raw_comms_trace = comms_utils.commonUrlRead(remotePath=remotePath)
         else:
             try:
-                from param_bench.train.comms.pt.fb.internals import (
+                from fb.internals import (
                     readRemoteTrace as readFbRemoteTrace,
                 )

@@ -1589,7 +1589,7 @@ def readTrace(self, remotePath: str, rank: int) -> None:

         # Convert trace to comms trace.
         try:
-            from param_bench.train.comms.pt import commsTraceParser
+            import commsTraceParser
         except ImportError:
             logger.info("FB internals not present, using base parser.")
             self.comms_trace = extractCommsInfo(self.comms_trace)

train/comms/pt/comms_utils.py (+3, -3)

@@ -23,7 +23,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 try:
-    from param_bench.train.comms.pt.fb.internals import (
+    from fb.internals import (
         fbInitProfiler,
         fbSampleProfiler,
         fbStartProfiler,
@@ -38,8 +38,8 @@

 import numpy as np
 import torch
-from param_bench.train.comms.pt.param_profile import paramTimer
-from param_bench.train.comms.pt.pytorch_backend_utils import (
+from param_profile import paramTimer
+from pytorch_backend_utils import (
     backendFunctions,
     collectiveArgsHolder,
     customized_backend,

train/comms/pt/dlrm.py (+1, -1)

@@ -23,7 +23,7 @@
 import torch.nn as nn
 from comms_utils import paramCommsBench

-from param_bench.train.comms.pt import comms_utils
+import comms_utils
 from pytorch_dist_backend import PyTorchDistBackend
 from torch.autograd import Function


train/comms/pt/logger_utils.py (+1, -1)

@@ -8,7 +8,7 @@
 from enum import Enum
 from typing import abstractmethod, Dict, Optional

-from param_bench.train.comms.pt.pytorch_backend_utils import backendFunctions
+from pytorch_backend_utils import backendFunctions

 logger = logging.getLogger(__name__)


train/comms/pt/pytorch_backend_utils.py (+1, -1)

@@ -9,7 +9,7 @@

 import torch

-from param_bench.train.comms.pt.param_profile import paramTimer
+from param_profile import paramTimer

 from torch.distributed import ProcessGroup


train/comms/pt/pytorch_dist_backend.py (+3, -3)

@@ -13,8 +13,8 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from param_bench.train.comms.pt.param_profile import paramProfile
-from param_bench.train.comms.pt.pytorch_backend_utils import (
+from param_profile import paramProfile
+from pytorch_backend_utils import (
     backendFunctions,
     collectiveArgsHolder,
 )
@@ -27,7 +27,7 @@
     has_triton = False

 try:
-    from param_bench.train.comms.pt.fb.internals import (
+    from fb.internals import (
         all_to_all_internal,
         all_to_allv_internal,
         extend_distributed,

train/comms/pt/setup.py (+1, -1)

@@ -2,7 +2,7 @@


 def main():
-    package_base = "param_bench.train.comms.pt"
+    package_base = ""

     # List the packages and their dir mapping:
     # "install_destination_package_path": "source_dir_path"

train/comms/pt/tests/commsTraceReplay_tests.py (+3, -3)

@@ -5,9 +5,9 @@

 from comms_utils import commsArgs

-from param_bench.train.comms.pt.commsTraceReplay import commsTraceReplayBench
-from param_bench.train.comms.pt.tests.mocks.backend_mock import MockBackendFunction
-from param_bench.train.comms.pt.tests.test_utils import (
+from commsTraceReplay import commsTraceReplayBench
+from tests.mocks.backend_mock import MockBackendFunction
+from tests.test_utils import (
     commsParamsTest,
     createCommsArgs,
     testArgs,

train/comms/pt/tests/comms_utils_tests.py (+3, -3)

@@ -3,9 +3,9 @@

 import torch

-from param_bench.train.comms.pt import comms_utils
-from param_bench.train.comms.pt.tests.mocks.backend_mock import MockBackendFunction
-from param_bench.train.comms.pt.tests.test_utils import (
+import comms_utils
+from tests.mocks.backend_mock import MockBackendFunction
+from tests.test_utils import (
     bootstrap_info_test,
     commsParamsTest,
 )
