# Source code for world_models.losses.convae_loss
"""Loss functions for World Models training.
This module provides loss functions for training the VAE and other World Models components.
"""
import torch.nn.functional as F
import torch
# [docs]
def conv_vae_loss_fn(
    reconst: torch.Tensor, x: torch.Tensor, mu: torch.Tensor, logsigma: torch.Tensor
) -> torch.Tensor:
    """Compute the ConvVAE loss: summed reconstruction error plus KL divergence.

    The loss combines:

    1. Reconstruction loss: the sum (not the mean) of squared errors between
       ``reconst`` and ``x`` over every element.
    2. KL divergence between the diagonal Gaussian posterior
       ``N(mu, exp(logsigma) ** 2)`` and the standard normal prior ``N(0, I)``,
       in closed form and summed over every element.

    The total loss is ``reconstruction + KLD``. Both terms are sums rather than
    means, so the loss scales with batch size and with image/latent size.

    Args:
        reconst: Reconstructed images from the VAE decoder. Should match the
            shape of ``x`` (presumably ``(batch, channels, height, width)``;
            this is not checked here).
        x: Original input images.
        mu: Mean of the latent distribution.
        logsigma: Log *standard deviation* (not log variance) of the latent
            distribution. The log variance is ``2 * logsigma``. Must be
            broadcast-compatible with ``mu``.

    Returns:
        0-dim tensor containing the total VAE loss.

    Example:
        >>> recon_x, mu, logsigma = vae(images)  # doctest: +SKIP
        >>> loss = conv_vae_loss_fn(recon_x, images, mu, logsigma)  # doctest: +SKIP
        >>> loss.backward()  # doctest: +SKIP
    """
    # reduction="sum" is the non-deprecated spelling of size_average=False:
    # same value, but no deprecation UserWarning on every call.
    reconstruction = F.mse_loss(reconst, x, reduction="sum")
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) with log(sigma^2) = 2 * logsigma:
    #   -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    logvar = 2 * logsigma
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction + kld