Skip to content

Commit

Permalink
hotfix
Browse files Browse the repository at this point in the history
  • Loading branch information
Artur-Galstyan committed Apr 8, 2024
1 parent 2059dca commit 0575a8f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
12 changes: 12 additions & 0 deletions jaxonloader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def __len__(self) -> int:
def __getitem__(self, idx):
return self.data[idx]

def split(self, ratio: float) -> tuple["SingleArrayDataset", "SingleArrayDataset"]:
split = int(len(self.data) * ratio)
return SingleArrayDataset(self.data[:split]), SingleArrayDataset(
self.data[split:]
)


class DataTargetDataset(JaxonDataset):
def __init__(self, data: NDArray, target: NDArray):
Expand All @@ -43,3 +49,9 @@ def __len__(self) -> int:

def __getitem__(self, idx):
return self.data[idx], self.target[idx]

def split(self, ratio: float) -> tuple["DataTargetDataset", "DataTargetDataset"]:
split = int(len(self.data) * ratio)
return DataTargetDataset(
self.data[:split], self.target[:split]
), DataTargetDataset(self.data[split:], self.target[split:])
1 change: 1 addition & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

@nox.session
def tests(session):
session.install("uv")
session.install("pip")
session.run("uv", "pip", "install", ".[dev]")
session.run("uv", "pip", "install", "-e", ".")
Expand Down

0 comments on commit 0575a8f

Please sign in to comment.