



class CustomDataset(data.Dataset):  # must subclass data.Dataset
    """Skeleton of a map-style dataset; fill in the three methods below.

    As written, the dataset reports a length of 0. ``__getitem__`` raises
    ``NotImplementedError`` until it is implemented.
    """

    def __init__(self):
        # 1. Initialize file path or list of file names.
        pass

    def __getitem__(self, index):
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        # Note: step 1 reads a single sample (the one at `index`), not the
        # whole dataset.
        raise NotImplementedError("CustomDataset.__getitem__ must be implemented")

    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0


class MNIST(data.Dataset):
    """MNIST digits dataset backed by pre-processed ``torch.save`` files.

    Only the split selected by ``train`` is loaded into memory.

    NOTE(review): this class reads ``self.processed_folder``,
    ``self.training_file`` and ``self.test_file``, and calls
    ``self.download()`` and ``self._check_exists()``. None of these are
    defined here; presumably they are class attributes / methods defined
    elsewhere. Confirm.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        """Load one split of the dataset from ``root``.

        Args:
            root: Directory that contains ``self.processed_folder``.
            train: If True load the training split, otherwise the test split.
            transform: Optional callable applied to the PIL image in ``__getitem__``.
            target_transform: Optional callable applied to the label in ``__getitem__``.
            download: If True, call ``self.download()`` before checking for files.

        Raises:
            RuntimeError: if ``self._check_exists()`` returns a falsy value.
        """
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        # Load exactly one split. The attributes of the other split are not set.
        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(root, self.processed_folder, self.test_file))

    def __getitem__(self, index):
        """Return ``(img, target)`` for sample ``index`` of the loaded split.

        ``img`` is built with ``Image.fromarray(..., mode='L')`` and then passed
        through ``self.transform`` if one is set. ``target`` is passed through
        ``self.target_transform`` if one is set.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the loaded split."""
        # Derived from the loaded data rather than hard-coded (60000 / 10000),
        # so the count always matches what __getitem__ can index.
        if self.train:
            return len(self.train_data)
        return len(self.test_data)


