Training with custom dataset. Error: object is not callable #80

Open
HiFei4869 opened this issue Aug 3, 2023 · 2 comments
HiFei4869 commented Aug 3, 2023

I am trying to modify the training script to train on my own dataset (the original version uses CIFAR10/MNIST). My dataset is a folder of images in JPG format.
I've run into a problem with the data module.
Here's the modified code in train.py:

def create_dset(config):
    INV_SCALER = lambda x: x
    SCALER = lambda x: x
    NULL_VAL = None
    PLOT_VAR_IDXS = None
    PLOT_VAR_NAMES = None
    PAD_VAL = None


    if config.dset in ["mnist", "cifar"]:
        if config.dset == "mnist":
            config.target_points = 28 - config.context_points
            datasetCls = stf.data.image_completion.MNISTDset
            PLOT_VAR_IDXS = [18, 24]
            PLOT_VAR_NAMES = ["18th row", "24th row"]
        else:
            config.target_points = 32 * 32 - config.context_points
            datasetCls = stf.data.image_completion.CIFARDset
            PLOT_VAR_IDXS = [0]
            PLOT_VAR_NAMES = ["Reds"]
        DATA_MODULE = stf.data.DataModule(
            datasetCls=datasetCls,
            dataset_kwargs={"context_points": config.context_points},
            batch_size=config.batch_size,
            workers=config.workers,
            overfit=args.overfit,
        )

        return (
            DATA_MODULE,
            INV_SCALER,
            SCALER,
            NULL_VAL,
            PLOT_VAR_IDXS,
            PLOT_VAR_NAMES,
            PAD_VAL,
        )
    # Try to use my own data set here
    elif config.dset == "custom":
        data_dir = "./spacetimeformer/mydata"

        # Define data transformations
        transform = transforms.Compose([
            transforms.Resize((256, 256)),  # Resize the images to a specific size
            transforms.ToTensor(),          # Convert images to tensors
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize the pixel values
        ])
        
        # Create the custom image dataset
        dataset = CUSTOMDset(data_dir, context_points=config.context_points, transform=transform)
        
        DATA_MODULE = stf.data.DataModule(
            datasetCls=dataset,
            dataset_kwargs={"context_points": config.context_points},
            batch_size=config.batch_size,
            workers=config.workers,
            overfit=args.overfit,
        )

        # Rest of the values remain the same as in the "mnist" and "cifar" cases
        PLOT_VAR_IDXS = [0]
        PLOT_VAR_NAMES = ["Reds"]
        PAD_VAL = None
    
        return (
            DATA_MODULE,
            INV_SCALER,
            SCALER,
            NULL_VAL,
            PLOT_VAR_IDXS,
            PLOT_VAR_NAMES,
            PAD_VAL,
        )

I get an error when I call the function later:
( data_module, inv_scaler, scaler, null_val, plot_var_idxs, plot_var_names, pad_val, ) = create_dset(args)
The failure occurs when I then run test_dataloader = data_module.test_dataloader().
The error message is:

Traceback (most recent call last):
	File "train_my.py", line 525, in <module>
	  main(args)
	File "train_my.py", line 435, in main
	  test_dataloader = data_module.test_dataloader()
	File "/home/spacetimeformer/spacetimeformer/data/datamodule.py", line 37, in test_dataloader
	  return self._make_dloader("test", shuffle=shuffle)
	File "/home/spacetimeformer/spacetimeformer/data/datamodule.py", line 44, in _make_dloader
	  self.datasetCls(**self.dataset_kwargs, split=split),
TypeError: 'CUSTOMDset' object is not callable
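
The last frame of the traceback shows the data module evaluating self.datasetCls(**self.dataset_kwargs, split=split). That suggests datasetCls is expected to be the dataset class itself, not an already constructed CUSTOMDset instance. A minimal sketch of that behavior follows; StandInDataModule and ToyDset are hypothetical stand-ins written for illustration, not the repo's actual classes, and only the call expression is taken from the traceback.

```python
class StandInDataModule:
    """Hypothetical stand-in that mirrors only the call shown in the traceback."""

    def __init__(self, datasetCls, dataset_kwargs):
        self.datasetCls = datasetCls
        self.dataset_kwargs = dataset_kwargs

    def make_dataset(self, split):
        # Same expression as the datamodule.py frame in the traceback.
        return self.datasetCls(**self.dataset_kwargs, split=split)


class ToyDset:
    """Hypothetical dataset; it accepts `split` because the module passes it."""

    def __init__(self, context_points, split="train"):
        self.context_points = context_points
        self.split = split


# Passing an instance reproduces the error: instances are not callable.
broken = StandInDataModule(ToyDset(context_points=4), {"context_points": 4})
try:
    broken.make_dataset("test")
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Passing the class lets the module build one dataset per split.
working = StandInDataModule(ToyDset, {"context_points": 4})
test_dset = working.make_dataset("test")
```

In this sketch the first module fails with the same kind of TypeError as above, while the second builds a dataset whose split is "test".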

I tried printing the object, and it gives me <spacetimeformer.data.datamodule.DataModule object at 0x7f59cd8f84c0>.

My environment is: Python=3.8, torch=2.0.1.

Any ideas about the bug, or about how to train on custom data, would be appreciated! I'm new to PyTorch Lightning and am struggling to solve this problem.
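
If the data module does construct the dataset itself and passes a split keyword, as the traceback suggests, then a custom dataset would need to accept split in its constructor, and extra arguments such as the folder path would need to be bound beforehand. The sketch below is one possible shape, under assumptions: FolderImageDset and make_dataset_factory are hypothetical names, the split names "train" and "val" and the 80/10/10 ratio are guesses (only "test" appears in the traceback), and image loading is left out.

```python
import functools
from pathlib import Path


class FolderImageDset:
    """Hypothetical dataset over a folder of .jpg files with a simple split."""

    def __init__(self, data_dir, context_points, split="train", transform=None):
        files = sorted(Path(data_dir).glob("*.jpg"))
        # Assumed 80/10/10 split by sorted file order; adjust as needed.
        n_train = int(0.8 * len(files))
        n_val = int(0.1 * len(files))
        bounds = {
            "train": (0, n_train),
            "val": (n_train, n_train + n_val),
            "test": (n_train + n_val, len(files)),
        }
        start, stop = bounds[split]
        self.files = files[start:stop]
        self.context_points = context_points
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # Placeholder: returns the file path. Loading the image, applying
        # the transform, and the context/target split would go here.
        return self.files[idx]


def make_dataset_factory(data_dir, transform=None):
    # Bind the arguments the data module does not know about, so the result
    # can be called as factory(context_points=..., split=...).
    return functools.partial(FolderImageDset, data_dir=data_dir, transform=transform)
```

The returned factory could then be passed as datasetCls=make_dataset_factory(data_dir), with dataset_kwargs={"context_points": config.context_points} left unchanged. Whether the rest of the training script accepts the items this dataset yields is not covered here.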

HiFei4869 changed the title from "Training with" to "Training with custom dataset. Error: object is not callable" on Aug 3, 2023

nchesk commented Sep 12, 2023

@HiFei4869 have you made any progress? I'm trying to use it on my own dataset as well.


szdrnja commented Sep 13, 2023

I haven't tried a custom dataset yet, but I plan to do so soon. I have one question, though: why aren't you using the setup that's used in the repo, mainly torch 1.11.0?
