Commit df508c1
Parent(s): 341188c
PPO playing Acrobot-v1 from https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c
Files changed:
- README.md +4 -3
- benchmark_publish.py +91 -0
- colab_requirements.txt +3 -1
- enjoy.py +13 -88
- huggingface_publish.py +177 -0
- lambda_labs/lambda_requirements.txt +3 -1
- poetry.lock +47 -1
- publish/markdown_format.py +210 -0
- pyproject.toml +2 -0
- replay.meta.json +1 -1
- runner/config.py +11 -8
- runner/evaluate.py +103 -0
- runner/running_utils.py +4 -4
- shared/callbacks/eval_callback.py +27 -20
- shared/policy/policy.py +3 -1
- shared/stats.py +3 -24
README.md CHANGED
@@ -42,7 +42,7 @@ By default training goes to a rl-algo-impls project while benchmarks go to
 rl-algo-impls-benchmarks. During training and benchmarking runs, videos of the best
 models and the model weights are uploaded to WandB.
 
-Before doing …
+Before doing anything below, you'll need to create a wandb account and run `wandb
 login`.
 
 
@@ -50,7 +50,7 @@ login`.
 ## Usage
 /sgoodfriend/rl-algo-impls: https://github.com/sgoodfriend/rl-algo-impls
 
-Note: While the model state dictionary and hyperaparameters are saved, the
+Note: While the model state dictionary and hyperaparameters are saved, the latest
 implementation could be sufficiently different to not be able to reproduce similar
 results. You might need to checkout the commit the agent was trained on:
 [5598ebc](https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c).
@@ -68,7 +68,8 @@ notebook.
 
 ## Training
 If you want the highest chance to reproduce these results, you'll want to checkout the
-commit the agent was trained on: [5598ebc](https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c).
+commit the agent was trained on: [5598ebc](https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c). While
+training is deterministic, different hardware will give different results.
 
 ```
 python train.py --algo ppo --env Acrobot-v1 --seed 4
benchmark_publish.py ADDED
@@ -0,0 +1,91 @@
+import argparse
+import subprocess
+import wandb
+import wandb.apis.public
+
+from collections import defaultdict
+from multiprocessing.pool import ThreadPool
+from typing import List, NamedTuple
+
+
+class RunGroup(NamedTuple):
+    algo: str
+    env_id: str
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--wandb-project-name",
+        type=str,
+        default="rl-algo-impls-benchmarks",
+        help="WandB project name to load runs from",
+    )
+    parser.add_argument(
+        "--wandb-entity",
+        type=str,
+        default=None,
+        help="WandB team of project. None uses default entity",
+    )
+    parser.add_argument("--wandb-tags", type=str, nargs="+", help="WandB tags")
+    parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
+    parser.add_argument(
+        "--envs", type=str, nargs="*", help="Optional filter down to these envs"
+    )
+    parser.add_argument(
+        "--huggingface-user",
+        type=str,
+        default=None,
+        help="Huggingface user or team to upload model cards. Defaults to huggingface-cli login user",
+    )
+    parser.add_argument(
+        "--pool-size",
+        type=int,
+        default=3,
+        help="How many publish jobs can run in parallel",
+    )
+    parser.set_defaults(
+        wandb_tags=["benchmark_5598ebc", "host_192-9-145-26"],
+        wandb_report_url="https://api.wandb.ai/links/sgoodfriend/6p2sjqtn",
+        envs=["CartPole-v1", "Acrobot-v1"],
+    )
+    args = parser.parse_args()
+    print(args)
+
+    api = wandb.Api()
+    all_runs = api.runs(
+        f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}"
+    )
+
+    required_tags = set(args.wandb_tags)
+    runs: List[wandb.apis.public.Run] = [
+        r
+        for r in all_runs
+        if required_tags.issubset(set(r.config.get("wandb_tags", [])))
+    ]
+
+    runs_paths_by_group = defaultdict(list)
+    for r in runs:
+        algo = r.config["algo"]
+        env = r.config["env"]
+        if args.envs and env not in args.envs:
+            continue
+        run_group = RunGroup(algo, env)
+        runs_paths_by_group[run_group].append("/".join(r.path))
+
+    def run(run_paths: List[str]) -> None:
+        publish_args = ["python", "huggingface_publish.py"]
+        publish_args.append("--wandb-run-paths")
+        publish_args.extend(run_paths)
+        publish_args.append("--wandb-report-url")
+        publish_args.append(args.wandb_report_url)
+        if args.huggingface_user:
+            publish_args.append("--huggingface-user")
+            publish_args.append(args.huggingface_user)
+        subprocess.run(publish_args)
+
+    tp = ThreadPool(args.pool_size)
+    for run_paths in runs_paths_by_group.values():
+        tp.apply_async(run, (run_paths,))
+    tp.close()
+    tp.join()
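Each run group is published by shelling out to huggingface_publish.py, with ThreadPool(args.pool_size) bounding concurrency. A rough sketch of the command a single worker assembles (the run paths below are made-up placeholders, not real runs):

```
# Sketch only: run paths are hypothetical entity/project/run_id triples.
run_paths = [
    "sgoodfriend/rl-algo-impls-benchmarks/aaaaaaaa",
    "sgoodfriend/rl-algo-impls-benchmarks/bbbbbbbb",
]
publish_args = ["python", "huggingface_publish.py", "--wandb-run-paths", *run_paths,
                "--wandb-report-url", "https://api.wandb.ai/links/sgoodfriend/6p2sjqtn"]
# subprocess.run(publish_args) blocks its worker thread until the publish job
# exits, so ThreadPool(3) caps concurrent publish jobs at three.
```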
colab_requirements.txt CHANGED
@@ -4,4 +4,6 @@ gym[box2d] >= 0.21.0, < 0.22
 pyglet == 1.5.27
 wandb >= 0.13.9, < 0.14
 pyvirtualdisplay == 3.0
-pybullet >= 3.2.5, < 3.3
+pybullet >= 3.2.5, < 3.3
+tabulate >= 0.9.0, < 0.10
+huggingface-hub >= 0.12.0, < 0.13
enjoy.py CHANGED
@@ -3,103 +3,28 @@ import os
 
 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 
-import …
-import …
-
-from dataclasses import dataclass
-from typing import Optional
-
-from runner.env import make_eval_env
-from runner.config import Config, RunArgs
-from runner.running_utils import (
-    base_parser,
-    load_hyperparams,
-    set_seeds,
-    get_device,
-    make_policy,
-)
-from shared.callbacks.eval_callback import evaluate
-
-
-@dataclass
-class EvalArgs(RunArgs):
-    render: bool = True
-    best: bool = True
-    n_envs: int = 1
-    n_episodes: int = 3
-    deterministic: Optional[bool] = None
-    wandb_run_path: Optional[str] = None
+from runner.evaluate import EvalArgs, evaluate_model
+from runner.running_utils import base_parser
 
 
 if __name__ == "__main__":
-    parser = base_parser()
+    parser = base_parser(multiple=False)
     parser.add_argument("--render", default=True, type=bool)
     parser.add_argument("--best", default=True, type=bool)
     parser.add_argument("--n_envs", default=1, type=int)
     parser.add_argument("--n_episodes", default=3, type=int)
-    parser.add_argument("--deterministic", default=None, type=bool)
+    parser.add_argument("--deterministic-eval", default=None, type=bool)
+    parser.add_argument(
+        "--no-print-returns", action="store_true", help="Limit printing"
+    )
+    # wandb-run-path overrides base RunArgs
    parser.add_argument("--wandb-run-path", default=None, type=str)
     parser.set_defaults(
-        …
+        algo=["ppo"],
     )
+    args = parser.parse_args()
+    args.algo = args.algo[0]
+    args.env = args.env[0]
     args = EvalArgs(**vars(parser.parse_args()))
 
-    if args.wandb_run_path:
-        import wandb
-
-        api = wandb.Api()
-        run = api.run(args.wandb_run_path)
-        hyperparams = run.config
-
-        args.algo = hyperparams["algo"]
-        args.env = hyperparams["env"]
-        args.use_deterministic_algorithms = hyperparams.get(
-            "use_deterministic_algorithms", True
-        )
-
-        config = Config(args, hyperparams, os.path.dirname(__file__))
-        model_path = config.model_dir_path(best=args.best, downloaded=True)
-
-        model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
-        run.file(model_archive_name).download()
-        if os.path.isdir(model_path):
-            shutil.rmtree(model_path)
-        shutil.unpack_archive(model_archive_name, model_path)
-        os.remove(model_archive_name)
-    else:
-        hyperparams = load_hyperparams(args.algo, args.env, os.path.dirname(__file__))
-
-        config = Config(args, hyperparams, os.path.dirname(__file__))
-        model_path = config.model_dir_path(best=args.best)
-
-    print(args)
-
-    set_seeds(args.seed, args.use_deterministic_algorithms)
-
-    env = make_eval_env(
-        config,
-        override_n_envs=args.n_envs,
-        render=args.render,
-        normalize_load_path=model_path,
-        **config.env_hyperparams,
-    )
-    device = get_device(config.device, env)
-    policy = make_policy(
-        args.algo,
-        env,
-        device,
-        load_path=model_path,
-        **config.policy_hyperparams,
-    ).eval()
-
-    if args.deterministic is None:
-        deterministic = config.eval_params.get("deterministic", True)
-    else:
-        deterministic = args.deterministic
-    evaluate(
-        env,
-        policy,
-        args.n_episodes,
-        render=args.render,
-        deterministic=deterministic,
-    )
+    evaluate_model(args, os.path.dirname(__file__))
huggingface_publish.py ADDED
@@ -0,0 +1,177 @@
+import os
+
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+import argparse
+import requests
+import shutil
+import subprocess
+import tempfile
+import wandb
+import wandb.apis.public
+
+from typing import List, Optional
+
+from huggingface_hub.hf_api import HfApi, upload_folder
+from huggingface_hub.repocard import metadata_save
+from publish.markdown_format import EvalTableData, model_card_text
+from runner.evaluate import EvalArgs, evaluate_model
+from runner.env import make_eval_env
+from shared.callbacks.eval_callback import evaluate
+from wrappers.vec_episode_recorder import VecEpisodeRecorder
+
+
+def publish(
+    wandb_run_paths: List[str],
+    wandb_report_url: str,
+    huggingface_user: Optional[str] = None,
+    huggingface_token: Optional[str] = None,
+) -> None:
+    api = wandb.Api()
+    runs = [api.run(rp) for rp in wandb_run_paths]
+    algo = runs[0].config["algo"]
+    env = runs[0].config["env"]
+    evaluations = [
+        evaluate_model(
+            EvalArgs(
+                algo,
+                env,
+                seed=r.config.get("seed", None),
+                render=False,
+                best=True,
+                n_envs=None,
+                n_episodes=10,
+                no_print_returns=True,
+                wandb_run_path="/".join(r.path),
+            ),
+            os.path.dirname(__file__),
+        )
+        for r in runs
+    ]
+    run_metadata = requests.get(runs[0].file("wandb-metadata.json").url).json()
+    table_data = list(EvalTableData(r, e) for r, e in zip(runs, evaluations))
+    best_eval = sorted(
+        table_data, key=lambda d: d.evaluation.stats.score, reverse=True
+    )[0]
+
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        _, (policy, stats, config) = best_eval
+
+        repo_name = config.model_name(include_seed=False)
+        repo_dir_path = os.path.join(tmpdirname, repo_name)
+        # Locally clone this repo to a temp directory
+        subprocess.run(["git", "clone", ".", repo_dir_path])
+        shutil.rmtree(os.path.join(repo_dir_path, ".git"))
+        model_path = config.model_dir_path(best=True, downloaded=True)
+        shutil.copytree(
+            model_path,
+            os.path.join(
+                repo_dir_path, "saved_models", config.model_dir_name(best=True)
+            ),
+        )
+
+        github_url = "https://github.com/sgoodfriend/rl-algo-impls"
+        commit_hash = run_metadata.get("git", {}).get("commit", None)
+        card_text = model_card_text(
+            algo,
+            env,
+            github_url,
+            commit_hash,
+            wandb_report_url,
+            table_data,
+            best_eval,
+        )
+        readme_filepath = os.path.join(repo_dir_path, "README.md")
+        os.remove(readme_filepath)
+        with open(readme_filepath, "w") as f:
+            f.write(card_text)
+
+        metadata = {
+            "library_name": "rl-algo-impls",
+            "tags": [
+                env,
+                algo,
+                "deep-reinforcement-learning",
+                "reinforcement-learning",
+            ],
+            "model-index": [
+                {
+                    "name": algo,
+                    "results": [
+                        {
+                            "metrics": [
+                                {
+                                    "type": "mean_reward",
+                                    "value": str(stats.score),
+                                    "name": "mean_reward",
+                                }
+                            ],
+                            "task": {
+                                "type": "reinforcement-learning",
+                                "name": "reinforcement-learning",
+                            },
+                            "dataset": {
+                                "name": env,
+                                "type": env,
+                            },
+                        }
+                    ],
+                }
+            ],
+        }
+        metadata_save(readme_filepath, metadata)
+
+        video_env = VecEpisodeRecorder(
+            make_eval_env(
+                config,
+                override_n_envs=1,
+                normalize_load_path=model_path,
+                **config.env_hyperparams,
+            ),
+            os.path.join(repo_dir_path, "replay"),
+            max_video_length=3600,
+        )
+        evaluate(
+            video_env,
+            policy,
+            1,
+            deterministic=config.eval_params.get("deterministic", True),
+        )
+
+        api = HfApi()
+        huggingface_user = huggingface_user or api.whoami()["name"]
+        huggingface_repo = f"{huggingface_user}/{repo_name}"
+        api.create_repo(
+            token=huggingface_token,
+            repo_id=huggingface_repo,
+            private=True,
+            exist_ok=True,
+        )
+        repo_url = upload_folder(
+            repo_id=huggingface_repo,
+            folder_path=repo_dir_path,
+            path_in_repo="",
+            commit_message=f"{algo.upper()} playing {env} from {github_url}/tree/{commit_hash}",
+            token=huggingface_token,
+        )
+        print(f"Pushed model to the hub: {repo_url}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--wandb-run-paths",
+        type=str,
+        nargs="+",
+        help="Run paths of the form entity/project/run_id",
+    )
+    parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
+    parser.add_argument(
+        "--huggingface-user",
+        type=str,
+        help="Huggingface user or team to upload model cards",
+        default=None,
+    )
+    args = parser.parse_args()
+    print(args)
+    publish(**vars(args))
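The script can also be driven programmatically. A minimal sketch of calling publish() directly (the run path is a hypothetical placeholder): it reevaluates each run's best model, picks the top scorer, and uploads the generated card plus weights and replay video.

```
# Sketch: the run path below is a placeholder of the form entity/project/run_id.
from huggingface_publish import publish

publish(
    wandb_run_paths=["sgoodfriend/rl-algo-impls-benchmarks/aaaaaaaa"],
    wandb_report_url="https://api.wandb.ai/links/sgoodfriend/6p2sjqtn",
    huggingface_user=None,  # None falls back to the huggingface-cli login user
)
```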
lambda_labs/lambda_requirements.txt CHANGED
@@ -6,4 +6,6 @@ gym[box2d] >= 0.21.0, < 0.22
 pyglet == 1.5.27
 wandb >= 0.13.9, < 0.14
 pyvirtualdisplay == 3.0
-pybullet >= 3.2.5, < 3.3
+pybullet >= 3.2.5, < 3.3
+tabulate >= 0.9.0, < 0.10
+huggingface-hub >= 0.12.0, < 0.13
poetry.lock CHANGED
@@ -1217,6 +1217,37 @@ chardet = ["chardet (>=2.2)"]
 genshi = ["genshi"]
 lxml = ["lxml"]
 
+[[package]]
+name = "huggingface-hub"
+version = "0.12.0"
+description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
+category = "main"
+optional = false
+python-versions = ">=3.7.0"
+files = [
+    {file = "huggingface_hub-0.12.0-py3-none-any.whl", hash = "sha256:93809eabbfb2058a808bddf8b2a70f645de3f9df73ce87ddf5163d4c74b71c0c"},
+    {file = "huggingface_hub-0.12.0.tar.gz", hash = "sha256:da82c9ec8f9d8f976ffd3fd8249d20bb35c2dd3145a9f7ca1106f0ebefd9afa0"},
+]
+
+[package.dependencies]
+filelock = "*"
+packaging = ">=20.9"
+pyyaml = ">=5.1"
+requests = "*"
+tqdm = ">=4.42.1"
+typing-extensions = ">=3.7.4.3"
+
+[package.extras]
+all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "black (==22.3)", "flake8 (>=3.8.3)", "flake8-bugbear", "isort (>=5.5.4)", "jedi", "mypy (==0.982)", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
+cli = ["InquirerPy (==0.3.4)"]
+dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "black (==22.3)", "flake8 (>=3.8.3)", "flake8-bugbear", "isort (>=5.5.4)", "jedi", "mypy (==0.982)", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
+fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
+quality = ["black (==22.3)", "flake8 (>=3.8.3)", "flake8-bugbear", "isort (>=5.5.4)", "mypy (==0.982)"]
+tensorflow = ["graphviz", "pydot", "tensorflow"]
+testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "isort (>=5.5.4)", "jedi", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile"]
+torch = ["torch"]
+typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
+
 [[package]]
 name = "idna"
 version = "3.4"
@@ -3687,6 +3718,21 @@ pure-eval = "*"
 [package.extras]
 tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"]
 
+[[package]]
+name = "tabulate"
+version = "0.9.0"
+description = "Pretty-print tabular data"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"},
+    {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"},
+]
+
+[package.extras]
+widechars = ["wcwidth"]
+
 [[package]]
 name = "tensorboard"
 version = "2.11.0"
@@ -4152,4 +4198,4 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools"
 [metadata]
 lock-version = "2.0"
 python-versions = "~3.10"
-content-hash = "…
+content-hash = "89d4861857be881d3c6cb591d17fb98396b8c117b24a8d4ce4b6593ac8048670"
publish/markdown_format.py ADDED
@@ -0,0 +1,210 @@
+import os
+import pandas as pd
+import wandb.apis.public
+import yaml
+
+from collections import defaultdict
+from dataclasses import dataclass, asdict
+from typing import Any, Dict, Iterable, List, NamedTuple, Optional, TypeVar
+from urllib.parse import urlparse
+
+from runner.evaluate import Evaluation
+
+EvaluationRowSelf = TypeVar("EvaluationRowSelf", bound="EvaluationRow")
+
+
+@dataclass
+class EvaluationRow:
+    algo: str
+    env: str
+    seed: Optional[int]
+    reward_mean: float
+    reward_std: float
+    eval_episodes: int
+    best: str
+    wandb_url: str
+
+    @staticmethod
+    def data_frame(rows: List[EvaluationRowSelf]) -> pd.DataFrame:
+        results = defaultdict(list)
+        for r in rows:
+            for k, v in asdict(r).items():
+                results[k].append(v)
+        return pd.DataFrame(results)
+
+
+class EvalTableData(NamedTuple):
+    run: wandb.apis.public.Run
+    evaluation: Evaluation
+
+
+def evaluation_table(table_data: Iterable[EvalTableData]) -> str:
+    best_stats = sorted(
+        [d.evaluation.stats for d in table_data], key=lambda r: r.score, reverse=True
+    )[0]
+    table_data = sorted(table_data, key=lambda d: d.evaluation.config.seed() or 0)
+    rows = [
+        EvaluationRow(
+            config.algo,
+            config.env_id,
+            config.seed(),
+            stats.score.mean,
+            stats.score.std,
+            len(stats),
+            "*" if stats == best_stats else "",
+            f"[wandb]({r.url})",
+        )
+        for (r, (_, stats, config)) in table_data
+    ]
+    df = EvaluationRow.data_frame(rows)
+    return df.to_markdown(index=False)
+
+
+def github_project_link(github_url: str) -> str:
+    return f"[{urlparse(github_url).path}]({github_url})"
+
+
+def header_section(algo: str, env: str, github_url: str, wandb_report_url: str) -> str:
+    algo_caps = algo.upper()
+    lines = [
+        f"# **{algo_caps}** Agent playing **{env}**",
+        f"This is a trained model of a **{algo_caps}** agent playing **{env}** using "
+        f"the {github_project_link(github_url)} repo.",
+        f"All models trained at this commit can be found at {wandb_report_url}.",
+    ]
+    return "\n\n".join(lines)
+
+
+def github_tree_link(github_url: str, commit_hash: Optional[str]) -> str:
+    if not commit_hash:
+        return github_project_link(github_url)
+    return f"[{commit_hash[:7]}]({github_url}/tree/{commit_hash})"
+
+
+def results_section(
+    table_data: List[EvalTableData], algo: str, github_url: str, commit_hash: str
+) -> str:
+    # type: ignore
+    lines = [
+        "## Training Results",
+        f"This model was trained from {len(table_data)} trainings of **{algo.upper()}** "
+        + "agents using different initial seeds. "
+        + f"These agents were trained by checking out "
+        + f"{github_tree_link(github_url, commit_hash)}. "
+        + "The best and last models were kept from each training. "
+        + "This submission has loaded the best models from each training, reevaluates "
+        + "them, and selects the best model from these latest evaluations (mean - std).",
+    ]
+    lines.append(evaluation_table(table_data))
+    return "\n\n".join(lines)
+
+
+def prerequisites_section() -> str:
+    return """
+### Prerequisites: Weights & Biases (WandB)
+Training and benchmarking assumes you have a Weights & Biases project to upload runs to.
+By default training goes to a rl-algo-impls project while benchmarks go to
+rl-algo-impls-benchmarks. During training and benchmarking runs, videos of the best
+models and the model weights are uploaded to WandB.
+
+Before doing anything below, you'll need to create a wandb account and run `wandb
+login`.
+"""
+
+
+def usage_section(github_url: str, run_path: str, commit_hash: str) -> str:
+    return f"""
+## Usage
+{urlparse(github_url).path}: {github_url}
+
+Note: While the model state dictionary and hyperaparameters are saved, the latest
+implementation could be sufficiently different to not be able to reproduce similar
+results. You might need to checkout the commit the agent was trained on:
+{github_tree_link(github_url, commit_hash)}.
+```
+# Downloads the model, sets hyperparameters, and runs agent for 3 episodes
+python enjoy.py --wandb-run-path={run_path}
+```
+
+Setup hasn't been completely worked out yet, so you might be best served by using Google
+Colab starting from the
+[colab_enjoy.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb)
+notebook.
+"""
+
+
+def training_setion(
+    github_url: str, commit_hash: str, algo: str, env: str, seed: Optional[int]
+) -> str:
+    return f"""
+## Training
+If you want the highest chance to reproduce these results, you'll want to checkout the
+commit the agent was trained on: {github_tree_link(github_url, commit_hash)}. While
+training is deterministic, different hardware will give different results.
+
+```
+python train.py --algo {algo} --env {env} {'--seed ' + str(seed) if seed is not None else ''}
+```
+
+Setup hasn't been completely worked out yet, so you might be best served by using Google
+Colab starting from the
+[colab_train.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb)
+notebook.
+"""
+
+
+def benchmarking_section(report_url: str) -> str:
+    return f"""
+## Benchmarking (with Lambda Labs instance)
+This and other models from {report_url} were generated by running a script on a Lambda
+Labs instance. In a Lambda Labs instance terminal:
+```
+git clone git@github.com:sgoodfriend/rl-algo-impls.git
+cd rl-algo-impls
+bash ./lambda_labs/setup.sh
+wandb login
+bash ./lambda_labs/benchmark.sh
+```
+
+### Alternative: Google Colab Pro+
+As an alternative,
+[colab_benchmark.ipynb](https://github.com/sgoodfriend/rl-algo-impls/tree/main/benchmarks#:~:text=colab_benchmark.ipynb),
+can be used. However, this requires a Google Colab Pro+ subscription and running across
+4 separate instances because otherwise running all jobs will exceed the 24-hour limit.
+"""
+
+
+def hyperparams_section(run_config: Dict[str, Any]) -> str:
+    return f"""
+## Hyperparameters
+This isn't exactly the format of hyperparams in {os.path.join("hyperparams",
+run_config["algo"] + ".yml")}, but instead the Wandb Run Config. However, it's very
+close and has some additional data:
+```
+{yaml.dump(run_config)}
+```
+"""
+
+
+def model_card_text(
+    algo: str,
+    env: str,
+    github_url: str,
+    commit_hash: str,
+    wandb_report_url: str,
+    table_data: List[EvalTableData],
+    best_eval: EvalTableData,
+) -> str:
+    run, (_, _, config) = best_eval
+    run_path = "/".join(run.path)
+    return "\n\n".join(
+        [
+            header_section(algo, env, github_url, wandb_report_url),
+            results_section(table_data, algo, github_url, commit_hash),
+            prerequisites_section(),
+            usage_section(github_url, run_path, commit_hash),
+            training_setion(github_url, commit_hash, algo, env, config.seed()),
+            benchmarking_section(wandb_report_url),
+            hyperparams_section(run.config),
+        ]
+    )
pyproject.toml CHANGED
@@ -21,6 +21,8 @@ wandb = "^0.13.9"
 conda-lock = "^1.3.0"
 torch-tb-profiler = "^0.4.1"
 jupyter = "^1.0.0"
+tabulate = "^0.9.0"
+huggingface-hub = "^0.12.0"
 
 [build-system]
 requires = ["poetry-core"]
replay.meta.json CHANGED
@@ -1 +1 @@
-{"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 5.1.2 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with clang version 14.0.6\\nconfiguration: --prefix=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl --cc=arm64-apple-darwin20.0.0-clang --cxx=arm64-apple-darwin20.0.0-clang++ --nm=arm64-apple-darwin20.0.0-nm --ar=arm64-apple-darwin20.0.0-ar --disable-doc --disable-openssl --enable-demuxer=dash --enable-hardcoded-tables --enable-libfreetype --enable-libfontconfig --enable-libopenh264 --enable-cross-compile --arch=arm64 --target-os=darwin --cross-prefix=arm64-apple-darwin20.0.0- --host-cc=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/x86_64-apple-darwin13.4.0-clang --enable-neon --enable-gnutls --enable-libmp3lame --enable-libvpx --enable-pthreads --enable-gpl --enable-libx264 --enable-libx265 --enable-libaom --enable-libsvtav1 --enable-libxml2 --enable-pic --enable-shared --disable-static --enable-version3 --enable-zlib --pkg-config=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/pkg-config\\nlibavutil 57. 28.100 / 57. 28.100\\nlibavcodec 59. 37.100 / 59. 37.100\\nlibavformat 59. 27.100 / 59. 27.100\\nlibavdevice 59. 7.100 / 59. 7.100\\nlibavfilter 8. 44.100 / 8. 44.100\\nlibswscale 6. 7.100 / 6. 7.100\\nlibswresample 4. 7.100 / 4. 7.100\\nlibpostproc 56. 6.100 / 56. 6.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "500x500", "-pix_fmt", "rgb24", "-framerate", "30", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "30", "/var/folders/9g/my5557_91xddp6lx00nkzly80000gn/T/…
+{"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 5.1.2 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with clang version 14.0.6\\nconfiguration: --prefix=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl --cc=arm64-apple-darwin20.0.0-clang --cxx=arm64-apple-darwin20.0.0-clang++ --nm=arm64-apple-darwin20.0.0-nm --ar=arm64-apple-darwin20.0.0-ar --disable-doc --disable-openssl --enable-demuxer=dash --enable-hardcoded-tables --enable-libfreetype --enable-libfontconfig --enable-libopenh264 --enable-cross-compile --arch=arm64 --target-os=darwin --cross-prefix=arm64-apple-darwin20.0.0- --host-cc=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/x86_64-apple-darwin13.4.0-clang --enable-neon --enable-gnutls --enable-libmp3lame --enable-libvpx --enable-pthreads --enable-gpl --enable-libx264 --enable-libx265 --enable-libaom --enable-libsvtav1 --enable-libxml2 --enable-pic --enable-shared --disable-static --enable-version3 --enable-zlib --pkg-config=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/pkg-config\\nlibavutil 57. 28.100 / 57. 28.100\\nlibavcodec 59. 37.100 / 59. 37.100\\nlibavformat 59. 27.100 / 59. 27.100\\nlibavdevice 59. 7.100 / 59. 7.100\\nlibavfilter 8. 44.100 / 8. 44.100\\nlibswscale 6. 7.100 / 6. 7.100\\nlibswresample 4. 7.100 / 4. 7.100\\nlibpostproc 56. 6.100 / 56. 6.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "500x500", "-pix_fmt", "rgb24", "-framerate", "30", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "30", "/var/folders/9g/my5557_91xddp6lx00nkzly80000gn/T/tmpz2flad47/ppo-Acrobot-v1/replay.mp4"]}, "episode": {"r": -73.0, "l": 74, "t": 1.297341}}
runner/config.py CHANGED
@@ -59,14 +59,17 @@ class Config:
     def eval_params(self) -> Dict[str, Any]:
         return self.hyperparams.get("eval_params", {})
 
+    @property
+    def algo(self) -> str:
+        return self.args.algo
+
     @property
     def env_id(self) -> str:
         return self.args.env
 
-    @property
-    def model_name(self) -> str:
-        …
-        if self.args.seed is not None:
+    def model_name(self, include_seed: bool = True) -> str:
+        parts = [self.algo, self.env_id]
+        if include_seed and self.args.seed is not None:
             parts.append(f"S{self.args.seed}")
         make_kwargs = self.env_hyperparams.get("make_kwargs", {})
         if make_kwargs:
@@ -81,7 +84,7 @@ class Config:
 
     @property
     def run_name(self) -> str:
-        parts = [self.model_name, self.run_id]
+        parts = [self.model_name(), self.run_id]
         return "-".join(parts)
 
     @property
@@ -97,7 +100,7 @@ class Config:
         best: bool = False,
         extension: str = "",
     ) -> str:
-        return self.model_name + ("-best" if best else "") + extension
+        return self.model_name() + ("-best" if best else "") + extension
 
     def model_dir_path(self, best: bool = False, downloaded: bool = False) -> str:
         return os.path.join(
@@ -123,8 +126,8 @@ class Config:
 
     @property
     def video_prefix(self) -> str:
-        return os.path.join(self.videos_dir, self.model_name)
+        return os.path.join(self.videos_dir, self.model_name())
 
     @property
     def best_videos_dir(self) -> str:
-        return os.path.join(self.videos_dir, f"{self.model_name}-best")
+        return os.path.join(self.videos_dir, f"{self.model_name()}-best")
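Turning model_name from a property into a method lets huggingface_publish.py drop the seed suffix (include_seed=False) so one Hub repo name covers every seed of an algo/env pair. A behavior sketch, assuming the parts list is "-".joined downstream as run_name does (the join itself sits outside the hunks shown here):

```
# Sketch of expected naming; the "-" join is an assumption from run_name.
config.model_name()                    # e.g. "ppo-Acrobot-v1-S4" when seed=4
config.model_name(include_seed=False)  # e.g. "ppo-Acrobot-v1", used as repo name
```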
runner/evaluate.py ADDED
@@ -0,0 +1,103 @@
+import os
+import shutil
+
+from dataclasses import dataclass
+from typing import NamedTuple, Optional
+
+from runner.env import make_eval_env
+from runner.config import Config, RunArgs
+from runner.running_utils import (
+    load_hyperparams,
+    set_seeds,
+    get_device,
+    make_policy,
+)
+from shared.callbacks.eval_callback import evaluate
+from shared.policy.policy import Policy
+from shared.stats import EpisodesStats
+
+
+@dataclass
+class EvalArgs(RunArgs):
+    render: bool = True
+    best: bool = True
+    n_envs: Optional[int] = 1
+    n_episodes: int = 3
+    deterministic_eval: Optional[bool] = None
+    no_print_returns: bool = False
+    wandb_run_path: Optional[str] = None
+
+
+class Evaluation(NamedTuple):
+    policy: Policy
+    stats: EpisodesStats
+    config: Config
+
+
+def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
+    if args.wandb_run_path:
+        import wandb
+
+        api = wandb.Api()
+        run = api.run(args.wandb_run_path)
+        hyperparams = run.config
+
+        args.algo = hyperparams["algo"]
+        args.env = hyperparams["env"]
+        args.seed = hyperparams.get("seed", None)
+        args.use_deterministic_algorithms = hyperparams.get(
+            "use_deterministic_algorithms", True
+        )
+
+        config = Config(args, hyperparams, root_dir)
+        model_path = config.model_dir_path(best=args.best, downloaded=True)
+
+        model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
+        run.file(model_archive_name).download()
+        if os.path.isdir(model_path):
+            shutil.rmtree(model_path)
+        shutil.unpack_archive(model_archive_name, model_path)
+        os.remove(model_archive_name)
+    else:
+        hyperparams = load_hyperparams(args.algo, args.env, root_dir)
+
+        config = Config(args, hyperparams, root_dir)
+        model_path = config.model_dir_path(best=args.best)
+
+    print(args)
+
+    set_seeds(args.seed, args.use_deterministic_algorithms)
+
+    env = make_eval_env(
+        config,
+        override_n_envs=args.n_envs,
+        render=args.render,
+        normalize_load_path=model_path,
+        **config.env_hyperparams,
+    )
+    device = get_device(config.device, env)
+    policy = make_policy(
+        args.algo,
+        env,
+        device,
+        load_path=model_path,
+        **config.policy_hyperparams,
+    ).eval()
+
+    deterministic = (
+        args.deterministic_eval
+        if args.deterministic_eval is not None
+        else config.eval_params.get("deterministic", True)
+    )
+    return Evaluation(
+        policy,
+        evaluate(
+            env,
+            policy,
+            args.n_episodes,
+            render=args.render,
+            deterministic=deterministic,
+            print_returns=not args.no_print_returns,
+        ),
+        config,
+    )
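Since Evaluation is a NamedTuple, callers can unpack it positionally, which is how huggingface_publish.py pulls out the policy, stats, and config of the best run. A minimal usage sketch with a hypothetical run path:

```
# Sketch: values are placeholders; EvalArgs inherits algo/env/seed from RunArgs.
args = EvalArgs(
    algo="ppo",
    env="Acrobot-v1",
    wandb_run_path="sgoodfriend/rl-algo-impls-benchmarks/aaaaaaaa",
)
policy, stats, config = evaluate_model(args, root_dir=".")
print(stats.score)  # aggregate score over the n_episodes evaluated
```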
runner/running_utils.py CHANGED
@@ -40,28 +40,28 @@ POLICIES: Dict[str, Type[Policy]] = {
 HYPERPARAMS_PATH = "hyperparams"
 
 
-def base_parser() -> argparse.ArgumentParser:
+def base_parser(multiple: bool = True) -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--algo",
         default="dqn",
         type=str,
         choices=list(ALGOS.keys()),
-        nargs="+",
+        nargs="+" if multiple else 1,
         help="Abbreviation(s) of algorithm(s)",
     )
     parser.add_argument(
         "--env",
         default="CartPole-v1",
         type=str,
-        nargs="+",
+        nargs="+" if multiple else 1,
         help="Name of environment(s) in gym",
     )
     parser.add_argument(
         "--seed",
         default=1,
         type=int,
-        nargs="*",
+        nargs="*" if multiple else "?",
         help="Seeds to run experiment. Unset will do one run with no set seed",
     )
     parser.add_argument(
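Note that with multiple=False, --algo and --env get nargs=1, so argparse still returns one-element lists; that is why enjoy.py unwraps args.algo[0] and args.env[0] after parsing. A standalone sketch of the argparse behavior (not the repo's parser):

```
import argparse

# Demonstrates nargs=1 vs nargs="?" as used when multiple=False.
p = argparse.ArgumentParser()
p.add_argument("--algo", nargs=1)              # always parses to a 1-element list
p.add_argument("--seed", nargs="?", type=int)  # parses to a scalar (or None)
print(p.parse_args(["--algo", "ppo", "--seed", "4"]))
# -> Namespace(algo=['ppo'], seed=4), hence the args.algo[0] unwrapping
```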
shared/callbacks/eval_callback.py CHANGED
@@ -22,7 +22,10 @@ class EvaluateAccumulator(EpisodeAccumulator):
         self.print_returns = print_returns
 
     def on_done(self, ep_idx: int, episode: Episode) -> None:
-        if …
+        if (
+            len(self.completed_episodes_by_env_idx[ep_idx])
+            >= self.goal_episodes_per_env
+        ):
             return
         self.completed_episodes_by_env_idx[ep_idx].append(episode)
         if self.print_returns:
@@ -36,11 +39,14 @@ class EvaluateAccumulator(EpisodeAccumulator):
         return sum(len(ce) for ce in self.completed_episodes_by_env_idx)
 
     @property
-    def episodes(self) -> …
-        return list(itertools.chain(*self.completed_episodes_by_env_idx))
+    def episodes(self) -> List[Episode]:
+        return list(itertools.chain(*self.completed_episodes_by_env_idx))
 
     def is_done(self) -> bool:
-        return all(…
+        return all(
+            len(ce) == self.goal_episodes_per_env
+            for ce in self.completed_episodes_by_env_idx
+        )
 
 
 def evaluate(
@@ -108,7 +114,7 @@ class EvalCallback(Callback):
     def on_step(self, timesteps_elapsed: int = 1) -> bool:
         super().on_step(timesteps_elapsed)
         if self.timesteps_elapsed // self.step_freq >= len(self.stats):
-            …
+            sync_vec_normalize(self.policy.vec_normalize, self.env)
             self.evaluate()
         return True
 
@@ -134,10 +140,12 @@
             assert self.best_model_path
             self.policy.save(self.best_model_path)
             print("Saved best model")
-            self.best.write_to_tensorboard(…
+            self.best.write_to_tensorboard(
+                self.tb_writer, "best_eval", self.timesteps_elapsed
+            )
         if strictly_better and self.record_best_videos:
             assert self.video_env and self.best_video_dir
-            …
+            sync_vec_normalize(self.policy.vec_normalize, self.video_env)
             self.best_video_base_path = os.path.join(
                 self.best_video_dir, str(self.timesteps_elapsed)
             )
@@ -159,16 +167,15 @@
 
         return eval_stat
 
-…
-        eval_env_wrapper.…
-…
-        eval_env_wrapper = eval_env_wrapper.venv
+
+
+def sync_vec_normalize(
+    origin_vec_normalize: Optional[VecNormalize], destination_env: VecEnv
+) -> None:
+    if origin_vec_normalize is not None:
+        eval_env_wrapper = destination_env
+        while isinstance(eval_env_wrapper, VecEnvWrapper):
+            if isinstance(eval_env_wrapper, VecNormalize):
+                if hasattr(origin_vec_normalize, "obs_rms"):
+                    eval_env_wrapper.obs_rms = deepcopy(origin_vec_normalize.obs_rms)
+                eval_env_wrapper.ret_rms = deepcopy(origin_vec_normalize.ret_rms)
+            eval_env_wrapper = eval_env_wrapper.venv
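sync_vec_normalize copies the training-side VecNormalize running statistics (obs_rms/ret_rms) into whichever VecNormalize sits in the destination env's wrapper chain, so evaluation and best-video rollouts see the same observation and reward scaling the policy trained under. A usage sketch, assuming policy.vec_normalize holds the training VecNormalize as the callback code implies:

```
# Sketch: keep eval normalization in lockstep with training before evaluating.
sync_vec_normalize(policy.vec_normalize, eval_env)
stats = evaluate(eval_env, policy, 10)  # 10 episodes with synced obs/ret scaling
```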
shared/policy/policy.py CHANGED
@@ -54,7 +54,9 @@ class Policy(nn.Module, ABC):
     @abstractmethod
     def load(self, path: str) -> None:
         # VecNormalize load occurs in env.py
-        self.load_state_dict(…
+        self.load_state_dict(
+            torch.load(os.path.join(path, MODEL_FILENAME), map_location=self.device)
+        )
 
     def reset_noise(self) -> None:
         pass
shared/stats.py CHANGED
@@ -94,6 +94,9 @@ class EpisodesStats:
             f"Length: {self.length}"
         )
 
+    def __len__(self) -> int:
+        return len(self.episodes)
+
     def _asdict(self) -> dict:
         return {
             "n_episodes": len(self.episodes),
@@ -147,27 +150,3 @@
 
     def stats(self) -> EpisodesStats:
         return EpisodesStats(self.episodes)
-
-
-class RolloutStats(EpisodeAccumulator):
-    def __init__(self, num_envs: int, print_n_episodes: int, tb_writer: SummaryWriter):
-        super().__init__(num_envs)
-        self.print_n_episodes = print_n_episodes
-        self.epochs: List[EpisodesStats] = []
-        self.tb_writer = tb_writer
-
-    def on_done(self, ep_idx: int, episode: Episode) -> None:
-        if (
-            self.print_n_episodes >= 0
-            and len(self.episodes) % self.print_n_episodes == 0
-        ):
-            sample = self.episodes[-self.print_n_episodes :]
-            epoch = EpisodesStats(sample)
-            self.epochs.append(epoch)
-            total_steps = np.sum([e.length for e in self.episodes])
-            print(
-                f"Episode: {len(self.episodes)} | "
-                f"{epoch} | "
-                f"Total Steps: {total_steps}"
-            )
-            epoch.write_to_tensorboard(self.tb_writer, "train", global_step=total_steps)
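Adding __len__ lets EpisodesStats report its episode count directly; publish/markdown_format.py relies on it via len(stats) for the eval_episodes column. A one-line sketch:

```
stats = EpisodesStats(episodes)      # sketch; episodes is a list of Episode objects
assert len(stats) == len(episodes)   # __len__ delegates to the underlying list
```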