Skip to content

Commit d5914e4

Browse files
committed
For #28441: refactored to clean up duplication in flag checks
1 parent e092d83 commit d5914e4

File tree

2 files changed

+51
-54
lines changed

2 files changed

+51
-54
lines changed

shotgun_api3/shotgun.py

Lines changed: 47 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def __init__(self, host, meta):
134134
self.is_dev = False
135135

136136
self.version = tuple(self.version[:3])
137-
self._ensure_json_supported()
137+
self.ensure_json_supported()
138138

139139

140140
def _ensure_support(self, feature, raise_hell=True):
@@ -157,23 +157,23 @@ def _ensure_support(self, feature, raise_hell=True):
157157
return True
158158

159159

160-
def _ensure_json_supported(self):
160+
def ensure_json_supported(self):
161161
"""Wrapper for ensure_support"""
162-
self._ensure_support({
162+
return self._ensure_support({
163163
'version': (2, 4, 0),
164164
'label': 'JSON API'
165165
})
166166

167167
def ensure_include_archived_projects(self):
168168
"""Wrapper for ensure_support"""
169-
self._ensure_support({
169+
return self._ensure_support({
170170
'version': (5, 3, 14),
171171
'label': 'include_archived_projects parameter'
172172
})
173173

174174
def ensure_include_template_projects(self):
175175
"""Wrapper for ensure_support"""
176-
self._ensure_support({
176+
return self._ensure_support({
177177
'version': (6, 0, 0),
178178
'label': 'include_template_projects parameter'
179179
})
@@ -525,7 +525,7 @@ def find_one(self, entity_type, filters, fields=None, order=None,
525525
:param include_template_projects: Optional, flag to include entities
526526
belonging to template projects. Default: False
527527
528-
:returns: Result
528+
:returns: dict of requested entity's fields, or None if not found.
529529
"""
530530

531531
results = self.find(entity_type, filters, fields, order,
@@ -588,22 +588,13 @@ def find(self, entity_type, filters, fields=None, order=None,
588588
raise ShotgunError("Deprecated: Use of filter_operator for find()"
589589
" is not valid any more. See the documentation on find()")
590590

591-
if not include_archived_projects:
592-
# This defaults to True on the server (no argument is sent)
593-
# So we only need to check the server version if it is False
594-
self.server_caps.ensure_include_archived_projects()
595-
596-
if include_template_projects:
597-
# This defaults to False on the server (no argument is sent)
598-
# So we only need to check the server version if it is True
599-
self.server_caps.ensure_include_template_projects()
600-
601-
602591
params = self._construct_read_parameters(entity_type,
603592
fields,
604593
filters,
605594
retired_only,
606-
order,
595+
order)
596+
597+
params = self._construct_flag_parameters(params,
607598
include_archived_projects,
608599
include_template_projects)
609600

@@ -642,31 +633,24 @@ def find(self, entity_type, filters, fields=None, order=None,
642633
return self._parse_records(records)
643634

644635

645-
646636
def _construct_read_parameters(self,
647637
entity_type,
648638
fields,
649639
filters,
650640
retired_only,
651-
order,
652-
include_archived_projects,
653-
include_template_projects):
654-
params = {}
655-
params["type"] = entity_type
656-
params["return_fields"] = fields or ["id"]
657-
params["filters"] = filters
658-
params["return_only"] = (retired_only and 'retired') or "active"
659-
params["return_paging_info"] = True
660-
params["paging"] = { "entities_per_page": self.config.records_per_page,
661-
"current_page": 1 }
662-
663-
if include_archived_projects is False:
664-
# Defaults to True on the server, so only pass it if it's False
665-
params["include_archived_projects"] = False
641+
order):
666642

667-
if include_template_projects is True:
668-
# Defaults to False on the server, so only pass it if it's True
669-
params["include_template_projects"] = True
643+
params = {
644+
"type": entity_type,
645+
"return_fields": fields or ["id"],
646+
"filters": filters,
647+
"return_only": (retired_only and 'retired') or "active",
648+
"return_paging_info": True,
649+
"paging": {
650+
"entities_per_page": self.config.records_per_page,
651+
"current_page": 1
652+
}
653+
}
670654

671655
if order:
672656
sort_list = []
@@ -680,6 +664,29 @@ def _construct_read_parameters(self,
680664
'direction' : sort['direction']
681665
})
682666
params['sorts'] = sort_list
667+
668+
return params
669+
670+
671+
def _construct_flag_parameters(self,
672+
params,
673+
include_archived_projects,
674+
include_template_projects):
675+
676+
if not include_archived_projects:
677+
# This defaults to True on the server (no argument is sent)
678+
# So we only need to check the server version if it's False
679+
self.server_caps.ensure_include_archived_projects()
680+
# Only pass it if it's False
681+
params["include_archived_projects"] = False
682+
683+
if include_template_projects:
684+
# This defaults to False on the server (no argument is sent)
685+
# So we only need to check the server version if it's True
686+
self.server_caps.ensure_include_template_projects()
687+
# Only pass it if it's True
688+
params["include_template_projects"] = True
689+
683690
return params
684691

685692

@@ -715,19 +722,9 @@ def summarize(self,
715722
"summaries": summary_fields,
716723
"filters": filters}
717724

718-
if not include_archived_projects:
719-
# This defaults to True on the server (no argument is sent)
720-
# So we only need to check the server version if it is False
721-
self.server_caps.ensure_include_archived_projects()
722-
# Only pass it if it's False
723-
params["include_archived_projects"] = False
724-
725-
if include_template_projects:
726-
# This defaults to False on the server (no argument is sent)
727-
# So we only need to check the server version if it is True
728-
self.server_caps.ensure_include_template_projects()
729-
# Only pass it if it's True
730-
params["include_template_projects"] = True
725+
params = self._construct_flag_parameters(params,
726+
include_archived_projects,
727+
include_template_projects)
731728

732729
if grouping != None:
733730
params['grouping'] = grouping

tests/test_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,18 +74,18 @@ def test_server_version_json(self):
7474
sc = ServerCapabilities("foo", {"version" : (2,4,0)})
7575

7676
sc.version = (2,3,99)
77-
self.assertRaises(api.ShotgunError, sc._ensure_json_supported)
77+
self.assertRaises(api.ShotgunError, sc.ensure_json_supported)
7878
self.assertRaises(api.ShotgunError, ServerCapabilities, "foo",
7979
{"version" : (2,2,0)})
8080

8181
sc.version = (0,0,0)
82-
self.assertRaises(api.ShotgunError, sc._ensure_json_supported)
82+
self.assertRaises(api.ShotgunError, sc.ensure_json_supported)
8383

8484
sc.version = (2,4,0)
85-
sc._ensure_json_supported()
85+
sc.ensure_json_supported()
8686

8787
sc.version = (2,5,0)
88-
sc._ensure_json_supported()
88+
sc.ensure_json_supported()
8989

9090

9191
def test_session_uuid(self):

0 commit comments

Comments
 (0)