-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Zeyi Huang
authored and
Zeyi Huang
committed
Apr 11, 2022
1 parent
d723f53
commit 5d950fe
Showing
258 changed files
with
37,654 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Code of Conduct | ||
|
||
Facebook has adopted a Code of Conduct that we expect project participants to adhere to. | ||
Please read the [full text](https://code.fb.com/codeofconduct/) | ||
so that you can understand what actions will and will not be tolerated. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Contributing to `DomainBed` | ||
We want to make contributing to this project as easy and transparent as | ||
possible. | ||
|
||
## Pull Requests | ||
We actively welcome your pull requests. | ||
|
||
1. Fork the repo and create your branch from `master`. | ||
2. If you've added code that should be tested, add tests. | ||
3. If you've changed APIs, update the documentation. | ||
4. Ensure the test suite passes. | ||
5. Make sure your code lints. | ||
6. If you haven't already, complete the Contributor License Agreement ("CLA"). | ||
|
||
## Contributor License Agreement ("CLA") | ||
In order to accept your pull request, we need you to submit a CLA. You only need | ||
to do this once to work on any of Facebook's open source projects. | ||
|
||
Complete your CLA here: <https://code.facebook.com/cla> | ||
|
||
## Issues | ||
We use GitHub issues to track public bugs. Please ensure your description is | ||
clear and has sufficient instructions to be able to reproduce the issue. | ||
|
||
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe | ||
disclosure of security bugs. In those cases, please go through the process | ||
outlined on that page and do not file a public issue. | ||
|
||
## License | ||
By contributing to `DomainBed`, you agree that your contributions | ||
will be licensed under the LICENSE file in the root directory of this source | ||
tree. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
import torch | ||
import torchvision | ||
from torch.utils.data import Dataset | ||
from torchvision import transforms | ||
from PIL import Image | ||
import warnings | ||
import bisect | ||
from os.path import join, dirname, exists | ||
from random import random, sample | ||
|
||
dataset = {} | ||
dataset["PACS"] = ["art_painting", "cartoon", "photo", "sketch"] | ||
|
||
available_datasets = dataset["PACS"] | ||
|
||
class ConcatDataset(Dataset): | ||
""" | ||
Dataset to concatenate multiple datasets. | ||
Purpose: useful to assemble different existing datasets, possibly | ||
large-scale datasets as the concatenation operation is done in an | ||
on-the-fly manner. | ||
Arguments: | ||
datasets (sequence): List of datasets to be concatenated | ||
""" | ||
|
||
@staticmethod | ||
def cumsum(sequence): | ||
r, s = [], 0 | ||
for e in sequence: | ||
l = len(e) | ||
r.append(l + s) | ||
s += l | ||
return r | ||
|
||
def __init__(self, datasets): | ||
super(ConcatDataset, self).__init__() | ||
assert len(datasets) > 0, 'datasets should not be an empty iterable' | ||
self.datasets = list(datasets) | ||
self.cumulative_sizes = self.cumsum(self.datasets) | ||
|
||
def __len__(self): | ||
return self.cumulative_sizes[-1] | ||
|
||
def __getitem__(self, idx): | ||
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) | ||
if dataset_idx == 0: | ||
sample_idx = idx | ||
else: | ||
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] | ||
return self.datasets[dataset_idx][sample_idx], dataset_idx | ||
|
||
@property | ||
def cummulative_sizes(self): | ||
warnings.warn("cummulative_sizes attribute is renamed to " | ||
"cumulative_sizes", DeprecationWarning, stacklevel=2) | ||
return self.cumulative_sizes | ||
|
||
class MyDataset(Dataset): | ||
def __init__(self, names, labels, img_transformer=None, data_dir='./'): | ||
self.names = names | ||
self.labels = labels | ||
self.data_dir = data_dir | ||
|
||
self._image_transformer = img_transformer | ||
|
||
def get_image(self, index): | ||
framename = self.data_dir + '/' + self.names[index] | ||
img = Image.open(framename).convert('RGB') | ||
return self._image_transformer(img) | ||
|
||
def __getitem__(self, index): | ||
|
||
img = self.get_image(index) | ||
|
||
return img, int(self.labels[index]) | ||
|
||
def __len__(self): | ||
return len(self.names) | ||
|
||
def get_random_subset(names, labels, percent): | ||
""" | ||
:param names: list of names | ||
:param labels: list of labels | ||
:param percent: 0 < float < 1 | ||
:return: | ||
""" | ||
samples = len(names) | ||
amount = int(samples * percent) | ||
random_index = sample(range(samples), amount) | ||
name_val = [names[k] for k in random_index] | ||
name_train = [v for k, v in enumerate(names) if k not in random_index] | ||
labels_val = [labels[k] for k in random_index] | ||
labels_train = [v for k, v in enumerate(labels) if k not in random_index] | ||
return name_train, name_val, labels_train, labels_val | ||
|
||
def _dataset_info(txt_labels, num_classes=10000): | ||
with open(txt_labels, 'r') as f: | ||
images_list = f.readlines() | ||
|
||
file_names = [] | ||
labels = [] | ||
for row in images_list: | ||
row = row.split(' ') | ||
if int(row[1]) >= num_classes: | ||
continue | ||
file_names.append(row[0]) | ||
labels.append(int(row[1])) | ||
|
||
return file_names, labels | ||
|
||
def get_split_dataset_info(txt_list, val_percentage): | ||
names, labels = _dataset_info(txt_list) | ||
return get_random_subset(names, labels, val_percentage) | ||
|
||
def get_train_dataloader(args): | ||
|
||
dataset_list = args.source | ||
assert isinstance(dataset_list, list) | ||
val_datasets = [] | ||
img_transformer = get_train_transformers(args) | ||
img_num_per_domain = [] | ||
train_loader_list = [] | ||
for dname in dataset_list: | ||
name_train, labels_train = _dataset_info(join(args.datalist_dir, args.dataset, '%s_train_kfold.txt' % dname)) | ||
name_val, labels_val = _dataset_info(join(args.datalist_dir, args.dataset, '%s_crossval_kfold.txt' % dname)) | ||
train_dataset = MyDataset(name_train, labels_train, img_transformer=img_transformer, data_dir=args.data_dir) | ||
val_datasets.append(MyDataset(name_val, labels_val, img_transformer=get_val_transformer(args), data_dir=args.data_dir)) | ||
|
||
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=12, pin_memory=True, drop_last=False) | ||
img_num_per_domain.append(len(name_train)) | ||
train_loader_list.append(train_loader) | ||
|
||
val_dataset = ConcatDataset(val_datasets) | ||
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=12, pin_memory=True, drop_last=False) | ||
return train_loader_list, val_loader, img_num_per_domain | ||
|
||
def get_val_dataloader(args): | ||
|
||
if isinstance(args.target, list): | ||
img_tr = get_val_transformer(args) | ||
val_datasets = [] | ||
for dname in args.target: | ||
names, labels = _dataset_info(join(args.datalist_dir, args.dataset, '%s_test.txt' % dname)) | ||
val_datasets.append(MyDataset(names, labels, img_transformer=img_tr, data_dir=args.data_dir)) | ||
|
||
dataset = ConcatDataset(val_datasets) | ||
|
||
else: | ||
names, labels = _dataset_info(join(args.datalist_dir, args.dataset, '%s_test.txt' % args.target)) | ||
img_tr = get_val_transformer(args) | ||
val_dataset = MyDataset(names, labels, img_transformer=img_tr, data_dir=args.data_dir) | ||
|
||
dataset = ConcatDataset([val_dataset]) | ||
|
||
loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=12, pin_memory=True, drop_last=False) | ||
return loader | ||
|
||
def get_train_transformers(args): | ||
img_tr = [transforms.RandomResizedCrop(int(args.image_size), (args.min_scale, args.max_scale))] | ||
if args.flip > 0.0: | ||
img_tr.append(transforms.RandomHorizontalFlip(args.flip)) | ||
if args.jitter > 0.0: | ||
img_tr.append(transforms.ColorJitter(brightness=args.jitter, contrast=args.jitter, saturation=args.jitter, hue=min(0.5, args.jitter))) | ||
|
||
img_tr = img_tr + [transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])] | ||
|
||
return transforms.Compose(img_tr) | ||
|
||
def get_val_transformer(args): | ||
img_tr = [transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor(), | ||
transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])] | ||
return transforms.Compose(img_tr) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import torch | ||
import torch.nn as nn | ||
from torch.autograd import Function | ||
|
||
class GradReverse(torch.autograd.Function): | ||
@staticmethod | ||
def forward(ctx, x, lambd, reverse=True): | ||
ctx.lambd = lambd | ||
ctx.reverse=reverse | ||
return x.view_as(x) | ||
|
||
@staticmethod | ||
def backward(ctx, grad_output): | ||
if ctx.reverse: | ||
return (grad_output * -ctx.lambd), None, None | ||
else: | ||
return (grad_output * ctx.lambd), None, None | ||
|
||
def grad_reverse(x, lambd=1.0, reverse=True): | ||
return GradReverse.apply(x, lambd, reverse) | ||
|
||
class DisNet(nn.Module): | ||
def __init__(self, in_channels, num_domains, layers=[1024, 256]): | ||
super(DisNet, self).__init__() | ||
self.domain_classifier = nn.ModuleList() | ||
|
||
# self.domain_classifier.append(nn.Linear(in_channels, layers[0])) | ||
# for i in range(1, len(layers)): | ||
# self.domain_classifier.append(nn.Sequential( | ||
# nn.ReLU(inplace=True), | ||
# nn.Linear(layers[i-1], layers[i]))) | ||
# self.domain_classifier.append(nn.ReLU(inplace=True)) | ||
# self.domain_classifier.append(nn.Dropout()) | ||
# self.domain_classifier.append(nn.Linear(layers[-1], num_domains)) | ||
# self.domain_classifier = nn.Sequential(*self.domain_classifier) | ||
self.domain_classifier = nn.Linear(in_channels,num_domains) | ||
for m in self.modules(): | ||
if isinstance(m, nn.Linear): | ||
nn.init.xavier_uniform_(m.weight, .1) | ||
nn.init.constant_(m.bias, 0.) | ||
|
||
self.lambda_ = 0.0 | ||
|
||
def set_lambda(self, lambda_): | ||
self.lambda_ = lambda_ | ||
|
||
def forward(self, x): | ||
x = grad_reverse(x, self.lambda_) | ||
return self.domain_classifier(x) | ||
|
||
def get_params(self, lr): | ||
return [{"params": self.domain_classifier.parameters(), "lr": lr}] | ||
|
||
class ClsNet(nn.Module): | ||
def __init__(self, in_channels, num_domains, num_classes, reverse=True, layers=[1024, 256]): | ||
super(ClsNet, self).__init__() | ||
self.classifier_list = nn.ModuleList() | ||
for _ in range(num_domains): | ||
# class_list = nn.ModuleList() | ||
# class_list.append(nn.Linear(in_channels, layers[0])) | ||
# for i in range(1, len(layers)): | ||
# class_list.append(nn.Sequential( | ||
# nn.ReLU(inplace=True), | ||
# nn.Linear(layers[i-1], layers[i]) | ||
# )) | ||
# class_list.append(nn.ReLU(inplace=True)) | ||
# class_list.append(nn.Dropout()) | ||
# class_list.append(nn.Linear(layers[-1], num_classes)) | ||
# self.classifier_list.append(nn.Sequential(*class_list)) | ||
self.classifier_list.append(nn.Linear(in_channels,num_classes)) | ||
for m in self.classifier_list.modules(): | ||
if isinstance(m, nn.Linear): | ||
nn.init.xavier_uniform_(m.weight, .1) | ||
nn.init.constant_(m.bias, 0.) | ||
|
||
self.lambda_ = 0 | ||
self.reverse = reverse | ||
|
||
def set_lambda(self, lambda_): | ||
self.lambda_ = lambda_ | ||
|
||
def forward(self, x): | ||
output = [] | ||
for c, x_ in zip(self.classifier_list, x): | ||
if len(x_) == 0: | ||
output.append(None) | ||
else: | ||
x_ = grad_reverse(x_, self.lambda_, self.reverse) | ||
output.append(c(x_)) | ||
|
||
return output | ||
|
||
def get_params(self, lr): | ||
return [{"params": self.classifier_list.parameters(), "lr": lr}] | ||
|
||
def aux_Models(in_channels, num_domains, num_classes, layers_dis=[], layers_cls=[]): | ||
|
||
dis_model = DisNet(in_channels, num_domains, layers_dis) | ||
c_model = ClsNet(in_channels, num_domains, num_classes, reverse=False, layers=layers_cls) | ||
cp_model = ClsNet(in_channels, num_domains, num_classes, reverse=True, layers=layers_cls) | ||
|
||
return dis_model, c_model, cp_model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from domainbed.ERDG.models import resnet | ||
|
||
nets_map = { | ||
'resnet18': resnet.resnet18, | ||
'resnet50': resnet.resnet50, | ||
} | ||
|
||
|
||
def get_network(name): | ||
if name not in nets_map: | ||
raise ValueError('Name of network unknown %s' % name) | ||
|
||
def get_network_fn(**kwargs): | ||
return nets_map[name](**kwargs) | ||
|
||
return get_network_fn | ||
|
||
print(get_network('resnet18')) |
Oops, something went wrong.