Source code for world_models.transforms.image
from typing import Any
from logging import getLogger
from PIL import ImageFilter
import torch
import torchvision.transforms as transforms
# Named after this module (rather than the root logger) so its verbosity can
# be configured independently; records still propagate to the root handlers.
logger = getLogger(__name__)
[docs]
def make_transforms(
    crop_size: int = 224,
    crop_scale: tuple[float, float] = (0.3, 1.0),
    color_jitter: float = 1.0,
    horizontal_flip: bool = False,
    color_distortion: bool = False,
    gaussian_blur: bool = False,
    normalization: tuple[tuple[float, ...], tuple[float, ...]] = (
        (0.485, 0.456, 0.406),
        (0.229, 0.224, 0.225),
    ),
) -> Any:
    """Build the training-time augmentation pipeline for image inputs.

    The pipeline always begins with a ``RandomResizedCrop`` and always ends
    with ``ToTensor`` followed by ``Normalize``. The optional stages are
    inserted in between, in this order, only when their flag is set:
    horizontal flip, color distortion, Gaussian blur.

    Args:
        crop_size: Output size passed to ``RandomResizedCrop``.
        crop_scale: ``scale`` range passed to ``RandomResizedCrop``.
        color_jitter: Strength multiplier for the color-distortion stage;
            only used when ``color_distortion`` is true.
        horizontal_flip: Insert ``RandomHorizontalFlip()`` after the crop.
        color_distortion: Insert a color-jitter step (applied with p=0.8)
            followed by random grayscale conversion (p=0.2).
        gaussian_blur: Insert ``GaussianBlur(p=0.5)``.
        normalization: Pair whose items 0 and 1 are handed to ``Normalize``
            as mean and std respectively.

    Returns:
        A ``torchvision.transforms.Compose`` wrapping the selected stages.
    """
    logger.info("making imagenet data transforms")

    def _color_distortion(strength: float = 1.0) -> Any:
        # Brightness/contrast/saturation jitter scale with 0.8*strength,
        # hue with 0.2*strength.
        jitter = transforms.ColorJitter(
            0.8 * strength, 0.8 * strength, 0.8 * strength, 0.2 * strength
        )
        return transforms.Compose(
            [
                transforms.RandomApply([jitter], p=0.8),
                transforms.RandomGrayscale(p=0.2),
            ]
        )

    stages = [transforms.RandomResizedCrop(crop_size, scale=crop_scale)]
    if horizontal_flip:
        stages.append(transforms.RandomHorizontalFlip())
    if color_distortion:
        stages.append(_color_distortion(strength=color_jitter))
    if gaussian_blur:
        stages.append(GaussianBlur(p=0.5))
    # Tensor conversion must precede Normalize, which operates on tensors.
    stages.extend(
        [
            transforms.ToTensor(),
            transforms.Normalize(normalization[0], normalization[1]),
        ]
    )
    return transforms.Compose(stages)
[docs]
class GaussianBlur:
    """Probabilistic Gaussian blur augmentation for PIL images.

    With probability ``p`` the image is blurred via
    ``PIL.ImageFilter.GaussianBlur`` using a radius drawn uniformly from
    ``[radius_min, radius_max)``; otherwise the input is returned unchanged.
    Both random draws use torch's default generator, so ``torch.manual_seed``
    controls them.

    Args:
        p: Probability of applying the blur; must lie in ``[0, 1]``.
            Integers such as ``0`` or ``1`` are accepted.
        radius_min: Lower bound of the sampled blur radius.
        radius_max: Upper bound of the sampled blur radius.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1]``.
    """

    def __init__(
        self, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0
    ) -> None:
        # Fail fast at construction instead of on the first call (typically
        # inside a data-loader worker), where torch.bernoulli would reject it.
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"p must be in [0, 1], got {p!r}")
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img: Any) -> Any:
        """Return ``img`` blurred with probability ``self.prob``, else ``img``."""
        # torch.bernoulli requires a floating-point probability tensor; an int
        # such as p=1 would otherwise yield a Long tensor and raise. The cast
        # leaves float inputs (and the RNG stream) unchanged.
        if torch.bernoulli(torch.tensor(float(self.prob))) == 0:
            return img
        radius = self.radius_min + torch.rand(1).item() * (
            self.radius_max - self.radius_min
        )
        return img.filter(ImageFilter.GaussianBlur(radius=radius))

    def __repr__(self) -> str:
        # Shown by torchvision's Compose repr when printing a pipeline.
        return (
            f"{type(self).__name__}(p={self.prob}, "
            f"radius_min={self.radius_min}, radius_max={self.radius_max})"
        )