Skip to content

Commit 882310f

Browse files
authored
Feat fedavg tf1.15 (#952)
1 parent 423273b commit 882310f

19 files changed

+1152
-2
lines changed

Makefile

+8
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ protobuf:
2323
--grpc_python_out=. \
2424
protocols/fedlearner/channel/*.proto
2525

26+
python -m grpc_tools.protoc -I. \
27+
--python_out=. \
28+
fedlearner/fedavg/cluster/cluster.proto
29+
python -m grpc_tools.protoc -I. \
30+
--python_out=. \
31+
--grpc_python_out=. \
32+
fedlearner/fedavg/training_service.proto
33+
2634
lint:
2735
pylint --rcfile ci/pylintrc fedlearner example
2836

deploy/scripts/trainer/run_fedavg.sh

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#!/bin/bash
2+
set -ex
3+
4+
source /app/deploy/scripts/hdfs_common.sh
5+
source /app/deploy/scripts/env_to_args.sh
6+
7+
if [[ -n "${CODE_KEY}" ]]; then
8+
pull_code ${CODE_KEY} $PWD
9+
fi
10+
11+
if [[ $ROLE == "leader" ]]; then
12+
export FL_LEADER_ADDRESS="0.0.0.0:50051"
13+
elif [[ -n $PEER_ADDR ]]; then
14+
export FL_LEADER_ADDRESS=$PEER_ADDR
15+
fi
16+
17+
python $ROLE.py

deploy/scripts/wait4pair_wrapper.sh

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ while [[ true ]]; do
1212
export PEER_ADDR=`cat ${pair}`
1313
break
1414
else
15+
echo "still waiting for peer addr"
1516
sleep 1
1617
fi
1718
done

example/fedavg/mnist/follower.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import os
2+
from fedlearner.fedavg import train_from_keras_model
3+
from .model import create_model, x_train, y_train, x_test, y_test
4+
5+
fed_leader_address = os.getenv("FL_LEADER_ADDRESS", "0.0.0.0:6870")
6+
fl_name = "follower"
7+
fl_cluster = {
8+
"leader": {
9+
"name": "leader",
10+
"address": fed_leader_address
11+
},
12+
"followers": [{
13+
"name": "follower"
14+
}]
15+
}
16+
17+
model = create_model()
18+
x = x_train[len(x_train) // 2:]
19+
y = y_train[len(y_train) // 2:]
20+
train_from_keras_model(model,
21+
x,
22+
y,
23+
batch_size=30,
24+
epochs=1,
25+
fl_name=fl_name,
26+
fl_cluster=fl_cluster,
27+
steps_per_sync=10)
28+
29+
model.evaluate(x_test, y_test)

example/fedavg/mnist/leader.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import os
2+
from fedlearner.fedavg import train_from_keras_model
3+
from .model import create_model, x_train, y_train, x_test, y_test
4+
5+
fed_leader_address = os.getenv("FL_LEADER_ADDRESS", "0.0.0.0:6870")
6+
fl_name = "leader"
7+
fl_cluster = {
8+
"leader": {
9+
"name": "leader",
10+
"address": fed_leader_address
11+
},
12+
"followers": [{
13+
"name": "follower"
14+
}]
15+
}
16+
17+
model = create_model()
18+
x = x_train[:len(x_train) // 2]
19+
y = y_train[:len(y_train) // 2]
20+
train_from_keras_model(model,
21+
x,
22+
y,
23+
batch_size=30,
24+
epochs=1,
25+
fl_name=fl_name,
26+
fl_cluster=fl_cluster,
27+
steps_per_sync=10)
28+
29+
model.evaluate(x_test, y_test)

example/fedavg/mnist/model.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import tensorflow as tf
2+
import numpy as np
3+
4+
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
5+
x_train = x_train.reshape(x_train.shape[0], -1).astype(np.float32) / 255.0
6+
y_train = y_train.astype(np.int32)
7+
8+
x_test = x_test.reshape(x_test.shape[0], -1).astype(np.float32) / 255.0
9+
y_test = y_test.astype(np.int32)
10+
11+
12+
def create_model():
13+
model = tf.keras.Sequential([
14+
tf.keras.layers.Dense(200, activation='relu', input_shape=(784, )),
15+
tf.keras.layers.Dense(200, activation='relu'),
16+
tf.keras.layers.Dense(10, activation='softmax'),
17+
])
18+
model.compile(optimizer=tf.keras.optimizers.SGD(0.01),
19+
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
20+
metrics=['acc'])
21+
return model

fedlearner/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@
1919
from fedlearner import data_join
2020
from fedlearner import proxy
2121
from fedlearner import trainer
22+
from fedlearner import fedavg

fedlearner/common/grpc_utils.py

+157
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import time
2+
import collections
3+
import grpc
4+
from . import fl_logging as logging
5+
6+
7+
class LocalServicerContext(grpc.ServicerContext):
8+
9+
def invocation_metadata(self):
10+
return ()
11+
12+
def peer(self):
13+
return "local"
14+
15+
def peer_identities(self):
16+
return None
17+
18+
def peer_identity_key(self):
19+
return None
20+
21+
def auth_context(self):
22+
return dict()
23+
24+
def set_compression(self, compression):
25+
return grpc.Compression.NoCompression
26+
27+
def send_initial_metadata(self, initial_metadata):
28+
pass
29+
30+
def set_trailing_metadata(self, trailing_metadata):
31+
pass
32+
33+
def abort(self, code, details):
34+
pass
35+
36+
def abort_with_status(self, status):
37+
pass
38+
39+
def set_code(self, code):
40+
pass
41+
42+
def set_details(self, details):
43+
pass
44+
45+
def disable_next_message_compression(self):
46+
pass
47+
48+
def is_active(self):
49+
return True
50+
51+
def time_remaining(self):
52+
return None
53+
54+
def cancel(self):
55+
pass
56+
57+
def add_callback(self, callback):
58+
pass
59+
60+
61+
def call_with_retry(call, max_retry_times=None, retry_interval=1):
62+
retry_times = 0
63+
while True:
64+
try:
65+
retry_times += 1
66+
return call()
67+
except grpc.RpcError as e:
68+
if max_retry_times is None or retry_times < max_retry_times:
69+
logging.warning(
70+
"grpc call error, status: %s"
71+
", details: %s, wait %ds for retry", e.code(), e.details(),
72+
retry_interval)
73+
time.sleep(retry_interval)
74+
else:
75+
raise e
76+
77+
78+
#def remote_insecure_channel(address, options=None, compression=None):
79+
# EGRESS_URL = os.getenv('EGRESS_URL', None)
80+
# EGRESS_HOST = os.environ.get('EGRESS_HOST', None)
81+
# EGRESS_DOMAIN = os.environ.get('EGRESS_DOMAIN', None)
82+
# if not EGRESS_URL:
83+
# return grpc.insecure_channel(address, options, compression)
84+
#
85+
# options = list(options) if options else list()
86+
# default_authority = EGRESS_HOST or address
87+
# options.append(('grpc.default_authority', default_authority))
88+
# channel = grpc.insecure_channel(EGRESS_URL, options, compression)
89+
#
90+
# if EGRESS_DOMAIN:
91+
# address = address + '.' + EGRESS_DOMAIN
92+
# channel = grpc.intercept_channel(
93+
# channel, add_metadata_interceptor({'x-host': address}))
94+
#
95+
# return channel
96+
#
97+
#
98+
#def add_metadata_interceptor(headers):
99+
# if not isinstance(headers, dict):
100+
# raise TypeError("headers must be a dict")
101+
# headers = list(headers.items())
102+
#
103+
# def add_metadata_fn(client_call_details, request_iterator,
104+
# request_streaming, response_streaming):
105+
# metadata = list(client_call_details.metadata or [])
106+
# metadata.extend(headers)
107+
# client_call_details = _ClientCallDetails(
108+
# client_call_details.method, client_call_details.timeout, metadata,
109+
# client_call_details.credentials)
110+
# return client_call_details, request_iterator, None
111+
#
112+
# return _GenericClientInterceptor(add_metadata_fn)
113+
114+
115+
class _ClientCallDetails(
116+
collections.namedtuple(
117+
'_ClientCallDetails',
118+
('method', 'timeout', 'metadata', 'credentials')),
119+
grpc.ClientCallDetails):
120+
pass
121+
122+
123+
class _GenericClientInterceptor(grpc.UnaryUnaryClientInterceptor,
124+
grpc.UnaryStreamClientInterceptor,
125+
grpc.StreamUnaryClientInterceptor,
126+
grpc.StreamStreamClientInterceptor):
127+
128+
def __init__(self, interceptor_function):
129+
self._fn = interceptor_function
130+
131+
def intercept_unary_unary(self, continuation, client_call_details,
132+
request):
133+
new_details, new_request_iterator, postprocess = self._fn(
134+
client_call_details, iter((request, )), False, False)
135+
response = continuation(new_details, next(new_request_iterator))
136+
return postprocess(response) if postprocess else response
137+
138+
def intercept_unary_stream(self, continuation, client_call_details,
139+
request):
140+
new_details, new_request_iterator, postprocess = self._fn(
141+
client_call_details, iter((request, )), False, True)
142+
response_it = continuation(new_details, next(new_request_iterator))
143+
return postprocess(response_it) if postprocess else response_it
144+
145+
def intercept_stream_unary(self, continuation, client_call_details,
146+
request_iterator):
147+
new_details, new_request_iterator, postprocess = self._fn(
148+
client_call_details, request_iterator, True, False)
149+
response = continuation(new_details, new_request_iterator)
150+
return postprocess(response) if postprocess else response
151+
152+
def intercept_stream_stream(self, continuation, client_call_details,
153+
request_iterator):
154+
new_details, new_request_iterator, postprocess = self._fn(
155+
client_call_details, request_iterator, True, True)
156+
response_it = continuation(new_details, new_request_iterator)
157+
return postprocess(response_it) if postprocess else response_it

fedlearner/fedavg/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
2+
from .fedavg import train_from_keras_model

fedlearner/fedavg/_global_context.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2020 The FedLearner Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import threading
17+
from fedlearner.common import stats
18+
19+
class _GlobalContext:
20+
def __init__(self):
21+
self.job = os.getenv("FL_JOB") \
22+
or os.getenv("APPLICATION_ID") \
23+
or "unknow"
24+
self.task = os.getenv("FL_TASK") \
25+
or "unknow"
26+
self.task_index = os.getenv("FL_TASK_INDEX") \
27+
or os.getenv("INDEX") \
28+
or "0"
29+
self.task_index = int(self.task_index)
30+
31+
self._stats_client = None
32+
33+
self._lock = threading.Lock()
34+
35+
@property
36+
def stats_client(self):
37+
if self._stats_client:
38+
return self._stats_client
39+
40+
with self._lock:
41+
if not self._stats_client:
42+
self._stats_client = stats.with_tags({
43+
"job": self.job,
44+
"task": self.task,
45+
"task_index": self.task_index,
46+
})
47+
48+
return self._stats_client
49+
50+
global_context = _GlobalContext()

fedlearner/fedavg/cluster/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
2+
from .cluster_pb2 import FLNodeDef, FLClusterDef
3+
from .cluster_spec import FLClusterSpec
+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
syntax = "proto3";
2+
3+
package fedlearner.cluster;
4+
5+
message FLNodeDef {
6+
string name = 1;
7+
string address = 2;
8+
}
9+
10+
message FLClusterDef {
11+
FLNodeDef leader = 1;
12+
repeated FLNodeDef followers = 2;
13+
}

0 commit comments

Comments
 (0)