diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..f948375 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,10 @@ +version: 2 +updates: + - package-ecosystem: "yarn" + directory: "/" + schedule: + interval: "monthly" + groups: + all-updates: + patterns: + - "*" diff --git a/.github/workflows/build-test-publish.yml b/.github/workflows/build-test.yml similarity index 81% rename from .github/workflows/build-test-publish.yml rename to .github/workflows/build-test.yml index cbdaff4..1f4d176 100644 --- a/.github/workflows/build-test-publish.yml +++ b/.github/workflows/build-test.yml @@ -1,10 +1,12 @@ -name: Build, Test, Publish +name: Build and Test -on: [push] +on: + pull_request: + push: jobs: - test-publish: - name: Build, test, and publish package + test: + name: Build, test, and prepare docs runs-on: ubuntu-latest permissions: contents: read @@ -36,16 +38,15 @@ jobs: mkdir -p docs/media [ -d "build" ] && cp -r build docs/build [ -d "build" ] && cp -r build docs/media/build - [ -d "examples" ] && cp -r examples docs/examples - [ -d "examples" ] && cp -r examples docs/media/examples + [ -d "examples" ] && cp -rL examples docs/examples + [ -d "examples" ] && cp -rL examples docs/media/examples - uses: actions/upload-pages-artifact@v4 with: path: './docs' - - run: yarn pkg-pr-new publish - deploy-docs: + if: github.event_name == 'push' && github.ref == 'refs/heads/main' environment: name: github-pages url: ${{ steps.deployment.outputs.page_url }} @@ -53,7 +54,7 @@ jobs: permissions: pages: write id-token: write - needs: test-publish + needs: test steps: - uses: actions/deploy-pages@v5 id: deployment diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..8b95bbc --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,25 @@ +name: Publish + +on: + push: + branches: [main] + workflow_dispatch: + +jobs: + publish: + name: Publish package preview + runs-on: ubuntu-latest + permissions: + contents: read + timeout-minutes: 10 + steps: + - uses: actions/checkout@v6 + - run: corepack enable + - uses: actions/setup-node@v6 + with: + node-version: 24 + cache: yarn + + - run: yarn install --immutable + - run: yarn build + - run: yarn pkg-pr-new publish diff --git a/.github/workflows/pyodide-test.yml b/.github/workflows/pyodide-test.yml index 99cf865..c94fd46 100644 --- a/.github/workflows/pyodide-test.yml +++ b/.github/workflows/pyodide-test.yml @@ -1,6 +1,8 @@ name: Pyodide Test -on: [push] +on: + pull_request: + push: jobs: pyodide-test: diff --git a/.github/workflows/verify-generated-tests.yml b/.github/workflows/verify-generated-tests.yml index 459bb77..d883a5d 100644 --- a/.github/workflows/verify-generated-tests.yml +++ b/.github/workflows/verify-generated-tests.yml @@ -1,6 +1,8 @@ name: Verify Generated Tests -on: [push] +on: + pull_request: + push: jobs: verify-sync: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index cf153fe..bf5d47c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,18 +1,38 @@ +## Development Commands + +```bash +yarn install # Install dependencies +yarn build # Build for browser and node (runs clean first) +yarn test # Run tests against src/ (uses mocha + tsx) +yarn test:build # Run tests against built library in build/ +yarn test:coverage # Run tests with coverage report +yarn test:watch # Watch mode for tests (alias: yarn watch) +yarn lint # Lint src/ with ESLint +yarn serve # Serve on localhost:8080 +yarn docs # Build TypeDoc documentation +yarn update-tests # Regenerate test/testcases.gen.js from scripts/generate_tests.py +``` + +To run a single test file: +```bash +yarn mocha --node-option conditions=torch-src test/tensor.test.js +``` + ## Codebase Structure - [`src`](src) - [`index.ts`](src/index.ts) is the entry point of the library. - [`tensor.ts`](src/tensor.ts) is the main tensor class. - - [`function`](function) contains all functions that tensors can perform. - - [`nn`](nn) contains all neural network modules (for everything under `torch.nn`). - - [`optim`](optim) contains all optimizers (for everything under `torch.optim`). - - [`creation`](creation) contains all tensor creation functions (all functions that create a tensor not from scratch, including `zeros`, `randn`). + - [`functions`](src/functions) contains all functions that tensors can perform. + - [`nn`](src/nn) contains all neural network modules (for everything under `torch.nn`). + - [`optim`](src/optim) contains all optimizers (for everything under `torch.optim`). + - [`creation`](src/creation) contains all tensor creation functions (all functions that create a tensor not from scratch, including `zeros`, `randn`). - [`examples`](examples) contains example usages of the library, including on node, on the browser, and using pyodide on the browser. - [`test`](test) contains the test cases of the library, including on node and on the browser. See [Testing](#testing). ### Development Scripts -You can use `yarn watch` to automatically test after each edit. +Use `yarn watch` (or `yarn test:watch`) to automatically re-run tests on each edit. ### Adding a new Function diff --git a/README.md b/README.md index 3d39c83..3b1e6a4 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,18 @@ # torch machine-learning libraries for Source Academy +The primary objective of this project is to create a reimplementation of PyTorch in TypeScript, with an educational focus. This project is developed with Source Academy integration in mind. + +This project reimplements core parts of PyTorch while trying to keep the codebase simple, and the API as close to PyTorch as possible. + +Using Pyodide, we can run Python code in the browser. Using `pyodide_bridge.py` in a way similar to `examples/pyodide/` we can run PyTorch-like code in the browser. + +## Notable differences with PyTorch + +- This library exposes extra information for debuggers and visualizers to catch, as seen in `events` in [`src/util.ts`](src/util.ts). It is similar to hooks in PyTorch. +- This library does not differentiate between LongTensors and FloatTensors. It uses `number` for all tensor elements. +- This library does not currently support devices, such as GPUs. + ## Getting Started Install yarn: @@ -48,4 +60,10 @@ yarn serve ## Contributing -For detailed information on the codebase and tests, see [CONTRIBUTING.md](CONTRIBUTING.md). +Contributions are welcome. The short version: + +1. Run `yarn test` to verify everything passes. +2. Add tests for new ops or behaviour changes. +3. Follow the existing patterns — new ops go in [src/functions/ops.ts](src/functions/ops.ts). + +For full details on the codebase, how to add operations, and the testing setup, see [CONTRIBUTING.md](CONTRIBUTING.md). diff --git a/examples/basic_backpropagation.js b/examples/basic_backpropagation.js index 8274ac2..72a5a37 100644 --- a/examples/basic_backpropagation.js +++ b/examples/basic_backpropagation.js @@ -1,4 +1,4 @@ -import { Tensor } from '../build/node/torch.node.es.js'; +import { Tensor } from '../build/node/torch.node.es.mjs'; const x = new Tensor([2.0], { requires_grad: true }); const y = x.pow(new Tensor([2.0])); diff --git a/examples/pyodide/.gitignore b/examples/pyodide/.gitignore index 2aa8c99..cb95714 100644 --- a/examples/pyodide/.gitignore +++ b/examples/pyodide/.gitignore @@ -142,3 +142,6 @@ dist vite.config.js.timestamp-* vite.config.ts.timestamp-* .vite/ + +# Pyodide package cache +.pyodide-packages/ diff --git a/examples/pyodide/bridge.py b/examples/pyodide/bridge.py deleted file mode 100644 index 872c259..0000000 --- a/examples/pyodide/bridge.py +++ /dev/null @@ -1,805 +0,0 @@ -# bridge.py -# Provides a PyTorch-compatible Python API over js_torch (the TypeScript torch library). -# -# Before loading this file, set the following globals in Pyodide: -# js_torch - the torch module (window.torch from the UMD build) - -from pyodide.ffi import JsProxy, to_js - - -# --------------------------------------------------------------------------- -# Internal helpers -# --------------------------------------------------------------------------- - -def _wrap_result(result): - """ - Wrap a JS return value: - - JsProxy (JS object/Tensor) -> Python Tensor - - Python primitive (int, float, bool) -> return as-is - JS primitives are automatically converted to Python by Pyodide, - so they will NOT be JsProxy instances. - """ - if isinstance(result, JsProxy): - return Tensor(result) - return result - - -def _transform(obj): - """Convert Python objects to JS-compatible types before passing to JS.""" - if isinstance(obj, Tensor): - return obj._js - if isinstance(obj, (list, tuple)): - return to_js([_transform(item) for item in obj]) - return obj - - -def _transform_args(args): - return [_transform(a) for a in args] - - -# --------------------------------------------------------------------------- -# Tensor -# --------------------------------------------------------------------------- - -class Tensor: - """Python wrapper around a JS Tensor, mirroring the PyTorch Tensor API.""" - - # ------------------------------------------------------------------ - # Construction - # ------------------------------------------------------------------ - - def __new__(cls, data, requires_grad=False): - # Return None for missing tensors so e.g. `tensor.grad` returns None - # when there is no gradient — matching PyTorch behaviour. - # Pyodide may represent JS null as a special JsNull type (not JsProxy, not None). - if data is None or type(data).__name__ in ('JsNull', 'JsUndefined'): - return None - return super().__new__(cls) - - def __init__(self, data, requires_grad=False): - if isinstance(data, JsProxy): - self._js = data - else: - js_data = to_js(data) if isinstance(data, (list, tuple)) else data - self._js = js_torch.tensor(js_data, requires_grad) - - # ------------------------------------------------------------------ - # Representation - # ------------------------------------------------------------------ - - def __repr__(self): - extra = ", requires_grad=True" if self.requires_grad else "" - return f"tensor({self.tolist()}{extra})" - - # ------------------------------------------------------------------ - # Data access - # ------------------------------------------------------------------ - - def tolist(self): - """Return tensor data as a (nested) Python list, or a Python scalar for 0-d tensors.""" - result = self._js.toArray() - if isinstance(result, JsProxy): - return result.to_py() - return result # scalar - - def item(self): - return self._js.item() - - # ------------------------------------------------------------------ - # Properties - # ------------------------------------------------------------------ - - @property - def shape(self): - return tuple(self._js.shape.to_py()) - - @property - def data(self): - """Detached view of the tensor data (no gradient).""" - return self.detach() - - @property - def requires_grad(self): - return bool(self._js.requires_grad) - - @requires_grad.setter - def requires_grad(self, value): - self._js.requires_grad = value - - @property - def grad(self): - raw = self._js.grad - if raw is None or type(raw).__name__ in ('JsNull', 'JsUndefined'): - return None - return Tensor(raw) - - @grad.setter - def grad(self, value): - self._js.grad = value._js if isinstance(value, Tensor) else None - - @property - def T(self): - if len(self.shape) < 2: - return self - return Tensor(self._js.transpose(0, 1)) - - # ------------------------------------------------------------------ - # Grad utilities - # ------------------------------------------------------------------ - - def backward(self, gradient=None): - if gradient is None: - self._js.backward() - else: - self._js.backward(gradient._js) - - def detach(self): - return Tensor(self._js.detach()) - - def zero_(self): - self._js.zero_() - return self - - def retain_grad(self): - self._js.retain_grad() - - # ------------------------------------------------------------------ - # Shape utilities - # ------------------------------------------------------------------ - - def size(self, dim=None): - s = self.shape - return s if dim is None else s[dim] - - def dim(self): - return len(self.shape) - - def numel(self): - n = 1 - for s in self.shape: - n *= s - return n - - def reshape(self, *args): - shape = list(args[0]) if len(args) == 1 and isinstance(args[0], (list, tuple)) else list(args) - return Tensor(self._js.reshape(to_js(shape))) - - def view(self, *args): - return self.reshape(*args) - - def squeeze(self, dim=None): - if dim is None: - new_shape = [s for s in self.shape if s != 1] - return Tensor(self._js.reshape(to_js(new_shape or [1]))) - return Tensor(self._js.squeeze(dim)) - - def unsqueeze(self, dim): - return Tensor(self._js.unsqueeze(dim)) - - def expand(self, *args): - shape = list(args[0]) if len(args) == 1 and isinstance(args[0], (list, tuple)) else list(args) - return Tensor(self._js.expand(to_js(shape))) - - def transpose(self, dim0, dim1): - return Tensor(self._js.transpose(dim0, dim1)) - - def flatten(self, start_dim=0, end_dim=-1): - return Tensor(self._js.flatten(start_dim, end_dim)) - - # ------------------------------------------------------------------ - # Reductions — default (no dim) sums all elements, matching PyTorch - # ------------------------------------------------------------------ - - def sum(self, dim=None, keepdim=False): - return Tensor(self._js.sum() if dim is None else self._js.sum(dim, keepdim)) - - def mean(self, dim=None, keepdim=False): - return Tensor(self._js.mean() if dim is None else self._js.mean(dim, keepdim)) - - def max(self, dim=None, keepdim=False): - return Tensor(self._js.max() if dim is None else self._js.max(dim, keepdim)) - - def min(self, dim=None, keepdim=False): - return Tensor(self._js.min() if dim is None else self._js.min(dim, keepdim)) - - # ------------------------------------------------------------------ - # Arithmetic — explicit methods - # ------------------------------------------------------------------ - - def _to_js(self, other): - return other._js if isinstance(other, Tensor) else other - - def add(self, other): return Tensor(self._js.add(self._to_js(other))) - def sub(self, other): return Tensor(self._js.sub(self._to_js(other))) - def mul(self, other): return Tensor(self._js.mul(self._to_js(other))) - def div(self, other): return Tensor(self._js.div(self._to_js(other))) - def pow(self, other): return Tensor(self._js.pow(self._to_js(other))) - def matmul(self, other): return Tensor(self._js.matmul(self._to_js(other))) - - # ------------------------------------------------------------------ - # Arithmetic operators - # ------------------------------------------------------------------ - - def __add__(self, other): return self.add(other) - def __radd__(self, other): return self.add(other) # add is commutative - def __sub__(self, other): return self.sub(other) - def __rsub__(self, other): - o = other if isinstance(other, Tensor) else Tensor(other) - return o.sub(self) - def __mul__(self, other): return self.mul(other) - def __rmul__(self, other): return self.mul(other) # mul is commutative - def __truediv__(self, other): return self.div(other) - def __rtruediv__(self, other): - o = other if isinstance(other, Tensor) else Tensor(other) - return o.div(self) - def __pow__(self, other): return self.pow(other) - def __rpow__(self, other): - o = other if isinstance(other, Tensor) else Tensor(other) - return o.pow(self) - def __matmul__(self, other): return self.matmul(other) - def __neg__(self): return Tensor(self._js.neg()) - def __abs__(self): return Tensor(self._js.abs()) - - # ------------------------------------------------------------------ - # Unary operations - # ------------------------------------------------------------------ - - def neg(self): return Tensor(self._js.neg()) - def abs(self): return Tensor(self._js.abs()) - def log(self): return Tensor(self._js.log()) - def exp(self): return Tensor(self._js.exp()) - def sqrt(self): return Tensor(self._js.sqrt()) - def square(self): return Tensor(self._js.square()) - def sin(self): return Tensor(self._js.sin()) - def cos(self): return Tensor(self._js.cos()) - def tan(self): return Tensor(self._js.tan()) - def sigmoid(self): return Tensor(self._js.sigmoid()) - def relu(self): return Tensor(js_torch.nn.functional.relu(self._js)) - def softmax(self, dim): return Tensor(self._js.softmax(dim)) - def clamp(self, min, max): return Tensor(self._js.clamp(min, max)) - def sign(self): return Tensor(self._js.sign()) - def reciprocal(self): return Tensor(self._js.reciprocal()) - def nan_to_num(self): return Tensor(self._js.nan_to_num()) - - # ------------------------------------------------------------------ - # Comparison - # ------------------------------------------------------------------ - - def lt(self, other): return Tensor(self._js.lt(self._to_js(other))) - def gt(self, other): return Tensor(self._js.gt(self._to_js(other))) - def le(self, other): return Tensor(self._js.le(self._to_js(other))) - def ge(self, other): return Tensor(self._js.ge(self._to_js(other))) - def eq(self, other): return Tensor(self._js.eq(self._to_js(other))) - def ne(self, other): return Tensor(self._js.ne(self._to_js(other))) - - def allclose(self, other, rtol=1e-5, atol=1e-8, equal_nan=False): - return bool(js_torch.allclose(self._js, other._js, rtol, atol, equal_nan)) - - # ------------------------------------------------------------------ - # Type conversions - # ------------------------------------------------------------------ - - def __float__(self): return float(self.item()) - def __int__(self): return int(self.item()) - def __bool__(self): return bool(self.item()) - def __format__(self, fmt): return format(self.item(), fmt) - - # ------------------------------------------------------------------ - # Indexing - # ------------------------------------------------------------------ - - def __getitem__(self, key): - if isinstance(key, int): - return Tensor(self._js.index(key)) - if isinstance(key, tuple): - result = self._js - for k in key: - if isinstance(k, int): - result = result.index(k) - else: - raise NotImplementedError( - "Only integer indexing is supported in multi-dimensional indexing" - ) - return Tensor(result) - if isinstance(key, slice): - start, stop, step = key.indices(self.shape[0]) - data = [Tensor(self._js.index(i)).tolist() for i in range(start, stop, step)] - return Tensor(data) - raise TypeError(f"Invalid index type: {type(key).__name__}") - - # ------------------------------------------------------------------ - # Iteration and length - # ------------------------------------------------------------------ - - def __len__(self): - return self.shape[0] - - def __iter__(self): - data = self.tolist() - if not isinstance(data, list): - raise TypeError("iteration over a 0-d tensor") - for item in data: - yield Tensor(item) - - # ------------------------------------------------------------------ - # Catch-all: delegate unknown attribute accesses to the JS tensor. - # Returned JsProxy objects are wrapped in Tensor; primitives pass through. - # ------------------------------------------------------------------ - - def __getattr__(self, name): - if name.startswith('_'): - raise AttributeError(name) - def method(*args, **kwargs): - js_args = _transform_args(args) - return _wrap_result(self._js.__getattribute__(name)(*js_args)) - return method - - -# --------------------------------------------------------------------------- -# Typed tensor subclasses -# --------------------------------------------------------------------------- - -def _trunc_nested(data): - """Truncate all numbers in a nested list toward zero (for LongTensor).""" - if isinstance(data, (int, float)): - return int(data) # Python int() truncates toward zero - return [_trunc_nested(item) for item in data] - - -class FloatTensor(Tensor): - """ - A Tensor that stores floating-point values. - Equivalent to a regular Tensor; provided for PyTorch API compatibility. - """ - def __init__(self, data, requires_grad=False): - if isinstance(data, JsProxy): - super().__init__(data) - else: - super().__init__(data, requires_grad) - - -class LongTensor(Tensor): - """ - A Tensor whose values are truncated to integers (toward zero). - LongTensor([-1.7]) -> tensor([-1]), LongTensor([1.9]) -> tensor([1]). - """ - def __init__(self, data, requires_grad=False): - if isinstance(data, JsProxy): - super().__init__(data) - else: - truncated = _trunc_nested(data) if isinstance(data, (list, tuple)) else int(data) - super().__init__(truncated, requires_grad) - - -# --------------------------------------------------------------------------- -# no_grad context manager — actually disables grad in the JS engine -# --------------------------------------------------------------------------- - -class _NoGrad: - def __enter__(self): - self._prev = js_torch.enable_no_grad() - return self - - def __exit__(self, *args): - js_torch.disable_no_grad(self._prev) - - -# --------------------------------------------------------------------------- -# Parameter -# --------------------------------------------------------------------------- - -class Parameter(Tensor): - """A Tensor that is automatically registered as a parameter.""" - def __init__(self, data, requires_grad=True): - if isinstance(data, Tensor): - self._js = js_torch.nn.Parameter.new(data._js) - elif isinstance(data, JsProxy): - self._js = js_torch.nn.Parameter.new(data) - else: - self._js = js_torch.nn.Parameter.new(js_torch.tensor(data)) - if not requires_grad: - self._js.requires_grad = False - - -# --------------------------------------------------------------------------- -# Module — pure-Python base class for user-defined models -# --------------------------------------------------------------------------- - -class Module: - """ - Pure-Python nn.Module. Subclass this to build models using bridge Tensors. - Assign `Parameter` or `_NNModule` instances as attributes and they are - automatically tracked by `parameters()`. - """ - - def __init__(self): - object.__setattr__(self, '_parameters', {}) - object.__setattr__(self, '_modules', {}) - object.__setattr__(self, 'training', True) - - def __setattr__(self, name, value): - try: - params = object.__getattribute__(self, '_parameters') - modules = object.__getattribute__(self, '_modules') - except AttributeError: - object.__setattr__(self, name, value) - return - - if isinstance(value, Parameter): - params[name] = value - elif isinstance(value, (Module, _NNModule)): - modules[name] = value - object.__setattr__(self, name, value) - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def forward(self, *args, **kwargs): - raise NotImplementedError - - def parameters(self): - params = list(object.__getattribute__(self, '_parameters').values()) - for mod in object.__getattribute__(self, '_modules').values(): - params.extend(mod.parameters()) - return params - - def named_parameters(self, prefix=''): - result = [] - for name, p in object.__getattribute__(self, '_parameters').items(): - full = f"{prefix}.{name}" if prefix else name - result.append((full, p)) - for mod_name, mod in object.__getattribute__(self, '_modules').items(): - full_mod = f"{prefix}.{mod_name}" if prefix else mod_name - result.extend(mod.named_parameters(full_mod)) - return result - - def train(self, mode=True): - object.__setattr__(self, 'training', mode) - for mod in object.__getattribute__(self, '_modules').values(): - mod.train(mode) - return self - - def eval(self): - return self.train(False) - - def zero_grad(self): - for p in self.parameters(): - p.grad = None - - -# --------------------------------------------------------------------------- -# _NNModule — wraps a JS nn.Module instance -# --------------------------------------------------------------------------- - -class _NNModule: - """Wraps a JS nn.Module returned by the nn factory functions.""" - - def __init__(self, js_module): - self._module = js_module - - def __call__(self, *args): - js_args = [a._js if isinstance(a, Tensor) else a for a in args] - return Tensor(self._module.call(*js_args)) - - def forward(self, *args): - js_args = [a._js if isinstance(a, Tensor) else a for a in args] - return Tensor(self._module.forward(*js_args)) - - def parameters(self): - return [Tensor(p) for p in self._module.parameters().to_py()] - - def named_parameters(self, prefix=''): - raw = self._module.named_parameters(prefix).to_py() - return [(pair[0], Tensor(pair[1])) for pair in raw] - - def train(self, mode=True): - self._module.train(mode) - return self - - def eval(self): - return self.train(False) - - def zero_grad(self): - for p in self.parameters(): - p.grad = None - - -# --------------------------------------------------------------------------- -# nn.functional -# --------------------------------------------------------------------------- - -class _NNFunctional: - def relu(self, input): - return Tensor(js_torch.nn.functional.relu(input._js)) - - def sigmoid(self, input): - return Tensor(js_torch.nn.functional.sigmoid(input._js)) - - def leaky_relu(self, input, negative_slope=0.01): - return Tensor(js_torch.nn.functional.leaky_relu(input._js, negative_slope)) - - def max_pool2d(self, input, kernel_size, stride=None, padding=0): - if stride is None: - return Tensor(js_torch.nn.functional.max_pool2d(input._js, kernel_size)) - return Tensor(js_torch.nn.functional.max_pool2d(input._js, kernel_size, stride, padding)) - - def nll_loss(self, input, target, reduction='mean'): - return Tensor(js_torch.nn.functional.nll_loss(input._js, target._js, reduction)) - - def __getattr__(self, name): - if name.startswith('_'): - raise AttributeError(name) - def fn(*args, **kwargs): - return _wrap_result(js_torch.nn.functional.__getattribute__(name)(*_transform_args(args))) - return fn - - -# --------------------------------------------------------------------------- -# nn.parameter namespace -# --------------------------------------------------------------------------- - -class _NNParameterNamespace: - def __init__(self): - self.Parameter = Parameter - - -# --------------------------------------------------------------------------- -# nn namespace -# --------------------------------------------------------------------------- - -class _NNNamespace: - def __init__(self): - self.functional = _NNFunctional() - self.parameter = _NNParameterNamespace() - self.Module = Module - self.Parameter = Parameter - - def Linear(self, in_features, out_features, bias=True): - return _NNModule(js_torch.nn.Linear.new(in_features, out_features, bias)) - - def ReLU(self): - return _NNModule(js_torch.nn.ReLU.new()) - - def Sigmoid(self): - return _NNModule(js_torch.nn.Sigmoid.new()) - - def Sequential(self, *modules): - js_mods = [m._module for m in modules] - return _NNModule(js_torch.nn.Sequential.new(*js_mods)) - - def MSELoss(self, reduction='mean'): - return _NNModule(js_torch.nn.MSELoss.new(reduction)) - - def L1Loss(self, reduction='mean'): - return _NNModule(js_torch.nn.L1Loss.new(reduction)) - - def BCELoss(self, weight=None, reduction='mean'): - js_weight = weight._js if isinstance(weight, Tensor) else None - return _NNModule(js_torch.nn.BCELoss.new(js_weight, reduction)) - - def CrossEntropyLoss(self, reduction='mean'): - return _NNModule(js_torch.nn.CrossEntropyLoss.new(reduction)) - - def Conv1d(self, in_channels, out_channels, kernel_size, - stride=1, padding=0, dilation=1, groups=1, bias=True): - return _NNModule(js_torch.nn.Conv1d.new( - in_channels, out_channels, kernel_size, - stride, padding, dilation, groups, bias - )) - - def Conv2d(self, in_channels, out_channels, kernel_size, - stride=1, padding=0, dilation=1, groups=1, bias=True): - return _NNModule(js_torch.nn.Conv2d.new( - in_channels, out_channels, kernel_size, - stride, padding, dilation, groups, bias - )) - - def Conv3d(self, in_channels, out_channels, kernel_size, - stride=1, padding=0, dilation=1, groups=1, bias=True): - return _NNModule(js_torch.nn.Conv3d.new( - in_channels, out_channels, kernel_size, - stride, padding, dilation, groups, bias - )) - - def LeakyReLU(self, negative_slope=0.01): - return _NNModule(js_torch.nn.LeakyReLU.new(negative_slope)) - - def MaxPool2d(self, kernel_size, stride=None, padding=0): - if stride is None: - return _NNModule(js_torch.nn.MaxPool2d.new(kernel_size)) - return _NNModule(js_torch.nn.MaxPool2d.new(kernel_size, stride, padding)) - - def Dropout(self, p=0.5): - return _NNModule(js_torch.nn.Dropout.new(p)) - - def Softmax(self, dim): - return _NNModule(js_torch.nn.Softmax.new(dim)) - - def Flatten(self, start_dim=1, end_dim=-1): - return _NNModule(js_torch.nn.Flatten.new(start_dim, end_dim)) - - def NLLLoss(self, reduction='mean'): - return _NNModule(js_torch.nn.NLLLoss.new(reduction)) - - -# --------------------------------------------------------------------------- -# optim wrappers -# --------------------------------------------------------------------------- - -class _Optimizer: - def __init__(self, js_optim): - self._optim = js_optim - - def step(self): - self._optim.step() - - def zero_grad(self): - self._optim.zero_grad() - - -class _OptimNamespace: - def SGD(self, params, lr=0.001, momentum=0.0, dampening=0.0, - weight_decay=0.0, nesterov=False, maximize=False): - js_params = to_js([p._js for p in params]) - return _Optimizer(js_torch.optim.SGD.new( - js_params, lr, momentum, dampening, weight_decay, nesterov, maximize - )) - - def Adam(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0.0, amsgrad=False, maximize=False): - js_params = to_js([p._js for p in params]) - js_betas = to_js(list(betas)) - return _Optimizer(js_torch.optim.Adam.new( - js_params, lr, js_betas, eps, weight_decay, amsgrad, maximize - )) - - def Adagrad(self, params, lr=0.01, lr_decay=0, weight_decay=0, eps=1e-10): - js_params = to_js([p._js for p in params]) - return _Optimizer(js_torch.optim.Adagrad.new( - js_params, lr, lr_decay, weight_decay, eps - )) - - -# --------------------------------------------------------------------------- -# torch namespace -# --------------------------------------------------------------------------- - -class _Torch: - def __init__(self): - self.nn = _NNNamespace() - self.optim = _OptimNamespace() - self.no_grad = _NoGrad - self.Tensor = Tensor - self.FloatTensor = FloatTensor - self.LongTensor = LongTensor - - @property - def tensor(self): - return Tensor - - # --- creation functions --- - - def _shape_from_args(self, args): - return list(args[0]) if len(args) == 1 and isinstance(args[0], (list, tuple)) else list(args) - - def zeros(self, *args, **kwargs): - return Tensor(js_torch.zeros(to_js(self._shape_from_args(args)))) - - def ones(self, *args, **kwargs): - return Tensor(js_torch.ones(to_js(self._shape_from_args(args)))) - - def zeros_like(self, input): - return Tensor(js_torch.zeros_like(input._js)) - - def ones_like(self, input): - return Tensor(js_torch.ones_like(input._js)) - - def randn(self, *args, **kwargs): - return Tensor(js_torch.randn(to_js(self._shape_from_args(args)))) - - def rand(self, *args, **kwargs): - return Tensor(js_torch.rand(to_js(self._shape_from_args(args)))) - - def arange(self, start, end=None, step=1): - if end is None: - end = start - start = 0 - return Tensor(js_torch.arange(start, end, step)) - - def linspace(self, start, end, steps): - return Tensor(js_torch.linspace(start, end, steps)) - - def empty(self, *args, **kwargs): - return Tensor(js_torch.empty(to_js(self._shape_from_args(args)))) - - def empty_like(self, input): - return Tensor(js_torch.empty_like(input._js)) - - def full(self, shape, fill_value): - return Tensor(js_torch.full(to_js(list(shape)), fill_value)) - - def full_like(self, input, fill_value): - return Tensor(js_torch.full_like(input._js, fill_value)) - - def rand_like(self, input): - return Tensor(js_torch.rand_like(input._js)) - - def randn_like(self, input): - return Tensor(js_torch.randn_like(input._js)) - - def randint_like(self, input, low, high): - return Tensor(js_torch.randint_like(input._js, low, high)) - - # --- utility functions --- - - def is_tensor(self, obj): - return isinstance(obj, Tensor) - - def is_nonzero(self, input): - if input.numel() != 1: - raise RuntimeError( - "Boolean value of Tensor with more than one element is ambiguous" - ) - return bool(input.item() != 0) - - def numel(self, input): - return input.numel() - - # --- functional wrappers --- - - def sum(self, input, dim=None, keepdim=False): - return input.sum(dim, keepdim) - - def mean(self, input, dim=None, keepdim=False): - return input.mean(dim, keepdim) - - def sigmoid(self, input): - return input.sigmoid() - - def relu(self, input): - return input.relu() - - def softmax(self, input, dim): - return input.softmax(dim) - - def clamp(self, input, min, max): - return input.clamp(min, max) - - def clip(self, input, min, max): - return self.clamp(input, min, max) - - def flatten(self, input, start_dim=0, end_dim=-1): - return input.flatten(start_dim, end_dim) - - def allclose(self, a, b, rtol=1e-5, atol=1e-8, equal_nan=False): - return a.allclose(b, rtol, atol, equal_nan) - - def is_grad_enabled(self): - return bool(js_torch.is_grad_enabled()) - - def cat(self, tensors, dim=0): - if isinstance(tensors, Tensor): - tensors = [tensors] - return Tensor(js_torch.cat(to_js([t._js for t in tensors]), dim)) - - def concatenate(self, tensors, dim=0): - return self.cat(tensors, dim) - - def concat(self, tensors, dim=0): - return self.cat(tensors, dim) - - def stack(self, tensors, dim=0): - return Tensor(js_torch.stack(to_js([t._js for t in tensors]), dim)) - - def Size(self, shape): - return list(shape) - - def __getattr__(self, name): - if name.startswith('_'): - raise AttributeError(name) - def fn(*args, **kwargs): - return _wrap_result(js_torch.__getattribute__(name)(*_transform_args(args))) - return fn - - -torch = _Torch() diff --git a/examples/pyodide/bridge.py b/examples/pyodide/bridge.py new file mode 120000 index 0000000..deb96f4 --- /dev/null +++ b/examples/pyodide/bridge.py @@ -0,0 +1 @@ +../../pyodide_bridge.py \ No newline at end of file diff --git a/examples/pyodide/index.html b/examples/pyodide/index.html index f3ea57b..a0feac7 100644 --- a/examples/pyodide/index.html +++ b/examples/pyodide/index.html @@ -19,6 +19,7 @@

Torch + Pyodide

+ @@ -54,6 +55,8 @@

Output:

throw new Error("torch.js failed to load. Run `yarn build` first."); } + await pyodide.loadPackage('numpy'); + // Make torch available to Python pyodide.globals.set("js_torch", window.torch); diff --git a/examples/pyodide/main.js b/examples/pyodide/main.js index ab17d47..9bb81c5 100644 --- a/examples/pyodide/main.js +++ b/examples/pyodide/main.js @@ -1,5 +1,5 @@ import { loadPyodide } from 'pyodide'; -import * as torch from 'torch'; +import * as torch from '@sourceacademy/torch'; import { readFileSync, writeFileSync, mkdirSync, readdirSync, existsSync } from 'fs'; import { execSync } from 'child_process'; import path from 'path'; @@ -11,8 +11,12 @@ function getPyFiles() { return readdirSync(PY_DIR).filter(f => f.endsWith('.py')).sort(); } +const PYODIDE_PACKAGE_CACHE_DIR = './.pyodide-packages'; + async function setupPyodide() { - const pyodide = await loadPyodide(); + if (!existsSync(PYODIDE_PACKAGE_CACHE_DIR)) mkdirSync(PYODIDE_PACKAGE_CACHE_DIR); + const pyodide = await loadPyodide({ packageCacheDir: PYODIDE_PACKAGE_CACHE_DIR }); + await pyodide.loadPackage('numpy'); pyodide.globals.set('js_torch', torch); pyodide.globals.set('_js_is_null', (x) => x == null); pyodide.runPython(readFileSync('./bridge.py', 'utf8')); diff --git a/examples/pyodide/package.json b/examples/pyodide/package.json index c7b84a7..af748cc 100644 --- a/examples/pyodide/package.json +++ b/examples/pyodide/package.json @@ -8,7 +8,7 @@ "test-pyodide": "node main.js test" }, "dependencies": { - "pyodide": "^0.29.3", - "torch": "portal:../../" + "@sourceacademy/torch": "portal:../../", + "pyodide": "^0.29.3" } } diff --git a/examples/pyodide/py/numpy_interop.py b/examples/pyodide/py/numpy_interop.py new file mode 100644 index 0000000..2fe7ac0 --- /dev/null +++ b/examples/pyodide/py/numpy_interop.py @@ -0,0 +1,60 @@ +# numpy_interop.py — Verify torch.from_numpy() and Tensor.numpy() interop. + +import numpy as np + +print("=== from_numpy: 1-D array ===") +arr = np.array([1.0, 2.0, 3.0], dtype=np.float32) +t = torch.from_numpy(arr) +print("numpy array:", arr.tolist()) +print("tensor:", t) +print("> shape correct:", t.shape == (3,)) +print("> values correct:", t.allclose(torch.tensor([1.0, 2.0, 3.0]))) + +print("\n=== from_numpy: 2-D array ===") +arr2d = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) +t2d = torch.from_numpy(arr2d) +print("numpy array:", arr2d.tolist()) +print("tensor:", t2d) +print("> shape correct:", t2d.shape == (2, 2)) +print("> values correct:", t2d.allclose(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))) + +print("\n=== numpy(): tensor -> ndarray ===") +t3 = torch.tensor([10.0, 20.0, 30.0]) +a3 = t3.numpy() +print("tensor:", t3) +print("numpy type:", type(a3).__name__) +print("numpy values:", a3.tolist()) +print("> type is ndarray:", isinstance(a3, np.ndarray)) +print("> values correct:", a3.tolist() == [10.0, 20.0, 30.0]) + +print("\n=== numpy(): 2-D tensor ===") +t4 = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) +a4 = t4.numpy() +print("numpy shape:", list(a4.shape)) +print("numpy values:", a4.tolist()) +print("> shape correct:", list(a4.shape) == [2, 2]) +print("> values correct:", a4.tolist() == [[1.0, 2.0], [3.0, 4.0]]) + +print("\n=== numpy(): requires grad error ===") +t5 = torch.tensor([1.0, 2.0], requires_grad=True) +try: + t5.numpy() + print("> ERROR: should have raised RuntimeError") +except RuntimeError as e: + print("Got expected RuntimeError:", str(e)[:50]) + print("> requires_grad error raised correctly:", True) + +print("\n=== numpy() after detach ===") +t6 = torch.tensor([5.0, 6.0], requires_grad=True) +a6 = t6.detach().numpy() +print("values:", a6.tolist()) +print("> detach().numpy() works:", a6.tolist() == [5.0, 6.0]) + +print("\n=== round-trip: numpy -> tensor -> numpy ===") +original = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32) +roundtrip = torch.from_numpy(original).numpy() +print("original:", original.tolist()) +print("round-trip:", roundtrip.tolist()) +print("> round-trip correct:", original.tolist() == roundtrip.tolist()) + +print("\nAll numpy_interop checks passed.") diff --git a/examples/pyodide/yarn.lock b/examples/pyodide/yarn.lock index 0c93a3a..d053097 100644 --- a/examples/pyodide/yarn.lock +++ b/examples/pyodide/yarn.lock @@ -5,6 +5,12 @@ __metadata: version: 8 cacheKey: 10c0 +"@sourceacademy/torch@portal:../../::locator=example-pyodide%40workspace%3A.": + version: 0.0.0-use.local + resolution: "@sourceacademy/torch@portal:../../::locator=example-pyodide%40workspace%3A." + languageName: node + linkType: soft + "@types/emscripten@npm:^1.41.4": version: 1.41.5 resolution: "@types/emscripten@npm:1.41.5" @@ -16,8 +22,8 @@ __metadata: version: 0.0.0-use.local resolution: "example-pyodide@workspace:." dependencies: + "@sourceacademy/torch": "portal:../../" pyodide: "npm:^0.29.3" - torch: "portal:../../" languageName: unknown linkType: soft @@ -31,12 +37,6 @@ __metadata: languageName: node linkType: hard -"torch@portal:../../::locator=example-pyodide%40workspace%3A.": - version: 0.0.0-use.local - resolution: "torch@portal:../../::locator=example-pyodide%40workspace%3A." - languageName: node - linkType: soft - "ws@npm:^8.5.0": version: 8.19.0 resolution: "ws@npm:8.19.0" diff --git a/package.json b/package.json index 469ce72..7d12bd8 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "type": "module", - "name": "torch", + "name": "@sourceacademy/torch", "packageManager": "yarn@4.10.3", "version": "0.1.0", "description": "machine-learning libraries for Source Academy", @@ -12,7 +12,7 @@ "files": [ "build", "src", - "examples" + "pyodide_bridge.py" ], "main": "./build/node/torch.node.cjs", "module": "./build/node/torch.node.es.mjs", diff --git a/pyodide_bridge.py b/pyodide_bridge.py new file mode 100644 index 0000000..64a11c3 --- /dev/null +++ b/pyodide_bridge.py @@ -0,0 +1,828 @@ +# bridge.py +# Provides a PyTorch-compatible Python API over js_torch (the TypeScript torch library). +# +# Before loading this file, set the following globals in Pyodide: +# js_torch - the torch module (window.torch from the UMD build) + +from pyodide.ffi import JsProxy, to_js + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +def _wrap_result(result): + """ + Wrap a JS return value: + - JsProxy (JS object/Tensor) -> Python Tensor + - Python primitive (int, float, bool) -> return as-is + JS primitives are automatically converted to Python by Pyodide, + so they will NOT be JsProxy instances. + """ + if isinstance(result, JsProxy): + return Tensor(result) + return result + + +def _transform(obj): + """Convert Python objects to JS-compatible types before passing to JS.""" + if isinstance(obj, Tensor): + return obj._js + if isinstance(obj, (list, tuple)): + return to_js([_transform(item) for item in obj]) + return obj + + +def _transform_args(args): + return [_transform(a) for a in args] + + +# --------------------------------------------------------------------------- +# Tensor +# --------------------------------------------------------------------------- + +class Tensor: + """Python wrapper around a JS Tensor, mirroring the PyTorch Tensor API.""" + + # ------------------------------------------------------------------ + # Construction + # ------------------------------------------------------------------ + + def __new__(cls, data, requires_grad=False): + # Return None for missing tensors so e.g. `tensor.grad` returns None + # when there is no gradient — matching PyTorch behaviour. + # Pyodide may represent JS null as a special JsNull type (not JsProxy, not None). + if data is None or type(data).__name__ in ('JsNull', 'JsUndefined'): + return None + return super().__new__(cls) + + def __init__(self, data, requires_grad=False): + if isinstance(data, JsProxy): + self._js = data + else: + js_data = to_js(data) if isinstance(data, (list, tuple)) else data + self._js = js_torch.tensor(js_data, requires_grad) + + # ------------------------------------------------------------------ + # Representation + # ------------------------------------------------------------------ + + def __repr__(self): + extra = ", requires_grad=True" if self.requires_grad else "" + return f"tensor({self.tolist()}{extra})" + + # ------------------------------------------------------------------ + # Data access + # ------------------------------------------------------------------ + + def tolist(self): + """Return tensor data as a (nested) Python list, or a Python scalar for 0-d tensors.""" + result = self._js.toArray() + if isinstance(result, JsProxy): + return result.to_py() + return result # scalar + + def item(self): + return self._js.item() + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def shape(self): + return tuple(self._js.shape.to_py()) + + @property + def data(self): + """Detached view of the tensor data (no gradient).""" + return self.detach() + + @property + def requires_grad(self): + return bool(self._js.requires_grad) + + @requires_grad.setter + def requires_grad(self, value): + self._js.requires_grad = value + + @property + def grad(self): + raw = self._js.grad + if raw is None or type(raw).__name__ in ('JsNull', 'JsUndefined'): + return None + return Tensor(raw) + + @grad.setter + def grad(self, value): + self._js.grad = value._js if isinstance(value, Tensor) else None + + @property + def T(self): + if len(self.shape) < 2: + return self + return Tensor(self._js.transpose(0, 1)) + + # ------------------------------------------------------------------ + # Grad utilities + # ------------------------------------------------------------------ + + def backward(self, gradient=None): + if gradient is None: + self._js.backward() + else: + self._js.backward(gradient._js) + + def detach(self): + return Tensor(self._js.detach()) + + def zero_(self): + self._js.zero_() + return self + + def retain_grad(self): + self._js.retain_grad() + + # ------------------------------------------------------------------ + # Shape utilities + # ------------------------------------------------------------------ + + def size(self, dim=None): + s = self.shape + return s if dim is None else s[dim] + + def dim(self): + return len(self.shape) + + def numel(self): + n = 1 + for s in self.shape: + n *= s + return n + + def reshape(self, *args): + shape = list(args[0]) if len(args) == 1 and isinstance(args[0], (list, tuple)) else list(args) + return Tensor(self._js.reshape(to_js(shape))) + + def view(self, *args): + return self.reshape(*args) + + def squeeze(self, dim=None): + if dim is None: + new_shape = [s for s in self.shape if s != 1] + return Tensor(self._js.reshape(to_js(new_shape))) + return Tensor(self._js.squeeze(dim)) + + def unsqueeze(self, dim): + return Tensor(self._js.unsqueeze(dim)) + + def expand(self, *args): + shape = list(args[0]) if len(args) == 1 and isinstance(args[0], (list, tuple)) else list(args) + return Tensor(self._js.expand(to_js(shape))) + + def transpose(self, dim0, dim1): + return Tensor(self._js.transpose(dim0, dim1)) + + def flatten(self, start_dim=0, end_dim=-1): + return Tensor(self._js.flatten(start_dim, end_dim)) + + # ------------------------------------------------------------------ + # Reductions — default (no dim) sums all elements, matching PyTorch + # ------------------------------------------------------------------ + + def sum(self, dim=None, keepdim=False): + return Tensor(self._js.sum() if dim is None else self._js.sum(dim, keepdim)) + + def mean(self, dim=None, keepdim=False): + return Tensor(self._js.mean() if dim is None else self._js.mean(dim, keepdim)) + + def max(self, dim=None, keepdim=False): + return Tensor(self._js.max() if dim is None else self._js.max(dim, keepdim)) + + def min(self, dim=None, keepdim=False): + return Tensor(self._js.min() if dim is None else self._js.min(dim, keepdim)) + + # ------------------------------------------------------------------ + # Arithmetic — explicit methods + # ------------------------------------------------------------------ + + def _to_js(self, other): + return other._js if isinstance(other, Tensor) else other + + def add(self, other): return Tensor(self._js.add(self._to_js(other))) + def sub(self, other): return Tensor(self._js.sub(self._to_js(other))) + def mul(self, other): return Tensor(self._js.mul(self._to_js(other))) + def div(self, other): return Tensor(self._js.div(self._to_js(other))) + def pow(self, other): return Tensor(self._js.pow(self._to_js(other))) + def matmul(self, other): return Tensor(self._js.matmul(self._to_js(other))) + + # ------------------------------------------------------------------ + # Arithmetic operators + # ------------------------------------------------------------------ + + def __add__(self, other): return self.add(other) + def __radd__(self, other): return self.add(other) # add is commutative + def __sub__(self, other): return self.sub(other) + def __rsub__(self, other): + o = other if isinstance(other, Tensor) else Tensor(other) + return o.sub(self) + def __mul__(self, other): return self.mul(other) + def __rmul__(self, other): return self.mul(other) # mul is commutative + def __truediv__(self, other): return self.div(other) + def __rtruediv__(self, other): + o = other if isinstance(other, Tensor) else Tensor(other) + return o.div(self) + def __pow__(self, other): return self.pow(other) + def __rpow__(self, other): + o = other if isinstance(other, Tensor) else Tensor(other) + return o.pow(self) + def __matmul__(self, other): return self.matmul(other) + def __neg__(self): return Tensor(self._js.neg()) + def __abs__(self): return Tensor(self._js.abs()) + + # ------------------------------------------------------------------ + # Unary operations + # ------------------------------------------------------------------ + + def neg(self): return Tensor(self._js.neg()) + def abs(self): return Tensor(self._js.abs()) + def log(self): return Tensor(self._js.log()) + def exp(self): return Tensor(self._js.exp()) + def sqrt(self): return Tensor(self._js.sqrt()) + def square(self): return Tensor(self._js.square()) + def sin(self): return Tensor(self._js.sin()) + def cos(self): return Tensor(self._js.cos()) + def tan(self): return Tensor(self._js.tan()) + def sigmoid(self): return Tensor(self._js.sigmoid()) + def relu(self): return Tensor(js_torch.nn.functional.relu(self._js)) + def softmax(self, dim): return Tensor(self._js.softmax(dim)) + def clamp(self, min, max): return Tensor(self._js.clamp(min, max)) + def sign(self): return Tensor(self._js.sign()) + def reciprocal(self): return Tensor(self._js.reciprocal()) + def nan_to_num(self): return Tensor(self._js.nan_to_num()) + + # ------------------------------------------------------------------ + # Comparison + # ------------------------------------------------------------------ + + def lt(self, other): return Tensor(self._js.lt(self._to_js(other))) + def gt(self, other): return Tensor(self._js.gt(self._to_js(other))) + def le(self, other): return Tensor(self._js.le(self._to_js(other))) + def ge(self, other): return Tensor(self._js.ge(self._to_js(other))) + def eq(self, other): return Tensor(self._js.eq(self._to_js(other))) + def ne(self, other): return Tensor(self._js.ne(self._to_js(other))) + + def allclose(self, other, rtol=1e-5, atol=1e-8, equal_nan=False): + return bool(js_torch.allclose(self._js, other._js, rtol, atol, equal_nan)) + + # ------------------------------------------------------------------ + # NumPy interop + # ------------------------------------------------------------------ + + def numpy(self): + import numpy as np + if self.requires_grad: + raise RuntimeError( + "Can't call numpy() on Tensor that requires grad. " + "Use tensor.detach().numpy() instead." + ) + return np.array(self.tolist()) + + # ------------------------------------------------------------------ + # Type conversions + # ------------------------------------------------------------------ + + def __float__(self): return float(self.item()) + def __int__(self): return int(self.item()) + def __bool__(self): return bool(self.item()) + def __format__(self, fmt): return format(self.item(), fmt) + + # ------------------------------------------------------------------ + # Indexing + # ------------------------------------------------------------------ + + def __getitem__(self, key): + if isinstance(key, int): + return Tensor(self._js.index(key)) + if isinstance(key, tuple): + result = self._js + for k in key: + if isinstance(k, int): + result = result.index(k) + else: + raise NotImplementedError( + "Only integer indexing is supported in multi-dimensional indexing" + ) + return Tensor(result) + if isinstance(key, slice): + raise NotImplementedError( + "Slice indexing is not implemented; converting through Python lists " + "would return a detached copy and break tensor/autograd semantics" + ) + raise TypeError(f"Invalid index type: {type(key).__name__}") + + # ------------------------------------------------------------------ + # Iteration and length + # ------------------------------------------------------------------ + + def __len__(self): + if self.dim() == 0: + raise TypeError("len() of a 0-d tensor") + return self.shape[0] + + def __iter__(self): + if self.dim() == 0: + raise TypeError("iteration over a 0-d tensor") + for i in range(self.shape[0]): + yield self[i] + + # ------------------------------------------------------------------ + # Catch-all: delegate unknown attribute accesses to the JS tensor. + # Returned JsProxy objects are wrapped in Tensor; primitives pass through. + # ------------------------------------------------------------------ + + def __getattr__(self, name): + if name.startswith('_'): + raise AttributeError(name) + def method(*args, **kwargs): + if kwargs: + raise TypeError( + f"{name}() does not support keyword arguments in this bridge; " + f"got unexpected keyword argument(s): {', '.join(sorted(kwargs.keys()))}" + ) + js_args = _transform_args(args) + return _wrap_result(self._js.__getattribute__(name)(*js_args)) + return method + + +# --------------------------------------------------------------------------- +# Typed tensor subclasses +# --------------------------------------------------------------------------- + +def _trunc_nested(data): + """Truncate all numbers in a nested list toward zero (for LongTensor).""" + if isinstance(data, (int, float)): + return int(data) # Python int() truncates toward zero + return [_trunc_nested(item) for item in data] + + +class FloatTensor(Tensor): + """ + A Tensor that stores floating-point values. + Equivalent to a regular Tensor; provided for PyTorch API compatibility. + """ + def __init__(self, data, requires_grad=False): + if isinstance(data, JsProxy): + super().__init__(data) + else: + super().__init__(data, requires_grad) + + +class LongTensor(Tensor): + """ + A Tensor whose values are truncated to integers (toward zero). + LongTensor([-1.7]) -> tensor([-1]), LongTensor([1.9]) -> tensor([1]). + """ + def __init__(self, data, requires_grad=False): + if isinstance(data, JsProxy): + super().__init__(data) + else: + truncated = _trunc_nested(data) if isinstance(data, (list, tuple)) else int(data) + super().__init__(truncated, requires_grad) + + +# --------------------------------------------------------------------------- +# no_grad context manager — actually disables grad in the JS engine +# --------------------------------------------------------------------------- + +class _NoGrad: + def __enter__(self): + self._prev = js_torch.enable_no_grad() + return self + + def __exit__(self, *args): + js_torch.disable_no_grad(self._prev) + + +# --------------------------------------------------------------------------- +# Parameter +# --------------------------------------------------------------------------- + +class Parameter(Tensor): + """A Tensor that is automatically registered as a parameter.""" + def __init__(self, data, requires_grad=True): + if isinstance(data, Tensor): + self._js = js_torch.nn.Parameter.new(data._js) + elif isinstance(data, JsProxy): + self._js = js_torch.nn.Parameter.new(data) + else: + self._js = js_torch.nn.Parameter.new(js_torch.tensor(data)) + if not requires_grad: + self._js.requires_grad = False + + +# --------------------------------------------------------------------------- +# Module — pure-Python base class for user-defined models +# --------------------------------------------------------------------------- + +class Module: + """ + Pure-Python nn.Module. Subclass this to build models using bridge Tensors. + Assign `Parameter` or `_NNModule` instances as attributes and they are + automatically tracked by `parameters()`. + """ + + def __init__(self): + object.__setattr__(self, '_parameters', {}) + object.__setattr__(self, '_modules', {}) + object.__setattr__(self, 'training', True) + + def __setattr__(self, name, value): + try: + params = object.__getattribute__(self, '_parameters') + modules = object.__getattribute__(self, '_modules') + except AttributeError: + object.__setattr__(self, name, value) + return + + if isinstance(value, Parameter): + params[name] = value + elif isinstance(value, (Module, _NNModule)): + modules[name] = value + object.__setattr__(self, name, value) + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def forward(self, *args, **kwargs): + raise NotImplementedError + + def parameters(self): + params = list(object.__getattribute__(self, '_parameters').values()) + for mod in object.__getattribute__(self, '_modules').values(): + params.extend(mod.parameters()) + return params + + def named_parameters(self, prefix=''): + result = [] + for name, p in object.__getattribute__(self, '_parameters').items(): + full = f"{prefix}.{name}" if prefix else name + result.append((full, p)) + for mod_name, mod in object.__getattribute__(self, '_modules').items(): + full_mod = f"{prefix}.{mod_name}" if prefix else mod_name + result.extend(mod.named_parameters(full_mod)) + return result + + def train(self, mode=True): + object.__setattr__(self, 'training', mode) + for mod in object.__getattribute__(self, '_modules').values(): + mod.train(mode) + return self + + def eval(self): + return self.train(False) + + def zero_grad(self): + for p in self.parameters(): + p.grad = None + + +# --------------------------------------------------------------------------- +# _NNModule — wraps a JS nn.Module instance +# --------------------------------------------------------------------------- + +class _NNModule: + """Wraps a JS nn.Module returned by the nn factory functions.""" + + def __init__(self, js_module): + self._module = js_module + + def __call__(self, *args): + js_args = [a._js if isinstance(a, Tensor) else a for a in args] + return Tensor(self._module.call(*js_args)) + + def forward(self, *args): + js_args = [a._js if isinstance(a, Tensor) else a for a in args] + return Tensor(self._module.forward(*js_args)) + + def parameters(self): + return [Tensor(p) for p in self._module.parameters().to_py()] + + def named_parameters(self, prefix=''): + raw = self._module.named_parameters(prefix).to_py() + return [(pair[0], Tensor(pair[1])) for pair in raw] + + def train(self, mode=True): + self._module.train(mode) + return self + + def eval(self): + return self.train(False) + + def zero_grad(self): + for p in self.parameters(): + p.grad = None + + +# --------------------------------------------------------------------------- +# nn.functional +# --------------------------------------------------------------------------- + +class _NNFunctional: + def relu(self, input): + return Tensor(js_torch.nn.functional.relu(input._js)) + + def sigmoid(self, input): + return Tensor(js_torch.nn.functional.sigmoid(input._js)) + + def leaky_relu(self, input, negative_slope=0.01): + return Tensor(js_torch.nn.functional.leaky_relu(input._js, negative_slope)) + + def max_pool2d(self, input, kernel_size, stride=None, padding=0): + if stride is None: + return Tensor(js_torch.nn.functional.max_pool2d(input._js, kernel_size)) + return Tensor(js_torch.nn.functional.max_pool2d(input._js, kernel_size, stride, padding)) + + def nll_loss(self, input, target, reduction='mean'): + return Tensor(js_torch.nn.functional.nll_loss(input._js, target._js, reduction)) + + def __getattr__(self, name): + if name.startswith('_'): + raise AttributeError(name) + def fn(*args, **kwargs): + return _wrap_result(js_torch.nn.functional.__getattribute__(name)(*_transform_args(args))) + return fn + + +# --------------------------------------------------------------------------- +# nn.parameter namespace +# --------------------------------------------------------------------------- + +class _NNParameterNamespace: + def __init__(self): + self.Parameter = Parameter + + +# --------------------------------------------------------------------------- +# nn namespace +# --------------------------------------------------------------------------- + +class _NNNamespace: + def __init__(self): + self.functional = _NNFunctional() + self.parameter = _NNParameterNamespace() + self.Module = Module + self.Parameter = Parameter + + def Linear(self, in_features, out_features, bias=True): + return _NNModule(js_torch.nn.Linear.new(in_features, out_features, bias)) + + def ReLU(self): + return _NNModule(js_torch.nn.ReLU.new()) + + def Sigmoid(self): + return _NNModule(js_torch.nn.Sigmoid.new()) + + def Sequential(self, *modules): + js_mods = [m._module for m in modules] + return _NNModule(js_torch.nn.Sequential.new(*js_mods)) + + def MSELoss(self, reduction='mean'): + return _NNModule(js_torch.nn.MSELoss.new(reduction)) + + def L1Loss(self, reduction='mean'): + return _NNModule(js_torch.nn.L1Loss.new(reduction)) + + def BCELoss(self, weight=None, reduction='mean'): + js_weight = weight._js if isinstance(weight, Tensor) else None + return _NNModule(js_torch.nn.BCELoss.new(js_weight, reduction)) + + def CrossEntropyLoss(self, reduction='mean'): + return _NNModule(js_torch.nn.CrossEntropyLoss.new(reduction)) + + def Conv1d(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, bias=True): + return _NNModule(js_torch.nn.Conv1d.new( + in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias + )) + + def Conv2d(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, bias=True): + return _NNModule(js_torch.nn.Conv2d.new( + in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias + )) + + def Conv3d(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, bias=True): + return _NNModule(js_torch.nn.Conv3d.new( + in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias + )) + + def LeakyReLU(self, negative_slope=0.01): + return _NNModule(js_torch.nn.LeakyReLU.new(negative_slope)) + + def MaxPool2d(self, kernel_size, stride=None, padding=0): + if stride is None: + return _NNModule(js_torch.nn.MaxPool2d.new(kernel_size)) + return _NNModule(js_torch.nn.MaxPool2d.new(kernel_size, stride, padding)) + + def Dropout(self, p=0.5): + return _NNModule(js_torch.nn.Dropout.new(p)) + + def Softmax(self, dim): + return _NNModule(js_torch.nn.Softmax.new(dim)) + + def Flatten(self, start_dim=1, end_dim=-1): + return _NNModule(js_torch.nn.Flatten.new(start_dim, end_dim)) + + def NLLLoss(self, reduction='mean'): + return _NNModule(js_torch.nn.NLLLoss.new(reduction)) + + +# --------------------------------------------------------------------------- +# optim wrappers +# --------------------------------------------------------------------------- + +class _Optimizer: + def __init__(self, js_optim): + self._optim = js_optim + + def step(self): + self._optim.step() + + def zero_grad(self): + self._optim.zero_grad() + + +class _OptimNamespace: + def SGD(self, params, lr=0.001, momentum=0.0, dampening=0.0, + weight_decay=0.0, nesterov=False, maximize=False): + js_params = to_js([p._js for p in params]) + return _Optimizer(js_torch.optim.SGD.new( + js_params, lr, momentum, dampening, weight_decay, nesterov, maximize + )) + + def Adam(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0.0, amsgrad=False, maximize=False): + js_params = to_js([p._js for p in params]) + js_betas = to_js(list(betas)) + return _Optimizer(js_torch.optim.Adam.new( + js_params, lr, js_betas, eps, weight_decay, amsgrad, maximize + )) + + def Adagrad(self, params, lr=0.01, lr_decay=0, weight_decay=0, eps=1e-10): + js_params = to_js([p._js for p in params]) + return _Optimizer(js_torch.optim.Adagrad.new( + js_params, lr, lr_decay, weight_decay, eps + )) + + +# --------------------------------------------------------------------------- +# torch namespace +# --------------------------------------------------------------------------- + +class _Torch: + def __init__(self): + self.nn = _NNNamespace() + self.optim = _OptimNamespace() + self.no_grad = _NoGrad + self.Tensor = Tensor + self.FloatTensor = FloatTensor + self.LongTensor = LongTensor + + @property + def tensor(self): + return Tensor + + # --- creation functions --- + + def _shape_from_args(self, args): + return list(args[0]) if len(args) == 1 and isinstance(args[0], (list, tuple)) else list(args) + + def zeros(self, *args, **kwargs): + return Tensor(js_torch.zeros(to_js(self._shape_from_args(args)))) + + def ones(self, *args, **kwargs): + return Tensor(js_torch.ones(to_js(self._shape_from_args(args)))) + + def zeros_like(self, input): + return Tensor(js_torch.zeros_like(input._js)) + + def ones_like(self, input): + return Tensor(js_torch.ones_like(input._js)) + + def randn(self, *args, **kwargs): + return Tensor(js_torch.randn(to_js(self._shape_from_args(args)))) + + def rand(self, *args, **kwargs): + return Tensor(js_torch.rand(to_js(self._shape_from_args(args)))) + + def arange(self, start, end=None, step=1): + if end is None: + end = start + start = 0 + return Tensor(js_torch.arange(start, end, step)) + + def linspace(self, start, end, steps): + return Tensor(js_torch.linspace(start, end, steps)) + + def empty(self, *args, **kwargs): + return Tensor(js_torch.empty(to_js(self._shape_from_args(args)))) + + def empty_like(self, input): + return Tensor(js_torch.empty_like(input._js)) + + def full(self, shape, fill_value): + return Tensor(js_torch.full(to_js(list(shape)), fill_value)) + + def full_like(self, input, fill_value): + return Tensor(js_torch.full_like(input._js, fill_value)) + + def rand_like(self, input): + return Tensor(js_torch.rand_like(input._js)) + + def randn_like(self, input): + return Tensor(js_torch.randn_like(input._js)) + + def randint_like(self, input, low, high): + return Tensor(js_torch.randint_like(input._js, low, high)) + + # --- utility functions --- + + def is_tensor(self, obj): + return isinstance(obj, Tensor) + + def from_numpy(self, array): + return Tensor(array.tolist()) + + def is_nonzero(self, input): + if input.numel() != 1: + raise RuntimeError( + "Boolean value of Tensor with more than one element is ambiguous" + ) + return bool(input.item() != 0) + + def numel(self, input): + return input.numel() + + # --- functional wrappers --- + + def sum(self, input, dim=None, keepdim=False): + return input.sum(dim, keepdim) + + def mean(self, input, dim=None, keepdim=False): + return input.mean(dim, keepdim) + + def sigmoid(self, input): + return input.sigmoid() + + def relu(self, input): + return input.relu() + + def softmax(self, input, dim): + return input.softmax(dim) + + def clamp(self, input, min, max): + return input.clamp(min, max) + + def clip(self, input, min, max): + return self.clamp(input, min, max) + + def flatten(self, input, start_dim=0, end_dim=-1): + return input.flatten(start_dim, end_dim) + + def allclose(self, a, b, rtol=1e-5, atol=1e-8, equal_nan=False): + return a.allclose(b, rtol, atol, equal_nan) + + def is_grad_enabled(self): + return bool(js_torch.is_grad_enabled()) + + def cat(self, tensors, dim=0): + if isinstance(tensors, Tensor): + tensors = [tensors] + return Tensor(js_torch.cat(to_js([t._js for t in tensors]), dim)) + + def concatenate(self, tensors, dim=0): + return self.cat(tensors, dim) + + def concat(self, tensors, dim=0): + return self.cat(tensors, dim) + + def stack(self, tensors, dim=0): + return Tensor(js_torch.stack(to_js([t._js for t in tensors]), dim)) + + def Size(self, shape): + return list(shape) + + def __getattr__(self, name): + if name.startswith('_'): + raise AttributeError(name) + def fn(*args, **kwargs): + return _wrap_result(js_torch.__getattribute__(name)(*_transform_args(args))) + return fn + + +torch = _Torch() diff --git a/test/backward.test.js b/test/backward.test.js index d66171b..a195c32 100644 --- a/test/backward.test.js +++ b/test/backward.test.js @@ -1,5 +1,5 @@ import { assert } from 'chai'; -import { Tensor } from 'torch'; +import { Tensor } from '@sourceacademy/torch'; describe('Autograd', () => { it('y = x**2, dy/dx at x=2 is 4', () => { diff --git a/test/broadcast.test.js b/test/broadcast.test.js index 524a2a8..104c4f7 100644 --- a/test/broadcast.test.js +++ b/test/broadcast.test.js @@ -1,5 +1,5 @@ import { assert } from 'chai'; -import { Tensor, add, __left_index__, __right_index__ } from 'torch'; +import { Tensor, add, __left_index__, __right_index__ } from '@sourceacademy/torch'; describe('Broadcast Index', () => { it('Broadcast indices should be correct in 2d array', () => { diff --git a/test/creation.test.ts b/test/creation.test.ts index 3641f9f..fc2f55a 100644 --- a/test/creation.test.ts +++ b/test/creation.test.ts @@ -1,6 +1,6 @@ import { assert } from 'chai'; -import * as torch from 'torch'; -import { Tensor } from 'torch'; +import * as torch from '@sourceacademy/torch'; +import { Tensor } from '@sourceacademy/torch'; describe('Creation Functions', () => { describe('tensor', () => { diff --git a/test/custom_operations.test.js b/test/custom_operations.test.js index 1277394..9deccb0 100644 --- a/test/custom_operations.test.js +++ b/test/custom_operations.test.js @@ -1,5 +1,5 @@ import { assert } from 'chai'; -import { Tensor } from 'torch'; +import { Tensor } from '@sourceacademy/torch'; describe('Custom Operations', () => { diff --git a/test/event_listener.test.js b/test/event_listener.test.js index 4a2aa2a..647308a 100644 --- a/test/event_listener.test.js +++ b/test/event_listener.test.js @@ -1,5 +1,5 @@ import { assert } from 'chai'; -import * as torch from 'torch'; +import * as torch from '@sourceacademy/torch'; describe('Event Bus', () => { const a = new torch.Tensor([1, 2, 3], { requires_grad: true }); diff --git a/test/export.test.ts b/test/export.test.ts index 167aba4..d04ad5c 100644 --- a/test/export.test.ts +++ b/test/export.test.ts @@ -3,7 +3,7 @@ import { _atenMap } from '../src/export'; import { _getAllOperationNames } from '../src/functions/registry'; // Import torch to trigger operation registration side effects -import 'torch'; +import '@sourceacademy/torch'; describe('Export', () => { // List from https://docs.pytorch.org/docs/2.10/user_guide/torch_compiler/torch.compiler_ir.html diff --git a/test/functional.test.js b/test/functional.test.js index d209194..4a5f8c3 100644 --- a/test/functional.test.js +++ b/test/functional.test.js @@ -1,6 +1,6 @@ import { assert } from 'chai'; -import * as torch from 'torch'; -import { Tensor } from 'torch'; +import * as torch from '@sourceacademy/torch'; +import { Tensor } from '@sourceacademy/torch'; describe('Functional', () => { describe('Addition', () => { diff --git a/test/generated.test.js b/test/generated.test.js index 7d880d9..5514dcb 100644 --- a/test/generated.test.js +++ b/test/generated.test.js @@ -1,5 +1,5 @@ -import * as torch from 'torch'; -import { Tensor } from 'torch'; +import * as torch from '@sourceacademy/torch'; +import { Tensor } from '@sourceacademy/torch'; import { assert } from 'chai'; import { testData } from './testcases.gen.js'; diff --git a/test/grad_mode.test.js b/test/grad_mode.test.js index e6ca5e2..05de450 100644 --- a/test/grad_mode.test.js +++ b/test/grad_mode.test.js @@ -1,5 +1,5 @@ import { assert } from 'chai'; -import { Tensor, no_grad, enable_no_grad, disable_no_grad, is_grad_enabled } from 'torch'; +import { Tensor, no_grad, enable_no_grad, disable_no_grad, is_grad_enabled } from '@sourceacademy/torch'; describe('Grad Mode', () => { afterEach(() => { diff --git a/test/index.html b/test/index.html index 29e8f24..b05ca6d 100644 --- a/test/index.html +++ b/test/index.html @@ -20,7 +20,7 @@