Skip to content

Commit b783a65

Browse files
authored
feat: adds util function to get available first factors (#610)
- Adds util function to get available first factors - Cleans up supertokens __init__ file to reduce redundancy - Adds test to ensure FactorIds class and method are in sync ref: supertokens/supertokens-node#1021
1 parent 08825fa commit b783a65

File tree

7 files changed

+109
-37
lines changed

7 files changed

+109
-37
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99
## [unreleased]
1010

1111
## [0.31.0] - 2025-08-21
12+
- Adds util function to get available first factors
13+
1214
### Adds plugins support
1315
- Adds an `experimental` property (`SuperTokensExperimentalConfig`) to the `SuperTokensConfig`
1416
- Plugins can be configured under using the `plugins` property in the `experimental` config

supertokens_python/__init__.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414

15-
from typing import Any, Dict, List, Optional
15+
from typing import List, Optional
1616

1717
from typing_extensions import Literal
1818

19-
from supertokens_python.framework.request import BaseRequest
2019
from supertokens_python.recipe_module import RecipeModule
2120
from supertokens_python.types import RecipeUserId
2221

@@ -30,6 +29,8 @@
3029
SupertokensExperimentalConfig,
3130
SupertokensInputConfig,
3231
SupertokensPublicConfig,
32+
get_request_from_user_context,
33+
is_recipe_initialized,
3334
)
3435

3536
# Some Pydantic models need a rebuild to resolve ForwardRefs
@@ -69,19 +70,10 @@ def get_all_cors_headers() -> List[str]:
6970
return Supertokens.get_instance().get_all_cors_headers()
7071

7172

72-
def get_request_from_user_context(
73-
user_context: Optional[Dict[str, Any]],
74-
) -> Optional[BaseRequest]:
75-
return Supertokens.get_instance().get_request_from_user_context(user_context)
76-
77-
7873
def convert_to_recipe_user_id(user_id: str) -> RecipeUserId:
7974
return RecipeUserId(user_id)
8075

8176

82-
is_recipe_initialized = Supertokens.is_recipe_initialized
83-
84-
8577
__all__ = [
8678
"AppInfo",
8779
"InputAppInfo",

supertokens_python/asyncio/__init__.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typing import Any, Dict, List, Optional, Union
1515

1616
from supertokens_python import Supertokens
17+
from supertokens_python.exceptions import BadInputError
1718
from supertokens_python.interfaces import (
1819
CreateUserIdMappingOkResult,
1920
DeleteUserIdMappingOkResult,
@@ -26,8 +27,9 @@
2627
)
2728
from supertokens_python.recipe.accountlinking.interfaces import GetUsersResult
2829
from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe
30+
from supertokens_python.recipe.session.interfaces import SessionContainer
2931
from supertokens_python.types import User
30-
from supertokens_python.types.base import AccountInfoInput
32+
from supertokens_python.types.base import AccountInfoInput, UserContext
3133

3234

3335
async def get_users_oldest_first(
@@ -172,3 +174,31 @@ async def list_users_by_account_info(
172174
do_union_of_account_info,
173175
user_context,
174176
)
177+
178+
179+
async def get_available_first_factors(
180+
tenant_id: str,
181+
session: Optional[SessionContainer],
182+
user_context: Optional[UserContext],
183+
):
184+
from supertokens_python.auth_utils import (
185+
filter_out_invalid_first_factors_or_throw_if_all_are_invalid,
186+
)
187+
from supertokens_python.recipe.multifactorauth.types import FactorIds
188+
189+
available_first_factors: List[str] = []
190+
191+
try:
192+
available_first_factors = (
193+
await filter_out_invalid_first_factors_or_throw_if_all_are_invalid(
194+
factor_ids=FactorIds.get_all_factors(),
195+
tenant_id=tenant_id,
196+
has_session=session is not None,
197+
user_context=user_context if user_context is not None else {},
198+
)
199+
)
200+
except BadInputError:
201+
# All the factors were invalid, so we let it pass through and return the empty list
202+
pass
203+
204+
return available_first_factors

supertokens_python/recipe/multifactorauth/types.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,27 @@ class NormalisedMultiFactorAuthConfig(
6161

6262

6363
class FactorIds:
64-
EMAILPASSWORD: Literal["emailpassword"] = "emailpassword"
65-
OTP_EMAIL: Literal["otp-email"] = "otp-email"
66-
OTP_PHONE: Literal["otp-phone"] = "otp-phone"
67-
LINK_EMAIL: Literal["link-email"] = "link-email"
68-
LINK_PHONE: Literal["link-phone"] = "link-phone"
69-
THIRDPARTY: Literal["thirdparty"] = "thirdparty"
70-
TOTP: Literal["totp"] = "totp"
71-
WEBAUTHN: Literal["webauthn"] = "webauthn"
64+
EMAILPASSWORD = "emailpassword"
65+
OTP_EMAIL = "otp-email"
66+
OTP_PHONE = "otp-phone"
67+
LINK_EMAIL = "link-email"
68+
LINK_PHONE = "link-phone"
69+
THIRDPARTY = "thirdparty"
70+
TOTP = "totp"
71+
WEBAUTHN = "webauthn"
72+
73+
@staticmethod
74+
def get_all_factors():
75+
return [
76+
FactorIds.EMAILPASSWORD,
77+
FactorIds.OTP_EMAIL,
78+
FactorIds.OTP_PHONE,
79+
FactorIds.LINK_EMAIL,
80+
FactorIds.LINK_PHONE,
81+
FactorIds.THIRDPARTY,
82+
FactorIds.TOTP,
83+
FactorIds.WEBAUTHN,
84+
]
7285

7386

7487
class FactorIdsAndType:

supertokens_python/supertokens.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
SuperTokensPlugin,
4343
SuperTokensPublicPlugin,
4444
)
45+
from supertokens_python.types.base import UserContext
4546
from supertokens_python.types.response import CamelCaseBaseModel
4647

4748
from .constants import FDI_KEY_HEADER, RID_KEY_HEADER, USER_COUNT
@@ -181,13 +182,13 @@ def __init__(
181182
self.mode = mode
182183

183184
def get_top_level_website_domain(
184-
self, request: Optional[BaseRequest], user_context: Dict[str, Any]
185+
self, request: Optional[BaseRequest], user_context: UserContext
185186
) -> str:
186187
return get_top_level_domain_for_same_site_resolution(
187188
self.get_origin(request, user_context).get_as_string_dangerous()
188189
)
189190

190-
def get_origin(self, request: Optional[BaseRequest], user_context: Dict[str, Any]):
191+
def get_origin(self, request: Optional[BaseRequest], user_context: UserContext):
191192
origin = self.__origin
192193
if origin is None:
193194
origin = self.__website_domain
@@ -211,7 +212,7 @@ def defaultImpl(o: Any):
211212

212213

213214
def manage_session_post_response(
214-
session: SessionContainer, response: BaseResponse, user_context: Dict[str, Any]
215+
session: SessionContainer, response: BaseResponse, user_context: UserContext
215216
):
216217
# Something similar happens in handle_error of session/recipe.py
217218
for mutator in session.response_mutators:
@@ -577,7 +578,7 @@ async def get_user_count(
577578
self,
578579
include_recipe_ids: Union[None, List[str]],
579580
tenant_id: Optional[str] = None,
580-
user_context: Optional[Dict[str, Any]] = None,
581+
user_context: Optional[UserContext] = None,
581582
) -> int:
582583
querier = Querier.get_instance(None)
583584
include_recipe_ids_str = None
@@ -601,7 +602,7 @@ async def create_user_id_mapping(
601602
external_user_id: str,
602603
external_user_id_info: Optional[str],
603604
force: Optional[bool],
604-
user_context: Optional[Dict[str, Any]],
605+
user_context: Optional[UserContext],
605606
) -> Union[
606607
CreateUserIdMappingOkResult,
607608
UnknownSupertokensUserIDError,
@@ -641,7 +642,7 @@ async def get_user_id_mapping(
641642
self,
642643
user_id: str,
643644
user_id_type: Optional[UserIDTypes],
644-
user_context: Optional[Dict[str, Any]],
645+
user_context: Optional[UserContext],
645646
) -> Union[GetUserIdMappingOkResult, UnknownMappingError]:
646647
querier = Querier.get_instance(None)
647648

@@ -676,7 +677,7 @@ async def delete_user_id_mapping(
676677
user_id: str,
677678
user_id_type: Optional[UserIDTypes],
678679
force: Optional[bool],
679-
user_context: Optional[Dict[str, Any]],
680+
user_context: Optional[UserContext],
680681
) -> DeleteUserIdMappingOkResult:
681682
querier = Querier.get_instance(None)
682683

@@ -708,7 +709,7 @@ async def update_or_delete_user_id_mapping_info(
708709
user_id: str,
709710
user_id_type: Optional[UserIDTypes],
710711
external_user_id_info: Optional[str],
711-
user_context: Optional[Dict[str, Any]],
712+
user_context: Optional[UserContext],
712713
) -> Union[UpdateOrDeleteUserIdMappingInfoOkResult, UnknownMappingError]:
713714
querier = Querier.get_instance(None)
714715

@@ -734,7 +735,7 @@ async def update_or_delete_user_id_mapping_info(
734735
raise_general_exception("Please upgrade the SuperTokens core to >= 3.15.0")
735736

736737
async def middleware(
737-
self, request: BaseRequest, response: BaseResponse, user_context: Dict[str, Any]
738+
self, request: BaseRequest, response: BaseResponse, user_context: UserContext
738739
) -> Union[BaseResponse, None]:
739740
from supertokens_python.recipe.session.recipe import SessionRecipe
740741

@@ -907,7 +908,7 @@ async def handle_supertokens_error(
907908
request: BaseRequest,
908909
err: Exception,
909910
response: BaseResponse,
910-
user_context: Dict[str, Any],
911+
user_context: UserContext,
911912
) -> Optional[BaseResponse]:
912913
log_debug_message("errorHandler: Started")
913914
log_debug_message(
@@ -935,7 +936,7 @@ async def handle_supertokens_error(
935936

936937
def get_request_from_user_context(
937938
self,
938-
user_context: Optional[Dict[str, Any]] = None,
939+
user_context: Optional[UserContext] = None,
939940
) -> Optional[BaseRequest]:
940941
if user_context is None:
941942
return None
@@ -948,20 +949,22 @@ def get_request_from_user_context(
948949

949950
return user_context.get("_default", {}).get("request")
950951

951-
@staticmethod
952-
def is_recipe_initialized(recipe_id: str) -> bool:
952+
def is_recipe_initialized(self, recipe_id: str) -> bool:
953953
"""
954954
Check if a recipe is initialized.
955955
:param recipe_id: The ID of the recipe to check.
956956
:return: Whether the recipe is initialized.
957957
"""
958958
return any(
959-
recipe.get_recipe_id() == recipe_id
960-
for recipe in Supertokens.get_instance().recipe_modules
959+
recipe.get_recipe_id() == recipe_id for recipe in self.recipe_modules
961960
)
962961

963962

964963
def get_request_from_user_context(
965-
user_context: Optional[Dict[str, Any]],
964+
user_context: Optional[UserContext],
966965
) -> Optional[BaseRequest]:
967966
return Supertokens.get_instance().get_request_from_user_context(user_context)
967+
968+
969+
def is_recipe_initialized(recipe_id: str) -> bool:
970+
return Supertokens.get_instance().is_recipe_initialized(recipe_id)

supertokens_python/syncio/__init__.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525
UserIdMappingAlreadyExistsError,
2626
UserIDTypes,
2727
)
28+
from supertokens_python.recipe.session.interfaces import SessionContainer
2829
from supertokens_python.types import User
29-
from supertokens_python.types.base import AccountInfoInput
30+
from supertokens_python.types.base import AccountInfoInput, UserContext
3031

3132

3233
def get_users_oldest_first(
@@ -178,3 +179,19 @@ def list_users_by_account_info(
178179
tenant_id, account_info, do_union_of_account_info, user_context
179180
)
180181
)
182+
183+
184+
def get_available_first_factors(
185+
tenant_id: str,
186+
session: Optional[SessionContainer],
187+
user_context: Optional[UserContext],
188+
):
189+
from supertokens_python.asyncio import (
190+
get_available_first_factors as async_get_available_first_factors,
191+
)
192+
193+
return sync(
194+
async_get_available_first_factors(
195+
tenant_id=tenant_id, session=session, user_context=user_context
196+
)
197+
)

tests/test_mfa.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from supertokens_python.recipe.multifactorauth.types import FactorIds
2+
3+
4+
def test_get_all_factors():
5+
"""Test that FactorIds.get_all_factors returns all factors defined in FactorIds class."""
6+
factors_from_dict: list[str] = []
7+
for k, v in FactorIds.__dict__.items():
8+
if (
9+
(not k.startswith("__") or not k.endswith("__"))
10+
and not k.startswith("<")
11+
and isinstance(v, str)
12+
):
13+
factors_from_dict.append(v)
14+
15+
assert factors_from_dict == FactorIds.get_all_factors()

0 commit comments

Comments
 (0)