Skip to content

Commit 9acca61

Browse files
committed
For #28441: refactored to clean up duplication in flag checks
1 parent ba83707 commit 9acca61

File tree

2 files changed

+55
-54
lines changed

2 files changed

+55
-54
lines changed

shotgun_api3/shotgun.py

Lines changed: 51 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def __init__(self, host, meta):
134134
self.is_dev = False
135135

136136
self.version = tuple(self.version[:3])
137-
self._ensure_json_supported()
137+
self.ensure_json_supported()
138138

139139

140140
def _ensure_support(self, feature):
@@ -151,25 +151,28 @@ def _ensure_support(self, feature):
151151
"%s requires server version %s or higher, "\
152152
"server is %s" % (feature['label'], _version_str(feature['version']), _version_str(self.version))
153153
)
154+
return False
155+
else:
156+
return True
154157

155158

156-
def _ensure_json_supported(self):
159+
def ensure_json_supported(self):
157160
"""Wrapper for ensure_support"""
158-
self._ensure_support({
161+
return self._ensure_support({
159162
'version': (2, 4, 0),
160163
'label': 'JSON API'
161164
})
162165

163166
def ensure_include_archived_projects(self):
164167
"""Wrapper for ensure_support"""
165-
self._ensure_support({
168+
return self._ensure_support({
166169
'version': (5, 3, 14),
167170
'label': 'include_archived_projects parameter'
168171
})
169172

170173
def ensure_include_template_projects(self):
171174
"""Wrapper for ensure_support"""
172-
self._ensure_support({
175+
return self._ensure_support({
173176
'version': (6, 0, 0),
174177
'label': 'include_template_projects parameter'
175178
})
@@ -514,7 +517,7 @@ def find_one(self, entity_type, filters, fields=None, order=None,
514517
:param include_template_projects: Optional, flag to include entities
515518
belonging to template projects. Default: False
516519
517-
:returns: Result
520+
:returns: dict of requested entity's fields, or None if not found.
518521
"""
519522

520523
results = self.find(entity_type, filters, fields, order,
@@ -577,22 +580,13 @@ def find(self, entity_type, filters, fields=None, order=None,
577580
raise ShotgunError("Deprecated: Use of filter_operator for find()"
578581
" is not valid any more. See the documentation on find()")
579582

580-
if not include_archived_projects:
581-
# This defaults to True on the server (no argument is sent)
582-
# So we only need to check the server version if it is False
583-
self.server_caps.ensure_include_archived_projects()
584-
585-
if include_template_projects:
586-
# This defaults to False on the server (no argument is sent)
587-
# So we only need to check the server version if it is True
588-
self.server_caps.ensure_include_template_projects()
589-
590-
591583
params = self._construct_read_parameters(entity_type,
592584
fields,
593585
filters,
594586
retired_only,
595-
order,
587+
order)
588+
589+
params = self._construct_flag_parameters(params,
596590
include_archived_projects,
597591
include_template_projects)
598592

@@ -631,31 +625,24 @@ def find(self, entity_type, filters, fields=None, order=None,
631625
return self._parse_records(records)
632626

633627

634-
635628
def _construct_read_parameters(self,
636629
entity_type,
637630
fields,
638631
filters,
639632
retired_only,
640-
order,
641-
include_archived_projects,
642-
include_template_projects):
643-
params = {}
644-
params["type"] = entity_type
645-
params["return_fields"] = fields or ["id"]
646-
params["filters"] = filters
647-
params["return_only"] = (retired_only and 'retired') or "active"
648-
params["return_paging_info"] = True
649-
params["paging"] = { "entities_per_page": self.config.records_per_page,
650-
"current_page": 1 }
651-
652-
if include_archived_projects is False:
653-
# Defaults to True on the server, so only pass it if it's False
654-
params["include_archived_projects"] = False
633+
order):
655634

656-
if include_template_projects is True:
657-
# Defaults to False on the server, so only pass it if it's True
658-
params["include_template_projects"] = True
635+
params = {
636+
"type": entity_type,
637+
"return_fields": fields or ["id"],
638+
"filters": filters,
639+
"return_only": (retired_only and 'retired') or "active",
640+
"return_paging_info": True,
641+
"paging": {
642+
"entities_per_page": self.config.records_per_page,
643+
"current_page": 1
644+
}
645+
}
659646

660647
if order:
661648
sort_list = []
@@ -669,8 +656,32 @@ def _construct_read_parameters(self,
669656
'direction' : sort['direction']
670657
})
671658
params['sorts'] = sort_list
659+
660+
return params
661+
662+
663+
def _construct_flag_parameters(self,
664+
params,
665+
include_archived_projects,
666+
include_template_projects):
667+
668+
if not include_archived_projects:
669+
# This defaults to True on the server (no argument is sent)
670+
# So we only need to check the server version if it's False
671+
self.server_caps.ensure_include_archived_projects()
672+
# Only pass it if it's False
673+
params["include_archived_projects"] = False
674+
675+
if include_template_projects:
676+
# This defaults to False on the server (no argument is sent)
677+
# So we only need to check the server version if it's True
678+
self.server_caps.ensure_include_template_projects()
679+
# Only pass it if it's True
680+
params["include_template_projects"] = True
681+
672682
return params
673683

684+
674685
def summarize(self,
675686
entity_type,
676687
filters,
@@ -695,19 +706,9 @@ def summarize(self,
695706
"summaries": summary_fields,
696707
"filters": filters}
697708

698-
if not include_archived_projects:
699-
# This defaults to True on the server (no argument is sent)
700-
# So we only need to check the server version if it is False
701-
self.server_caps.ensure_include_archived_projects()
702-
# Only pass it if it's False
703-
params["include_archived_projects"] = False
704-
705-
if include_template_projects:
706-
# This defaults to False on the server (no argument is sent)
707-
# So we only need to check the server version if it is True
708-
self.server_caps.ensure_include_template_projects()
709-
# Only pass it if it's True
710-
params["include_template_projects"] = True
709+
params = self._construct_flag_parameters(params,
710+
include_archived_projects,
711+
include_template_projects)
711712

712713
if grouping != None:
713714
params['grouping'] = grouping

tests/test_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,18 +74,18 @@ def test_server_version_json(self):
7474
sc = ServerCapabilities("foo", {"version" : (2,4,0)})
7575

7676
sc.version = (2,3,99)
77-
self.assertRaises(api.ShotgunError, sc._ensure_json_supported)
77+
self.assertRaises(api.ShotgunError, sc.ensure_json_supported)
7878
self.assertRaises(api.ShotgunError, ServerCapabilities, "foo",
7979
{"version" : (2,2,0)})
8080

8181
sc.version = (0,0,0)
82-
self.assertRaises(api.ShotgunError, sc._ensure_json_supported)
82+
self.assertRaises(api.ShotgunError, sc.ensure_json_supported)
8383

8484
sc.version = (2,4,0)
85-
sc._ensure_json_supported()
85+
sc.ensure_json_supported()
8686

8787
sc.version = (2,5,0)
88-
sc._ensure_json_supported()
88+
sc.ensure_json_supported()
8989

9090

9191
def test_session_uuid(self):

0 commit comments

Comments
 (0)