Commit

first version

Zeyi Huang authored and Zeyi Huang committed Apr 11, 2022
1 parent d723f53 commit 5d950fe
Showing 258 changed files with 37,654 additions and 0 deletions.
5 changes: 5 additions & 0 deletions DomainBed/CODE_OF_CONDUCT.md
@@ -0,0 +1,5 @@
# Code of Conduct

Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
Please read the [full text](https://code.fb.com/codeofconduct/)
so that you can understand what actions will and will not be tolerated.
32 changes: 32 additions & 0 deletions DomainBed/CONTRIBUTING.md
@@ -0,0 +1,32 @@
# Contributing to `DomainBed`
We want to make contributing to this project as easy and transparent as
possible.

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `master`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please make sure your description is
clear and includes sufficient instructions to reproduce the issue.

Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## License
By contributing to `DomainBed`, you agree that your contributions
will be licensed under the LICENSE file in the root directory of this source
tree.
173 changes: 173 additions & 0 deletions DomainBed/ERDG/data/dataset.py
@@ -0,0 +1,173 @@
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import warnings
import bisect
from os.path import join
from random import sample

dataset = {}
dataset["PACS"] = ["art_painting", "cartoon", "photo", "sketch"]

available_datasets = dataset["PACS"]

class ConcatDataset(Dataset):
"""
Dataset to concatenate multiple datasets.
Purpose: useful to assemble different existing datasets, possibly
large-scale datasets as the concatenation operation is done in an
on-the-fly manner.
Arguments:
datasets (sequence): List of datasets to be concatenated
"""

@staticmethod
def cumsum(sequence):
r, s = [], 0
for e in sequence:
l = len(e)
r.append(l + s)
s += l
return r

def __init__(self, datasets):
super(ConcatDataset, self).__init__()
assert len(datasets) > 0, 'datasets should not be an empty iterable'
self.datasets = list(datasets)
self.cumulative_sizes = self.cumsum(self.datasets)

def __len__(self):
return self.cumulative_sizes[-1]

def __getitem__(self, idx):
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return self.datasets[dataset_idx][sample_idx], dataset_idx

@property
def cummulative_sizes(self):
warnings.warn("cummulative_sizes attribute is renamed to "
"cumulative_sizes", DeprecationWarning, stacklevel=2)
return self.cumulative_sizes
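
# Index-mapping sketch (hypothetical datasets, not part of this commit):
# with member datasets of sizes 3 and 5, cumulative_sizes == [3, 8], so
# flat index 4 resolves to sample 1 of the second dataset:
#   >>> cat = ConcatDataset([ds_a, ds_b])   # len(ds_a) == 3, len(ds_b) == 5
#   >>> _, dataset_idx = cat[4]
#   >>> dataset_idx
#   1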

class MyDataset(Dataset):
def __init__(self, names, labels, img_transformer=None, data_dir='./'):
self.names = names
self.labels = labels
self.data_dir = data_dir

self._image_transformer = img_transformer

    def get_image(self, index):
        framename = join(self.data_dir, self.names[index])
        img = Image.open(framename).convert('RGB')
        # img_transformer defaults to None, so guard before applying it.
        if self._image_transformer is not None:
            return self._image_transformer(img)
        return img

    def __getitem__(self, index):
        img = self.get_image(index)
        return img, int(self.labels[index])

def __len__(self):
return len(self.names)

def get_random_subset(names, labels, percent):
    """
    Randomly split (names, labels) into train and validation subsets.

    :param names: list of image file names
    :param labels: list of integer labels, parallel to names
    :param percent: float in (0, 1), fraction of samples moved to validation
    :return: name_train, name_val, labels_train, labels_val
    """
    samples = len(names)
    amount = int(samples * percent)
    random_index = sample(range(samples), amount)
    index_set = set(random_index)  # O(1) membership tests for the complement
    name_val = [names[k] for k in random_index]
    name_train = [v for k, v in enumerate(names) if k not in index_set]
    labels_val = [labels[k] for k in random_index]
    labels_train = [v for k, v in enumerate(labels) if k not in index_set]
    return name_train, name_val, labels_train, labels_val
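
# Illustrative example (hypothetical file names): with four samples and
# percent=0.25, one random sample moves to the validation split:
#   >>> tr_n, va_n, tr_l, va_l = get_random_subset(
#   ...     ['a.jpg', 'b.jpg', 'c.jpg', 'd.jpg'], [0, 1, 0, 1], 0.25)
#   >>> len(tr_n), len(va_n)
#   (3, 1)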

def _dataset_info(txt_labels, num_classes=10000):
with open(txt_labels, 'r') as f:
images_list = f.readlines()

file_names = []
labels = []
    for row in images_list:
        # split() tolerates extra whitespace and strips the trailing newline;
        # skip blank lines so a trailing empty line does not raise IndexError.
        row = row.split()
        if not row:
            continue
        if int(row[1]) >= num_classes:
            continue
        file_names.append(row[0])
        labels.append(int(row[1]))

return file_names, labels
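
# Expected datalist format (one sample per line: an image path followed by an
# integer class label, separated by whitespace). The paths are hypothetical:
#   art_painting/dog/pic_001.jpg 0
#   art_painting/elephant/pic_002.jpg 1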

def get_split_dataset_info(txt_list, val_percentage):
names, labels = _dataset_info(txt_list)
return get_random_subset(names, labels, val_percentage)

def get_train_dataloader(args):

dataset_list = args.source
assert isinstance(dataset_list, list)
val_datasets = []
img_transformer = get_train_transformers(args)
img_num_per_domain = []
train_loader_list = []
for dname in dataset_list:
name_train, labels_train = _dataset_info(join(args.datalist_dir, args.dataset, '%s_train_kfold.txt' % dname))
name_val, labels_val = _dataset_info(join(args.datalist_dir, args.dataset, '%s_crossval_kfold.txt' % dname))
train_dataset = MyDataset(name_train, labels_train, img_transformer=img_transformer, data_dir=args.data_dir)
val_datasets.append(MyDataset(name_val, labels_val, img_transformer=get_val_transformer(args), data_dir=args.data_dir))

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=12, pin_memory=True, drop_last=False)
img_num_per_domain.append(len(name_train))
train_loader_list.append(train_loader)

val_dataset = ConcatDataset(val_datasets)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=12, pin_memory=True, drop_last=False)
return train_loader_list, val_loader, img_num_per_domain

def get_val_dataloader(args):

if isinstance(args.target, list):
img_tr = get_val_transformer(args)
val_datasets = []
for dname in args.target:
names, labels = _dataset_info(join(args.datalist_dir, args.dataset, '%s_test.txt' % dname))
val_datasets.append(MyDataset(names, labels, img_transformer=img_tr, data_dir=args.data_dir))

dataset = ConcatDataset(val_datasets)

else:
names, labels = _dataset_info(join(args.datalist_dir, args.dataset, '%s_test.txt' % args.target))
img_tr = get_val_transformer(args)
val_dataset = MyDataset(names, labels, img_transformer=img_tr, data_dir=args.data_dir)

dataset = ConcatDataset([val_dataset])

loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=12, pin_memory=True, drop_last=False)
return loader

def get_train_transformers(args):
img_tr = [transforms.RandomResizedCrop(int(args.image_size), (args.min_scale, args.max_scale))]
if args.flip > 0.0:
img_tr.append(transforms.RandomHorizontalFlip(args.flip))
if args.jitter > 0.0:
img_tr.append(transforms.ColorJitter(brightness=args.jitter, contrast=args.jitter, saturation=args.jitter, hue=min(0.5, args.jitter)))

img_tr = img_tr + [transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]

return transforms.Compose(img_tr)

def get_val_transformer(args):
img_tr = [transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
return transforms.Compose(img_tr)
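
# Minimal sketch of the `args` namespace these helpers expect, inferred from
# the attribute accesses above; the values are illustrative, not tuned
# defaults:
#   from argparse import Namespace
#   args = Namespace(source=['art_painting', 'cartoon', 'sketch'],
#                    target=['photo'], dataset='PACS',
#                    datalist_dir='./datalists', data_dir='./data',
#                    batch_size=32, image_size=224, min_scale=0.8,
#                    max_scale=1.0, flip=0.5, jitter=0.4)
#   train_loaders, val_loader, img_num_per_domain = get_train_dataloader(args)
#   test_loader = get_val_dataloader(args)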
Empty file.
102 changes: 102 additions & 0 deletions DomainBed/ERDG/models/aux_models.py
@@ -0,0 +1,102 @@
import torch
import torch.nn as nn

class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; scales (and optionally flips) gradients."""

    @staticmethod
    def forward(ctx, x, lambd, reverse=True):
        ctx.lambd = lambd
        ctx.reverse = reverse
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Return one gradient per forward input; lambd and reverse need none.
        if ctx.reverse:
            return (grad_output * -ctx.lambd), None, None
        else:
            return (grad_output * ctx.lambd), None, None

def grad_reverse(x, lambd=1.0, reverse=True):
return GradReverse.apply(x, lambd, reverse)
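
# Behavior sketch (not part of the commit): with reverse=True the backward
# pass flips the gradient sign and scales it by lambd:
#   >>> x = torch.ones(2, requires_grad=True)
#   >>> grad_reverse(x, lambd=0.5).sum().backward()
#   >>> x.grad
#   tensor([-0.5000, -0.5000])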

class DisNet(nn.Module):
def __init__(self, in_channels, num_domains, layers=[1024, 256]):
super(DisNet, self).__init__()
self.domain_classifier = nn.ModuleList()

# self.domain_classifier.append(nn.Linear(in_channels, layers[0]))
# for i in range(1, len(layers)):
# self.domain_classifier.append(nn.Sequential(
# nn.ReLU(inplace=True),
# nn.Linear(layers[i-1], layers[i])))
# self.domain_classifier.append(nn.ReLU(inplace=True))
# self.domain_classifier.append(nn.Dropout())
# self.domain_classifier.append(nn.Linear(layers[-1], num_domains))
# self.domain_classifier = nn.Sequential(*self.domain_classifier)
        self.domain_classifier = nn.Linear(in_channels, num_domains)
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight, .1)
nn.init.constant_(m.bias, 0.)

self.lambda_ = 0.0

def set_lambda(self, lambda_):
self.lambda_ = lambda_

def forward(self, x):
x = grad_reverse(x, self.lambda_)
return self.domain_classifier(x)

def get_params(self, lr):
return [{"params": self.domain_classifier.parameters(), "lr": lr}]

class ClsNet(nn.Module):
def __init__(self, in_channels, num_domains, num_classes, reverse=True, layers=[1024, 256]):
super(ClsNet, self).__init__()
self.classifier_list = nn.ModuleList()
for _ in range(num_domains):
# class_list = nn.ModuleList()
# class_list.append(nn.Linear(in_channels, layers[0]))
# for i in range(1, len(layers)):
# class_list.append(nn.Sequential(
# nn.ReLU(inplace=True),
# nn.Linear(layers[i-1], layers[i])
# ))
# class_list.append(nn.ReLU(inplace=True))
# class_list.append(nn.Dropout())
# class_list.append(nn.Linear(layers[-1], num_classes))
# self.classifier_list.append(nn.Sequential(*class_list))
            self.classifier_list.append(nn.Linear(in_channels, num_classes))
for m in self.classifier_list.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight, .1)
nn.init.constant_(m.bias, 0.)

self.lambda_ = 0
self.reverse = reverse

def set_lambda(self, lambda_):
self.lambda_ = lambda_

def forward(self, x):
output = []
for c, x_ in zip(self.classifier_list, x):
if len(x_) == 0:
output.append(None)
else:
x_ = grad_reverse(x_, self.lambda_, self.reverse)
output.append(c(x_))

return output

def get_params(self, lr):
return [{"params": self.classifier_list.parameters(), "lr": lr}]

def aux_Models(in_channels, num_domains, num_classes, layers_dis=[], layers_cls=[]):

dis_model = DisNet(in_channels, num_domains, layers_dis)
c_model = ClsNet(in_channels, num_domains, num_classes, reverse=False, layers=layers_cls)
cp_model = ClsNet(in_channels, num_domains, num_classes, reverse=True, layers=layers_cls)

return dis_model, c_model, cp_model
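
# Usage sketch (illustrative shapes, assuming a 2048-channel backbone, three
# source domains, and the seven PACS classes). DisNet takes the pooled batch
# directly; the ClsNets expect one feature tensor per domain:
#   dis, c, cp = aux_Models(in_channels=2048, num_domains=3, num_classes=7)
#   feats = torch.randn(9, 2048)
#   domain_logits = dis(feats)                             # shape (9, 3)
#   class_logits = c([feats[:3], feats[3:6], feats[6:]])   # three (3, 7) tensors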
18 changes: 18 additions & 0 deletions DomainBed/ERDG/models/model_factory.py
@@ -0,0 +1,18 @@
from domainbed.ERDG.models import resnet

nets_map = {
'resnet18': resnet.resnet18,
'resnet50': resnet.resnet50,
}


def get_network(name):
if name not in nets_map:
        raise ValueError('Unknown network name: %s' % name)

def get_network_fn(**kwargs):
return nets_map[name](**kwargs)

return get_network_fn
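
# Usage sketch: get_network returns a constructor whose keyword arguments are
# forwarded to the underlying builder. The `pretrained` flag below is a
# hypothetical example, not a confirmed parameter of resnet.resnet18:
#   resnet_fn = get_network('resnet18')
#   model = resnet_fn(pretrained=True)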

if __name__ == '__main__':
    # Smoke test; guarded so importing this module has no print side effect.
    print(get_network('resnet18'))