diff --git a/.github/dependabot.yml b/.github/dependabot.yml
new file mode 100644
index 0000000..f948375
--- /dev/null
+++ b/.github/dependabot.yml
@@ -0,0 +1,10 @@
+version: 2
+updates:
+ - package-ecosystem: "npm"
+ directory: "/"
+ schedule:
+ interval: "monthly"
+ groups:
+ all-updates:
+ patterns:
+ - "*"
diff --git a/.github/workflows/build-test-publish.yml b/.github/workflows/build-test.yml
similarity index 81%
rename from .github/workflows/build-test-publish.yml
rename to .github/workflows/build-test.yml
index cbdaff4..1f4d176 100644
--- a/.github/workflows/build-test-publish.yml
+++ b/.github/workflows/build-test.yml
@@ -1,10 +1,12 @@
-name: Build, Test, Publish
+name: Build and Test
-on: [push]
+on:
+ pull_request:
+ push:
jobs:
- test-publish:
- name: Build, test, and publish package
+ test:
+ name: Build, test, and prepare docs
runs-on: ubuntu-latest
permissions:
contents: read
@@ -36,16 +38,15 @@ jobs:
mkdir -p docs/media
[ -d "build" ] && cp -r build docs/build
[ -d "build" ] && cp -r build docs/media/build
- [ -d "examples" ] && cp -r examples docs/examples
- [ -d "examples" ] && cp -r examples docs/media/examples
+ [ -d "examples" ] && cp -rL examples docs/examples
+ [ -d "examples" ] && cp -rL examples docs/media/examples
- uses: actions/upload-pages-artifact@v4
with:
path: './docs'
- - run: yarn pkg-pr-new publish
-
deploy-docs:
+ if: github.event_name == 'push' && github.ref == 'refs/heads/main'
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
@@ -53,7 +54,7 @@ jobs:
permissions:
pages: write
id-token: write
- needs: test-publish
+ needs: test
steps:
- uses: actions/deploy-pages@v5
id: deployment
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
new file mode 100644
index 0000000..8b95bbc
--- /dev/null
+++ b/.github/workflows/publish.yml
@@ -0,0 +1,25 @@
+name: Publish
+
+on:
+ push:
+ branches: [main]
+ workflow_dispatch:
+
+jobs:
+ publish:
+ name: Publish package preview
+ runs-on: ubuntu-latest
+ permissions:
+ contents: read
+ timeout-minutes: 10
+ steps:
+ - uses: actions/checkout@v6
+ - run: corepack enable
+ - uses: actions/setup-node@v6
+ with:
+ node-version: 24
+ cache: yarn
+
+ - run: yarn install --immutable
+ - run: yarn build
+ - run: yarn pkg-pr-new publish
diff --git a/.github/workflows/pyodide-test.yml b/.github/workflows/pyodide-test.yml
index 99cf865..c94fd46 100644
--- a/.github/workflows/pyodide-test.yml
+++ b/.github/workflows/pyodide-test.yml
@@ -1,6 +1,8 @@
name: Pyodide Test
-on: [push]
+on:
+ pull_request:
+ push:
jobs:
pyodide-test:
diff --git a/.github/workflows/verify-generated-tests.yml b/.github/workflows/verify-generated-tests.yml
index 459bb77..d883a5d 100644
--- a/.github/workflows/verify-generated-tests.yml
+++ b/.github/workflows/verify-generated-tests.yml
@@ -1,6 +1,8 @@
name: Verify Generated Tests
-on: [push]
+on:
+ pull_request:
+ push:
jobs:
verify-sync:
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index cf153fe..bf5d47c 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,18 +1,38 @@
+## Development Commands
+
+```bash
+yarn install # Install dependencies
+yarn build # Build for browser and node (runs clean first)
+yarn test # Run tests against src/ (uses mocha + tsx)
+yarn test:build # Run tests against built library in build/
+yarn test:coverage # Run tests with coverage report
+yarn test:watch # Watch mode for tests (alias: yarn watch)
+yarn lint # Lint src/ with ESLint
+yarn serve # Serve on localhost:8080
+yarn docs # Build TypeDoc documentation
+yarn update-tests # Regenerate test/testcases.gen.js from scripts/generate_tests.py
+```
+
+To run a single test file:
+```bash
+yarn mocha --node-option conditions=torch-src test/tensor.test.js
+```
+
## Codebase Structure
- [`src`](src)
- [`index.ts`](src/index.ts) is the entry point of the library.
- [`tensor.ts`](src/tensor.ts) is the main tensor class.
- - [`function`](function) contains all functions that tensors can perform.
- - [`nn`](nn) contains all neural network modules (for everything under `torch.nn`).
- - [`optim`](optim) contains all optimizers (for everything under `torch.optim`).
- - [`creation`](creation) contains all tensor creation functions (all functions that create a tensor not from scratch, including `zeros`, `randn`).
+ - [`functions`](src/functions) contains all functions that tensors can perform.
+ - [`nn`](src/nn) contains all neural network modules (for everything under `torch.nn`).
+ - [`optim`](src/optim) contains all optimizers (for everything under `torch.optim`).
+ - [`creation`](src/creation) contains all tensor creation functions (all functions that create a tensor not from scratch, including `zeros`, `randn`).
- [`examples`](examples) contains example usages of the library, including on node, on the browser, and using pyodide on the browser.
- [`test`](test) contains the test cases of the library, including on node and on the browser. See [Testing](#testing).
### Development Scripts
-You can use `yarn watch` to automatically test after each edit.
+Use `yarn watch` (or `yarn test:watch`) to automatically re-run tests on each edit.
### Adding a new Function
diff --git a/README.md b/README.md
index 3d39c83..3b1e6a4 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,18 @@
# torch
machine-learning libraries for Source Academy
+The primary objective of this project is to create a reimplementation of PyTorch in TypeScript, with an educational focus. This project is developed with Source Academy integration in mind.
+
+This project reimplements core parts of PyTorch while trying to keep the codebase simple, and the API as close to PyTorch as possible.
+
+Using Pyodide, we can run Python code in the browser. Using `pyodide_bridge.py` in a way similar to `examples/pyodide/` we can run PyTorch-like code in the browser.
+
+## Notable differences with PyTorch
+
+- This library exposes extra information for debuggers and visualizers to catch, as seen in `events` in [`src/util.ts`](src/util.ts). It is similar to hooks in PyTorch.
+- This library does not differentiate between LongTensors and FloatTensors. It uses `number` for all tensor elements.
+- This library does not currently support devices, such as GPUs.
+
## Getting Started
Install yarn:
@@ -48,4 +60,10 @@ yarn serve
## Contributing
-For detailed information on the codebase and tests, see [CONTRIBUTING.md](CONTRIBUTING.md).
+Contributions are welcome. The short version:
+
+1. Run `yarn test` to verify everything passes.
+2. Add tests for new ops or behaviour changes.
+3. Follow the existing patterns — new ops go in [src/functions/ops.ts](src/functions/ops.ts).
+
+For full details on the codebase, how to add operations, and the testing setup, see [CONTRIBUTING.md](CONTRIBUTING.md).
diff --git a/examples/basic_backpropagation.js b/examples/basic_backpropagation.js
index 8274ac2..72a5a37 100644
--- a/examples/basic_backpropagation.js
+++ b/examples/basic_backpropagation.js
@@ -1,4 +1,4 @@
-import { Tensor } from '../build/node/torch.node.es.js';
+import { Tensor } from '../build/node/torch.node.es.mjs';
const x = new Tensor([2.0], { requires_grad: true });
const y = x.pow(new Tensor([2.0]));
diff --git a/examples/pyodide/.gitignore b/examples/pyodide/.gitignore
index 2aa8c99..cb95714 100644
--- a/examples/pyodide/.gitignore
+++ b/examples/pyodide/.gitignore
@@ -142,3 +142,6 @@ dist
vite.config.js.timestamp-*
vite.config.ts.timestamp-*
.vite/
+
+# Pyodide package cache
+.pyodide-packages/
diff --git a/examples/pyodide/bridge.py b/examples/pyodide/bridge.py
deleted file mode 100644
index 872c259..0000000
--- a/examples/pyodide/bridge.py
+++ /dev/null
@@ -1,805 +0,0 @@
-# bridge.py
-# Provides a PyTorch-compatible Python API over js_torch (the TypeScript torch library).
-#
-# Before loading this file, set the following globals in Pyodide:
-# js_torch - the torch module (window.torch from the UMD build)
-
-from pyodide.ffi import JsProxy, to_js
-
-
-# ---------------------------------------------------------------------------
-# Internal helpers
-# ---------------------------------------------------------------------------
-
-def _wrap_result(result):
- """
- Wrap a JS return value:
- - JsProxy (JS object/Tensor) -> Python Tensor
- - Python primitive (int, float, bool) -> return as-is
- JS primitives are automatically converted to Python by Pyodide,
- so they will NOT be JsProxy instances.
- """
- if isinstance(result, JsProxy):
- return Tensor(result)
- return result
-
-
-def _transform(obj):
- """Convert Python objects to JS-compatible types before passing to JS."""
- if isinstance(obj, Tensor):
- return obj._js
- if isinstance(obj, (list, tuple)):
- return to_js([_transform(item) for item in obj])
- return obj
-
-
-def _transform_args(args):
- return [_transform(a) for a in args]
-
-
-# ---------------------------------------------------------------------------
-# Tensor
-# ---------------------------------------------------------------------------
-
-class Tensor:
- """Python wrapper around a JS Tensor, mirroring the PyTorch Tensor API."""
-
- # ------------------------------------------------------------------
- # Construction
- # ------------------------------------------------------------------
-
- def __new__(cls, data, requires_grad=False):
- # Return None for missing tensors so e.g. `tensor.grad` returns None
- # when there is no gradient — matching PyTorch behaviour.
- # Pyodide may represent JS null as a special JsNull type (not JsProxy, not None).
- if data is None or type(data).__name__ in ('JsNull', 'JsUndefined'):
- return None
- return super().__new__(cls)
-
- def __init__(self, data, requires_grad=False):
- if isinstance(data, JsProxy):
- self._js = data
- else:
- js_data = to_js(data) if isinstance(data, (list, tuple)) else data
- self._js = js_torch.tensor(js_data, requires_grad)
-
- # ------------------------------------------------------------------
- # Representation
- # ------------------------------------------------------------------
-
- def __repr__(self):
- extra = ", requires_grad=True" if self.requires_grad else ""
- return f"tensor({self.tolist()}{extra})"
-
- # ------------------------------------------------------------------
- # Data access
- # ------------------------------------------------------------------
-
- def tolist(self):
- """Return tensor data as a (nested) Python list, or a Python scalar for 0-d tensors."""
- result = self._js.toArray()
- if isinstance(result, JsProxy):
- return result.to_py()
- return result # scalar
-
- def item(self):
- return self._js.item()
-
- # ------------------------------------------------------------------
- # Properties
- # ------------------------------------------------------------------
-
- @property
- def shape(self):
- return tuple(self._js.shape.to_py())
-
- @property
- def data(self):
- """Detached view of the tensor data (no gradient)."""
- return self.detach()
-
- @property
- def requires_grad(self):
- return bool(self._js.requires_grad)
-
- @requires_grad.setter
- def requires_grad(self, value):
- self._js.requires_grad = value
-
- @property
- def grad(self):
- raw = self._js.grad
- if raw is None or type(raw).__name__ in ('JsNull', 'JsUndefined'):
- return None
- return Tensor(raw)
-
- @grad.setter
- def grad(self, value):
- self._js.grad = value._js if isinstance(value, Tensor) else None
-
- @property
- def T(self):
- if len(self.shape) < 2:
- return self
- return Tensor(self._js.transpose(0, 1))
-
- # ------------------------------------------------------------------
- # Grad utilities
- # ------------------------------------------------------------------
-
- def backward(self, gradient=None):
- if gradient is None:
- self._js.backward()
- else:
- self._js.backward(gradient._js)
-
- def detach(self):
- return Tensor(self._js.detach())
-
- def zero_(self):
- self._js.zero_()
- return self
-
- def retain_grad(self):
- self._js.retain_grad()
-
- # ------------------------------------------------------------------
- # Shape utilities
- # ------------------------------------------------------------------
-
- def size(self, dim=None):
- s = self.shape
- return s if dim is None else s[dim]
-
- def dim(self):
- return len(self.shape)
-
- def numel(self):
- n = 1
- for s in self.shape:
- n *= s
- return n
-
- def reshape(self, *args):
- shape = list(args[0]) if len(args) == 1 and isinstance(args[0], (list, tuple)) else list(args)
- return Tensor(self._js.reshape(to_js(shape)))
-
- def view(self, *args):
- return self.reshape(*args)
-
- def squeeze(self, dim=None):
- if dim is None:
- new_shape = [s for s in self.shape if s != 1]
- return Tensor(self._js.reshape(to_js(new_shape or [1])))
- return Tensor(self._js.squeeze(dim))
-
- def unsqueeze(self, dim):
- return Tensor(self._js.unsqueeze(dim))
-
- def expand(self, *args):
- shape = list(args[0]) if len(args) == 1 and isinstance(args[0], (list, tuple)) else list(args)
- return Tensor(self._js.expand(to_js(shape)))
-
- def transpose(self, dim0, dim1):
- return Tensor(self._js.transpose(dim0, dim1))
-
- def flatten(self, start_dim=0, end_dim=-1):
- return Tensor(self._js.flatten(start_dim, end_dim))
-
- # ------------------------------------------------------------------
- # Reductions — default (no dim) sums all elements, matching PyTorch
- # ------------------------------------------------------------------
-
- def sum(self, dim=None, keepdim=False):
- return Tensor(self._js.sum() if dim is None else self._js.sum(dim, keepdim))
-
- def mean(self, dim=None, keepdim=False):
- return Tensor(self._js.mean() if dim is None else self._js.mean(dim, keepdim))
-
- def max(self, dim=None, keepdim=False):
- return Tensor(self._js.max() if dim is None else self._js.max(dim, keepdim))
-
- def min(self, dim=None, keepdim=False):
- return Tensor(self._js.min() if dim is None else self._js.min(dim, keepdim))
-
- # ------------------------------------------------------------------
- # Arithmetic — explicit methods
- # ------------------------------------------------------------------
-
- def _to_js(self, other):
- return other._js if isinstance(other, Tensor) else other
-
- def add(self, other): return Tensor(self._js.add(self._to_js(other)))
- def sub(self, other): return Tensor(self._js.sub(self._to_js(other)))
- def mul(self, other): return Tensor(self._js.mul(self._to_js(other)))
- def div(self, other): return Tensor(self._js.div(self._to_js(other)))
- def pow(self, other): return Tensor(self._js.pow(self._to_js(other)))
- def matmul(self, other): return Tensor(self._js.matmul(self._to_js(other)))
-
- # ------------------------------------------------------------------
- # Arithmetic operators
- # ------------------------------------------------------------------
-
- def __add__(self, other): return self.add(other)
- def __radd__(self, other): return self.add(other) # add is commutative
- def __sub__(self, other): return self.sub(other)
- def __rsub__(self, other):
- o = other if isinstance(other, Tensor) else Tensor(other)
- return o.sub(self)
- def __mul__(self, other): return self.mul(other)
- def __rmul__(self, other): return self.mul(other) # mul is commutative
- def __truediv__(self, other): return self.div(other)
- def __rtruediv__(self, other):
- o = other if isinstance(other, Tensor) else Tensor(other)
- return o.div(self)
- def __pow__(self, other): return self.pow(other)
- def __rpow__(self, other):
- o = other if isinstance(other, Tensor) else Tensor(other)
- return o.pow(self)
- def __matmul__(self, other): return self.matmul(other)
- def __neg__(self): return Tensor(self._js.neg())
- def __abs__(self): return Tensor(self._js.abs())
-
- # ------------------------------------------------------------------
- # Unary operations
- # ------------------------------------------------------------------
-
- def neg(self): return Tensor(self._js.neg())
- def abs(self): return Tensor(self._js.abs())
- def log(self): return Tensor(self._js.log())
- def exp(self): return Tensor(self._js.exp())
- def sqrt(self): return Tensor(self._js.sqrt())
- def square(self): return Tensor(self._js.square())
- def sin(self): return Tensor(self._js.sin())
- def cos(self): return Tensor(self._js.cos())
- def tan(self): return Tensor(self._js.tan())
- def sigmoid(self): return Tensor(self._js.sigmoid())
- def relu(self): return Tensor(js_torch.nn.functional.relu(self._js))
- def softmax(self, dim): return Tensor(self._js.softmax(dim))
- def clamp(self, min, max): return Tensor(self._js.clamp(min, max))
- def sign(self): return Tensor(self._js.sign())
- def reciprocal(self): return Tensor(self._js.reciprocal())
- def nan_to_num(self): return Tensor(self._js.nan_to_num())
-
- # ------------------------------------------------------------------
- # Comparison
- # ------------------------------------------------------------------
-
- def lt(self, other): return Tensor(self._js.lt(self._to_js(other)))
- def gt(self, other): return Tensor(self._js.gt(self._to_js(other)))
- def le(self, other): return Tensor(self._js.le(self._to_js(other)))
- def ge(self, other): return Tensor(self._js.ge(self._to_js(other)))
- def eq(self, other): return Tensor(self._js.eq(self._to_js(other)))
- def ne(self, other): return Tensor(self._js.ne(self._to_js(other)))
-
- def allclose(self, other, rtol=1e-5, atol=1e-8, equal_nan=False):
- return bool(js_torch.allclose(self._js, other._js, rtol, atol, equal_nan))
-
- # ------------------------------------------------------------------
- # Type conversions
- # ------------------------------------------------------------------
-
- def __float__(self): return float(self.item())
- def __int__(self): return int(self.item())
- def __bool__(self): return bool(self.item())
- def __format__(self, fmt): return format(self.item(), fmt)
-
- # ------------------------------------------------------------------
- # Indexing
- # ------------------------------------------------------------------
-
- def __getitem__(self, key):
- if isinstance(key, int):
- return Tensor(self._js.index(key))
- if isinstance(key, tuple):
- result = self._js
- for k in key:
- if isinstance(k, int):
- result = result.index(k)
- else:
- raise NotImplementedError(
- "Only integer indexing is supported in multi-dimensional indexing"
- )
- return Tensor(result)
- if isinstance(key, slice):
- start, stop, step = key.indices(self.shape[0])
- data = [Tensor(self._js.index(i)).tolist() for i in range(start, stop, step)]
- return Tensor(data)
- raise TypeError(f"Invalid index type: {type(key).__name__}")
-
- # ------------------------------------------------------------------
- # Iteration and length
- # ------------------------------------------------------------------
-
- def __len__(self):
- return self.shape[0]
-
- def __iter__(self):
- data = self.tolist()
- if not isinstance(data, list):
- raise TypeError("iteration over a 0-d tensor")
- for item in data:
- yield Tensor(item)
-
- # ------------------------------------------------------------------
- # Catch-all: delegate unknown attribute accesses to the JS tensor.
- # Returned JsProxy objects are wrapped in Tensor; primitives pass through.
- # ------------------------------------------------------------------
-
- def __getattr__(self, name):
- if name.startswith('_'):
- raise AttributeError(name)
- def method(*args, **kwargs):
- js_args = _transform_args(args)
- return _wrap_result(self._js.__getattribute__(name)(*js_args))
- return method
-
-
-# ---------------------------------------------------------------------------
-# Typed tensor subclasses
-# ---------------------------------------------------------------------------
-
-def _trunc_nested(data):
- """Truncate all numbers in a nested list toward zero (for LongTensor)."""
- if isinstance(data, (int, float)):
- return int(data) # Python int() truncates toward zero
- return [_trunc_nested(item) for item in data]
-
-
-class FloatTensor(Tensor):
- """
- A Tensor that stores floating-point values.
- Equivalent to a regular Tensor; provided for PyTorch API compatibility.
- """
- def __init__(self, data, requires_grad=False):
- if isinstance(data, JsProxy):
- super().__init__(data)
- else:
- super().__init__(data, requires_grad)
-
-
-class LongTensor(Tensor):
- """
- A Tensor whose values are truncated to integers (toward zero).
- LongTensor([-1.7]) -> tensor([-1]), LongTensor([1.9]) -> tensor([1]).
- """
- def __init__(self, data, requires_grad=False):
- if isinstance(data, JsProxy):
- super().__init__(data)
- else:
- truncated = _trunc_nested(data) if isinstance(data, (list, tuple)) else int(data)
- super().__init__(truncated, requires_grad)
-
-
-# ---------------------------------------------------------------------------
-# no_grad context manager — actually disables grad in the JS engine
-# ---------------------------------------------------------------------------
-
-class _NoGrad:
- def __enter__(self):
- self._prev = js_torch.enable_no_grad()
- return self
-
- def __exit__(self, *args):
- js_torch.disable_no_grad(self._prev)
-
-
-# ---------------------------------------------------------------------------
-# Parameter
-# ---------------------------------------------------------------------------
-
-class Parameter(Tensor):
- """A Tensor that is automatically registered as a parameter."""
- def __init__(self, data, requires_grad=True):
- if isinstance(data, Tensor):
- self._js = js_torch.nn.Parameter.new(data._js)
- elif isinstance(data, JsProxy):
- self._js = js_torch.nn.Parameter.new(data)
- else:
- self._js = js_torch.nn.Parameter.new(js_torch.tensor(data))
- if not requires_grad:
- self._js.requires_grad = False
-
-
-# ---------------------------------------------------------------------------
-# Module — pure-Python base class for user-defined models
-# ---------------------------------------------------------------------------
-
-class Module:
- """
- Pure-Python nn.Module. Subclass this to build models using bridge Tensors.
- Assign `Parameter` or `_NNModule` instances as attributes and they are
- automatically tracked by `parameters()`.
- """
-
- def __init__(self):
- object.__setattr__(self, '_parameters', {})
- object.__setattr__(self, '_modules', {})
- object.__setattr__(self, 'training', True)
-
- def __setattr__(self, name, value):
- try:
- params = object.__getattribute__(self, '_parameters')
- modules = object.__getattribute__(self, '_modules')
- except AttributeError:
- object.__setattr__(self, name, value)
- return
-
- if isinstance(value, Parameter):
- params[name] = value
- elif isinstance(value, (Module, _NNModule)):
- modules[name] = value
- object.__setattr__(self, name, value)
-
- def __call__(self, *args, **kwargs):
- return self.forward(*args, **kwargs)
-
- def forward(self, *args, **kwargs):
- raise NotImplementedError
-
- def parameters(self):
- params = list(object.__getattribute__(self, '_parameters').values())
- for mod in object.__getattribute__(self, '_modules').values():
- params.extend(mod.parameters())
- return params
-
- def named_parameters(self, prefix=''):
- result = []
- for name, p in object.__getattribute__(self, '_parameters').items():
- full = f"{prefix}.{name}" if prefix else name
- result.append((full, p))
- for mod_name, mod in object.__getattribute__(self, '_modules').items():
- full_mod = f"{prefix}.{mod_name}" if prefix else mod_name
- result.extend(mod.named_parameters(full_mod))
- return result
-
- def train(self, mode=True):
- object.__setattr__(self, 'training', mode)
- for mod in object.__getattribute__(self, '_modules').values():
- mod.train(mode)
- return self
-
- def eval(self):
- return self.train(False)
-
- def zero_grad(self):
- for p in self.parameters():
- p.grad = None
-
-
-# ---------------------------------------------------------------------------
-# _NNModule — wraps a JS nn.Module instance
-# ---------------------------------------------------------------------------
-
-class _NNModule:
- """Wraps a JS nn.Module returned by the nn factory functions."""
-
- def __init__(self, js_module):
- self._module = js_module
-
- def __call__(self, *args):
- js_args = [a._js if isinstance(a, Tensor) else a for a in args]
- return Tensor(self._module.call(*js_args))
-
- def forward(self, *args):
- js_args = [a._js if isinstance(a, Tensor) else a for a in args]
- return Tensor(self._module.forward(*js_args))
-
- def parameters(self):
- return [Tensor(p) for p in self._module.parameters().to_py()]
-
- def named_parameters(self, prefix=''):
- raw = self._module.named_parameters(prefix).to_py()
- return [(pair[0], Tensor(pair[1])) for pair in raw]
-
- def train(self, mode=True):
- self._module.train(mode)
- return self
-
- def eval(self):
- return self.train(False)
-
- def zero_grad(self):
- for p in self.parameters():
- p.grad = None
-
-
-# ---------------------------------------------------------------------------
-# nn.functional
-# ---------------------------------------------------------------------------
-
-class _NNFunctional:
- def relu(self, input):
- return Tensor(js_torch.nn.functional.relu(input._js))
-
- def sigmoid(self, input):
- return Tensor(js_torch.nn.functional.sigmoid(input._js))
-
- def leaky_relu(self, input, negative_slope=0.01):
- return Tensor(js_torch.nn.functional.leaky_relu(input._js, negative_slope))
-
- def max_pool2d(self, input, kernel_size, stride=None, padding=0):
- if stride is None:
- return Tensor(js_torch.nn.functional.max_pool2d(input._js, kernel_size))
- return Tensor(js_torch.nn.functional.max_pool2d(input._js, kernel_size, stride, padding))
-
- def nll_loss(self, input, target, reduction='mean'):
- return Tensor(js_torch.nn.functional.nll_loss(input._js, target._js, reduction))
-
- def __getattr__(self, name):
- if name.startswith('_'):
- raise AttributeError(name)
- def fn(*args, **kwargs):
- return _wrap_result(js_torch.nn.functional.__getattribute__(name)(*_transform_args(args)))
- return fn
-
-
-# ---------------------------------------------------------------------------
-# nn.parameter namespace
-# ---------------------------------------------------------------------------
-
-class _NNParameterNamespace:
- def __init__(self):
- self.Parameter = Parameter
-
-
-# ---------------------------------------------------------------------------
-# nn namespace
-# ---------------------------------------------------------------------------
-
-class _NNNamespace:
- def __init__(self):
- self.functional = _NNFunctional()
- self.parameter = _NNParameterNamespace()
- self.Module = Module
- self.Parameter = Parameter
-
- def Linear(self, in_features, out_features, bias=True):
- return _NNModule(js_torch.nn.Linear.new(in_features, out_features, bias))
-
- def ReLU(self):
- return _NNModule(js_torch.nn.ReLU.new())
-
- def Sigmoid(self):
- return _NNModule(js_torch.nn.Sigmoid.new())
-
- def Sequential(self, *modules):
- js_mods = [m._module for m in modules]
- return _NNModule(js_torch.nn.Sequential.new(*js_mods))
-
- def MSELoss(self, reduction='mean'):
- return _NNModule(js_torch.nn.MSELoss.new(reduction))
-
- def L1Loss(self, reduction='mean'):
- return _NNModule(js_torch.nn.L1Loss.new(reduction))
-
- def BCELoss(self, weight=None, reduction='mean'):
- js_weight = weight._js if isinstance(weight, Tensor) else None
- return _NNModule(js_torch.nn.BCELoss.new(js_weight, reduction))
-
- def CrossEntropyLoss(self, reduction='mean'):
- return _NNModule(js_torch.nn.CrossEntropyLoss.new(reduction))
-
- def Conv1d(self, in_channels, out_channels, kernel_size,
- stride=1, padding=0, dilation=1, groups=1, bias=True):
- return _NNModule(js_torch.nn.Conv1d.new(
- in_channels, out_channels, kernel_size,
- stride, padding, dilation, groups, bias
- ))
-
- def Conv2d(self, in_channels, out_channels, kernel_size,
- stride=1, padding=0, dilation=1, groups=1, bias=True):
- return _NNModule(js_torch.nn.Conv2d.new(
- in_channels, out_channels, kernel_size,
- stride, padding, dilation, groups, bias
- ))
-
- def Conv3d(self, in_channels, out_channels, kernel_size,
- stride=1, padding=0, dilation=1, groups=1, bias=True):
- return _NNModule(js_torch.nn.Conv3d.new(
- in_channels, out_channels, kernel_size,
- stride, padding, dilation, groups, bias
- ))
-
- def LeakyReLU(self, negative_slope=0.01):
- return _NNModule(js_torch.nn.LeakyReLU.new(negative_slope))
-
- def MaxPool2d(self, kernel_size, stride=None, padding=0):
- if stride is None:
- return _NNModule(js_torch.nn.MaxPool2d.new(kernel_size))
- return _NNModule(js_torch.nn.MaxPool2d.new(kernel_size, stride, padding))
-
- def Dropout(self, p=0.5):
- return _NNModule(js_torch.nn.Dropout.new(p))
-
- def Softmax(self, dim):
- return _NNModule(js_torch.nn.Softmax.new(dim))
-
- def Flatten(self, start_dim=1, end_dim=-1):
- return _NNModule(js_torch.nn.Flatten.new(start_dim, end_dim))
-
- def NLLLoss(self, reduction='mean'):
- return _NNModule(js_torch.nn.NLLLoss.new(reduction))
-
-
-# ---------------------------------------------------------------------------
-# optim wrappers
-# ---------------------------------------------------------------------------
-
-class _Optimizer:
- def __init__(self, js_optim):
- self._optim = js_optim
-
- def step(self):
- self._optim.step()
-
- def zero_grad(self):
- self._optim.zero_grad()
-
-
-class _OptimNamespace:
- def SGD(self, params, lr=0.001, momentum=0.0, dampening=0.0,
- weight_decay=0.0, nesterov=False, maximize=False):
- js_params = to_js([p._js for p in params])
- return _Optimizer(js_torch.optim.SGD.new(
- js_params, lr, momentum, dampening, weight_decay, nesterov, maximize
- ))
-
- def Adam(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0.0, amsgrad=False, maximize=False):
- js_params = to_js([p._js for p in params])
- js_betas = to_js(list(betas))
- return _Optimizer(js_torch.optim.Adam.new(
- js_params, lr, js_betas, eps, weight_decay, amsgrad, maximize
- ))
-
- def Adagrad(self, params, lr=0.01, lr_decay=0, weight_decay=0, eps=1e-10):
- js_params = to_js([p._js for p in params])
- return _Optimizer(js_torch.optim.Adagrad.new(
- js_params, lr, lr_decay, weight_decay, eps
- ))
-
-
-# ---------------------------------------------------------------------------
-# torch namespace
-# ---------------------------------------------------------------------------
-
-class _Torch:
- def __init__(self):
- self.nn = _NNNamespace()
- self.optim = _OptimNamespace()
- self.no_grad = _NoGrad
- self.Tensor = Tensor
- self.FloatTensor = FloatTensor
- self.LongTensor = LongTensor
-
- @property
- def tensor(self):
- return Tensor
-
- # --- creation functions ---
-
- def _shape_from_args(self, args):
- return list(args[0]) if len(args) == 1 and isinstance(args[0], (list, tuple)) else list(args)
-
- def zeros(self, *args, **kwargs):
- return Tensor(js_torch.zeros(to_js(self._shape_from_args(args))))
-
- def ones(self, *args, **kwargs):
- return Tensor(js_torch.ones(to_js(self._shape_from_args(args))))
-
- def zeros_like(self, input):
- return Tensor(js_torch.zeros_like(input._js))
-
- def ones_like(self, input):
- return Tensor(js_torch.ones_like(input._js))
-
- def randn(self, *args, **kwargs):
- return Tensor(js_torch.randn(to_js(self._shape_from_args(args))))
-
- def rand(self, *args, **kwargs):
- return Tensor(js_torch.rand(to_js(self._shape_from_args(args))))
-
- def arange(self, start, end=None, step=1):
- if end is None:
- end = start
- start = 0
- return Tensor(js_torch.arange(start, end, step))
-
- def linspace(self, start, end, steps):
- return Tensor(js_torch.linspace(start, end, steps))
-
- def empty(self, *args, **kwargs):
- return Tensor(js_torch.empty(to_js(self._shape_from_args(args))))
-
- def empty_like(self, input):
- return Tensor(js_torch.empty_like(input._js))
-
- def full(self, shape, fill_value):
- return Tensor(js_torch.full(to_js(list(shape)), fill_value))
-
- def full_like(self, input, fill_value):
- return Tensor(js_torch.full_like(input._js, fill_value))
-
- def rand_like(self, input):
- return Tensor(js_torch.rand_like(input._js))
-
- def randn_like(self, input):
- return Tensor(js_torch.randn_like(input._js))
-
- def randint_like(self, input, low, high):
- return Tensor(js_torch.randint_like(input._js, low, high))
-
- # --- utility functions ---
-
- def is_tensor(self, obj):
- return isinstance(obj, Tensor)
-
- def is_nonzero(self, input):
- if input.numel() != 1:
- raise RuntimeError(
- "Boolean value of Tensor with more than one element is ambiguous"
- )
- return bool(input.item() != 0)
-
- def numel(self, input):
- return input.numel()
-
- # --- functional wrappers ---
-
- def sum(self, input, dim=None, keepdim=False):
- return input.sum(dim, keepdim)
-
- def mean(self, input, dim=None, keepdim=False):
- return input.mean(dim, keepdim)
-
- def sigmoid(self, input):
- return input.sigmoid()
-
- def relu(self, input):
- return input.relu()
-
- def softmax(self, input, dim):
- return input.softmax(dim)
-
- def clamp(self, input, min, max):
- return input.clamp(min, max)
-
- def clip(self, input, min, max):
- return self.clamp(input, min, max)
-
- def flatten(self, input, start_dim=0, end_dim=-1):
- return input.flatten(start_dim, end_dim)
-
- def allclose(self, a, b, rtol=1e-5, atol=1e-8, equal_nan=False):
- return a.allclose(b, rtol, atol, equal_nan)
-
- def is_grad_enabled(self):
- return bool(js_torch.is_grad_enabled())
-
- def cat(self, tensors, dim=0):
- if isinstance(tensors, Tensor):
- tensors = [tensors]
- return Tensor(js_torch.cat(to_js([t._js for t in tensors]), dim))
-
- def concatenate(self, tensors, dim=0):
- return self.cat(tensors, dim)
-
- def concat(self, tensors, dim=0):
- return self.cat(tensors, dim)
-
- def stack(self, tensors, dim=0):
- return Tensor(js_torch.stack(to_js([t._js for t in tensors]), dim))
-
- def Size(self, shape):
- return list(shape)
-
- def __getattr__(self, name):
- if name.startswith('_'):
- raise AttributeError(name)
- def fn(*args, **kwargs):
- return _wrap_result(js_torch.__getattribute__(name)(*_transform_args(args)))
- return fn
-
-
-torch = _Torch()
diff --git a/examples/pyodide/bridge.py b/examples/pyodide/bridge.py
new file mode 120000
index 0000000..deb96f4
--- /dev/null
+++ b/examples/pyodide/bridge.py
@@ -0,0 +1 @@
+../../pyodide_bridge.py
\ No newline at end of file
diff --git a/examples/pyodide/index.html b/examples/pyodide/index.html
index f3ea57b..a0feac7 100644
--- a/examples/pyodide/index.html
+++ b/examples/pyodide/index.html
@@ -19,6 +19,7 @@
Torch + Pyodide
+
@@ -54,6 +55,8 @@ Output:
throw new Error("torch.js failed to load. Run `yarn build` first.");
}
+ await pyodide.loadPackage('numpy');
+
// Make torch available to Python
pyodide.globals.set("js_torch", window.torch);
diff --git a/examples/pyodide/main.js b/examples/pyodide/main.js
index ab17d47..9bb81c5 100644
--- a/examples/pyodide/main.js
+++ b/examples/pyodide/main.js
@@ -1,5 +1,5 @@
import { loadPyodide } from 'pyodide';
-import * as torch from 'torch';
+import * as torch from '@sourceacademy/torch';
import { readFileSync, writeFileSync, mkdirSync, readdirSync, existsSync } from 'fs';
import { execSync } from 'child_process';
import path from 'path';
@@ -11,8 +11,12 @@ function getPyFiles() {
return readdirSync(PY_DIR).filter(f => f.endsWith('.py')).sort();
}
+const PYODIDE_PACKAGE_CACHE_DIR = './.pyodide-packages';
+
async function setupPyodide() {
- const pyodide = await loadPyodide();
+ if (!existsSync(PYODIDE_PACKAGE_CACHE_DIR)) mkdirSync(PYODIDE_PACKAGE_CACHE_DIR);
+ const pyodide = await loadPyodide({ packageCacheDir: PYODIDE_PACKAGE_CACHE_DIR });
+ await pyodide.loadPackage('numpy');
pyodide.globals.set('js_torch', torch);
pyodide.globals.set('_js_is_null', (x) => x == null);
pyodide.runPython(readFileSync('./bridge.py', 'utf8'));
diff --git a/examples/pyodide/package.json b/examples/pyodide/package.json
index c7b84a7..af748cc 100644
--- a/examples/pyodide/package.json
+++ b/examples/pyodide/package.json
@@ -8,7 +8,7 @@
"test-pyodide": "node main.js test"
},
"dependencies": {
- "pyodide": "^0.29.3",
- "torch": "portal:../../"
+ "@sourceacademy/torch": "portal:../../",
+ "pyodide": "^0.29.3"
}
}
diff --git a/examples/pyodide/py/numpy_interop.py b/examples/pyodide/py/numpy_interop.py
new file mode 100644
index 0000000..2fe7ac0
--- /dev/null
+++ b/examples/pyodide/py/numpy_interop.py
@@ -0,0 +1,60 @@
+# numpy_interop.py — Verify torch.from_numpy() and Tensor.numpy() interop.
+
+import numpy as np
+
+print("=== from_numpy: 1-D array ===")
+arr = np.array([1.0, 2.0, 3.0], dtype=np.float32)
+t = torch.from_numpy(arr)
+print("numpy array:", arr.tolist())
+print("tensor:", t)
+print("> shape correct:", t.shape == (3,))
+print("> values correct:", t.allclose(torch.tensor([1.0, 2.0, 3.0])))
+
+print("\n=== from_numpy: 2-D array ===")
+arr2d = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
+t2d = torch.from_numpy(arr2d)
+print("numpy array:", arr2d.tolist())
+print("tensor:", t2d)
+print("> shape correct:", t2d.shape == (2, 2))
+print("> values correct:", t2d.allclose(torch.tensor([[1.0, 2.0], [3.0, 4.0]])))
+
+print("\n=== numpy(): tensor -> ndarray ===")
+t3 = torch.tensor([10.0, 20.0, 30.0])
+a3 = t3.numpy()
+print("tensor:", t3)
+print("numpy type:", type(a3).__name__)
+print("numpy values:", a3.tolist())
+print("> type is ndarray:", isinstance(a3, np.ndarray))
+print("> values correct:", a3.tolist() == [10.0, 20.0, 30.0])
+
+print("\n=== numpy(): 2-D tensor ===")
+t4 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
+a4 = t4.numpy()
+print("numpy shape:", list(a4.shape))
+print("numpy values:", a4.tolist())
+print("> shape correct:", list(a4.shape) == [2, 2])
+print("> values correct:", a4.tolist() == [[1.0, 2.0], [3.0, 4.0]])
+
+print("\n=== numpy(): requires grad error ===")
+t5 = torch.tensor([1.0, 2.0], requires_grad=True)
+try:
+ t5.numpy()
+ print("> ERROR: should have raised RuntimeError")
+except RuntimeError as e:
+ print("Got expected RuntimeError:", str(e)[:50])
+ print("> requires_grad error raised correctly:", True)
+
+print("\n=== numpy() after detach ===")
+t6 = torch.tensor([5.0, 6.0], requires_grad=True)
+a6 = t6.detach().numpy()
+print("values:", a6.tolist())
+print("> detach().numpy() works:", a6.tolist() == [5.0, 6.0])
+
+print("\n=== round-trip: numpy -> tensor -> numpy ===")
+original = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
+roundtrip = torch.from_numpy(original).numpy()
+print("original:", original.tolist())
+print("round-trip:", roundtrip.tolist())
+print("> round-trip correct:", original.tolist() == roundtrip.tolist())
+
+print("\nAll numpy_interop checks passed.")
diff --git a/examples/pyodide/yarn.lock b/examples/pyodide/yarn.lock
index 0c93a3a..d053097 100644
--- a/examples/pyodide/yarn.lock
+++ b/examples/pyodide/yarn.lock
@@ -5,6 +5,12 @@ __metadata:
version: 8
cacheKey: 10c0
+"@sourceacademy/torch@portal:../../::locator=example-pyodide%40workspace%3A.":
+ version: 0.0.0-use.local
+ resolution: "@sourceacademy/torch@portal:../../::locator=example-pyodide%40workspace%3A."
+ languageName: node
+ linkType: soft
+
"@types/emscripten@npm:^1.41.4":
version: 1.41.5
resolution: "@types/emscripten@npm:1.41.5"
@@ -16,8 +22,8 @@ __metadata:
version: 0.0.0-use.local
resolution: "example-pyodide@workspace:."
dependencies:
+ "@sourceacademy/torch": "portal:../../"
pyodide: "npm:^0.29.3"
- torch: "portal:../../"
languageName: unknown
linkType: soft
@@ -31,12 +37,6 @@ __metadata:
languageName: node
linkType: hard
-"torch@portal:../../::locator=example-pyodide%40workspace%3A.":
- version: 0.0.0-use.local
- resolution: "torch@portal:../../::locator=example-pyodide%40workspace%3A."
- languageName: node
- linkType: soft
-
"ws@npm:^8.5.0":
version: 8.19.0
resolution: "ws@npm:8.19.0"
diff --git a/package.json b/package.json
index 469ce72..7d12bd8 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
{
"type": "module",
- "name": "torch",
+ "name": "@sourceacademy/torch",
"packageManager": "yarn@4.10.3",
"version": "0.1.0",
"description": "machine-learning libraries for Source Academy",
@@ -12,7 +12,7 @@
"files": [
"build",
"src",
- "examples"
+ "pyodide_bridge.py"
],
"main": "./build/node/torch.node.cjs",
"module": "./build/node/torch.node.es.mjs",
diff --git a/pyodide_bridge.py b/pyodide_bridge.py
new file mode 100644
index 0000000..64a11c3
--- /dev/null
+++ b/pyodide_bridge.py
@@ -0,0 +1,828 @@
+# pyodide_bridge.py (symlinked as examples/pyodide/bridge.py)
+# Provides a PyTorch-compatible Python API over js_torch (the TypeScript torch library).
+#
+# Before loading this file, set the following globals in Pyodide:
+# js_torch - the torch module (window.torch from the UMD build)
+
+from pyodide.ffi import JsProxy, to_js
+
+
+# ---------------------------------------------------------------------------
+# Internal helpers
+# ---------------------------------------------------------------------------
+
+def _wrap_result(result):
+ """
+ Wrap a JS return value:
+ - JsProxy (JS object/Tensor) -> Python Tensor
+ - Python primitive (int, float, bool) -> return as-is
+ JS primitives are automatically converted to Python by Pyodide,
+ so they will NOT be JsProxy instances.
+ """
+ if isinstance(result, JsProxy):
+ return Tensor(result)
+ return result
+
+
+def _transform(obj):
+ """Convert Python objects to JS-compatible types before passing to JS."""
+ if isinstance(obj, Tensor):
+ return obj._js
+ if isinstance(obj, (list, tuple)):
+ return to_js([_transform(item) for item in obj])
+ return obj
+
+
+def _transform_args(args):
+ return [_transform(a) for a in args]
+
+
+# ---------------------------------------------------------------------------
+# Tensor
+# ---------------------------------------------------------------------------
+
+class Tensor:
+ """Python wrapper around a JS Tensor, mirroring the PyTorch Tensor API."""
+
+ # ------------------------------------------------------------------
+ # Construction
+ # ------------------------------------------------------------------
+
+ def __new__(cls, data, requires_grad=False):
+ # Return None for missing tensors so e.g. `tensor.grad` returns None
+ # when there is no gradient — matching PyTorch behaviour.
+ # Pyodide may represent JS null as a special JsNull type (not JsProxy, not None).
+ if data is None or type(data).__name__ in ('JsNull', 'JsUndefined'):
+ return None
+ return super().__new__(cls)
+
+ def __init__(self, data, requires_grad=False):
+ if isinstance(data, JsProxy):
+ self._js = data
+ else:
+ js_data = to_js(data) if isinstance(data, (list, tuple)) else data
+ self._js = js_torch.tensor(js_data, requires_grad)
+
+ # ------------------------------------------------------------------
+ # Representation
+ # ------------------------------------------------------------------
+
+ def __repr__(self):
+ extra = ", requires_grad=True" if self.requires_grad else ""
+ return f"tensor({self.tolist()}{extra})"
+
+ # ------------------------------------------------------------------
+ # Data access
+ # ------------------------------------------------------------------
+
+ def tolist(self):
+ """Return tensor data as a (nested) Python list, or a Python scalar for 0-d tensors."""
+ result = self._js.toArray()
+ if isinstance(result, JsProxy):
+ return result.to_py()
+ return result # scalar
+
+ def item(self):
+ return self._js.item()
+
+ # ------------------------------------------------------------------
+ # Properties
+ # ------------------------------------------------------------------
+
+ @property
+ def shape(self):
+ return tuple(self._js.shape.to_py())
+
+ @property
+ def data(self):
+ """Detached view of the tensor data (no gradient)."""
+ return self.detach()
+
+ @property
+ def requires_grad(self):
+ return bool(self._js.requires_grad)
+
+ @requires_grad.setter
+ def requires_grad(self, value):
+ self._js.requires_grad = value
+
+ @property
+ def grad(self):
+ raw = self._js.grad
+ if raw is None or type(raw).__name__ in ('JsNull', 'JsUndefined'):
+ return None
+ return Tensor(raw)
+
+ @grad.setter
+ def grad(self, value):
+ self._js.grad = value._js if isinstance(value, Tensor) else None
+
+ @property
+ def T(self):
+ if len(self.shape) < 2:
+ return self
+ return Tensor(self._js.transpose(0, 1))
+
+ # ------------------------------------------------------------------
+ # Grad utilities
+ # ------------------------------------------------------------------
+
+ def backward(self, gradient=None):
+ if gradient is None:
+ self._js.backward()
+ else:
+ self._js.backward(gradient._js)
+
+ def detach(self):
+ return Tensor(self._js.detach())
+
+ def zero_(self):
+ self._js.zero_()
+ return self
+
+ def retain_grad(self):
+ self._js.retain_grad()
+
+ # ------------------------------------------------------------------
+ # Shape utilities
+ # ------------------------------------------------------------------
+
+ def size(self, dim=None):
+ s = self.shape
+ return s if dim is None else s[dim]
+
+ def dim(self):
+ return len(self.shape)
+
+ def numel(self):
+ n = 1
+ for s in self.shape:
+ n *= s
+ return n
+
+ def reshape(self, *args):
+ shape = list(args[0]) if len(args) == 1 and isinstance(args[0], (list, tuple)) else list(args)
+ return Tensor(self._js.reshape(to_js(shape)))
+
+ def view(self, *args):
+ return self.reshape(*args)
+
+ def squeeze(self, dim=None):
+ if dim is None:
+ new_shape = [s for s in self.shape if s != 1]
+ return Tensor(self._js.reshape(to_js(new_shape)))
+ return Tensor(self._js.squeeze(dim))
+
+ def unsqueeze(self, dim):
+ return Tensor(self._js.unsqueeze(dim))
+
+ def expand(self, *args):
+ shape = list(args[0]) if len(args) == 1 and isinstance(args[0], (list, tuple)) else list(args)
+ return Tensor(self._js.expand(to_js(shape)))
+
+ def transpose(self, dim0, dim1):
+ return Tensor(self._js.transpose(dim0, dim1))
+
+ def flatten(self, start_dim=0, end_dim=-1):
+ return Tensor(self._js.flatten(start_dim, end_dim))
+
+ # ------------------------------------------------------------------
+    # Reductions — default (no dim) reduces over all elements, matching PyTorch
+ # ------------------------------------------------------------------
+
+ def sum(self, dim=None, keepdim=False):
+ return Tensor(self._js.sum() if dim is None else self._js.sum(dim, keepdim))
+
+ def mean(self, dim=None, keepdim=False):
+ return Tensor(self._js.mean() if dim is None else self._js.mean(dim, keepdim))
+
+ def max(self, dim=None, keepdim=False):
+ return Tensor(self._js.max() if dim is None else self._js.max(dim, keepdim))
+
+ def min(self, dim=None, keepdim=False):
+ return Tensor(self._js.min() if dim is None else self._js.min(dim, keepdim))
+
+ # ------------------------------------------------------------------
+ # Arithmetic — explicit methods
+ # ------------------------------------------------------------------
+
+ def _to_js(self, other):
+ return other._js if isinstance(other, Tensor) else other
+
+ def add(self, other): return Tensor(self._js.add(self._to_js(other)))
+ def sub(self, other): return Tensor(self._js.sub(self._to_js(other)))
+ def mul(self, other): return Tensor(self._js.mul(self._to_js(other)))
+ def div(self, other): return Tensor(self._js.div(self._to_js(other)))
+ def pow(self, other): return Tensor(self._js.pow(self._to_js(other)))
+ def matmul(self, other): return Tensor(self._js.matmul(self._to_js(other)))
+
+ # ------------------------------------------------------------------
+ # Arithmetic operators
+ # ------------------------------------------------------------------
+
+ def __add__(self, other): return self.add(other)
+ def __radd__(self, other): return self.add(other) # add is commutative
+ def __sub__(self, other): return self.sub(other)
+ def __rsub__(self, other):
+ o = other if isinstance(other, Tensor) else Tensor(other)
+ return o.sub(self)
+ def __mul__(self, other): return self.mul(other)
+ def __rmul__(self, other): return self.mul(other) # mul is commutative
+ def __truediv__(self, other): return self.div(other)
+ def __rtruediv__(self, other):
+ o = other if isinstance(other, Tensor) else Tensor(other)
+ return o.div(self)
+ def __pow__(self, other): return self.pow(other)
+ def __rpow__(self, other):
+ o = other if isinstance(other, Tensor) else Tensor(other)
+ return o.pow(self)
+ def __matmul__(self, other): return self.matmul(other)
+ def __neg__(self): return Tensor(self._js.neg())
+ def __abs__(self): return Tensor(self._js.abs())
+
+ # ------------------------------------------------------------------
+ # Unary operations
+ # ------------------------------------------------------------------
+
+ def neg(self): return Tensor(self._js.neg())
+ def abs(self): return Tensor(self._js.abs())
+ def log(self): return Tensor(self._js.log())
+ def exp(self): return Tensor(self._js.exp())
+ def sqrt(self): return Tensor(self._js.sqrt())
+ def square(self): return Tensor(self._js.square())
+ def sin(self): return Tensor(self._js.sin())
+ def cos(self): return Tensor(self._js.cos())
+ def tan(self): return Tensor(self._js.tan())
+ def sigmoid(self): return Tensor(self._js.sigmoid())
+ def relu(self): return Tensor(js_torch.nn.functional.relu(self._js))
+ def softmax(self, dim): return Tensor(self._js.softmax(dim))
+ def clamp(self, min, max): return Tensor(self._js.clamp(min, max))
+ def sign(self): return Tensor(self._js.sign())
+ def reciprocal(self): return Tensor(self._js.reciprocal())
+ def nan_to_num(self): return Tensor(self._js.nan_to_num())
+
+ # ------------------------------------------------------------------
+ # Comparison
+ # ------------------------------------------------------------------
+
+ def lt(self, other): return Tensor(self._js.lt(self._to_js(other)))
+ def gt(self, other): return Tensor(self._js.gt(self._to_js(other)))
+ def le(self, other): return Tensor(self._js.le(self._to_js(other)))
+ def ge(self, other): return Tensor(self._js.ge(self._to_js(other)))
+ def eq(self, other): return Tensor(self._js.eq(self._to_js(other)))
+ def ne(self, other): return Tensor(self._js.ne(self._to_js(other)))
+
+ def allclose(self, other, rtol=1e-5, atol=1e-8, equal_nan=False):
+ return bool(js_torch.allclose(self._js, other._js, rtol, atol, equal_nan))
+
+ # ------------------------------------------------------------------
+ # NumPy interop
+ # ------------------------------------------------------------------
+
+ def numpy(self):
+ import numpy as np
+ if self.requires_grad:
+ raise RuntimeError(
+ "Can't call numpy() on Tensor that requires grad. "
+ "Use tensor.detach().numpy() instead."
+ )
+ return np.array(self.tolist())
+
+ # ------------------------------------------------------------------
+ # Type conversions
+ # ------------------------------------------------------------------
+
+ def __float__(self): return float(self.item())
+ def __int__(self): return int(self.item())
+ def __bool__(self): return bool(self.item())
+ def __format__(self, fmt): return format(self.item(), fmt)
+
+ # ------------------------------------------------------------------
+ # Indexing
+ # ------------------------------------------------------------------
+
+ def __getitem__(self, key):
+ if isinstance(key, int):
+ return Tensor(self._js.index(key))
+ if isinstance(key, tuple):
+ result = self._js
+ for k in key:
+ if isinstance(k, int):
+ result = result.index(k)
+ else:
+ raise NotImplementedError(
+ "Only integer indexing is supported in multi-dimensional indexing"
+ )
+ return Tensor(result)
+ if isinstance(key, slice):
+ raise NotImplementedError(
+ "Slice indexing is not implemented; converting through Python lists "
+ "would return a detached copy and break tensor/autograd semantics"
+ )
+ raise TypeError(f"Invalid index type: {type(key).__name__}")
+
+ # ------------------------------------------------------------------
+ # Iteration and length
+ # ------------------------------------------------------------------
+
+ def __len__(self):
+ if self.dim() == 0:
+ raise TypeError("len() of a 0-d tensor")
+ return self.shape[0]
+
+ def __iter__(self):
+ if self.dim() == 0:
+ raise TypeError("iteration over a 0-d tensor")
+ for i in range(self.shape[0]):
+ yield self[i]
+
+ # ------------------------------------------------------------------
+ # Catch-all: delegate unknown attribute accesses to the JS tensor.
+ # Returned JsProxy objects are wrapped in Tensor; primitives pass through.
+ # ------------------------------------------------------------------
+
+ def __getattr__(self, name):
+ if name.startswith('_'):
+ raise AttributeError(name)
+ def method(*args, **kwargs):
+ if kwargs:
+ raise TypeError(
+ f"{name}() does not support keyword arguments in this bridge; "
+ f"got unexpected keyword argument(s): {', '.join(sorted(kwargs.keys()))}"
+ )
+ js_args = _transform_args(args)
+ return _wrap_result(self._js.__getattribute__(name)(*js_args))
+ return method
+
+
+# ---------------------------------------------------------------------------
+# Typed tensor subclasses
+# ---------------------------------------------------------------------------
+
+def _trunc_nested(data):
+ """Truncate all numbers in a nested list toward zero (for LongTensor)."""
+ if isinstance(data, (int, float)):
+ return int(data) # Python int() truncates toward zero
+ return [_trunc_nested(item) for item in data]
+
+
+class FloatTensor(Tensor):
+ """
+ A Tensor that stores floating-point values.
+ Equivalent to a regular Tensor; provided for PyTorch API compatibility.
+ """
+ def __init__(self, data, requires_grad=False):
+ if isinstance(data, JsProxy):
+ super().__init__(data)
+ else:
+ super().__init__(data, requires_grad)
+
+
+class LongTensor(Tensor):
+ """
+ A Tensor whose values are truncated to integers (toward zero).
+ LongTensor([-1.7]) -> tensor([-1]), LongTensor([1.9]) -> tensor([1]).
+ """
+ def __init__(self, data, requires_grad=False):
+ if isinstance(data, JsProxy):
+ super().__init__(data)
+ else:
+ truncated = _trunc_nested(data) if isinstance(data, (list, tuple)) else int(data)
+ super().__init__(truncated, requires_grad)
+
+
+# ---------------------------------------------------------------------------
+# no_grad context manager — actually disables grad in the JS engine
+# ---------------------------------------------------------------------------
+
+class _NoGrad:
+ def __enter__(self):
+ self._prev = js_torch.enable_no_grad()
+ return self
+
+ def __exit__(self, *args):
+ js_torch.disable_no_grad(self._prev)
+
+
+# ---------------------------------------------------------------------------
+# Parameter
+# ---------------------------------------------------------------------------
+
+class Parameter(Tensor):
+ """A Tensor that is automatically registered as a parameter."""
+ def __init__(self, data, requires_grad=True):
+ if isinstance(data, Tensor):
+ self._js = js_torch.nn.Parameter.new(data._js)
+ elif isinstance(data, JsProxy):
+ self._js = js_torch.nn.Parameter.new(data)
+ else:
+ self._js = js_torch.nn.Parameter.new(js_torch.tensor(data))
+ if not requires_grad:
+ self._js.requires_grad = False
+
+
+# ---------------------------------------------------------------------------
+# Module — pure-Python base class for user-defined models
+# ---------------------------------------------------------------------------
+
+class Module:
+ """
+ Pure-Python nn.Module. Subclass this to build models using bridge Tensors.
+ Assign `Parameter` or `_NNModule` instances as attributes and they are
+ automatically tracked by `parameters()`.
+ """
+
+ def __init__(self):
+ object.__setattr__(self, '_parameters', {})
+ object.__setattr__(self, '_modules', {})
+ object.__setattr__(self, 'training', True)
+
+ def __setattr__(self, name, value):
+ try:
+ params = object.__getattribute__(self, '_parameters')
+ modules = object.__getattribute__(self, '_modules')
+ except AttributeError:
+ object.__setattr__(self, name, value)
+ return
+
+ if isinstance(value, Parameter):
+ params[name] = value
+ elif isinstance(value, (Module, _NNModule)):
+ modules[name] = value
+ object.__setattr__(self, name, value)
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
+
+ def forward(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def parameters(self):
+ params = list(object.__getattribute__(self, '_parameters').values())
+ for mod in object.__getattribute__(self, '_modules').values():
+ params.extend(mod.parameters())
+ return params
+
+ def named_parameters(self, prefix=''):
+ result = []
+ for name, p in object.__getattribute__(self, '_parameters').items():
+ full = f"{prefix}.{name}" if prefix else name
+ result.append((full, p))
+ for mod_name, mod in object.__getattribute__(self, '_modules').items():
+ full_mod = f"{prefix}.{mod_name}" if prefix else mod_name
+ result.extend(mod.named_parameters(full_mod))
+ return result
+
+ def train(self, mode=True):
+ object.__setattr__(self, 'training', mode)
+ for mod in object.__getattribute__(self, '_modules').values():
+ mod.train(mode)
+ return self
+
+ def eval(self):
+ return self.train(False)
+
+ def zero_grad(self):
+ for p in self.parameters():
+ p.grad = None
+
+
+# ---------------------------------------------------------------------------
+# _NNModule — wraps a JS nn.Module instance
+# ---------------------------------------------------------------------------
+
+class _NNModule:
+ """Wraps a JS nn.Module returned by the nn factory functions."""
+
+ def __init__(self, js_module):
+ self._module = js_module
+
+ def __call__(self, *args):
+ js_args = [a._js if isinstance(a, Tensor) else a for a in args]
+ return Tensor(self._module.call(*js_args))
+
+ def forward(self, *args):
+ js_args = [a._js if isinstance(a, Tensor) else a for a in args]
+ return Tensor(self._module.forward(*js_args))
+
+ def parameters(self):
+ return [Tensor(p) for p in self._module.parameters().to_py()]
+
+ def named_parameters(self, prefix=''):
+ raw = self._module.named_parameters(prefix).to_py()
+ return [(pair[0], Tensor(pair[1])) for pair in raw]
+
+ def train(self, mode=True):
+ self._module.train(mode)
+ return self
+
+ def eval(self):
+ return self.train(False)
+
+ def zero_grad(self):
+ for p in self.parameters():
+ p.grad = None
+
+
+# ---------------------------------------------------------------------------
+# nn.functional
+# ---------------------------------------------------------------------------
+
+class _NNFunctional:
+ def relu(self, input):
+ return Tensor(js_torch.nn.functional.relu(input._js))
+
+ def sigmoid(self, input):
+ return Tensor(js_torch.nn.functional.sigmoid(input._js))
+
+ def leaky_relu(self, input, negative_slope=0.01):
+ return Tensor(js_torch.nn.functional.leaky_relu(input._js, negative_slope))
+
+ def max_pool2d(self, input, kernel_size, stride=None, padding=0):
+ if stride is None:
+ return Tensor(js_torch.nn.functional.max_pool2d(input._js, kernel_size))
+ return Tensor(js_torch.nn.functional.max_pool2d(input._js, kernel_size, stride, padding))
+
+ def nll_loss(self, input, target, reduction='mean'):
+ return Tensor(js_torch.nn.functional.nll_loss(input._js, target._js, reduction))
+
+ def __getattr__(self, name):
+ if name.startswith('_'):
+ raise AttributeError(name)
+ def fn(*args, **kwargs):
+ return _wrap_result(js_torch.nn.functional.__getattribute__(name)(*_transform_args(args)))
+ return fn
+
+
+# ---------------------------------------------------------------------------
+# nn.parameter namespace
+# ---------------------------------------------------------------------------
+
+class _NNParameterNamespace:
+ def __init__(self):
+ self.Parameter = Parameter
+
+
+# ---------------------------------------------------------------------------
+# nn namespace
+# ---------------------------------------------------------------------------
+
+class _NNNamespace:
+ def __init__(self):
+ self.functional = _NNFunctional()
+ self.parameter = _NNParameterNamespace()
+ self.Module = Module
+ self.Parameter = Parameter
+
+ def Linear(self, in_features, out_features, bias=True):
+ return _NNModule(js_torch.nn.Linear.new(in_features, out_features, bias))
+
+ def ReLU(self):
+ return _NNModule(js_torch.nn.ReLU.new())
+
+ def Sigmoid(self):
+ return _NNModule(js_torch.nn.Sigmoid.new())
+
+ def Sequential(self, *modules):
+ js_mods = [m._module for m in modules]
+ return _NNModule(js_torch.nn.Sequential.new(*js_mods))
+
+ def MSELoss(self, reduction='mean'):
+ return _NNModule(js_torch.nn.MSELoss.new(reduction))
+
+ def L1Loss(self, reduction='mean'):
+ return _NNModule(js_torch.nn.L1Loss.new(reduction))
+
+ def BCELoss(self, weight=None, reduction='mean'):
+ js_weight = weight._js if isinstance(weight, Tensor) else None
+ return _NNModule(js_torch.nn.BCELoss.new(js_weight, reduction))
+
+ def CrossEntropyLoss(self, reduction='mean'):
+ return _NNModule(js_torch.nn.CrossEntropyLoss.new(reduction))
+
+ def Conv1d(self, in_channels, out_channels, kernel_size,
+ stride=1, padding=0, dilation=1, groups=1, bias=True):
+ return _NNModule(js_torch.nn.Conv1d.new(
+ in_channels, out_channels, kernel_size,
+ stride, padding, dilation, groups, bias
+ ))
+
+ def Conv2d(self, in_channels, out_channels, kernel_size,
+ stride=1, padding=0, dilation=1, groups=1, bias=True):
+ return _NNModule(js_torch.nn.Conv2d.new(
+ in_channels, out_channels, kernel_size,
+ stride, padding, dilation, groups, bias
+ ))
+
+ def Conv3d(self, in_channels, out_channels, kernel_size,
+ stride=1, padding=0, dilation=1, groups=1, bias=True):
+ return _NNModule(js_torch.nn.Conv3d.new(
+ in_channels, out_channels, kernel_size,
+ stride, padding, dilation, groups, bias
+ ))
+
+ def LeakyReLU(self, negative_slope=0.01):
+ return _NNModule(js_torch.nn.LeakyReLU.new(negative_slope))
+
+ def MaxPool2d(self, kernel_size, stride=None, padding=0):
+ if stride is None:
+ return _NNModule(js_torch.nn.MaxPool2d.new(kernel_size))
+ return _NNModule(js_torch.nn.MaxPool2d.new(kernel_size, stride, padding))
+
+ def Dropout(self, p=0.5):
+ return _NNModule(js_torch.nn.Dropout.new(p))
+
+ def Softmax(self, dim):
+ return _NNModule(js_torch.nn.Softmax.new(dim))
+
+ def Flatten(self, start_dim=1, end_dim=-1):
+ return _NNModule(js_torch.nn.Flatten.new(start_dim, end_dim))
+
+ def NLLLoss(self, reduction='mean'):
+ return _NNModule(js_torch.nn.NLLLoss.new(reduction))
+
+
+# ---------------------------------------------------------------------------
+# optim wrappers
+# ---------------------------------------------------------------------------
+
+class _Optimizer:
+ def __init__(self, js_optim):
+ self._optim = js_optim
+
+ def step(self):
+ self._optim.step()
+
+ def zero_grad(self):
+ self._optim.zero_grad()
+
+
+class _OptimNamespace:
+ def SGD(self, params, lr=0.001, momentum=0.0, dampening=0.0,
+ weight_decay=0.0, nesterov=False, maximize=False):
+ js_params = to_js([p._js for p in params])
+ return _Optimizer(js_torch.optim.SGD.new(
+ js_params, lr, momentum, dampening, weight_decay, nesterov, maximize
+ ))
+
+ def Adam(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=0.0, amsgrad=False, maximize=False):
+ js_params = to_js([p._js for p in params])
+ js_betas = to_js(list(betas))
+ return _Optimizer(js_torch.optim.Adam.new(
+ js_params, lr, js_betas, eps, weight_decay, amsgrad, maximize
+ ))
+
+ def Adagrad(self, params, lr=0.01, lr_decay=0, weight_decay=0, eps=1e-10):
+ js_params = to_js([p._js for p in params])
+ return _Optimizer(js_torch.optim.Adagrad.new(
+ js_params, lr, lr_decay, weight_decay, eps
+ ))
+
+
+# ---------------------------------------------------------------------------
+# torch namespace
+# ---------------------------------------------------------------------------
+
+class _Torch:
+ def __init__(self):
+ self.nn = _NNNamespace()
+ self.optim = _OptimNamespace()
+ self.no_grad = _NoGrad
+ self.Tensor = Tensor
+ self.FloatTensor = FloatTensor
+ self.LongTensor = LongTensor
+
+ @property
+ def tensor(self):
+ return Tensor
+
+ # --- creation functions ---
+
+ def _shape_from_args(self, args):
+ return list(args[0]) if len(args) == 1 and isinstance(args[0], (list, tuple)) else list(args)
+
+ def zeros(self, *args, **kwargs):
+ return Tensor(js_torch.zeros(to_js(self._shape_from_args(args))))
+
+ def ones(self, *args, **kwargs):
+ return Tensor(js_torch.ones(to_js(self._shape_from_args(args))))
+
+ def zeros_like(self, input):
+ return Tensor(js_torch.zeros_like(input._js))
+
+ def ones_like(self, input):
+ return Tensor(js_torch.ones_like(input._js))
+
+ def randn(self, *args, **kwargs):
+ return Tensor(js_torch.randn(to_js(self._shape_from_args(args))))
+
+ def rand(self, *args, **kwargs):
+ return Tensor(js_torch.rand(to_js(self._shape_from_args(args))))
+
+ def arange(self, start, end=None, step=1):
+ if end is None:
+ end = start
+ start = 0
+ return Tensor(js_torch.arange(start, end, step))
+
+ def linspace(self, start, end, steps):
+ return Tensor(js_torch.linspace(start, end, steps))
+
+ def empty(self, *args, **kwargs):
+ return Tensor(js_torch.empty(to_js(self._shape_from_args(args))))
+
+ def empty_like(self, input):
+ return Tensor(js_torch.empty_like(input._js))
+
+ def full(self, shape, fill_value):
+ return Tensor(js_torch.full(to_js(list(shape)), fill_value))
+
+ def full_like(self, input, fill_value):
+ return Tensor(js_torch.full_like(input._js, fill_value))
+
+ def rand_like(self, input):
+ return Tensor(js_torch.rand_like(input._js))
+
+ def randn_like(self, input):
+ return Tensor(js_torch.randn_like(input._js))
+
+ def randint_like(self, input, low, high):
+ return Tensor(js_torch.randint_like(input._js, low, high))
+
+ # --- utility functions ---
+
+ def is_tensor(self, obj):
+ return isinstance(obj, Tensor)
+
+ def from_numpy(self, array):
+ return Tensor(array.tolist())
+
+ def is_nonzero(self, input):
+ if input.numel() != 1:
+ raise RuntimeError(
+ "Boolean value of Tensor with more than one element is ambiguous"
+ )
+ return bool(input.item() != 0)
+
+ def numel(self, input):
+ return input.numel()
+
+ # --- functional wrappers ---
+
+ def sum(self, input, dim=None, keepdim=False):
+ return input.sum(dim, keepdim)
+
+ def mean(self, input, dim=None, keepdim=False):
+ return input.mean(dim, keepdim)
+
+ def sigmoid(self, input):
+ return input.sigmoid()
+
+ def relu(self, input):
+ return input.relu()
+
+ def softmax(self, input, dim):
+ return input.softmax(dim)
+
+ def clamp(self, input, min, max):
+ return input.clamp(min, max)
+
+ def clip(self, input, min, max):
+ return self.clamp(input, min, max)
+
+ def flatten(self, input, start_dim=0, end_dim=-1):
+ return input.flatten(start_dim, end_dim)
+
+ def allclose(self, a, b, rtol=1e-5, atol=1e-8, equal_nan=False):
+ return a.allclose(b, rtol, atol, equal_nan)
+
+ def is_grad_enabled(self):
+ return bool(js_torch.is_grad_enabled())
+
+ def cat(self, tensors, dim=0):
+ if isinstance(tensors, Tensor):
+ tensors = [tensors]
+ return Tensor(js_torch.cat(to_js([t._js for t in tensors]), dim))
+
+ def concatenate(self, tensors, dim=0):
+ return self.cat(tensors, dim)
+
+ def concat(self, tensors, dim=0):
+ return self.cat(tensors, dim)
+
+ def stack(self, tensors, dim=0):
+ return Tensor(js_torch.stack(to_js([t._js for t in tensors]), dim))
+
+ def Size(self, shape):
+ return list(shape)
+
+ def __getattr__(self, name):
+ if name.startswith('_'):
+ raise AttributeError(name)
+ def fn(*args, **kwargs):
+ return _wrap_result(js_torch.__getattribute__(name)(*_transform_args(args)))
+ return fn
+
+
+torch = _Torch()
diff --git a/test/backward.test.js b/test/backward.test.js
index d66171b..a195c32 100644
--- a/test/backward.test.js
+++ b/test/backward.test.js
@@ -1,5 +1,5 @@
import { assert } from 'chai';
-import { Tensor } from 'torch';
+import { Tensor } from '@sourceacademy/torch';
describe('Autograd', () => {
it('y = x**2, dy/dx at x=2 is 4', () => {
diff --git a/test/broadcast.test.js b/test/broadcast.test.js
index 524a2a8..104c4f7 100644
--- a/test/broadcast.test.js
+++ b/test/broadcast.test.js
@@ -1,5 +1,5 @@
import { assert } from 'chai';
-import { Tensor, add, __left_index__, __right_index__ } from 'torch';
+import { Tensor, add, __left_index__, __right_index__ } from '@sourceacademy/torch';
describe('Broadcast Index', () => {
it('Broadcast indices should be correct in 2d array', () => {
diff --git a/test/creation.test.ts b/test/creation.test.ts
index 3641f9f..fc2f55a 100644
--- a/test/creation.test.ts
+++ b/test/creation.test.ts
@@ -1,6 +1,6 @@
import { assert } from 'chai';
-import * as torch from 'torch';
-import { Tensor } from 'torch';
+import * as torch from '@sourceacademy/torch';
+import { Tensor } from '@sourceacademy/torch';
describe('Creation Functions', () => {
describe('tensor', () => {
diff --git a/test/custom_operations.test.js b/test/custom_operations.test.js
index 1277394..9deccb0 100644
--- a/test/custom_operations.test.js
+++ b/test/custom_operations.test.js
@@ -1,5 +1,5 @@
import { assert } from 'chai';
-import { Tensor } from 'torch';
+import { Tensor } from '@sourceacademy/torch';
describe('Custom Operations', () => {
diff --git a/test/event_listener.test.js b/test/event_listener.test.js
index 4a2aa2a..647308a 100644
--- a/test/event_listener.test.js
+++ b/test/event_listener.test.js
@@ -1,5 +1,5 @@
import { assert } from 'chai';
-import * as torch from 'torch';
+import * as torch from '@sourceacademy/torch';
describe('Event Bus', () => {
const a = new torch.Tensor([1, 2, 3], { requires_grad: true });
diff --git a/test/export.test.ts b/test/export.test.ts
index 167aba4..d04ad5c 100644
--- a/test/export.test.ts
+++ b/test/export.test.ts
@@ -3,7 +3,7 @@ import { _atenMap } from '../src/export';
import { _getAllOperationNames } from '../src/functions/registry';
// Import torch to trigger operation registration side effects
-import 'torch';
+import '@sourceacademy/torch';
describe('Export', () => {
// List from https://docs.pytorch.org/docs/2.10/user_guide/torch_compiler/torch.compiler_ir.html
diff --git a/test/functional.test.js b/test/functional.test.js
index d209194..4a5f8c3 100644
--- a/test/functional.test.js
+++ b/test/functional.test.js
@@ -1,6 +1,6 @@
import { assert } from 'chai';
-import * as torch from 'torch';
-import { Tensor } from 'torch';
+import * as torch from '@sourceacademy/torch';
+import { Tensor } from '@sourceacademy/torch';
describe('Functional', () => {
describe('Addition', () => {
diff --git a/test/generated.test.js b/test/generated.test.js
index 7d880d9..5514dcb 100644
--- a/test/generated.test.js
+++ b/test/generated.test.js
@@ -1,5 +1,5 @@
-import * as torch from 'torch';
-import { Tensor } from 'torch';
+import * as torch from '@sourceacademy/torch';
+import { Tensor } from '@sourceacademy/torch';
import { assert } from 'chai';
import { testData } from './testcases.gen.js';
diff --git a/test/grad_mode.test.js b/test/grad_mode.test.js
index e6ca5e2..05de450 100644
--- a/test/grad_mode.test.js
+++ b/test/grad_mode.test.js
@@ -1,5 +1,5 @@
import { assert } from 'chai';
-import { Tensor, no_grad, enable_no_grad, disable_no_grad, is_grad_enabled } from 'torch';
+import { Tensor, no_grad, enable_no_grad, disable_no_grad, is_grad_enabled } from '@sourceacademy/torch';
describe('Grad Mode', () => {
afterEach(() => {
diff --git a/test/index.html b/test/index.html
index 29e8f24..b05ca6d 100644
--- a/test/index.html
+++ b/test/index.html
@@ -20,7 +20,7 @@