Skip to content

Commit

Permalink
Update search_markets method to apply the total parameter to all type…
Browse files Browse the repository at this point in the history
…s, add tests (#901)

* Update search_markets method to apply the total parameter to all types, fixes #534

* Add integration tests for searching multiple types in multiple markets

* Update search_markets method to apply the total parameter to all types, add tests

---------

Co-authored-by: Stéphane Bruckert <[email protected]>
  • Loading branch information
rngolam and stephanebruckert authored Mar 15, 2023
1 parent f2d23e2 commit fe438c0
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 13 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

- Added optional `encoder_cls` argument to `CacheFileHandler`, which overwrite default encoder for token before writing to disk

### Added
- Added optional `encoder_cls` argument to `CacheFileHandler`, which overwrite default encoder for token before writing to disk
- Integration tests for searching multiple types in multiple markets (non-user endpoints)

### Fixed
- Fixed the regex for matching playlist URIs with the format spotify:user:USERNAME:playlist:PLAYLISTID.
- `search_markets` now factors the counts of all types in the `total` rather than just the first type ([#534](https://github.com/spotipy-dev/spotipy/issues/534))

### Removed

Expand Down
31 changes: 20 additions & 11 deletions spotipy/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

from spotipy.exceptions import SpotifyException

from collections import defaultdict

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -594,12 +596,12 @@ def search_markets(self, q, limit=10, offset=0, type="track", markets=None, tota
official documentation https://developer.spotify.com/documentation/web-api/reference/search/) # noqa
- limit - the number of items to return (min = 1, default = 10, max = 50). If a search is to be done on multiple
markets, then this limit is applied to each market. (e.g. search US, CA, MX each with a limit of 10).
If multiple types are specified, this applies to each type.
- offset - the index of the first item to return
- type - the types of items to return. One or more of 'artist', 'album',
'track', 'playlist', 'show', or 'episode'. If multiple types are desired, pass in a comma separated string.
- markets - A list of ISO 3166-1 alpha-2 country codes. Search all country markets by default.
- total - the total number of results to return if multiple markets are supplied in the search.
If multiple types are specified, this only applies to the first type.
- total - the total number of results to return across multiple markets and types.
"""
warnings.warn(
"Searching multiple markets is an experimental feature. "
Expand Down Expand Up @@ -2005,22 +2007,29 @@ def _search_multiple_markets(self, q, limit, offset, type, markets, total):
UserWarning,
)

results = {}
first_type = type.split(",")[0] + 's'
results = defaultdict(dict)
item_types = [item_type + "s" for item_type in type.split(",")]
count = 0

for country in markets:
result = self._get(
"search", q=q, limit=limit, offset=offset, type=type, market=country
)
results[country] = result
for item_type in item_types:
results[country][item_type] = result[item_type]

# Truncate the items list to the current limit
if len(results[country][item_type]['items']) > limit:
results[country][item_type]['items'] = \
results[country][item_type]['items'][:limit]

count += len(results[country][item_type]['items'])
if total and limit > total - count:
# when approaching `total` results, adjust `limit` to not request more
# items than needed
limit = total - count

count += len(result[first_type]['items'])
if total and count >= total:
break
if total and limit > total - count:
# when approaching `total` results, adjust `limit` to not request more
# items than needed
limit = total - count
return results

return results
81 changes: 81 additions & 0 deletions tests/integration/non_user_endpoints/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,87 @@ def test_artist_search_with_multiple_markets(self):
total_limited_results += len(results_limited[country]['artists']['items'])
self.assertTrue(total_limited_results <= total)

def test_multiple_types_search_with_multiple_markets(self):
total = 14

countries_list = ['GB', 'US', 'AU']
countries_tuple = ('GB', 'US', 'AU')

results_multiple = self.spotify.search_markets(q='abba', type='artist,track',
markets=countries_list)
results_all = self.spotify.search_markets(q='abba', type='artist,track')
results_tuple = self.spotify.search_markets(q='abba', type='artist,track',
markets=countries_tuple)
results_limited = self.spotify.search_markets(q='abba', limit=3, type='artist,track',
markets=countries_list, total=total)

# Asserts 'artists' property is present in all responses
self.assertTrue(
all('artists' in results_multiple[country] for country in results_multiple))
self.assertTrue(all('artists' in results_all[country] for country in results_all))
self.assertTrue(all('artists' in results_tuple[country] for country in results_tuple))
self.assertTrue(all('artists' in results_limited[country] for country in results_limited))

# Asserts 'tracks' property is present in all responses
self.assertTrue(
all('tracks' in results_multiple[country] for country in results_multiple))
self.assertTrue(all('tracks' in results_all[country] for country in results_all))
self.assertTrue(all('tracks' in results_tuple[country] for country in results_tuple))
self.assertTrue(all('tracks' in results_limited[country] for country in results_limited))

# Asserts 'artists' list is nonempty in unlimited searches
self.assertTrue(
all(len(results_multiple[country]['artists']['items']) > 0 for country in
results_multiple))
self.assertTrue(all(len(results_all[country]['artists']
['items']) > 0 for country in results_all))
self.assertTrue(
all(len(results_tuple[country]['artists']['items']) > 0 for country in results_tuple))

# Asserts 'tracks' list is nonempty in unlimited searches
self.assertTrue(
all(len(results_multiple[country]['tracks']['items']) > 0 for country in
results_multiple))
self.assertTrue(all(len(results_all[country]['tracks']
['items']) > 0 for country in results_all))
self.assertTrue(all(len(results_tuple[country]['tracks']
['items']) > 0 for country in results_tuple))

# Asserts artist name is the first artist result in all searches
self.assertTrue(all(results_multiple[country]['artists']['items']
[0]['name'] == 'ABBA' for country in results_multiple))
self.assertTrue(all(results_all[country]['artists']['items']
[0]['name'] == 'ABBA' for country in results_all))
self.assertTrue(all(results_tuple[country]['artists']['items']
[0]['name'] == 'ABBA' for country in results_tuple))
self.assertTrue(all(results_limited[country]['artists']['items']
[0]['name'] == 'ABBA' for country in results_limited))

# Asserts track name is present in responses from specified markets
self.assertTrue(all('Dancing Queen' in
[item['name'] for item in results_multiple[country]['tracks']['items']]
for country in results_multiple))
self.assertTrue(all('Dancing Queen' in
[item['name'] for item in results_tuple[country]['tracks']['items']]
for country in results_tuple))

# Asserts expected number of items are returned based on the total
# 3 artists + 3 tracks = 6 items returned from first market
# 3 artists + 3 tracks = 6 items returned from second market
# 2 artists + 0 tracks = 2 items returned from third market
# 14 items returned total
self.assertEqual(len(results_limited['GB']['artists']['items']), 3)
self.assertEqual(len(results_limited['GB']['tracks']['items']), 3)
self.assertEqual(len(results_limited['US']['artists']['items']), 3)
self.assertEqual(len(results_limited['US']['tracks']['items']), 3)
self.assertEqual(len(results_limited['AU']['artists']['items']), 2)
self.assertEqual(len(results_limited['AU']['tracks']['items']), 0)

item_count = sum([len(market_result['artists']['items']) + len(market_result['tracks']
['items']) for market_result in results_limited.values()])

self.assertEqual(item_count, total)

def test_artist_albums(self):
results = self.spotify.artist_albums(self.weezer_urn)
self.assertTrue('items' in results)
Expand Down

0 comments on commit fe438c0

Please sign in to comment.