diff --git a/beginner_source/basics/transforms_tutorial.py b/beginner_source/basics/transforms_tutorial.py index 33076958bf..03ebbeee73 100644 --- a/beginner_source/basics/transforms_tutorial.py +++ b/beginner_source/basics/transforms_tutorial.py @@ -23,42 +23,47 @@ The FashionMNIST features are in PIL Image format, and the labels are integers. For training, we need the features as normalized tensors, and the labels as one-hot encoded tensors. -To make these transformations, we use ``ToTensor`` and ``Lambda``. +To make these transformations, we use the ``torchvision.transforms.v2`` API along with ``torch.nn.functional.one_hot``. """ import torch +import torch.nn.functional as F from torchvision import datasets -from torchvision.transforms import ToTensor, Lambda +from torchvision.transforms import v2 ds = datasets.FashionMNIST( root="data", train=True, download=True, - transform=ToTensor(), - target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)) + transform=v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]), + target_transform=v2.Lambda( + lambda y: F.one_hot(torch.tensor(y), num_classes=10).float() + ), ) ################################################# -# ToTensor() +# ToImage() and ToDtype() # ------------------------------- # -# `ToTensor <https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.ToTensor>`_ -# converts a PIL image or NumPy ``ndarray`` into a ``FloatTensor``. and scales -# the image's pixel intensity values in the range [0., 1.] +# The ``torchvision.transforms.v2`` API replaces the legacy ``ToTensor`` transform with a two-step pipeline. +# `v2.ToImage <https://pytorch.org/vision/stable/generated/torchvision.transforms.v2.ToImage.html>`_ +# converts a PIL image or NumPy ``ndarray`` into a ``torchvision.tv_tensors.Image`` tensor, and +# `v2.ToDtype <https://pytorch.org/vision/stable/generated/torchvision.transforms.v2.ToDtype.html>`_ +# with ``scale=True`` casts it to ``float32`` and scales the pixel intensity values to the range [0., 1.]. # ############################################## # Lambda Transforms # ------------------------------- # -# Lambda transforms apply any user-defined lambda function. 
Here, we define a function -# to turn the integer into a one-hot encoded tensor. -# It first creates a zero tensor of size 10 (the number of labels in our dataset) and calls -# `scatter_ <https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html>`_ which assigns a -# ``value=1`` on the index as given by the label ``y``. +# Lambda transforms apply any user-defined lambda function. Here, we use +# `torch.nn.functional.one_hot <https://pytorch.org/docs/stable/generated/torch.nn.functional.one_hot.html>`_ +# to turn the integer label into a one-hot encoded tensor of size 10 (the number of labels in our dataset), +# then cast it to ``float`` to match the expected dtype. -target_transform = Lambda(lambda y: torch.zeros( - 10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1)) +target_transform = v2.Lambda( + lambda y: F.one_hot(torch.tensor(y), num_classes=10).float() +) ###################################################################### # -------------- @@ -67,4 +72,5 @@ ################################################################# # Further Reading # ~~~~~~~~~~~~~~~~~ -# - `torchvision.transforms API <https://pytorch.org/vision/stable/transforms.html>`_ +# - `Getting started with transforms v2 <https://pytorch.org/vision/stable/auto_examples/transforms/plot_transforms_getting_started.html>`_ +# - `torchvision.transforms.v2 API <https://pytorch.org/vision/stable/transforms.html>`_