File tree Expand file tree Collapse file tree 2 files changed +12
-5
lines changed
tensorflow_privacy/privacy/dp_query Expand file tree Collapse file tree 2 files changed +12
-5
lines changed Original file line number Diff line number Diff line change @@ -268,10 +268,11 @@ def derive_metrics(self, global_state):
268268
269269def _zeros_like (arg ):
270270 """A `zeros_like` function that also works for `tf.TensorSpec`s."""
271- try :
272- arg = tf .convert_to_tensor (value = arg )
273- except TypeError :
274- pass
271+ if not isinstance (arg , tf .TensorSpec ):
272+ try :
273+ arg = tf .convert_to_tensor (value = arg )
274+ except TypeError :
275+ pass
275276 return tf .zeros (arg .shape , arg .dtype )
276277
277278
Original file line number Diff line number Diff line change 1818
1919import distutils
2020import math
21- from typing import Optional
21+ from typing import Any , Optional
2222
2323import attr
2424import dp_accounting
@@ -136,6 +136,12 @@ def initial_global_state(self):
136136 arity = self ._arity ,
137137 inner_query_state = self ._inner_query .initial_global_state ())
138138
139+ def initial_sample_state (self , template : Optional [Any ] = None ):
140+ """Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
141+ unprocessed_sample_state = super ().initial_sample_state (template )
142+ sample_params = self .derive_sample_params (self .initial_global_state ())
143+ return self .preprocess_record (sample_params , unprocessed_sample_state )
144+
139145 def derive_sample_params (self , global_state ):
140146 """Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
141147 return (global_state .arity ,
You can’t perform that action at this time.
0 commit comments