12
12
from boto .exception import BotoServerError
13
13
from botocore .exceptions import ClientError
14
14
from boto .pyami .config import Config
15
- import botocore
16
- from typing import Tuple
15
+ from typing import Tuple , TYPE_CHECKING , Iterable , Any , Optional
16
+
17
+ if TYPE_CHECKING :
18
+ import mypy_boto3_rds
17
19
18
20
19
21
def fetch_aws_secret_key (access_key_id ) -> Tuple [str , str ]:
@@ -119,16 +121,33 @@ def connect_vpc(region, access_key_id):
119
121
return conn
120
122
121
123
124
+ def connect_rds_boto3 (region , access_key_id ) -> "mypy_boto3_rds.RDSClient" :
125
+ assert region
126
+ (access_key_id , secret_access_key ) = fetch_aws_secret_key (access_key_id )
127
+ client = boto3 .session .Session ().client (
128
+ "rds" ,
129
+ region_name = region ,
130
+ aws_access_key_id = access_key_id ,
131
+ aws_secret_access_key = secret_access_key ,
132
+ )
133
+ return client
134
+
135
+
122
136
def get_access_key_id ():
123
137
return os .environ .get ("EC2_ACCESS_KEY" ) or os .environ .get ("AWS_ACCESS_KEY_ID" )
124
138
125
139
126
- def retry (f , error_codes = [], logger = None ):
140
+ def retry (
141
+ f , error_codes : Optional [Iterable [Any ]] = None , logger = None , num_retries : int = 7
142
+ ):
127
143
"""
128
144
Retry function f up to 7 times. If error_codes argument is empty list, retry on all EC2 response errors,
129
145
otherwise, only on the specified error codes.
130
146
"""
131
147
148
+ if error_codes is None :
149
+ error_codes = []
150
+
132
151
def handle_exception (e ):
133
152
if hasattr (e , "error_code" ):
134
153
err_code = e .error_code
@@ -137,8 +156,12 @@ def handle_exception(e):
137
156
err_code = e .response ["Error" ]["Code" ]
138
157
err_msg = e .response ["Error" ]["Message" ]
139
158
140
- if i == num_retries or (error_codes != [] and err_code not in error_codes ):
141
- raise e
159
+ if err_code == "RequestLimitExceeded" :
160
+ return False
161
+
162
+ if error_codes and err_code not in error_codes :
163
+ return True
164
+
142
165
if logger is not None :
143
166
logger .log (
144
167
"got (possibly transient) EC2 error code '{0}': {1}. retrying..." .format (
@@ -147,8 +170,9 @@ def handle_exception(e):
147
170
)
148
171
149
172
def handle_boto3_exception (e ):
150
- if i == num_retries :
151
- raise e
173
+ if error_codes and getattr (e , "response" , {}).get ("code" ) not in error_codes :
174
+ return True
175
+
152
176
if logger is not None :
153
177
if hasattr (e , "response" ):
154
178
logger .log (
@@ -157,29 +181,23 @@ def handle_boto3_exception(e):
157
181
)
158
182
)
159
183
184
+ def should_abort (e ):
185
+ if isinstance (e , (SQSError , EC2ResponseError , BotoServerError )):
186
+ return handle_exception (e )
187
+ elif isinstance (e , ClientError ):
188
+ return handle_boto3_exception (e )
189
+
160
190
i = 0
161
- num_retries = 7
162
191
while i <= num_retries :
163
192
i += 1
164
193
next_sleep = 5 + random .random () * (2 ** i )
165
194
166
195
try :
167
196
return f ()
168
- except EC2ResponseError as e :
169
- handle_exception (e )
170
- except SQSError as e :
171
- handle_exception (e )
172
- except ClientError as e :
173
- handle_boto3_exception (e )
174
- except BotoServerError as e :
175
- if e .error_code == "RequestLimitExceeded" :
176
- num_retries += 1
177
- else :
178
- handle_exception (e )
179
- except botocore .exceptions .ClientError as e :
180
- handle_exception (e )
181
197
except Exception as e :
182
- raise e
198
+ if num_retries == i or should_abort (e ):
199
+ raise e
200
+ num_retries += 1
183
201
184
202
time .sleep (next_sleep )
185
203
0 commit comments