|
17 | 17 |
|
18 | 18 | import numpy as np
|
19 | 19 | import torch
|
20 |
| -from param_bench.train.comms.pt import comms_utils |
21 |
| -from param_bench.train.comms.pt.comms_utils import ( |
| 20 | +from import comms_utils |
| 21 | +from comms_utils import ( |
22 | 22 | bootstrap_info_holder,
|
23 | 23 | commsArgs,
|
24 | 24 | commsParamsHolderBase,
|
25 | 25 | paramCommsBench,
|
26 | 26 | paramStreamGuard,
|
27 | 27 | paramToCommName,
|
28 | 28 | )
|
29 |
| -from param_bench.train.comms.pt.param_profile import paramProfile, paramTimer |
30 |
| -from param_bench.train.comms.pt.pytorch_backend_utils import supportedP2pOps |
| 29 | +from param_profile import paramProfile, paramTimer |
| 30 | +from pytorch_backend_utils import supportedP2pOps |
31 | 31 |
|
32 | 32 | try:
|
33 | 33 | from trainer_iteration_wrapper import setTrainingIteration # @manual
|
@@ -64,7 +64,7 @@ def writeCommDetails(commsTracePerf: List, rank: int, folder: str = "./") -> Non
|
64 | 64 | if "://" in comms_file: # assume that "://" in directory path means remote store
|
65 | 65 | saveToLocal = False
|
66 | 66 | try:
|
67 |
| - from param_bench.train.comms.pt.fb.internals import ( |
| 67 | + from fb.internals import ( |
68 | 68 | writeRemoteTrace as writeFbRemoteTrace,
|
69 | 69 | )
|
70 | 70 |
|
@@ -1388,13 +1388,13 @@ def initBackend(
|
1388 | 1388 | """
|
1389 | 1389 | # init backend and corresponding function pointers
|
1390 | 1390 | if commsParams.nw_stack == "pytorch-dist":
|
1391 |
| - from param_bench.train.comms.pt.pytorch_dist_backend import ( |
| 1391 | + from pytorch_dist_backend import ( |
1392 | 1392 | PyTorchDistBackend,
|
1393 | 1393 | )
|
1394 | 1394 |
|
1395 | 1395 | self.backendFuncs = PyTorchDistBackend(bootstrap_info, commsParams)
|
1396 | 1396 | elif commsParams.nw_stack == "pytorch-xla-tpu":
|
1397 |
| - from param_bench.train.comms.pt.pytorch_tpu_backend import PyTorchTPUBackend |
| 1397 | + from pytorch_tpu_backend import PyTorchTPUBackend |
1398 | 1398 |
|
1399 | 1399 | self.backendFuncs = PyTorchTPUBackend(bootstrap_info, commsParams)
|
1400 | 1400 | else:
|
@@ -1526,7 +1526,7 @@ def readRawTrace(self, remotePath: str, rank: int) -> None:
|
1526 | 1526 | raw_comms_trace = comms_utils.commonUrlRead(remotePath=remotePath)
|
1527 | 1527 | else:
|
1528 | 1528 | try:
|
1529 |
| - from param_bench.train.comms.pt.fb.internals import ( |
| 1529 | + from fb.internals import ( |
1530 | 1530 | readRemoteTrace as readFbRemoteTrace,
|
1531 | 1531 | )
|
1532 | 1532 |
|
@@ -1589,7 +1589,7 @@ def readTrace(self, remotePath: str, rank: int) -> None:
|
1589 | 1589 |
|
1590 | 1590 | # Convert trace to comms trace.
|
1591 | 1591 | try:
|
1592 |
| - from param_bench.train.comms.pt import commsTraceParser |
| 1592 | + from import commsTraceParser |
1593 | 1593 | except ImportError:
|
1594 | 1594 | logger.info("FB internals not present, using base parser.")
|
1595 | 1595 | self.comms_trace = extractCommsInfo(self.comms_trace)
|
|
0 commit comments