
Commit 3cb15f7

support the MNIST-series datasets

1 parent f7badf8 commit 3cb15f7

File tree

13 files changed: +2147 −207 lines changed

LICENSE.txt (+203 −203)

Large diffs are not rendered by default.

python/jtorch/__init__.py (+19)

@@ -50,6 +50,11 @@ def inner(*args, **kw):
 from . import autograd
 from .autograd import *

+def ndimension(self):
+    return self.ndim
+Var.ndimension = ndimension
+
+
 Tensor = Var
 tensor = wrapper(array)

@@ -76,3 +81,17 @@ def forward(self, *args, **kw):
 import jtorch.nn
 from jtorch.nn import Module, Parameter
 import jtorch.optim
+
+from jtorch.utils.dtype import Dtype, get_string_dtype
+
+def frombuffer(buffer: bytearray,
+               *,
+               dtype: Dtype,
+               count: int = -1,
+               offset: int = 0,
+               requires_grad: bool = True) -> Tensor:
+    dtype = get_string_dtype(dtype)
+    tensor = jt.array(np.frombuffer(buffer, dtype, count=count, offset=offset))
+    if requires_grad and tensor.dtype.is_float():
+        tensor.requires_grad = True
+    return tensor
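
For orientation, a minimal usage sketch of the new frombuffer helper (the buffer contents here are illustrative; passing np.float32 exercises the callable-dtype path of get_string_dtype):

    import numpy as np
    import jtorch

    # Build a tensor from a raw byte buffer.
    buf = bytearray(np.arange(8, dtype=np.float32).tobytes())
    t = jtorch.frombuffer(buf, dtype=np.float32)
    # Note: unlike torch, float tensors default to requires_grad=True here.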

python/jtorch/tutorial/quickstart.py (+37 −1)

@@ -19,4 +19,40 @@
     train=False,
     download=True,
     transform=ToTensor(),
-)
+)
+
+batch_size = 64
+
+# Create data loaders.
+train_dataloader = DataLoader(training_data, batch_size=batch_size)
+test_dataloader = DataLoader(test_data, batch_size=batch_size)
+
+for X, y in test_dataloader:
+    print(f"Shape of X [N, C, H, W]: {X.shape}")
+    print(f"Shape of y: {y.shape} {y.dtype}")
+    break
+
+# Get cpu or gpu device for training.
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using {device} device")
+
+# Define model
+class NeuralNetwork(nn.Module):
+    def __init__(self):
+        super(NeuralNetwork, self).__init__()
+        self.flatten = nn.Flatten()
+        self.linear_relu_stack = nn.Sequential(
+            nn.Linear(28*28, 512),
+            nn.ReLU(),
+            nn.Linear(512, 512),
+            nn.ReLU(),
+            nn.Linear(512, 10)
+        )
+
+    def forward(self, x):
+        x = self.flatten(x)
+        logits = self.linear_relu_stack(x)
+        return logits
+
+model = NeuralNetwork().to(device)
+print(model)
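
A quick way to sanity-check the new model, mirroring the next step of the PyTorch quickstart this tutorial follows (a sketch; it assumes torch.rand with a device argument works through jtorch's compatibility layer):

    X = torch.rand(1, 28, 28, device=device)  # one fake 28x28 image
    logits = model(X)
    print(f"Predicted class: {logits.argmax(1)}")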
File renamed without changes.

python/jtorch/utils/data.py (+56 −2)

@@ -1,2 +1,56 @@
-class DataLoader:
-    pass
+from audioop import mul
+import jittor as jt
+from jittor.dataset import Dataset as JDataset
+
+from typing import Any, Callable, Iterable, Optional, Sequence, Union
+
+
+class Dataset:
+    def __getitem__(self, index):
+        raise NotImplementedError
+
+
+class DataLoader(JDataset):
+    def __init__(self, dataset: Dataset,
+                 batch_size: Optional[int] = 1,
+                 shuffle: Optional[bool] = None,
+                 sampler=None,
+                 batch_sampler=None,
+                 num_workers: int = 0,
+                 collate_fn=None,
+                 pin_memory: bool = False,
+                 drop_last: bool = False,
+                 timeout: float = 0,
+                 worker_init_fn=None,
+                 multiprocessing_context=None,
+                 generator=None,
+                 *, prefetch_factor: int = 2,
+                 persistent_workers: bool = False,
+                 pin_memory_device: str = "") -> None:
+        super().__init__(batch_size=batch_size,
+                         shuffle=shuffle,
+                         num_workers=num_workers,
+                         drop_last=drop_last)
+
+        unsupported_kwargs = {
+            "sampler": sampler,
+            "batch_sampler": batch_sampler,
+            "pin_memory": pin_memory,
+            "timeout": timeout,
+            "worker_init_fn": worker_init_fn,
+            "multiprocessing_context": multiprocessing_context,
+            "generator": generator,
+            "persistent_workers": persistent_workers,
+            "pin_memory_device": pin_memory_device
+        }
+        for kwarg, value in unsupported_kwargs.items():
+            if value:
+                jt.LOG.w(f"Not implemented Dataloader kwarg: {kwarg}")
+
+        self.collate_fn = collate_fn
+
+    def collate_batch(self, batch):
+        if self.collate_fn is not None:
+            return self.collate_fn(batch)
+        else:
+            return super().collate_batch(batch)
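
A short usage sketch of the wrapper (RangeDataset is a hypothetical toy dataset; passing an unsupported kwarg such as pin_memory only triggers the warning shown above rather than an error):

    from jtorch.utils.data import Dataset, DataLoader

    class RangeDataset(Dataset):  # hypothetical: yields (x, x*x) pairs
        def __init__(self, n):
            self.n = n
        def __getitem__(self, index):
            return index, index * index
        def __len__(self):
            return self.n

    loader = DataLoader(RangeDataset(100), batch_size=10, shuffle=True,
                        pin_memory=True)  # logs: Not implemented Dataloader kwarg: pin_memory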

python/jtorch/utils/dtype.py (+9)

@@ -0,0 +1,9 @@
+from typing import Callable, Union
+Dtype = Union[Callable, str]
+
+def get_string_dtype(dtype):
+    if callable(dtype):
+        dtype = dtype.__name__
+    if not isinstance(dtype, str):
+        raise ValueError(f"dtype is expected to be str, python type function, or jittor type function, but got {dtype}.")
+    return dtype
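
get_string_dtype normalizes a dtype argument to a string, which is what frombuffer above relies on; for example:

    from jtorch.utils.dtype import get_string_dtype

    get_string_dtype("float32")  # -> "float32" (strings pass through)
    get_string_dtype(float)      # -> "float"   (callables map to their __name__)
    get_string_dtype(42)         # raises ValueError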
New file (+46)

@@ -0,0 +1,46 @@
+import importlib.machinery
+import os
+
+
+def _download_file_from_remote_location(fpath: str, url: str) -> None:
+    pass
+
+
+def _is_remote_location_available() -> bool:
+    return False
+
+
+def _get_extension_path(lib_name):
+
+    lib_dir = os.path.dirname(__file__)
+    if os.name == "nt":
+        # Register the main torchvision library location on the default DLL path
+        import ctypes
+        import sys
+
+        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
+        with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
+        prev_error_mode = kernel32.SetErrorMode(0x0001)
+
+        if with_load_library_flags:
+            kernel32.AddDllDirectory.restype = ctypes.c_void_p
+
+        if sys.version_info >= (3, 8):
+            os.add_dll_directory(lib_dir)
+        elif with_load_library_flags:
+            res = kernel32.AddDllDirectory(lib_dir)
+            if res is None:
+                err = ctypes.WinError(ctypes.get_last_error())
+                err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
+                raise err
+
+        kernel32.SetErrorMode(prev_error_mode)
+
+    loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)
+
+    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
+    ext_specs = extfinder.find_spec(lib_name)
+    if ext_specs is None:
+        raise ImportError
+
+    return ext_specs.origin
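
_get_extension_path resolves a compiled extension module sitting next to this file; a sketch of the intended call ("_C" is a hypothetical extension name, following the torchvision convention this helper is adapted from):

    # Returns e.g. ".../_C.cpython-310-x86_64-linux-gnu.so",
    # or raises ImportError if no matching extension is found in lib_dir.
    so_path = _get_extension_path("_C")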
New file (+9)

@@ -0,0 +1,9 @@
+from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
+
+__all__ = (
+    "EMNIST",
+    "FashionMNIST",
+    "QMNIST",
+    "MNIST",
+    "KMNIST",
+)
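
With these re-exports in place, the quickstart above can build the MNIST-series datasets in the familiar torchvision style (a sketch; the diff view omits this file's path, so the import line is an assumption):

    from torchvision import datasets          # assumed torchvision-compatible package
    from torchvision.transforms import ToTensor

    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor(),
    )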
