|
8 | 8 | import re |
9 | 9 | import sys |
10 | 10 | import traceback |
| 11 | +import typing |
11 | 12 | import uuid |
12 | 13 | import warnings |
13 | 14 | import weakref |
|
20 | 21 | from logging import getLogger |
21 | 22 | from threading import Lock |
22 | 23 | from types import TracebackType |
23 | | -from typing import Any, Callable, Generator, Iterable, Iterator, NamedTuple, Sequence |
| 24 | +from typing import ( |
| 25 | + Any, |
| 26 | + Callable, |
| 27 | + Generator, |
| 28 | + Iterable, |
| 29 | + Iterator, |
| 30 | + NamedTuple, |
| 31 | + Sequence, |
| 32 | + TypeVar, |
| 33 | +) |
24 | 34 | from uuid import UUID |
25 | 35 |
|
26 | 36 | from cryptography.hazmat.backends import default_backend |
|
76 | 86 | QueryStatus, |
77 | 87 | ) |
78 | 88 | from .converter import SnowflakeConverter |
79 | | -from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor |
| 89 | +from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor, SnowflakeCursorBase |
80 | 90 | from .description import ( |
81 | 91 | CLIENT_NAME, |
82 | 92 | CLIENT_VERSION, |
|
125 | 135 | from .util_text import construct_hostname, parse_account, split_statements |
126 | 136 | from .wif_util import AttestationProvider |
127 | 137 |
|
| 138 | +if sys.version_info >= (3, 13) or typing.TYPE_CHECKING: |
| 139 | + CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase, default=SnowflakeCursor) |
| 140 | +else: |
| 141 | + CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase) |
| 142 | + |
128 | 143 | DEFAULT_CLIENT_PREFETCH_THREADS = 4 |
129 | 144 | MAX_CLIENT_PREFETCH_THREADS = 10 |
130 | 145 | MAX_CLIENT_FETCH_THREADS = 1024 |
@@ -1055,9 +1070,7 @@ def rollback(self) -> None: |
1055 | 1070 | """Rolls back the current transaction.""" |
1056 | 1071 | self.cursor().execute("ROLLBACK") |
1057 | 1072 |
|
1058 | | - def cursor( |
1059 | | - self, cursor_class: type[SnowflakeCursor] = SnowflakeCursor |
1060 | | - ) -> SnowflakeCursor: |
| 1073 | + def cursor(self, cursor_class: type[CursorCls] = SnowflakeCursor) -> CursorCls: |
1061 | 1074 | """Creates a cursor object. Each statement will be executed in a new cursor object.""" |
1062 | 1075 | logger.debug("cursor") |
1063 | 1076 | if not self.rest: |
|
0 commit comments