Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions .github/workflows/check_installation.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
name: Test Installation

on:
push:
branches:
- master
- main
pull_request:
branches:
- '**'
workflow_dispatch:

concurrency:
# older builds for the same pull request number or branch should be cancelled
cancel-in-progress: true
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}

jobs:
test-installation:
name: Test Boto Dependency
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.12

- name: Test default installation (should include boto)
shell: bash
run: |
python -m venv test_default_env
source test_default_env/bin/activate

python -m pip install .
pip freeze | grep boto || exit 1 # boto3/botocore should be installed by default

# Deactivate and clean up
deactivate
rm -rf test_default_env

- name: Test installation with SNOWFLAKE_NO_BOTO=1 (should exclude boto)
shell: bash
run: |
python -m venv test_no_boto_env
source test_no_boto_env/bin/activate

SNOWFLAKE_NO_BOTO=1 python -m pip install .

# Check that boto3 and botocore are NOT installed
pip freeze | grep boto && exit 1 # boto3 and botocore should not be installed

# Deactivate and clean up
deactivate
rm -rf test_no_boto_env
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ repos:
src/snowflake/connector/vendored/.*
)$
args: [--show-fixes]
- id: check-optional-imports
name: Check for direct imports of modules which might be unavailable
entry: python ci/pre-commit/check_optional_imports.py
language: system
files: ^src/snowflake/connector/.*\.py$
exclude: src/snowflake/connector/options.py
args: [--show-fixes]
- repo: https://github.com/PyCQA/flake8
rev: 7.1.1
hooks:
Expand Down
157 changes: 157 additions & 0 deletions ci/pre-commit/check_optional_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
#!/usr/bin/env python3
"""
Pre-commit hook to ensure optional dependencies are always imported from .options module.
This ensures that the connector can operate in environments where these optional libraries are not available.
"""
import argparse
import ast
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import List

# Top-level packages that must never be imported directly; per the module
# docstring they are expected to be imported from the .options module.
CHECKED_MODULES = [
    "boto3",
    "botocore",
    "aioboto3",
    "aiobotocore",
    "pandas",
    "pyarrow",
    "keyring",
]


@dataclass(frozen=True)
class ImportViolation:
    """A single forbidden import, rendered as ``filename:line:col: message``."""

    filename: str
    line: int
    col: int
    message: str

    def __str__(self):
        return f"{self.filename}:{self.line}:{self.col}: {self.message}"


class ImportChecker(ast.NodeVisitor):
    """Collects direct imports of the packages listed in CHECKED_MODULES.

    Walk a parsed module with ``visit(tree)``; every offending ``import`` or
    ``from ... import`` statement is appended to ``self.violations``.
    Imports inside the body of an ``if TYPE_CHECKING:`` block are ignored.
    """

    def __init__(self, filename: str):
        self.filename = filename
        self.violations: List[ImportViolation] = []

    @staticmethod
    def _is_type_checking(test: ast.expr) -> bool:
        """Return True for the tests ``TYPE_CHECKING`` and ``<name>.TYPE_CHECKING``."""
        if isinstance(test, ast.Name):
            return test.id == "TYPE_CHECKING"
        if isinstance(test, ast.Attribute):
            return test.attr == "TYPE_CHECKING"
        return False

    @staticmethod
    def _checked_root(module_name: str):
        """Return the CHECKED_MODULES entry that ``module_name`` belongs to, or None.

        A name belongs to an entry when it is that package itself or one of
        its dotted submodules. A bare prefix match is not enough: it would
        also catch unrelated packages such as ``keyrings.alt``.
        """
        for module in CHECKED_MODULES:
            if module_name == module or module_name.startswith(module + "."):
                return module
        return None

    def visit_If(self, node: ast.If):
        """Skip the body of ``if TYPE_CHECKING:`` but keep checking its else branch."""
        if self._is_type_checking(node.test):
            # Only the body is type-checker-only; the else/elif branch is what
            # actually executes at runtime, so it still has to be checked.
            for statement in node.orelse:
                self.visit(statement)
        else:
            self.generic_visit(node)

    def visit_Import(self, node: ast.Import):
        """Check ``import x`` / ``import x.y as z`` statements."""
        for alias in node.names:
            self._check_import(alias.name, node.lineno, node.col_offset)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom):
        """Check ``from x import y`` statements.

        Relative imports (``node.level > 0``) are never violations: the AST
        stores ``from .options import boto3`` as module ``"options"`` with
        level 1, and a sibling module such as ``.pandas_tools`` is project
        code, not the third-party package.
        """
        if node.level == 0 and node.module:
            module = self._checked_root(node.module)
            if module is not None:
                self.violations.append(
                    ImportViolation(
                        self.filename,
                        node.lineno,
                        node.col_offset,
                        f"Import from '{node.module}' is not allowed. Use 'from .options import {module}' instead",
                    )
                )
        self.generic_visit(node)

    def _check_import(self, module_name: str, line: int, col: int):
        """Record a violation if ``module_name`` is a checked package or one of its submodules."""
        module = self._checked_root(module_name)
        if module is not None:
            self.violations.append(
                ImportViolation(
                    self.filename,
                    line,
                    col,
                    f"Direct import of '{module_name}' is not allowed. Use 'from .options import {module}' instead",
                )
            )


def check_file(filename: str) -> List[ImportViolation]:
    """Return the optional-import violations found in one Python source file.

    The file is handed to ``ast.parse`` as raw bytes so that decoding follows
    Python's own source rules (UTF-8 by default, or a PEP 263 coding
    declaration) instead of the locale's preferred encoding.

    Args:
        filename: Path of the file to inspect.

    Returns:
        The violations collected by ``ImportChecker``, or an empty list when
        the file cannot be parsed.
    """
    source = Path(filename).read_bytes()
    try:
        tree = ast.parse(source, filename=filename)
    except (SyntaxError, ValueError):
        # Deliberately best-effort: unparsable files are skipped here.
        # ValueError covers sources containing null bytes on interpreters
        # that do not report them as SyntaxError.
        return []
    checker = ImportChecker(filename)
    checker.visit(tree)
    return checker.violations


def main():
"""Main function for pre-commit hook."""
parser = argparse.ArgumentParser(
description="Check that optional imports are only imported from .options module"
)
parser.add_argument("filenames", nargs="*", help="Filenames to check")
parser.add_argument(
"--show-fixes", action="store_true", help="Show suggested fixes"
)
args = parser.parse_args()

all_violations = []
for filename in args.filenames:
if not filename.endswith(".py"):
continue
all_violations.extend(check_file(filename))

# Show violations
if all_violations:
print("Optional import violations found:")
print()

for violation in all_violations:
print(f" {violation}")

if args.show_fixes:
print()
print("How to fix:")
print(" - Import optional modules only from .options module")
print(" - Example:")
print(" # CORRECT:")
print(" from .options import boto3, botocore, installed_boto")
print(" if installed_boto:")
print(" SigV4Auth = botocore.auth.SigV4Auth")
print()
print(" # INCORRECT:")
print(" import boto3")
print(" from botocore.auth import SigV4Auth")
print()
print(
" - This ensures the connector works in environments where optional libraries are not installed"
)

print()
print(f"Found {len(all_violations)} violation(s)")
return 1

return 0


# When executed as a script, run the hook and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())
8 changes: 6 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ project_urls =
python_requires = >=3.9
packages = find_namespace:
install_requires =
# The [boto] extra is added to the requirements by default unless the
# SNOWFLAKE_NO_BOTO environment variable is set; see setup.py.
asn1crypto>0.24.0,<2.0.0
boto3>=1.24
botocore>=1.24
cffi>=1.9,<2.0.0
cryptography>=3.1.0
pyOpenSSL>=22.0.0,<25.0.0
Expand Down Expand Up @@ -79,6 +79,9 @@ console_scripts =
snowflake-dump-certs = snowflake.connector.tool.dump_certs:main

[options.extras_require]
boto =
boto3>=1.24
botocore>=1.24
development =
Cython
coverage
Expand All @@ -100,4 +103,5 @@ secure-local-storage =
keyring>=23.1.0,<26.0.0
aio =
aiohttp>=3.12.14
aioboto =
aioboto3>=15.0.0
30 changes: 26 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import warnings

from setuptools import Extension, setup
from setuptools.command.egg_info import egg_info

CONNECTOR_SRC_DIR = os.path.join("src", "snowflake", "connector")
NANOARROW_SRC_DIR = os.path.join(CONNECTOR_SRC_DIR, "nanoarrow_cpp", "ArrowIterator")
Expand Down Expand Up @@ -38,9 +39,14 @@
extensions = None
cmd_class = {}

SNOWFLAKE_DISABLE_COMPILE_ARROW_EXTENSIONS = os.environ.get(
"SNOWFLAKE_DISABLE_COMPILE_ARROW_EXTENSIONS", "false"
).lower() in ("y", "yes", "t", "true", "1", "on")
# Case-insensitive spellings that count as "enabled" for the flags below.
_POSITIVE_VALUES = ("y", "yes", "t", "true", "1", "on")


def _env_flag(name):
    """Return True when environment variable ``name`` holds an affirmative value.

    An unset variable is treated as "false".
    """
    return os.environ.get(name, "false").lower() in _POSITIVE_VALUES


SNOWFLAKE_DISABLE_COMPILE_ARROW_EXTENSIONS = _env_flag(
    "SNOWFLAKE_DISABLE_COMPILE_ARROW_EXTENSIONS"
)
SNOWFLAKE_NO_BOTO = _env_flag("SNOWFLAKE_NO_BOTO")

try:
from Cython.Build import cythonize
Expand Down Expand Up @@ -88,7 +94,7 @@ def build_extension(self, ext):
ext.sources += [
os.path.join(
NANOARROW_ARROW_ITERATOR_SRC_DIR,
*((file,) if isinstance(file, str) else file)
*((file,) if isinstance(file, str) else file),
)
for file in {
"ArrayConverter.cpp",
Expand Down Expand Up @@ -174,6 +180,22 @@ def new__compile(obj, src: str, ext, cc_args, extra_postargs, pp_opts):

cmd_class = {"build_ext": MyBuildExt}


class SetDefaultInstallationExtras(egg_info):
    """egg_info command that turns the ``boto`` extra into a default requirement.

    Unless SNOWFLAKE_NO_BOTO is set, the requirements listed under the
    ``boto`` extra are appended to the distribution's ``install_requires``.
    """

    def finalize_options(self):
        super().finalize_options()

        # Explicit opt-out: leave install_requires exactly as declared.
        if SNOWFLAKE_NO_BOTO:
            return

        dist = self.distribution
        default_extras = dist.extras_require.get("boto", [])
        dist.install_requires += default_extras


# Register the custom egg_info command under the "egg_info" key of cmd_class
# (presumably passed to setup() as cmdclass below -- the argument is not
# visible in this hunk).
cmd_class["egg_info"] = SetDefaultInstallationExtras

setup(
version=version,
ext_modules=extensions,
Expand Down
26 changes: 18 additions & 8 deletions src/snowflake/connector/aio/_wif_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import os
from base64 import b64encode

import aioboto3
from aiobotocore.utils import AioInstanceMetadataRegionFetcher
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from snowflake.connector.options import (
aioboto3,
aiobotocore,
botocore,
installed_aioboto,
)

from ..errorcode import ER_WIF_CREDENTIALS_NOT_FOUND
from ..errors import ProgrammingError
from ..errors import MissingDependencyError, ProgrammingError
from ..wif_util import (
DEFAULT_ENTRA_SNOWFLAKE_RESOURCE,
SNOWFLAKE_AUDIENCE,
Expand All @@ -31,7 +33,9 @@ async def get_aws_region() -> str:
if "AWS_REGION" in os.environ: # Lambda
region = os.environ["AWS_REGION"]
else: # EC2
region = await AioInstanceMetadataRegionFetcher().retrieve_region()
region = (
await aiobotocore.utils.AioInstanceMetadataRegionFetcher().retrieve_region()
)

if not region:
raise ProgrammingError(
Expand All @@ -46,6 +50,12 @@ async def create_aws_attestation() -> WorkloadIdentityAttestation:

If the application isn't running on AWS or no credentials were found, raises an error.
"""
if not installed_aioboto:
raise MissingDependencyError(
msg="AWS Workload Identity Federation can't be used because aioboto3 or aiobotocore optional dependency is not installed. Try installing missing dependencies.",
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
)

session = aioboto3.Session()
aws_creds = await session.get_credentials()
if not aws_creds:
Expand All @@ -57,7 +67,7 @@ async def create_aws_attestation() -> WorkloadIdentityAttestation:
region = await get_aws_region()
partition = session.get_partition_for_region(region)
sts_hostname = get_aws_sts_hostname(region, partition)
request = AWSRequest(
request = botocore.awsrequest.AWSRequest(
method="POST",
url=f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15",
headers={
Expand All @@ -66,7 +76,7 @@ async def create_aws_attestation() -> WorkloadIdentityAttestation:
},
)

SigV4Auth(aws_creds, "sts", region).add_auth(request)
botocore.auth.SigV4Auth(aws_creds, "sts", region).add_auth(request)

assertion_dict = {
"url": request.url,
Expand Down
Loading
Loading