# Source code for world_models.transforms.transforms

from logging import getLogger

from PIL import ImageFilter

import torch
import torchvision.transforms as transforms

# Module-scoped logger (rather than the root logger) so this module's output
# can be levelled/filtered on its own; records still propagate to root handlers.
logger = getLogger(__name__)


def make_transforms(
    crop_size=224,
    crop_scale=(0.3, 1.0),
    color_jitter=1.0,
    horizontal_flip=False,
    color_distortion=False,
    gaussian_blur=False,
    normalization=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
):
    """Build the training-time image augmentation pipeline.

    The pipeline always begins with a random resized crop and always ends
    with tensor conversion followed by normalization. The horizontal flip,
    colour distortion and Gaussian blur stages in between are opt-in and,
    when enabled, appear in that order.

    Args:
        crop_size: Output size forwarded to ``RandomResizedCrop``.
        crop_scale: Area-fraction range forwarded to ``RandomResizedCrop``
            as ``scale``.
        color_jitter: Strength multiplier for the colour distortion stage;
            only used when ``color_distortion`` is true.
        horizontal_flip: If true, append a ``RandomHorizontalFlip``.
        color_distortion: If true, append a colour jitter applied with
            probability 0.8 followed by a random grayscale with p=0.2.
        gaussian_blur: If true, append a ``GaussianBlur`` with p=0.5.
        normalization: Pair whose first item is the mean and second item the
            std passed to ``Normalize``.

    Returns:
        A ``torchvision.transforms.Compose`` of the selected stages.
    """
    logger.info("making imagenet data transforms")

    def _build_color_distortion(strength=1.0):
        # The first three ColorJitter factors scale by 0.8x strength, the last by 0.2x.
        jitter = transforms.ColorJitter(
            0.8 * strength, 0.8 * strength, 0.8 * strength, 0.2 * strength
        )
        return transforms.Compose(
            [
                transforms.RandomApply([jitter], p=0.8),
                transforms.RandomGrayscale(p=0.2),
            ]
        )

    stages = [transforms.RandomResizedCrop(crop_size, scale=crop_scale)]
    if horizontal_flip:
        stages.append(transforms.RandomHorizontalFlip())
    if color_distortion:
        stages.append(_build_color_distortion(strength=color_jitter))
    if gaussian_blur:
        stages.append(GaussianBlur(p=0.5))
    # Tensor conversion must precede Normalize, which operates on tensors.
    stages.extend(
        [
            transforms.ToTensor(),
            transforms.Normalize(normalization[0], normalization[1]),
        ]
    )
    return transforms.Compose(stages)
class GaussianBlur:
    """Probabilistic Gaussian blur augmentation for PIL images.

    With probability ``p`` the image is blurred via
    ``PIL.ImageFilter.GaussianBlur`` using a radius drawn uniformly from
    ``[radius_min, radius_max)``; otherwise the input is returned unchanged.
    Both random draws use torch's global RNG, so ``torch.manual_seed``
    makes the augmentation reproducible.

    Args:
        p: Probability of applying the blur. Integer values such as ``0`` or
            ``1`` are accepted as well as floats.
        radius_min: Lower bound of the sampled blur radius.
        radius_max: Upper bound of the sampled blur radius.
    """

    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.0):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        """Return ``img`` blurred with probability ``self.prob``, else ``img`` itself."""
        # torch.bernoulli only accepts floating-point probability tensors;
        # without the cast an integer p (e.g. 0 or 1) raises a RuntimeError.
        if torch.bernoulli(torch.tensor(float(self.prob))) == 0:
            return img
        radius = self.radius_min + torch.rand(1).item() * (
            self.radius_max - self.radius_min
        )
        return img.filter(ImageFilter.GaussianBlur(radius=radius))

    def __repr__(self):
        return (
            f"{type(self).__name__}(p={self.prob}, "
            f"radius_min={self.radius_min}, radius_max={self.radius_max})"
        )