I am trying to modify the training script to use my own dataset for training (the original version uses CIFAR10/MNIST). My dataset is a folder of images in JPG format.
I've run into a problem with the data module.
Here's the modified code in train.py:
def create_dset(config):
    INV_SCALER = lambda x: x
    SCALER = lambda x: x
    NULL_VAL = None
    PLOT_VAR_IDXS = None
    PLOT_VAR_NAMES = None
    PAD_VAL = None
    if config.dset in ["mnist", "cifar"]:
        if config.dset == "mnist":
            config.target_points = 28 - config.context_points
            datasetCls = stf.data.image_completion.MNISTDset
            PLOT_VAR_IDXS = [18, 24]
            PLOT_VAR_NAMES = ["18th row", "24th row"]
        else:
            config.target_points = 32 * 32 - config.context_points
            datasetCls = stf.data.image_completion.CIFARDset
            PLOT_VAR_IDXS = [0]
            PLOT_VAR_NAMES = ["Reds"]
        DATA_MODULE = stf.data.DataModule(
            datasetCls=datasetCls,
            dataset_kwargs={"context_points": config.context_points},
            batch_size=config.batch_size,
            workers=config.workers,
            overfit=args.overfit,
        )
        return (
            DATA_MODULE,
            INV_SCALER,
            SCALER,
            NULL_VAL,
            PLOT_VAR_IDXS,
            PLOT_VAR_NAMES,
            PAD_VAL,
        )
    # Try to use my own data set here
    elif config.dset == "custom":
        data_dir = "./spacetimeformer/mydata"
        # Define data transformations
        transform = transforms.Compose([
            transforms.Resize((256, 256)),  # Resize the images to a specific size
            transforms.ToTensor(),  # Convert images to tensors
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize the pixel values
        ])
        # Create the custom image dataset
        dataset = CUSTOMDset(data_dir, context_points=config.context_points, transform=transform)
        DATA_MODULE = stf.data.DataModule(
            datasetCls=dataset,
            dataset_kwargs={"context_points": config.context_points},
            batch_size=config.batch_size,
            workers=config.workers,
            overfit=args.overfit,
        )
        # Rest of the values remain the same as in the "mnist" and "cifar" cases
        PLOT_VAR_IDXS = [0]
        PLOT_VAR_NAMES = ["Reds"]
        PAD_VAL = None
        return (
            DATA_MODULE,
            INV_SCALER,
            SCALER,
            NULL_VAL,
            PLOT_VAR_IDXS,
            PLOT_VAR_NAMES,
            PAD_VAL,
        )
I get an error after calling the function later: ( data_module, inv_scaler, scaler, null_val, plot_var_idxs, plot_var_names, pad_val, ) = create_dset(args)
The error occurs when I then call test_dataloader = data_module.test_dataloader().
The error message is:
Traceback (most recent call last):
  File "train_my.py", line 525, in <module>
    main(args)
  File "train_my.py", line 435, in main
    test_dataloader = data_module.test_dataloader()
  File "/home/spacetimeformer/spacetimeformer/data/datamodule.py", line 37, in test_dataloader
    return self._make_dloader("test", shuffle=shuffle)
  File "/home/spacetimeformer/spacetimeformer/data/datamodule.py", line 44, in _make_dloader
    self.datasetCls(**self.dataset_kwargs, split=split),
TypeError: 'CUSTOMDset' object is not callable
When I print the object, I get <spacetimeformer.data.datamodule.DataModule object at 0x7f59cd8f84c0>.
My environment is: Python=3.8, torch=2.0.1.
Any ideas about the bug or how to train on custom data would be appreciated! I'm really new to PyTorch Lightning and am struggling to solve the problem.
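Judging only from the traceback, the data module appears to build the dataset itself by calling self.datasetCls(**self.dataset_kwargs, split=split), so datasetCls likely needs to be the class rather than an already constructed CUSTOMDset instance. The sketch below reproduces that one call with stand-in classes. SketchDataModule and SketchCustomDset are hypothetical names introduced for illustration; they are not part of spacetimeformer, and the real CUSTOMDset is not shown in the post, so its constructor signature here is an assumption.

```python
# Stand-in for stf.data.DataModule. Only the call visible in the traceback
# (datamodule.py, line 44) is reproduced; everything else is assumed.
class SketchDataModule:
    def __init__(self, datasetCls, dataset_kwargs):
        self.datasetCls = datasetCls
        self.dataset_kwargs = dataset_kwargs

    def make_dataset(self, split):
        # Same expression as in the traceback: datasetCls must be callable.
        return self.datasetCls(**self.dataset_kwargs, split=split)


# Hypothetical stand-in for CUSTOMDset (assumed constructor arguments).
class SketchCustomDset:
    def __init__(self, data_dir, context_points, transform=None, split="train"):
        self.data_dir = data_dir
        self.context_points = context_points
        self.transform = transform
        self.split = split


# Passing an instance, as in the modified train.py, reproduces the error.
instance = SketchCustomDset("./spacetimeformer/mydata", context_points=10)
broken = SketchDataModule(
    datasetCls=instance,
    dataset_kwargs={"context_points": 10},
)
try:
    broken.make_dataset("test")
except TypeError as err:
    error_message = str(err)

# Passing the class and moving the constructor arguments into dataset_kwargs
# lets the data module create one dataset per split.
fixed = SketchDataModule(
    datasetCls=SketchCustomDset,
    dataset_kwargs={
        "data_dir": "./spacetimeformer/mydata",
        "context_points": 10,
        "transform": None,
    },
)
test_dset = fixed.make_dataset("test")
```

If this reading is right, the real CUSTOMDset would also need to accept a split keyword argument, since the data module passes one on every call.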
HiFei4869 changed the title from "Training with" to "Training with custom dataset. Error: object is not callable" on Aug 3, 2023
I haven't tried a custom dataset yet, but I plan to do so soon. One question, though: why aren't you using the setup from the repo, mainly torch 1.11.0?