Skip to content

Commit 69467f2

Browse files
committed
Add the implementation of SeFa and the interface.
1 parent 057f3a8 commit 69467f2

14 files changed

+4981
-2
lines changed

README.md

+32-2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,38 @@ In this repository, we propose a *closed-form* approach, termed as **SeFa**, for
3030
| Orientation | Vertical Position | Shape
3131
| ![image](./docs/assets/stylegan_car_orientation.gif) | ![image](./docs/assets/stylegan_car_vertical_position.gif) | ![image](./docs/assets/stylegan_car_shape.gif)
3232

33+
## Semantic Discovery
34+
35+
It is very simple to interpret a particular model with
36+
37+
```bash
38+
MODEL_NAME=stylegan_animeface512
39+
LAYER_IDX=0-1
40+
NUM_SAMPLES=5
41+
NUM_SEMANTICS=5
42+
python sefa.py ${MODEL_NAME} \
43+
-L ${LAYER_IDX} \
44+
-N ${NUM_SAMPLES} \
45+
-K ${NUM_SEMANTICS}
46+
```
47+
48+
After the program finishes, there will be two visualization pages in the directory `results`.
49+
50+
**NOTE:** The pre-trained models are borrowed from the [genforce](https://github.com/genforce/genforce) repository.
51+
52+
## Interface
53+
54+
We also provide an interface for interactive editing based on [StreamLit](https://www.streamlit.io/). This interface can be locally launched with
55+
56+
```bash
57+
pip install streamlit
58+
CUDA_VISIBLE_DEVICES=0 streamlit run interface.py
59+
```
60+
61+
After the interface is launched, users can play with it via a browser.
62+
63+
**NOTE:** We have prepared some latent codes in the directory `latent_codes` to ensure the synthesis quality, which is completely determined by the pre-trained generator. Users can simply skip these prepared codes by clicking the `Random` button.
64+
3365
## BibTeX
3466

3567
```bibtex
@@ -40,5 +72,3 @@ In this repository, we propose a *closed-form* approach, termed as **SeFa**, for
4072
year = {2020}
4173
}
4274
```
43-
44-
## Code Coming Soon

SessionState.py

+129
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
"""Adds pre-session state to StreamLit.
2+
3+
This file is borrowed from
4+
https://gist.github.com/tvst/036da038ab3e999a64497f42de966a92
5+
"""
6+
7+
# pylint: disable=protected-access
8+
9+
try:
10+
import streamlit.ReportThread as ReportThread
11+
from streamlit.server.Server import Server
12+
except ModuleNotFoundError:
13+
# Streamlit >= 0.65.0
14+
import streamlit.report_thread as ReportThread
15+
from streamlit.server.server import Server
16+
17+
18+
class SessionState(object):
19+
"""Hack to add per-session state to Streamlit.
20+
21+
Usage
22+
-----
23+
24+
>>> import SessionState
25+
>>>
26+
>>> session_state = SessionState.get(user_name='', favorite_color='black')
27+
>>> session_state.user_name
28+
''
29+
>>> session_state.user_name = 'Mary'
30+
>>> session_state.favorite_color
31+
'black'
32+
33+
Since you set user_name above, next time your script runs this will be the
34+
result:
35+
>>> session_state = get(user_name='', favorite_color='black')
36+
>>> session_state.user_name
37+
'Mary'
38+
39+
"""
40+
41+
def __init__(self, **kwargs):
42+
"""A new SessionState object.
43+
44+
Parameters
45+
----------
46+
**kwargs : any
47+
Default values for the session state.
48+
49+
Example
50+
-------
51+
>>> session_state = SessionState(user_name='', favorite_color='black')
52+
>>> session_state.user_name = 'Mary'
53+
''
54+
>>> session_state.favorite_color
55+
'black'
56+
57+
"""
58+
for key, val in kwargs.items():
59+
setattr(self, key, val)
60+
61+
62+
def get(**kwargs):
63+
"""Gets a SessionState object for the current session.
64+
65+
Creates a new object if necessary.
66+
67+
Parameters
68+
----------
69+
**kwargs : any
70+
Default values you want to add to the session state, if we're creating a
71+
new one.
72+
73+
Example
74+
-------
75+
>>> session_state = get(user_name='', favorite_color='black')
76+
>>> session_state.user_name
77+
''
78+
>>> session_state.user_name = 'Mary'
79+
>>> session_state.favorite_color
80+
'black'
81+
82+
Since you set user_name above, next time your script runs this will be the
83+
result:
84+
>>> session_state = get(user_name='', favorite_color='black')
85+
>>> session_state.user_name
86+
'Mary'
87+
88+
"""
89+
# Hack to get the session object from Streamlit.
90+
91+
ctx = ReportThread.get_report_ctx()
92+
93+
this_session = None
94+
95+
current_server = Server.get_current()
96+
if hasattr(current_server, '_session_infos'):
97+
# Streamlit < 0.56
98+
session_infos = Server.get_current()._session_infos.values()
99+
else:
100+
session_infos = Server.get_current()._session_info_by_id.values()
101+
102+
for session_info in session_infos:
103+
s = session_info.session
104+
if (
105+
# Streamlit < 0.54.0
106+
(hasattr(s, '_main_dg') and s._main_dg == ctx.main_dg)
107+
or
108+
# Streamlit >= 0.54.0
109+
(not hasattr(s, '_main_dg') and s.enqueue == ctx.enqueue)
110+
or
111+
# Streamlit >= 0.65.2
112+
(not hasattr(s, '_main_dg') and
113+
s._uploaded_file_mgr == ctx.uploaded_file_mgr)
114+
):
115+
this_session = s
116+
117+
if this_session is None:
118+
raise RuntimeError(
119+
"Oh noes. Couldn't get your Streamlit Session object. "
120+
'Are you doing something fancy with threads?')
121+
122+
# Got the session object! Now let's attach some state into it.
123+
124+
if not hasattr(this_session, '_custom_session_state'):
125+
this_session._custom_session_state = SessionState(**kwargs)
126+
127+
return this_session._custom_session_state
128+
129+
# pylint: enable=protected-access

interface.py

+128
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# python 3.7
2+
"""Demo."""
3+
4+
import numpy as np
5+
import torch
6+
import streamlit as st
7+
import SessionState
8+
9+
from models import parse_gan_type
10+
from utils import to_tensor
11+
from utils import postprocess
12+
from utils import load_generator
13+
from utils import factorize_weight
14+
15+
16+
@st.cache(allow_output_mutation=True, show_spinner=False)
17+
def get_model(model_name):
18+
"""Gets model by name."""
19+
return load_generator(model_name)
20+
21+
22+
@st.cache(allow_output_mutation=True, show_spinner=False)
23+
def factorize_model(model, layer_idx):
24+
"""Factorizes semantics from target layers of the given model."""
25+
return factorize_weight(model, layer_idx)
26+
27+
28+
def sample(model, gan_type, num=1):
29+
"""Samples latent codes."""
30+
codes = torch.randn(num, model.z_space_dim).cuda()
31+
if gan_type == 'pggan':
32+
codes = model.layer0.pixel_norm(codes)
33+
elif gan_type == 'stylegan':
34+
codes = model.mapping(codes)['w']
35+
codes = model.truncation(codes,
36+
trunc_psi=0.7,
37+
trunc_layers=8)
38+
elif gan_type == 'stylegan2':
39+
codes = model.mapping(codes)['w']
40+
codes = model.truncation(codes,
41+
trunc_psi=0.5,
42+
trunc_layers=18)
43+
codes = codes.detach().cpu().numpy()
44+
return codes
45+
46+
47+
@st.cache(allow_output_mutation=True, show_spinner=False)
48+
def synthesize(model, gan_type, code):
49+
"""Synthesizes an image with the give code."""
50+
if gan_type == 'pggan':
51+
image = model(to_tensor(code))['image']
52+
elif gan_type in ['stylegan', 'stylegan2']:
53+
image = model.synthesis(to_tensor(code))['image']
54+
image = postprocess(image)[0]
55+
return image
56+
57+
58+
def main():
59+
"""Main function (loop for StreamLit)."""
60+
st.title('Closed-Form Factorization of Latent Semantics in GANs')
61+
st.sidebar.title('Options')
62+
reset = st.sidebar.button('Reset')
63+
64+
model_name = st.sidebar.selectbox(
65+
'Model to Interpret',
66+
['stylegan_animeface512', 'stylegan_car512', 'stylegan_cat512',
67+
'pggan_celebahq1024'])
68+
69+
model = get_model(model_name)
70+
gan_type = parse_gan_type(model)
71+
layer_idx = st.sidebar.selectbox(
72+
'Layers to Interpret',
73+
['all', '0-1', '2-5', '6-13'])
74+
layers, boundaries, eigen_values = factorize_model(model, layer_idx)
75+
76+
num_semantics = st.sidebar.number_input(
77+
'Number of semantics', value=10, min_value=0, max_value=None, step=1)
78+
steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
79+
if gan_type == 'pggan':
80+
max_step = 5.0
81+
elif gan_type == 'stylegan':
82+
max_step = 2.0
83+
elif gan_type == 'stylegan2':
84+
max_step = 15.0
85+
for sem_idx in steps:
86+
eigen_value = eigen_values[sem_idx]
87+
steps[sem_idx] = st.sidebar.slider(
88+
f'Semantic {sem_idx:03d} (eigen value: {eigen_value:.3f})',
89+
value=0.0,
90+
min_value=-max_step,
91+
max_value=max_step,
92+
step=0.04 * max_step if not reset else 0.0)
93+
94+
image_placeholder = st.empty()
95+
button_placeholder = st.empty()
96+
97+
try:
98+
base_codes = np.load(f'latent_codes/{model_name}_latents.npy')
99+
except FileNotFoundError:
100+
base_codes = sample(model, gan_type)
101+
102+
state = SessionState.get(model_name=model_name,
103+
code_idx=0,
104+
codes=base_codes[0:1])
105+
if state.model_name != model_name:
106+
state.model_name = model_name
107+
state.code_idx = 0
108+
state.codes = base_codes[0:1]
109+
110+
if button_placeholder.button('Random', key=0):
111+
state.code_idx += 1
112+
if state.code_idx < base_codes.shape[0]:
113+
state.codes = base_codes[state.code_idx][np.newaxis]
114+
else:
115+
state.codes = sample(model, gan_type)
116+
117+
code = state.codes.copy()
118+
for sem_idx, step in steps.items():
119+
if gan_type == 'pggan':
120+
code += boundaries[sem_idx:sem_idx + 1] * step
121+
elif gan_type in ['stylegan', 'stylegan2']:
122+
code[:, layers, :] += boundaries[sem_idx:sem_idx + 1] * step
123+
image = synthesize(model, gan_type, code)
124+
image_placeholder.image(image / 255.0)
125+
126+
127+
if __name__ == '__main__':
128+
main()

0 commit comments

Comments
 (0)