-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathutils.py
77 lines (61 loc) · 2.17 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from prettytable import PrettyTable
import torch
import os
import random
import numpy as np
from tqdm import tqdm
import deepwave
import torch.nn as nn
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt
from scipy.io import loadmat
from tools import (SaveResults, rock_properties,
load_checkpoint, awgn)
from PyFWI.rock_physics import pcs2dv_gassmann
from pyfwi_tools import model_resizing
from typing import List, Tuple, Optional
# Absolute directory containing this module; model files are looked up below it.
PATH = os.path.dirname(os.path.abspath(__file__))
def earth_model(name, smooth=0, device="cpu"):
    """Load a stored velocity model and a Gaussian-smoothed copy of it.

    Parameters
    ----------
    name : str
        Model identifier, one of ``"marmousi_30"`` or ``"marmousi_bl"``.
        The model is read with ``torch.load`` from
        ``<PATH>/data_model/<name>.bin``.
    smooth : float, optional
        ``sigma`` passed to ``scipy.ndimage.gaussian_filter`` to build the
        smoothed model; ``0`` (default) gives an unsmoothed copy.
    device : str or torch.device, optional
        Device both returned tensors are moved to.

    Returns
    -------
    tuple of torch.Tensor
        ``(vp, vp0)``: the loaded model and its smoothed version.

    Raises
    ------
    ValueError
        If ``name`` is not one of the known models.
    """
    known_models = ("marmousi_30", "marmousi_bl")
    if name not in known_models:
        # Previously an unknown name fell through both branches and failed
        # later with an UnboundLocalError on ``vp``.
        raise ValueError(
            f"Unknown earth model {name!r}; expected one of {known_models}"
        )
    # Load onto the CPU: gaussian_filter converts its input to a NumPy array,
    # which is impossible for a tensor that was saved on a GPU.
    vp = torch.load(
        os.path.join(PATH, "data_model", f"{name}.bin"), map_location="cpu"
    )
    vp0 = torch.tensor(gaussian_filter(vp, sigma=smooth))
    return vp.to(device=device), vp0.to(device=device)
def count_parameters(model):
    """Print a table of trainable parameters of ``model`` and return their total.

    Only parameters with ``requires_grad`` set are listed and counted; each
    table row holds a name from ``model.named_parameters()`` and its element
    count. The grand total is printed below the table and returned.
    """
    trainable = [
        (param_name, tensor.numel())
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    table = PrettyTable(["Modules", "Parameters"])
    for param_name, n_elements in trainable:
        table.add_row([param_name, n_elements])
    total_params = sum(n_elements for _, n_elements in trainable)
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
def data_normalization(data, dim=1, eps=1e-10):
    """Scale ``data`` by its peak absolute amplitude along ``dim``.

    Each 1-D slice taken along ``dim`` is divided by the largest absolute
    value it contains, so its values end up bounded by 1 in magnitude. The
    original note describes this as per-receiver normalization along the
    t-direction, i.e. ``dim=1`` is presumably the time axis (TODO confirm
    against callers).

    Parameters
    ----------
    data : torch.Tensor
        Input tensor.
    dim : int, optional
        Axis along which the peak amplitude is measured (default ``1``).
    eps : float, optional
        Added to the peak to avoid division by zero for all-zero slices.

    Returns
    -------
    torch.Tensor
        Normalized tensor with the same shape as ``data``.
    """
    # Use the peak *absolute* amplitude. The signed maximum is 0 for traces
    # with no positive samples, which divided them by ~eps and blew values up
    # to ~1e10, and it under-scaled traces dominated by negative excursions.
    peak, _ = data.abs().max(dim=dim, keepdim=True)
    return data / (peak + eps)
def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    ``seed`` is applied to ``random``, ``numpy.random``, the torch CPU
    generator and all torch CUDA generators. ``PYTHONHASHSEED`` is also
    exported; note that the running interpreter reads it only at startup, so
    it affects subprocesses launched afterwards rather than this process.
    cuDNN autotuning is disabled and deterministic algorithms are requested.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed,
               torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def save_checkpoint(model, filename="my_checkpoint.pth.tar", optimizer=None):
    """Save a model's ``state_dict`` (and optionally an optimizer's) with ``torch.save``.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose ``state_dict()`` is stored under the key ``"state_dict"``.
    filename : str or path-like, optional
        Destination passed to ``torch.save``.
    optimizer : torch.optim.Optimizer, optional
        When given, its ``state_dict()`` is also stored under the key
        ``"optimizer"``. By default it is omitted, so the checkpoint contains
        only ``"state_dict"`` as before.
    """
    print("=> Saving checkpoint")
    checkpoint = {"state_dict": model.state_dict()}
    if optimizer is not None:
        checkpoint["optimizer"] = optimizer.state_dict()
    torch.save(checkpoint, filename)