from __future__ import annotations

import multiprocessing as mp
import queue
import traceback
from abc import ABC, abstractmethod
from multiprocessing import Queue
from typing import List, Dict, Any, Callable, Optional

import numpy as np
import torch
class SimWorker(mp.Process):
    """
    Worker process that manages a batch of environment instances.

    Handles batched stepping for parallel rollouts. The worker reads
    ``(cmd_type, data)`` tuples from ``command_queue`` (``None`` means shut
    down) and answers every handled command with exactly one
    ``(reply_type, payload)`` tuple on ``result_queue``:

    * ``"step"``   -> ``("step_result", [per-env result dicts])``
    * ``"reset"``  -> ``("reset_result", [{"obs": obs}, ...])``
    * ``"render"`` -> ``("render_result", [frame, ...])``
    * ``"close"``  -> ``("close_result", None)``, after which the worker exits

    If a command raises (or building the environments failed), the reply is
    ``("error", formatted_traceback)`` instead. The parent therefore always
    gets an answer rather than waiting forever on ``result_queue``.
    """

    def __init__(
        self,
        worker_id: int,
        env_factory: Callable,
        num_envs: int,
        command_queue: Queue,
        result_queue: Queue,
        seed: Optional[int] = None,
    ):
        super().__init__(daemon=True)
        self.worker_id = worker_id
        self.env_factory = env_factory
        self.num_envs = num_envs
        self.command_queue = command_queue
        self.result_queue = result_queue
        self.seed = seed
        self.envs: List[Any] = []
        self.dones = [False] * num_envs  # Track done status per env
        self.last_obs: List[Any] = []  # Store last obs for done envs
        self.running = True

    def run(self):
        """Main worker loop: build the environments, then serve commands until shutdown."""
        init_error = self._init_envs()
        try:
            while self.running:
                try:
                    command = self.command_queue.get(timeout=1.0)
                except queue.Empty:
                    continue  # No command yet; keep polling.
                if command is None:  # Shutdown signal
                    break
                cmd_type = None
                try:
                    cmd_type, data = command
                    if init_error is not None and cmd_type != "close":
                        # Envs could not be built: answer with that failure.
                        reply = ("error", init_error)
                    else:
                        reply = self._dispatch(cmd_type, data)
                except Exception:
                    # Report instead of swallowing: the parent blocks on a reply.
                    reply = ("error", traceback.format_exc())
                self.result_queue.put(reply)
                if cmd_type == "close":
                    break
        finally:
            # No-op if a "close" command already closed the environments.
            self._close_batch()

    def _init_envs(self) -> Optional[str]:
        """Create and seed ``num_envs`` environments.

        Returns:
            ``None`` on success, otherwise the formatted traceback of the
            exception raised while constructing or seeding an environment.
        """
        self.envs = []
        self.last_obs = []
        try:
            for i in range(self.num_envs):
                env = self.env_factory()
                if self.seed is not None and hasattr(env, "seed"):
                    # Distinct seed per env across all workers.
                    env.seed(self.seed + self.worker_id * self.num_envs + i)
                self.envs.append(env)
                # Initial obs will be set in reset
                self.last_obs.append(None)
        except Exception:
            return traceback.format_exc()
        return None

    def _dispatch(self, cmd_type: str, data: Any) -> tuple[str, Any]:
        """Execute one command and return the ``(reply_type, payload)`` to send back.

        Raises:
            ValueError: if ``cmd_type`` is not a known command.
        """
        if cmd_type == "step":
            return "step_result", self._step_batch(data)
        if cmd_type == "reset":
            return "reset_result", self._reset_batch()
        if cmd_type == "render":
            return "render_result", self._render_batch()
        if cmd_type == "close":
            self._close_batch()
            return "close_result", None
        raise ValueError(f"unknown command type: {cmd_type!r}")

    def _step_batch(self, actions: List[np.ndarray]) -> List[Dict[str, Any]]:
        """Step all environments in batch.

        An environment that has already reported ``done`` is not stepped
        again. It reports its final observation with ``reward=0.0`` and
        ``done=True`` until the next reset.
        """
        results = []
        for i, action in enumerate(actions):
            if self.dones[i]:
                # Env is done, return last obs with done=True
                results.append(
                    {"obs": self.last_obs[i], "reward": 0.0, "done": True, "info": {}}
                )
                continue
            obs, reward, done, info = self.envs[i].step(action)
            self.dones[i] = bool(done)
            if done:
                # Store last obs for future steps
                self.last_obs[i] = obs
            results.append({"obs": obs, "reward": reward, "done": done, "info": info})
        return results

    def _reset_batch(self) -> List[Dict[str, Any]]:
        """Reset all environments and clear their done flags."""
        results = []
        for i, env in enumerate(self.envs):
            obs = env.reset()
            self.dones[i] = False
            self.last_obs[i] = obs
            results.append({"obs": obs})
        return results

    def _render_batch(self) -> List[np.ndarray]:
        """Render all environments, returning one frame per env."""
        return [env.render() for env in self.envs]

    def _close_batch(self):
        """Close every environment, best effort and at most once.

        The env list is emptied first, so a second call is a no-op. Every env
        gets a close attempt even if an earlier one raises; the first such
        exception is re-raised afterwards.
        """
        envs, self.envs = self.envs, []
        first_error: Optional[BaseException] = None
        for env in envs:
            if not hasattr(env, "close"):
                continue
            try:
                env.close()
            except Exception as exc:
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error
class VectorizedEnv(ABC):
    """
    Abstract base class for vectorized environments.

    Manages multiple worker processes for parallel simulation. Each worker
    owns ``envs_per_worker`` environments. It communicates through its own
    pair of queues, ``command_queues[i]`` and ``result_queues[i]``.
    """

    def __init__(
        self,
        env_factory: Callable,
        num_workers: int = 2,
        envs_per_worker: int = 4,
        seed: Optional[int] = None,
    ):
        self.env_factory = env_factory
        self.num_workers = num_workers
        self.envs_per_worker = envs_per_worker
        self.total_envs = num_workers * envs_per_worker
        self.seed = seed
        self._closed = False
        # Create communication queues
        self.command_queues: List[Queue] = [Queue() for _ in range(num_workers)]
        self.result_queues: List[Queue] = [Queue() for _ in range(num_workers)]
        self.workers = []
        try:
            # Start workers
            for i in range(num_workers):
                worker = SimWorker(
                    worker_id=i,
                    env_factory=env_factory,
                    num_envs=envs_per_worker,
                    command_queue=self.command_queues[i],
                    result_queue=self.result_queues[i],
                    seed=seed,
                )
                worker.start()
                self.workers.append(worker)
            # Cache observation and action spaces from a dummy env
            dummy_env = env_factory()
            try:
                self.observation_space = dummy_env.observation_space
                self.action_space = dummy_env.action_space
            finally:
                if hasattr(dummy_env, "close"):
                    dummy_env.close()
        except BaseException:
            # Don't leave already-started workers running if construction fails.
            self.close()
            raise

    @abstractmethod
    def step_batch(self, actions: torch.Tensor) -> Dict[str, Any]:
        """Step all environments with batched actions."""

    @abstractmethod
    def reset_batch(self) -> Dict[str, Any]:
        """Reset all environments."""

    def render_batch(self) -> List[np.ndarray]:
        """Render all environments, returning frames in worker order.

        A reply is read from every worker before any error is raised, so the
        queues stay in sync for later calls.

        Raises:
            RuntimeError: if any worker replied with ``("error", traceback)``.
        """
        for q in self.command_queues:
            q.put(("render", None))
        frames: List[np.ndarray] = []
        errors: List[str] = []
        for worker_id, q in enumerate(self.result_queues):
            cmd, data = q.get()
            if cmd == "error":
                errors.append(f"worker {worker_id}:\n{data}")
            else:
                frames.extend(data)
        if errors:
            raise RuntimeError("render failed in " + "\n".join(errors))
        return frames

    def close(self):
        """Shut down all workers and close the queues. Safe to call more than once.

        Any worker still alive 5 seconds after the close request is terminated.
        """
        if getattr(self, "_closed", False):
            return
        self._closed = True
        for q in self.command_queues:
            q.put(("close", None))
        for worker in self.workers:
            worker.join(timeout=5.0)
            if worker.is_alive():
                # Worker did not honour the close request in time; force it down.
                worker.terminate()
                worker.join()
        for q in self.command_queues + self.result_queues:
            q.close()
class TorchVectorizedEnv(VectorizedEnv):
    """
    TorchWM-compatible vectorized environment.

    Returns batched tensors suitable for PyTorch training. Each observation
    is expected to be a dict with an ``"image"`` numpy array. Images are
    stacked along a new leading batch dimension and divided by 255.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._wait_for_workers()

    def _wait_for_workers(self):
        """Placeholder hook for synchronising with workers after start-up.

        Currently does nothing. The first command sent to a worker blocks on
        that worker's result queue until the worker answers.
        """
        # This is a simple way to sync; could be improved
        pass

    def _gather(self) -> List[Any]:
        """Read one reply from every worker and concatenate the payloads in worker order.

        All replies are consumed before raising, so an error from one worker
        does not leave stale replies queued for the next call.

        Raises:
            RuntimeError: if any worker replied with ``("error", traceback)``.
        """
        payloads: List[Any] = []
        errors: List[str] = []
        for worker_id, q in enumerate(self.result_queues):
            reply_type, payload = q.get()
            if reply_type == "error":
                errors.append(f"worker {worker_id}:\n{payload}")
            else:
                payloads.extend(payload)
        if errors:
            raise RuntimeError("environment worker failed\n" + "\n".join(errors))
        return payloads

    @staticmethod
    def _stack_images(all_obs: List[Dict[str, Any]]) -> torch.Tensor:
        """Stack ``obs["image"]`` across envs into one float32 tensor divided by 255.

        The division maps uint8 pixels to [0, 1]. NOTE(review): other dtypes
        are divided too, so confirm upstream images are uint8.
        """
        return torch.stack(
            [torch.from_numpy(obs["image"]).float() / 255.0 for obs in all_obs]
        )

    def step_batch(self, actions: torch.Tensor) -> Dict[str, Any]:
        """
        Step all environments with batched actions.

        Args:
            actions: Tensor whose leading dimension is ``total_envs``
                (e.g. shape ``(total_envs, action_dim)``). It may live on any
                device and may require grad.

        Returns:
            Dict with ``obs`` (``{"image": tensor}``), ``reward`` (float32),
            ``done`` (bool) tensors, and ``info`` as a list of per-env dicts.

        Raises:
            ValueError: if the leading dimension of ``actions`` is not ``total_envs``.
            RuntimeError: if any worker reports an error.
        """
        if actions.shape[0] != self.total_envs:
            raise ValueError(
                f"expected actions for {self.total_envs} envs, "
                f"got leading dimension {actions.shape[0]}"
            )
        # .numpy() rejects CUDA tensors and tensors that require grad.
        actions_np = actions.detach().cpu().numpy()
        per_worker = self.envs_per_worker
        for i, q in enumerate(self.command_queues):
            # Worker i owns the contiguous slice [i * per_worker, (i + 1) * per_worker).
            q.put(("step", actions_np[i * per_worker:(i + 1) * per_worker]))
        results = self._gather()
        return {
            "obs": {"image": self._stack_images([r["obs"] for r in results])},
            "reward": torch.tensor([r["reward"] for r in results], dtype=torch.float32),
            "done": torch.tensor([r["done"] for r in results], dtype=torch.bool),
            "info": [r["info"] for r in results],  # Keep as list for now
        }

    def reset_batch(self) -> Dict[str, Any]:
        """Reset all environments and return initial observations.

        Raises:
            RuntimeError: if any worker reports an error.
        """
        for q in self.command_queues:
            q.put(("reset", None))
        results = self._gather()
        return {"obs": {"image": self._stack_images([r["obs"] for r in results])}}