Skip to content

Commit fe8a14f

Browse files
authored
Merge pull request #5 from oracle/dev/v1.0.0rc1
v1.0.0
2 parents 6866491 + c7d736a commit fe8a14f

File tree

9 files changed

+178
-19
lines changed

9 files changed

+178
-19
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ Run
1313
python3 -m pip install select_ai
1414
```
1515

16+
## Documentation
17+
18+
See [Select AI for Python documentation][documentation]
19+
1620
## Samples
1721

1822
Examples can be found in the [/samples][samples] directory
@@ -81,6 +85,7 @@ Released under the Universal Permissive License v1.0 as shown at
8185
<https://oss.oracle.com/licenses/upl/>.
8286

8387
[contributing]: https://github.com/oracle/python-select-ai/blob/main/CONTRIBUTING.md
88+
[documentation]: https://docs.oracle.com/en/cloud/paas/autonomous-database/serverless/pysai/
8489
[ghdiscussions]: https://github.com/oracle/python-select-ai/discussions
8590
[ghissues]: https://github.com/oracle/python-select-ai/issues
8691
[samples]: https://github.com/oracle/python-select-ai/tree/main/samples

pyproject.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ keywords = [
2323
license = " UPL-1.0"
2424
license-files = ["LICENSE.txt"]
2525
classifiers = [
26-
"Development Status :: 4 - Beta",
26+
"Development Status :: 5 - Production/Stable",
2727
"Intended Audience :: Developers",
2828
"Natural Language :: English",
2929
"Operating System :: OS Independent",
@@ -34,7 +34,9 @@ classifiers = [
3434
"Programming Language :: Python :: 3.12",
3535
"Programming Language :: Python :: 3.13",
3636
"Programming Language :: Python :: Implementation :: CPython",
37-
"Topic :: Database"
37+
"Topic :: Database",
38+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
39+
"Topic :: Software Development :: Libraries :: Python Modules"
3840
]
3941
dependencies = [
4042
"oracledb",
@@ -45,6 +47,7 @@ dependencies = [
4547
Homepage = "https://github.com/oracle/python-select-ai"
4648
Repository = "https://github.com/oracle/python-select-ai"
4749
Issues = "https://github.com/oracle/python-select-ai/issues"
50+
Documentation = "https://docs.oracle.com/en/cloud/paas/autonomous-database/serverless/pysai/"
4851

4952
[tool.setuptools.packages.find]
5053
where = ["src"]

src/select_ai/_validations.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# -----------------------------------------------------------------------------
2+
# Copyright (c) 2025, Oracle and/or its affiliates.
3+
#
4+
# Licensed under the Universal Permissive License v 1.0 as shown at
5+
# http://oss.oracle.com/licenses/upl.
6+
# -----------------------------------------------------------------------------
7+
8+
import inspect
9+
from collections.abc import Mapping, Sequence, Set
10+
from functools import wraps
11+
from typing import Any, get_args, get_origin, get_type_hints
12+
13+
NoneType = type(None)
14+
15+
16+
def _match(value, annot) -> bool:
17+
"""Recursively validate value against a typing annotation."""
18+
if annot is Any:
19+
return True
20+
21+
origin = get_origin(annot)
22+
args = get_args(annot)
23+
24+
# Handle Annotated[T, ...] → treat as T
25+
if origin is getattr(__import__("typing"), "Annotated", None):
26+
annot = args[0]
27+
origin = get_origin(annot)
28+
args = get_args(annot)
29+
30+
# Optional[T] is Union[T, NoneType]
31+
if origin is getattr(__import__("typing"), "Union", None):
32+
return any(_match(value, a) for a in args)
33+
34+
# Literal[…]
35+
if origin is getattr(__import__("typing"), "Literal", None):
36+
return any(value == lit for lit in args)
37+
38+
# Tuple cases
39+
if origin is tuple:
40+
if not isinstance(value, tuple):
41+
return False
42+
if len(args) == 2 and args[1] is Ellipsis:
43+
# tuple[T, ...]
44+
return all(_match(v, args[0]) for v in value)
45+
if len(args) != len(value):
46+
return False
47+
return all(_match(v, a) for v, a in zip(value, args))
48+
49+
# Mappings (dict-like)
50+
if origin in (dict, Mapping):
51+
if not isinstance(value, Mapping):
52+
return False
53+
k_annot, v_annot = args if args else (Any, Any)
54+
return all(
55+
_match(k, k_annot) and _match(v, v_annot) for k, v in value.items()
56+
)
57+
58+
# Sequences (list, Sequence) – but not str/bytes
59+
if origin in (list, Sequence):
60+
if isinstance(value, (str, bytes)):
61+
return False
62+
if not isinstance(value, Sequence):
63+
return False
64+
elem_annot = args[0] if args else Any
65+
return all(_match(v, elem_annot) for v in value)
66+
67+
# Sets
68+
if origin in (set, frozenset, Set):
69+
if not isinstance(value, (set, frozenset)):
70+
return False
71+
elem_annot = args[0] if args else Any
72+
return all(_match(v, elem_annot) for v in value)
73+
74+
# Fall back to normal isinstance for non-typing classes
75+
if isinstance(annot, type):
76+
return isinstance(value, annot)
77+
78+
# If annot is a typing alias like 'list' without args
79+
if origin is not None:
80+
# Treat bare containers as accepting anything inside
81+
return isinstance(value, origin)
82+
83+
# Unknown/unsupported typing form: accept conservatively
84+
return True
85+
86+
87+
def enforce_types(func):
88+
# Resolve ForwardRefs using function globals (handles "User" as a string, etc.)
89+
hints = get_type_hints(
90+
func, globalns=func.__globals__, include_extras=True
91+
)
92+
sig = inspect.signature(func)
93+
94+
def _check(bound):
95+
for name, val in bound.arguments.items():
96+
if name in hints:
97+
annot = hints[name]
98+
if not _match(val, annot):
99+
raise TypeError(
100+
f"Argument '{name}' failed type check: expected {annot!r}, "
101+
f"got {type(val).__name__} -> {val!r}"
102+
)
103+
104+
if inspect.iscoroutinefunction(func):
105+
106+
@wraps(func)
107+
async def aw(*args, **kwargs):
108+
bound = sig.bind(*args, **kwargs)
109+
bound.apply_defaults()
110+
_check(bound)
111+
return await func(*args, **kwargs)
112+
113+
return aw
114+
else:
115+
116+
@wraps(func)
117+
def w(*args, **kwargs):
118+
bound = sig.bind(*args, **kwargs)
119+
bound.apply_defaults()
120+
_check(bound)
121+
return func(*args, **kwargs)
122+
123+
return w

src/select_ai/async_profile.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,15 @@ async def generate(
344344
keyword_parameters=parameters,
345345
)
346346
if data is not None:
347-
return await data.read()
348-
return None
347+
result = await data.read()
348+
else:
349+
result = None
350+
if action == Action.RUNSQL and result:
351+
return pandas.DataFrame(json.loads(result))
352+
elif action == Action.RUNSQL:
353+
return pandas.DataFrame()
354+
else:
355+
return result
349356

350357
async def chat(self, prompt, params: Mapping = None) -> str:
351358
"""Asynchronously chat with the LLM
@@ -411,8 +418,7 @@ async def run_sql(
411418
:param params: Parameters to include in the LLM request
412419
:return: pandas.DataFrame
413420
"""
414-
data = await self.generate(prompt, action=Action.RUNSQL, params=params)
415-
return pandas.DataFrame(json.loads(data))
421+
return await self.generate(prompt, action=Action.RUNSQL, params=params)
416422

417423
async def show_sql(self, prompt, params: Mapping = None):
418424
"""Show the generated SQL

src/select_ai/base_profile.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class ProfileAttributes(SelectAIDataClass):
7373
vector_index_name: Optional[str] = None
7474

7575
def __post_init__(self):
76+
super().__post_init__()
7677
if self.provider and not isinstance(self.provider, Provider):
7778
raise ValueError(
7879
f"'provider' must be an object of " f"type select_ai.Provider"

src/select_ai/profile.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import json
99
from contextlib import contextmanager
1010
from dataclasses import replace as dataclass_replace
11-
from typing import Iterator, Mapping, Optional, Union
11+
from typing import Generator, Iterator, Mapping, Optional, Union
1212

1313
import oracledb
1414
import pandas
@@ -258,7 +258,9 @@ def _from_db(cls, profile_name: str) -> "Profile":
258258
raise ProfileNotFoundError(profile_name=profile_name)
259259

260260
@classmethod
261-
def list(cls, profile_name_pattern: str = ".*") -> Iterator["Profile"]:
261+
def list(
262+
cls, profile_name_pattern: str = ".*"
263+
) -> Generator["Profile", None, None]:
262264
"""List AI Profiles saved in the database.
263265
264266
:param str profile_name_pattern: Regular expressions can be used
@@ -314,8 +316,15 @@ def generate(
314316
keyword_parameters=parameters,
315317
)
316318
if data is not None:
317-
return data.read()
318-
return None
319+
result = data.read()
320+
else:
321+
result = None
322+
if action == Action.RUNSQL and result:
323+
return pandas.DataFrame(json.loads(result))
324+
elif action == Action.RUNSQL:
325+
return pandas.DataFrame()
326+
else:
327+
return result
319328

320329
def chat(self, prompt: str, params: Mapping = None) -> str:
321330
"""Chat with the LLM
@@ -375,10 +384,7 @@ def run_sql(self, prompt: str, params: Mapping = None) -> pandas.DataFrame:
375384
:param params: Parameters to include in the LLM request
376385
:return: pandas.DataFrame
377386
"""
378-
data = json.loads(
379-
self.generate(prompt, action=Action.RUNSQL, params=params)
380-
)
381-
return pandas.DataFrame(data)
387+
return self.generate(prompt, action=Action.RUNSQL, params=params)
382388

383389
def show_sql(self, prompt: str, params: Mapping = None) -> str:
384390
"""Show the generated SQL

src/select_ai/provider.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import List, Optional, Union
1010

1111
from select_ai._abc import SelectAIDataClass
12+
from select_ai._validations import enforce_types
1213

1314
from .db import async_cursor, cursor
1415
from .sql import (
@@ -194,6 +195,7 @@ class AnthropicProvider(Provider):
194195
provider_endpoint = "api.anthropic.com"
195196

196197

198+
@enforce_types
197199
async def async_enable_provider(
198200
users: Union[str, List[str]], provider_endpoint: str = None
199201
):
@@ -210,7 +212,7 @@ async def async_enable_provider(
210212

211213
async with async_cursor() as cr:
212214
for user in users:
213-
await cr.execute(GRANT_PRIVILEGES_TO_USER.format(user))
215+
await cr.execute(GRANT_PRIVILEGES_TO_USER.format(user.strip()))
214216
if provider_endpoint:
215217
await cr.execute(
216218
ENABLE_AI_PROFILE_DOMAIN_FOR_USER,
@@ -219,6 +221,7 @@ async def async_enable_provider(
219221
)
220222

221223

224+
@enforce_types
222225
async def async_disable_provider(
223226
users: Union[str, List[str]], provider_endpoint: str = None
224227
):
@@ -234,7 +237,7 @@ async def async_disable_provider(
234237

235238
async with async_cursor() as cr:
236239
for user in users:
237-
await cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user))
240+
await cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user.strip()))
238241
if provider_endpoint:
239242
await cr.execute(
240243
DISABLE_AI_PROFILE_DOMAIN_FOR_USER,
@@ -243,6 +246,7 @@ async def async_disable_provider(
243246
)
244247

245248

249+
@enforce_types
246250
def enable_provider(
247251
users: Union[str, List[str]], provider_endpoint: str = None
248252
):
@@ -256,7 +260,7 @@ def enable_provider(
256260

257261
with cursor() as cr:
258262
for user in users:
259-
cr.execute(GRANT_PRIVILEGES_TO_USER.format(user))
263+
cr.execute(GRANT_PRIVILEGES_TO_USER.format(user.strip()))
260264
if provider_endpoint:
261265
cr.execute(
262266
ENABLE_AI_PROFILE_DOMAIN_FOR_USER,
@@ -265,6 +269,7 @@ def enable_provider(
265269
)
266270

267271

272+
@enforce_types
268273
def disable_provider(
269274
users: Union[str, List[str]], provider_endpoint: str = None
270275
):
@@ -279,7 +284,7 @@ def disable_provider(
279284

280285
with cursor() as cr:
281286
for user in users:
282-
cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user))
287+
cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user.strip()))
283288
if provider_endpoint:
284289
cr.execute(
285290
DISABLE_AI_PROFILE_DOMAIN_FOR_USER,

src/select_ai/vector_index.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,16 @@ def __init__(
119119
attributes: Optional[VectorIndexAttributes] = None,
120120
):
121121
"""Initialize a Vector Index"""
122+
if attributes and not isinstance(attributes, VectorIndexAttributes):
123+
raise TypeError(
124+
"'attributes' must be an object of type "
125+
"select_ai.VectorIndexAttributes"
126+
)
127+
if profile and not isinstance(profile, BaseProfile):
128+
raise TypeError(
129+
"'profile' must be an object of type "
130+
"select_ai.Profile or select_ai.AsyncProfile"
131+
)
122132
self.profile = profile
123133
self.index_name = index_name
124134
self.attributes = attributes

src/select_ai/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55
# http://oss.oracle.com/licenses/upl.
66
# -----------------------------------------------------------------------------
77

8-
__version__ = "1.0.0b1"
8+
__version__ = "1.0.0"

0 commit comments

Comments
 (0)