Snake-V1
This model is an A2C agent trained on the Jumanji Snake environment.
Developed by: InstaDeep
Model Sources
- Repository: Jumanji
- Paper: TBD
How to use
See the Jumanji repository for the model code and its requirements. Clone the repository, navigate to its root directory, and install the training requirements:
pip install --quiet -U pip -r requirements/requirements-train.txt .
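The example script below downloads the trained parameters from the Hugging Face Hub, rolls out the policy for several episodes, and saves a GIF animation and a PNG frame.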
import pickle
import joblib
import jax
from hydra import compose, initialize
from huggingface_hub import hf_hub_download
from jumanji.training.setup_train import setup_agent, setup_env
from jumanji.training.utils import first_from_device
# Initialise the Hydra config for the Snake environment and the A2C agent.
# `config_path` is relative, so run this from the root of the Jumanji checkout.
with initialize(version_base=None, config_path="jumanji/training/configs"):
    cfg = compose(config_name="config.yaml", overrides=["env=snake", "agent=a2c"])

# Download the pickled training state from the Hugging Face Hub.
REPO_ID = "d-byrne/snake-v1_training_state"
FILENAME = "Snake-v1_training_state"
model_weights = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# Only load training states from a repository you trust.
with open(model_weights, "rb") as f:
    training_state = pickle.load(f)
# Take the parameters from the first device; presumably the training state was
# replicated across devices during training (TODO confirm).
params = first_from_device(training_state.params_state.params)

# Build the environment and agent from the same config, then JIT-compile the
# actor's policy with stochastic=False (non-sampling action selection).
env = setup_env(cfg).unwrapped
agent = setup_agent(cfg, env)
policy = jax.jit(agent.make_policy(params.actor, stochastic=False))
# Roll out a few episodes with the trained policy, recording every visited state.
NUM_EPISODES = 10
# Extra copies of each episode's final state, so the GIF pauses on the terminal frame.
NUM_FREEZE_FRAMES = 10
states = []
key = jax.random.PRNGKey(cfg.seed)
# Wrap reset/step with jax.jit once, outside the loops, instead of re-wrapping
# them on every call.
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)
for _ in range(NUM_EPISODES):
    key, reset_key = jax.random.split(key)
    state, timestep = reset_fn(reset_key)
    while not timestep.last():
        key, action_key = jax.random.split(key)
        # Add a leading batch axis of size 1 to every observation leaf for the
        # policy, then drop that axis from the returned action.
        observation = jax.tree_util.tree_map(lambda x: x[None], timestep.observation)
        action, _ = policy(observation, action_key)
        state, timestep = step_fn(state, action.squeeze(axis=0))
        states.append(state)
    # Freeze the terminal frame to pause the GIF.
    states.extend([state] * NUM_FREEZE_FRAMES)

# Animate all recorded states into a GIF (interval=150, presumably milliseconds per frame).
env.animate(states, interval=150).save("./snake.gif")
# Render a single recorded frame and save it as a PNG.
# Note: `%matplotlib inline` is IPython magic and is a SyntaxError in a plain
# Python script. In a Jupyter notebook, run it in its own cell if you want
# inline display.
import matplotlib.pyplot as plt

FRAME_INDEX = 117
# Clamp the index so a shorter rollout does not raise IndexError.
env.render(states[min(FRAME_INDEX, len(states) - 1)])
plt.savefig("snake.png", dpi=300)