TorchWM
Documentation
Getting Started
Package Overview
API Reference
TorchWM
Index
Index
A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P | Q | R | S | T | U | W
A
ActionDecoder (class in world_models.vision.dreamer_decoder)
add() (world_models.memory.dreamer_memory.ReplayBuffer method)
add_exploration() (world_models.vision.dreamer_decoder.ActionDecoder method)
AllGather (class in world_models.utils.jepa_utils)
AllReduce (class in world_models.utils.jepa_utils)
AllReduceSum (class in world_models.utils.jepa_utils)
append() (world_models.memory.planet_memory.Episode method)
(world_models.memory.planet_memory.Memory method)
atanh() (world_models.vision.dreamer_decoder.TanhBijector method)
AverageMeter (class in world_models.utils.jepa_utils)
B
backward() (world_models.utils.jepa_utils.AllGather static method)
(world_models.utils.jepa_utils.AllReduce static method)
(world_models.utils.jepa_utils.AllReduceSum static method)
BATCH (world_models.configs.DiTConfig attribute)
BETA_END (world_models.configs.DiTConfig attribute)
BETA_START (world_models.configs.DiTConfig attribute)
C
CHANNELS (world_models.configs.DiTConfig attribute)
classes (world_models.datasets.imagenet1k.ImageNetSubset property)
compute_return() (in module world_models.utils.dreamer_utils)
ConvDecoder (class in world_models.vision.dreamer_decoder)
ConvEncoder (class in world_models.vision.dreamer_encoder)
copy_imgnt_locally() (in module world_models.datasets.imagenet1k)
CosineWDSchedule (class in world_models.utils.jepa_utils)
CSVLogger (class in world_models.utils.jepa_utils)
D
d (world_models.controller.rssm_policy.RSSMPolicy attribute)
DATASET (world_models.configs.DiTConfig attribute)
DDPM (class in world_models.models.diffusion.DDPM)
DenseDecoder (class in world_models.vision.dreamer_decoder)
DEPTH (world_models.configs.DiTConfig attribute)
detach_state() (world_models.models.dreamer_rssm.RSSM method)
device (world_models.controller.rssm_policy.RSSMPolicy attribute)
DiT (class in world_models.models.diffusion.DiT)
DiTConfig (class in world_models.configs)
DreamerConfig (class in world_models.configs)
DROP (world_models.configs.DiTConfig attribute)
dump_scalars_to_pickle() (world_models.utils.dreamer_utils.Logger method)
E
EMA (world_models.configs.DiTConfig attribute)
EMA_DECAY (world_models.configs.DiTConfig attribute)
entropy() (world_models.vision.dreamer_decoder.SampleDist method)
Episode (class in world_models.memory.planet_memory)
EPOCHS (world_models.configs.DiTConfig attribute)
F
filter_dataset_() (world_models.datasets.imagenet1k.ImageNetSubset method)
flush() (world_models.utils.dreamer_utils.Logger method)
forward() (world_models.models.diffusion.DiT.DiT method)
(world_models.models.diffusion.DiT.PatchEmbed method)
(world_models.models.diffusion.DiT.PatchUnEmbed method)
(world_models.models.diffusion.DiT.TransformerBlock method)
(world_models.utils.jepa_utils.AllGather static method)
(world_models.utils.jepa_utils.AllReduce static method)
(world_models.utils.jepa_utils.AllReduceSum static method)
(world_models.vision.dreamer_decoder.ActionDecoder method)
(world_models.vision.dreamer_decoder.ConvDecoder method)
(world_models.vision.dreamer_decoder.DenseDecoder method)
(world_models.vision.dreamer_encoder.ConvEncoder method)
FreezeParameters (class in world_models.utils.dreamer_utils)
G
GaussianBlur (class in world_models.transforms.transforms)
get_dist() (world_models.models.dreamer_rssm.RSSM method)
get_dit_config() (in module world_models.configs)
get_parameters() (in module world_models.utils.dreamer_utils)
gpu_timer() (in module world_models.utils.jepa_utils)
grad_logger() (in module world_models.utils.jepa_utils)
H
H (world_models.controller.rssm_policy.RSSMPolicy attribute)
HEADS (world_models.configs.DiTConfig attribute)
I
ImageNet (class in world_models.datasets.imagenet1k)
ImageNetSubset (class in world_models.datasets.imagenet1k)
imagine_rollout() (world_models.models.dreamer_rssm.RSSM method)
imagine_step() (world_models.models.dreamer_rssm.RSSM method)
IMG_SIZE (world_models.configs.DiTConfig attribute)
init_distributed() (in module world_models.utils.jepa_utils)
init_state() (world_models.models.dreamer_rssm.RSSM method)
J
JEPAConfig (class in world_models.configs)
K
K (world_models.controller.rssm_policy.RSSMPolicy attribute)
L
latent_size (world_models.controller.rssm_policy.RSSMPolicy attribute)
log() (world_models.utils.jepa_utils.CSVLogger method)
log_abs_det_jacobian() (world_models.vision.dreamer_decoder.TanhBijector method)
log_scalar() (world_models.utils.dreamer_utils.Logger method)
log_scalars() (world_models.utils.dreamer_utils.Logger method)
log_videos() (world_models.utils.dreamer_utils.Logger method)
Logger (class in world_models.utils.dreamer_utils)
LR (world_models.configs.DiTConfig attribute)
M
make_cifar10() (in module world_models.datasets.cifar10)
make_imagefolder() (in module world_models.datasets.imagenet1k)
make_imagenet1k() (in module world_models.datasets.imagenet1k)
make_transforms() (in module world_models.transforms.transforms)
MaskCollator (class in world_models.masks.multiblock)
(class in world_models.masks.random)
mean() (world_models.vision.dreamer_decoder.SampleDist method)
Memory (class in world_models.memory.planet_memory)
mode() (world_models.vision.dreamer_decoder.SampleDist method)
module
world_models.configs
world_models.controller.rssm_policy
world_models.datasets.cifar10
world_models.datasets.imagenet1k
world_models.masks.multiblock
world_models.masks.random
world_models.memory.dreamer_memory
world_models.memory.planet_memory
world_models.models
world_models.models.diffusion.DDPM
world_models.models.diffusion.DiT
world_models.models.dreamer_rssm
world_models.transforms.transforms
world_models.utils.dreamer_utils
world_models.utils.jepa_utils
world_models.vision.dreamer_decoder
world_models.vision.dreamer_encoder
N
N (world_models.controller.rssm_policy.RSSMPolicy attribute)
name (world_models.vision.dreamer_decoder.SampleDist property)
O
observe_rollout() (world_models.models.dreamer_rssm.RSSM method)
observe_step() (world_models.models.dreamer_rssm.RSSM method)
P
p_sample() (world_models.models.diffusion.DDPM.DDPM method)
PATCH (world_models.configs.DiTConfig attribute)
PatchEmbed (class in world_models.models.diffusion.DiT)
PatchUnEmbed (class in world_models.models.diffusion.DiT)
poll() (world_models.controller.rssm_policy.RSSMPolicy method)
Q
q_sample() (world_models.models.diffusion.DDPM.DDPM method)
R
repeat_interleave_batch() (in module world_models.utils.jepa_utils)
ReplayBuffer (class in world_models.memory.dreamer_memory)
reset() (world_models.controller.rssm_policy.RSSMPolicy method)
(world_models.utils.jepa_utils.AverageMeter method)
ROOT_PATH (world_models.configs.DiTConfig attribute)
RSSM (class in world_models.models.dreamer_rssm)
rssm (world_models.controller.rssm_policy.RSSMPolicy attribute)
RSSMPolicy (class in world_models.controller.rssm_policy)
S
sample() (world_models.memory.dreamer_memory.ReplayBuffer method)
(world_models.memory.planet_memory.Memory method)
(world_models.models.diffusion.DDPM.DDPM method)
(world_models.vision.dreamer_decoder.SampleDist method)
SampleDist (class in world_models.vision.dreamer_decoder)
seq_to_batch() (world_models.models.dreamer_rssm.RSSM method)
sign (world_models.vision.dreamer_decoder.TanhBijector property)
sinusoidal_time_embedding() (in module world_models.models.diffusion.DiT)
size (world_models.memory.planet_memory.Episode property)
(world_models.memory.planet_memory.Memory property)
stack_states() (world_models.models.dreamer_rssm.RSSM method)
state_size (world_models.controller.rssm_policy.RSSMPolicy attribute)
step() (world_models.masks.multiblock.MaskCollator method)
(world_models.masks.random.MaskCollator method)
(world_models.utils.jepa_utils.CosineWDSchedule method)
(world_models.utils.jepa_utils.WarmupCosineSchedule method)
T
T (world_models.controller.rssm_policy.RSSMPolicy attribute)
TanhBijector (class in world_models.vision.dreamer_decoder)
terminate() (world_models.memory.planet_memory.Episode method)
TIMESTEPS (world_models.configs.DiTConfig attribute)
to_dict() (world_models.configs.JEPAConfig method)
train() (world_models.models.diffusion.DiT.DiT class method)
TransformerBlock (class in world_models.models.diffusion.DiT)
trunc_normal_() (in module world_models.utils.jepa_utils)
U
update() (world_models.utils.jepa_utils.AverageMeter method)
W
WarmupCosineSchedule (class in world_models.utils.jepa_utils)
WIDTH (world_models.configs.DiTConfig attribute)
WORKDIR (world_models.configs.DiTConfig attribute)
world_models.configs
module
world_models.controller.rssm_policy
module
world_models.datasets.cifar10
module
world_models.datasets.imagenet1k
module
world_models.masks.multiblock
module
world_models.masks.random
module
world_models.memory.dreamer_memory
module
world_models.memory.planet_memory
module
world_models.models
module
world_models.models.diffusion.DDPM
module
world_models.models.diffusion.DiT
module
world_models.models.dreamer_rssm
module
world_models.transforms.transforms
module
world_models.utils.dreamer_utils
module
world_models.utils.jepa_utils
module
world_models.vision.dreamer_decoder
module
world_models.vision.dreamer_encoder
module