
Commit

update examples and tests type_guards
RandallPittmanOrSt committed Jul 30, 2024
1 parent d42667d commit 3b1de4b
Showing 13 changed files with 131 additions and 20 deletions.
2 changes: 2 additions & 0 deletions examples/bench.py
@@ -4,6 +4,7 @@
import netCDF4
from timeit import Timer
import os, sys
import type_guards

# create an n1dim by n2dim by n3dim random array.
n1dim = 30
@@ -15,6 +16,7 @@
array = uniform(size=(n1dim,n2dim,n3dim,n4dim))

def write_netcdf(filename,zlib=False,least_significant_digit=None,format='NETCDF4'):
assert type_guards.valid_format(format)
file = netCDF4.Dataset(filename,'w',format=format)
file.createDimension('n1', n1dim)
file.createDimension('n2', n2dim)
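The benchmark now asserts type_guards.valid_format(format) before handing the string to netCDF4.Dataset, so an invalid format string fails fast at runtime and a static checker can narrow the plain str to the stubs' Format literal. A minimal, illustrative sketch of the same pattern outside the benchmark (the filename and variable names below are placeholders, not part of the commit):

import netCDF4
import type_guards  # helper module added alongside the examples in this commit

fmt = "NETCDF4_CLASSIC"               # a checker infers plain `str` here
assert type_guards.valid_format(fmt)  # TypeGuard narrows fmt to the Format literal
nc = netCDF4.Dataset("bench_example.nc", "w", format=fmt)  # matches the stubbed signature
nc.close()

# An unsupported string would fail fast, before netCDF4 ever sees it:
# assert type_guards.valid_format("NETCDF5")   # AssertionError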
2 changes: 2 additions & 0 deletions examples/bench_compress4.py
@@ -4,6 +4,7 @@
import netCDF4
from timeit import Timer
import os, sys
import type_guards

# use real data.
URL="http://www.esrl.noaa.gov/psd/thredds/dodsC/Datasets/ncep.reanalysis/pressure/hgt.1990.nc"
@@ -24,6 +25,7 @@ def write_netcdf(filename,nsd,quantize_mode='BitGroom'):
file.createDimension('n1', None)
file.createDimension('n3', n3dim)
file.createDimension('n4', n4dim)
assert type_guards.valid_quantize_mode(quantize_mode)
foo = file.createVariable('data',\
'f4',('n1','n3','n4'),\
zlib=True,shuffle=True,\
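bench_compress4.py applies the same guard to quantize_mode right before createVariable. Since assert statements disappear under python -O, a script that must keep the runtime check could raise explicitly instead; the sketch below is such a variant and is an assumption, not code from this commit (the commit itself uses assert, which type checkers narrow on directly):

import type_guards  # helper module added in this commit


def checked_quantize_mode(quantize_mode: str):
    """Hypothetical variant of the benchmark's assert: reject invalid values
    with a ValueError that survives `python -O`."""
    if not type_guards.valid_quantize_mode(quantize_mode):
        raise ValueError(f"unsupported quantize_mode: {quantize_mode!r}")
    return quantize_mode


mode = checked_quantize_mode("BitGroom")  # "BitRound" and "GranularBitRound" also pass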
2 changes: 2 additions & 0 deletions examples/bench_diskless.py
@@ -4,6 +4,7 @@
import netCDF4
from timeit import Timer
import os, sys
import type_guards

# create an n1dim by n2dim by n3dim random array.
n1dim = 30
@@ -15,6 +16,7 @@
array = uniform(size=(n1dim,n2dim,n3dim,n4dim))

def write_netcdf(filename,zlib=False,least_significant_digit=None,format='NETCDF4',closeit=False):
assert type_guards.valid_format(format)
file = netCDF4.Dataset(filename,'w',format=format,diskless=True,persist=True)
file.createDimension('n1', n1dim)
file.createDimension('n2', n2dim)
105 changes: 105 additions & 0 deletions examples/type_guards.py
@@ -0,0 +1,105 @@
"""type_guards.py - Helpers for static and runtime type-checking of initialization arguments
for Dataset and Variable"""

from typing import TYPE_CHECKING, Any, Literal

from typing_extensions import TypeGuard

if TYPE_CHECKING:
# in stubs only
from netCDF4 import (
AccessMode,
CalendarType,
CompressionLevel,
CompressionType,
EndianType,
QuantizeMode,
)
from netCDF4 import Format as NCFormat
else:
AccessMode = Any
CalendarType = Any
CompressionLevel = Any
DiskFormat = Any
EndianType = Any
CompressionType = Any
NCFormat = Any
QuantizeMode = Any


def valid_access_mode(mode) -> TypeGuard[AccessMode]:
"""Check for a valid `mode` argument for opening a Dataset"""
return mode in {"r", "w", "r+", "a", "x", "rs", "ws", "r+s", "as"}


def valid_calendar(calendar) -> TypeGuard[CalendarType]:
"""Check for a valid `calendar` argument for cftime functions"""
return calendar in {
"standard",
"gregorian",
"proleptic_gregorian",
"noleap",
"365_day",
"360_day",
"julian",
"all_leap",
"366_day",
}


def valid_complevel(complevel) -> TypeGuard[CompressionLevel]:
"""Check for a valid `complevel` argument for creating a Variable"""
return isinstance(complevel, int) and 0 <= complevel <= 9


def valid_compression(compression) -> TypeGuard[CompressionType]:
"""Check for a valid `compression` argument for creating a Variable"""
return compression in {
"zlib",
"szip",
"zstd",
"bzip2",
"blosc_lz",
"blosc_lz4",
"blosc_lz4hc",
"blosc_zlib",
"blosc_zstd",
}


def valid_format(format) -> TypeGuard[NCFormat]:
"""Check for a valid `format` argument for opening a Dataset"""
return format in {
"NETCDF4",
"NETCDF4_CLASSIC",
"NETCDF3_CLASSIC",
"NETCDF3_64BIT_OFFSET",
"NETCDF3_64BIT_DATA",
}


def valid_endian(endian) -> TypeGuard[EndianType]:
"""Check for a valid `endian` argument for creating a Variable"""
return endian in {"native", "big", "little"}


def valid_bloscshuffle(blosc_shuffle) -> TypeGuard[Literal[0, 1, 2]]:
"""Check for a valid `blosc_shuffle` argument for creating a Variable"""
return blosc_shuffle in {0, 1, 2}


def valid_quantize_mode(quantize_mode) -> TypeGuard[QuantizeMode]:
"""Check for a valid `quantize_mode` argument for creating a Variable"""
return quantize_mode in {"BitGroom", "BitRound", "GranularBitRound"}


def valid_szip_coding(szip_coding) -> TypeGuard[Literal["nn", "ec"]]:
"""Check for a valid `szip_coding` argument for creating a Variable"""
return szip_coding in {"nn", "ec"}


def valid_szip_pixels_per_block(
szip_pixels_per_block,
) -> TypeGuard[Literal[4, 8, 16, 32]]:
"""Check for a valid `szip_pixels_per_block` argument for creating a Variable"""
return szip_pixels_per_block in {4, 8, 16, 32}
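Each helper above returns typing_extensions.TypeGuard[...]: at runtime it is only a set-membership test, but a True result tells the type checker that the argument has the narrower literal type imported from the netCDF4 stubs. A rough sketch of that narrowing under a checker such as mypy (the configure function is hypothetical; the literal values in the comments come from the guards above):

from typing_extensions import reveal_type  # runtime-safe reveal_type (typing_extensions >= 4.1)

import type_guards


def configure(complevel: int, endian: str) -> None:
    # Before the guards a checker only sees `int` and `str`.
    assert type_guards.valid_complevel(complevel)
    assert type_guards.valid_endian(endian)
    # After the asserts the values are treated as the stub literals
    # (0-9 for complevel, "native" / "big" / "little" for endian),
    # which is what keyword arguments typed with those aliases expect.
    reveal_type(complevel)
    reveal_type(endian)


configure(6, "little")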
8 changes: 4 additions & 4 deletions test/test_compression.py
@@ -3,7 +3,7 @@
from netCDF4.utils import _quantize
from numpy.testing import assert_almost_equal
import os, tempfile, unittest
import type_guards as n4t
import type_guards

ndim = 100000
ndim2 = 100
@@ -16,7 +16,7 @@

def write_netcdf(filename,zlib,least_significant_digit,data,dtype='f8',shuffle=False,contiguous=False,\
chunksizes=None,complevel=6,fletcher32=False):
assert n4t.valid_complevel(complevel) or complevel is None
assert type_guards.valid_complevel(complevel) or complevel is None
file = Dataset(filename,'w')
file.createDimension('n', ndim)
foo = file.createVariable('data',\
@@ -32,7 +32,7 @@ def write_netcdf(filename,zlib,least_significant_digit,data,dtype='f8',shuffle=F
#compression=''
#compression=0
#compression='gzip' # should fail
assert n4t.valid_compression(compression) or compression is None
assert type_guards.valid_compression(compression) or compression is None
foo2 = file.createVariable('data2',\
dtype,('n'),compression=compression,least_significant_digit=least_significant_digit,\
shuffle=shuffle,contiguous=contiguous,complevel=complevel,fletcher32=fletcher32,chunksizes=chunksizes)
@@ -49,7 +49,7 @@ def write_netcdf2(filename,zlib,least_significant_digit,data,dtype='f8',shuffle=
file = Dataset(filename,'w')
file.createDimension('n', ndim)
file.createDimension('n2', ndim2)
assert n4t.valid_complevel(complevel) or complevel is None
assert type_guards.valid_complevel(complevel) or complevel is None
foo = file.createVariable('data2',\
dtype,('n','n2'),zlib=zlib,least_significant_digit=least_significant_digit,\
shuffle=shuffle,contiguous=contiguous,complevel=complevel,fletcher32=fletcher32,chunksizes=chunksizes)
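In these tests complevel may legitimately be None (meaning "use the library default"), so the guard is combined with an explicit None check and the assert accepts both cases, letting a checker treat the value as the compression-level literal or None afterwards (assuming it combines the two branches of the or). A small, hypothetical sketch of the pattern in isolation:

from typing import Optional

import type_guards


def variable_kwargs(complevel: Optional[int] = 6) -> dict:
    # Mirrors the tests above (hypothetical helper): None means "library default",
    # anything else must be a valid compression level in 0-9.
    assert type_guards.valid_complevel(complevel) or complevel is None
    return {} if complevel is None else {"complevel": complevel}


print(variable_kwargs(4))     # {'complevel': 4}
print(variable_kwargs(None))  # {}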
6 changes: 3 additions & 3 deletions test/test_compression_blosc.py
@@ -1,10 +1,10 @@
from numpy.random.mtrand import uniform
from netCDF4 import Dataset
import type_guards as n4t
import type_guards
from numpy.testing import assert_almost_equal
import os, tempfile, unittest, sys
from filter_availability import no_plugins, has_blosc_filter
import type_guards as n4t
import type_guards


ndim = 100000
@@ -14,7 +14,7 @@
datarr = uniform(size=(ndim,))

def write_netcdf(filename,dtype='f8',blosc_shuffle=1,complevel=6):
assert (n4t.valid_complevel(complevel) or complevel is None) and n4t.valid_bloscshuffle(blosc_shuffle)
assert (type_guards.valid_complevel(complevel) or complevel is None) and type_guards.valid_bloscshuffle(blosc_shuffle)
nc = Dataset(filename,'w')
nc.createDimension('n', ndim)
foo = nc.createVariable('data',\
4 changes: 2 additions & 2 deletions test/test_compression_bzip2.py
@@ -3,15 +3,15 @@
from numpy.testing import assert_almost_equal
import os, tempfile, unittest, sys
from filter_availability import no_plugins, has_bzip2_filter
import type_guards as n4t
import type_guards

ndim = 100000
filename1 = tempfile.NamedTemporaryFile(suffix='.nc', delete=False).name
filename2 = tempfile.NamedTemporaryFile(suffix='.nc', delete=False).name
array = uniform(size=(ndim,))

def write_netcdf(filename,dtype='f8',complevel=6):
assert n4t.valid_complevel(complevel) or complevel is None
assert type_guards.valid_complevel(complevel) or complevel is None
nc = Dataset(filename,'w')
nc.createDimension('n', ndim)
foo = nc.createVariable('data',\
4 changes: 2 additions & 2 deletions test/test_compression_quant.py
@@ -3,7 +3,7 @@
from numpy.testing import assert_almost_equal
import numpy as np
import os, tempfile, unittest
import type_guards as n4t
import type_guards

ndim = 100000
nfiles = 7
@@ -17,7 +17,7 @@ def write_netcdf(filename,zlib,significant_digits,data,dtype='f8',shuffle=False,
complevel=6,quantize_mode="BitGroom"):
file = Dataset(filename,'w')
file.createDimension('n', ndim)
assert (n4t.valid_complevel(complevel) or complevel is None) and n4t.valid_quantize_mode(quantize_mode)
assert (type_guards.valid_complevel(complevel) or complevel is None) and type_guards.valid_quantize_mode(quantize_mode)
foo = file.createVariable('data',\
dtype,('n'),zlib=zlib,significant_digits=significant_digits,\
shuffle=shuffle,complevel=complevel,quantize_mode=quantize_mode)
4 changes: 2 additions & 2 deletions test/test_compression_zstd.py
@@ -3,7 +3,7 @@
from numpy.testing import assert_almost_equal
import os, tempfile, unittest, sys
from filter_availability import no_plugins, has_zstd_filter
import type_guards as n4t
import type_guards

ndim = 100000
filename1 = tempfile.NamedTemporaryFile(suffix='.nc', delete=False).name
@@ -13,7 +13,7 @@
def write_netcdf(filename,dtype='f8',complevel=6):
nc = Dataset(filename,'w')
nc.createDimension('n', ndim)
assert n4t.valid_complevel(complevel) or complevel is None
assert type_guards.valid_complevel(complevel) or complevel is None
foo = nc.createVariable('data',\
dtype,('n'),compression='zstd',complevel=complevel)
foo[:] = array
4 changes: 2 additions & 2 deletions test/test_endian.py
@@ -2,7 +2,7 @@
import numpy as np
import unittest, os, tempfile
from numpy.testing import assert_array_equal, assert_array_almost_equal
import type_guards as n4t
import type_guards

data = np.arange(12,dtype='f4').reshape(3,4)
FILE_NAME = tempfile.NamedTemporaryFile(suffix='.nc', delete=False).name
@@ -74,7 +74,7 @@ def issue310(file):
endian='little'
else:
raise ValueError('cannot determine native endianness')
assert n4t.valid_endian(endian) # mypy fails to narrow endian on its own
assert type_guards.valid_endian(endian) # mypy fails to narrow endian on its own
var_big_endian = nc.createVariable(\
'obs_big_endian', '>f8', ('obs', ),\
endian=endian,fill_value=fval)
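The inline comment in test_endian.py captures the motivation for the whole module: assigning string literals in an if/else still leaves endian inferred as plain str, which does not satisfy a parameter typed with the stubs' endian literal, so the TypeGuard assert does the narrowing. A condensed, illustrative sketch (the function name and structure are assumptions):

import sys

import type_guards


def native_endian() -> str:
    # Branchy literal assignments still leave `endian` inferred as plain `str`,
    # which is why the test needs the TypeGuard assert before createVariable.
    if sys.byteorder == "little":
        endian = "little"
    else:
        endian = "big"
    assert type_guards.valid_endian(endian)  # narrowed to "native" / "big" / "little"
    return endian


print(native_endian())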
4 changes: 2 additions & 2 deletions test/test_stringarr.py
@@ -3,7 +3,7 @@
import unittest
import os
from numpy.testing import assert_array_equal, assert_array_almost_equal
import type_guards as n4t
import type_guards

def generateString(length, alphabet=string.ascii_letters + string.digits + string.punctuation):
return(''.join([random.choice(alphabet) for i in range(length)]))
@@ -25,7 +25,7 @@ class StringArrayTestCase(unittest.TestCase):

def setUp(self):
self.file = FILE_NAME
assert n4t.valid_ncformat(FILE_FORMAT)
assert type_guards.valid_format(FILE_FORMAT)
nc = Dataset(FILE_NAME,'w',format=FILE_FORMAT)
nc.createDimension('n1',None)
nc.createDimension('n2',n2)
4 changes: 2 additions & 2 deletions test/test_types.py
@@ -6,7 +6,7 @@
from numpy.testing import assert_array_equal, assert_array_almost_equal
from numpy.random.mtrand import uniform
import netCDF4
import type_guards as n4t
import type_guards

# test primitive data types.

@@ -29,7 +29,7 @@ def setUp(self):
f.createDimension('n1', None)
f.createDimension('n2', n2dim)
for typ in datatypes:
assert n4t.valid_complevel(complevel) or complevel is None
assert type_guards.valid_complevel(complevel) or complevel is None
foo = f.createVariable('data_'+typ, typ, ('n1','n2',),zlib=zlib,complevel=complevel,shuffle=shuffle,least_significant_digit=least_significant_digit,fill_value=FillValue)
#foo._FillValue = FillValue
# test writing of _FillValue attribute for diff types
2 changes: 1 addition & 1 deletion test/type_guards.py
@@ -67,7 +67,7 @@ def valid_compression(compression) -> TypeGuard[CompressionType]:
}


def valid_ncformat(format) -> TypeGuard[NCFormat]:
def valid_format(format) -> TypeGuard[NCFormat]:
"""Check for a valid `format` argument for opening a Dataset"""
return format in {
"NETCDF4",

