Source code for world_models.benchmarks.reporting

from __future__ import annotations

from typing import Dict, Any
import csv
import json


def export_csv(results: Dict[str, Any], path: str) -> None:
    """Write the per-seed benchmark results to a CSV file.

    The file starts with the header row ``seed, mean, std, episode_returns``
    and then holds one row per entry of ``results["seeds"]``.

    Args:
        results: Mapping whose optional ``"seeds"`` entry maps each seed to a
            dict with the optional keys ``"mean"``, ``"std"`` and
            ``"episode_returns"``. A missing or ``None`` ``"seeds"`` entry
            produces a header-only file.
        path: Destination file; it is created or overwritten.

    Raises:
        OSError: If ``path`` cannot be opened for writing.
        TypeError: If ``episode_returns`` is not JSON-serializable.
    """
    seeds = results.get("seeds") or {}
    # newline="" lets the csv module control line endings; the explicit
    # encoding keeps the output independent of the platform locale.
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["seed", "mean", "std", "episode_returns"])
        writer.writerows(
            [
                seed,
                # Missing statistics become empty cells rather than zeros.
                data.get("mean", ""),
                data.get("std", ""),
                # The whole list is stored as one JSON-encoded cell.
                json.dumps(data.get("episode_returns", [])),
            ]
            for seed, data in seeds.items()
        )
def export_markdown(results: Dict[str, Any], path: str) -> None:
    """Write the benchmark results to a Markdown file.

    The output is a table with one row per entry of ``results["seeds"]``,
    followed by an "Aggregate" paragraph listing the mean, median and IQM
    read from ``results["aggregate"]``.

    Args:
        results: Mapping with optional ``"seeds"`` and ``"aggregate"``
            entries. Each seed maps to a dict with the optional keys
            ``"mean"``, ``"std"`` (both default to ``0.0``) and
            ``"episode_returns"``. Aggregate keys ``"mean"``, ``"median"``
            and ``"iqm"`` also default to ``0.0``.
        path: Destination file; it is created or overwritten.

    Raises:
        OSError: If ``path`` cannot be opened for writing.
    """
    seeds = results.get("seeds") or {}
    agg = results.get("aggregate") or {}

    # Two trailing spaces are Markdown's hard line break; a single trailing
    # space is ignored and would merge the aggregate lines into one paragraph.
    hard_break = "  "

    lines = [
        "| seed | mean | std | episode_returns |",
        "|---:|---:|---:|:---|",
    ]
    for seed, data in seeds.items():
        # Format episode returns compactly; an empty list is shown as "[]".
        er = data.get("episode_returns", [])
        er_str = ", ".join(f"{v:.1f}" for v in er) if er else "[]"
        mean = data.get("mean", 0.0)
        std = data.get("std", 0.0)
        lines.append(f"| {seed} | {mean:.3f} | {std:.3f} | {er_str} |")

    lines.extend(
        [
            "",
            "**Aggregate**",
            # Blank line so the heading is not joined to the values below.
            "",
            f"Mean: {agg.get('mean', 0.0):.3f}{hard_break}",
            f"Median: {agg.get('median', 0.0):.3f}{hard_break}",
            f"IQM: {agg.get('iqm', 0.0):.3f}",
        ]
    )

    # Explicit encoding keeps the output independent of the platform locale.
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
def export_latex(
    results: Dict[str, Any], path: str, caption: str = "Benchmark results"
) -> None:
    """Write the benchmark results to a file as a LaTeX ``table`` float.

    The table has one row per entry of ``results["seeds"]`` and a final
    "Aggregate IQM" row taken from ``results["aggregate"]["iqm"]``. It uses
    the ``\\toprule``/``\\midrule``/``\\bottomrule`` rules, which come from
    the ``booktabs`` package, so the including document must load it.

    Args:
        results: Mapping with optional ``"seeds"`` and ``"aggregate"``
            entries. Each seed maps to a dict with the optional keys
            ``"mean"``, ``"std"`` (both default to ``0.0``) and
            ``"episode_returns"``.
        path: Destination file; it is created or overwritten.
        caption: Text placed inside ``\\caption{...}``. It is inserted
            verbatim, as are the seed labels, so any LaTeX special
            characters must already be escaped by the caller.

    Raises:
        OSError: If ``path`` cannot be opened for writing.
    """
    seeds = results.get("seeds") or {}
    agg = results.get("aggregate") or {}

    # LaTeX ends a table row with a double backslash.
    row_end = r"\\"

    lines = [
        r"\begin{table}[ht]",
        r"\centering",
        r"\begin{tabular}{lrrr}",
        r"\toprule",
        f"Seed & Mean & Std & Episode Returns {row_end}",
        r"\midrule",
    ]
    for seed, data in seeds.items():
        # An empty list of returns is rendered as a dash placeholder.
        er = data.get("episode_returns", [])
        er_str = ", ".join(f"{v:.1f}" for v in er) if er else "--"
        mean = data.get("mean", 0.0)
        std = data.get("std", 0.0)
        lines.append(f"{seed} & {mean:.1f} & {std:.1f} & {er_str} {row_end}")

    lines.extend(
        [
            r"\midrule",
            # The two trailing cells stay empty to match the 4-column layout.
            f"Aggregate IQM & {agg.get('iqm', 0.0):.3f} & & {row_end}",
            r"\bottomrule",
            r"\end{tabular}",
            f"\\caption{{{caption}}}",
            r"\end{table}",
        ]
    )

    # Write UTF-8 explicitly so a non-ASCII caption does not depend on the
    # platform locale.
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))