diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..483b46a7e5340b1014cbaf98cf7f2e56e7d7a654 --- /dev/null +++ b/app.py @@ -0,0 +1,42 @@ +import streamlit as st +from utils import inject_custom_css +from PIL import Image +from matplotlib import rcParams + +color_title = "#000000" +color_text = "#000000" + +text = f""" +

馃嵅 Rewarded soups 馃嵅

+

Welcome to our interactive streamlit app showcasing the key concepts and experiments presented in our paper
Rewarded soups: towards Pareto-optimal alignment by interpolating weights fine-tuned on diverse rewards

+

Asbtract

+

Foundation models are first pre-trained on vast unsupervised datasets and then fine-tuned on labeled data. Reinforcement learning, notably from human feedback (RLHF), can further align the network with the intended usage. Yet the imperfections in the proxy reward may hinder the training and lead to suboptimal results; the diversity of objectives in real-world tasks and human opinions exacerbate the issue. This paper proposes embracing the heterogeneity of diverse rewards by following a multi-policy strategy. Rather than focusing on a single a priori reward, we aim for Pareto-optimal generalization across the entire space of preferences. To this end, we propose rewarded soup, first specializing multiple networks independently (one for each proxy reward) and then interpolating their weights linearly. This succeeds empirically because we show that the weights remain linearly connected when fine-tuned on diverse rewards from a shared pre-trained initialization. We demonstrate the effectiveness of our approach for text-to-text (summarization, Q&A, helpful assistant, review), text-image (image captioning, text-to-image generation, visual grounding, VQA), and control (locomotion) tasks. We hope to enhance the alignment of deep models, and how they interact with the world in all its diversity.

+ +

What will I find here ?

+ +

In this app, you will find interactive figures and qualitative examples demonstratating the effectiveness of our approach. Specifically, we detail the following tasks: RLHF of LLaMA for news summarization, RLHF of a diffusion model for text-to-image generation, and the locomotion task. To help the reproduction of these results, we also provide our code here.

+""" + + +def run_UI(): + rcParams['font.family'] = 'sans-serif' + rcParams['font.sans-serif'] = ['Tahoma'] + inject_custom_css("streamlit_app/assets/styles.css") + + st.markdown( + f""" + + {text} + """, + unsafe_allow_html=True, + ) + +if __name__ == "__main__": + img = Image.open("streamlit_app/assets/images/icon.png") + st.set_page_config( + page_title="Rewarded soups", + page_icon=img, + layout="wide", + ) + st.set_option("deprecation.showPyplotGlobalUse", False) + run_UI() diff --git a/assets/assets/images/soup.jpg b/assets/assets/images/soup.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0623e90a2749fb669ed57915bf43b7c82986f64f Binary files /dev/null and b/assets/assets/images/soup.jpg differ diff --git a/assets/images/icon.png b/assets/images/icon.png new file mode 100644 index 0000000000000000000000000000000000000000..1ba7ff247ec3746275840832cf706e923880081d Binary files /dev/null and b/assets/images/icon.png differ diff --git a/assets/styles.css b/assets/styles.css new file mode 100644 index 0000000000000000000000000000000000000000..2b4811b3347da3cebbdc6d1ec843c195bcdd09d3 --- /dev/null +++ b/assets/styles.css @@ -0,0 +1,186 @@ + +.settings{ + height:1rem; +} + +div{ + text-align: justify; +} + +div[role="button"]{ + font-size: 16px; +} + +h6{ + font-size: 14px; +} + +h1{ + text-align: center !important; +} +h4{ + text-align: center !important; +} + +table{ + font-size: 0.8vw; + text-align: right !important; +} + +.promptTextbox { + border: 2px solid #f3f3f3; + border-radius: 5px; + background-color: #f9f9f9; + box-shadow: 2px 2px 2px grey; + font-size: 11px; + margin: 0; + padding: 0; +} + +.promptHeader { + padding: 10px !important; + background-color: #ececec; + border-bottom: 1px solid #ddd; + font-weight: bold; + font-size: 16px; + margin: 0; + padding: 0; +} + +.promptContent { + padding: 20px; +} + +.modelOutputBox { + border: 2px solid #c3e6cb; + border-radius: 5px; + background-color: #d4edda; + box-shadow: 2px 2px 2px #73c476; + font-size: 14px; + +} + +.modelOutputHeader { + padding: 10px; + background-color: #ececec; + border-bottom: 1px solid #73c476; + font-weight: bold; + font-size: 16px; +} + +.modelOutputContent { + padding: 20px; +} + +.st-bx { + padding: 10px; + background-color: #ececec; + border-bottom: 1px solid #ddd; + font-weight: bold; + font-size: 16px; + margin: 0; + padding: 0; +} + +.lambda-header { + padding: 2px; + background-color: #d4edda; + font-weight: bold; /* make the text bold */ + font-size: 16px; + width: 50px; /* reduce the width */ + height: 30px; /* reduce the width */ + box-shadow: 2px 2px 2px grey; /* same shadow as .lambdas */ + text-align: center; +} + +.lambdas { + padding: 10px; + background-color: #d4edda; /* same color as .lambda-header */ + font-size: 16px; + box-shadow: 2px 2px 2px grey; + margin-bottom: 10px; /* creates space between this section and the next */ + overflow: hidden; /* this will prevent text exceeding the height from being visible */ + height: auto; /* adjust this to fit your text, accounting for padding and font size */ + width: auto; +} + +.imgContent { + display: flex; + justify-content: center; + padding: 20px; +} + +.imgContainer { + display: flex; + flex-direction: column; + align-items: center; + width: 30%; + max-width: 30%; + margin: 10px; +} + +.imglambda-header { + padding: 2px; + font-weight: bold; /* make the text bold */ + font-size: 16px; + width: 50px; /* reduce the width */ + height: 30px; /* reduce the width */ + text-align: center; +} + +.imglambdas { + padding: 10px; + background-color: #d4edda; /* same color as .lambda-header */ + font-size: 16px; + box-shadow: 2px 2px 2px grey; + margin-bottom: 10px; /* creates space between this section and the next */ + overflow: hidden; /* this will prevent text exceeding the height from being visible */ + height: auto; /* adjust this to fit your text, accounting for padding and font size */ + width: 100%; + max-width: 100%; + display: flex; + justify-content: center; /* Add this line to center the content horizontally */ + align-items: center; /* Add this line to center the content vertically */ +} +.imglambdas img { + max-height: 99%; + max-width: 99%; /* adjust to your preference */ + +} + +div[data-testid='stHorizontalBlock'] { + display: flex; + align-items: center; /* this will vertically center the items */ + padding: 0; /* removes padding */ + margin: 0; /* removes margin */ +} + +div[data-testid='column'] { + justify-content: center; /* this will horizontally center the plot */ + padding: 0; /* removes padding */ + margin: 0; /* removes margin */ +} + +@keyframes typing { + from { width: 0; } + to { width: 100%; } +} + +@keyframes blink-caret { + 0%, 100% { border-color: transparent; } + 50% { border-color: black; } +} + +.typing-effect { + border-right: .15em solid black; + white-space: nowrap; + overflow: hidden; + margin: 0 auto; + animation: typing 2s steps(40, end), blink-caret 2.8s step-end; +} + +.flex-container { + display: flex; + flex-wrap: wrap; + justify-content: center; +} diff --git a/data/imgen/data.pkl b/data/imgen/data.pkl new file mode 100644 index 0000000000000000000000000000000000000000..163e20551340027fcb4470d41f14373f2af572a9 Binary files /dev/null and b/data/imgen/data.pkl differ diff --git a/data/imgen/data_images.pkl b/data/imgen/data_images.pkl new file mode 100644 index 0000000000000000000000000000000000000000..334ed9815a565fe53b005e65bedd1791cfd8e81d Binary files /dev/null and b/data/imgen/data_images.pkl differ diff --git a/data/imgen/viz3/003594a888_0.png b/data/imgen/viz3/003594a888_0.png new file mode 100644 index 0000000000000000000000000000000000000000..0a8239765d86c2f30a40b413749688382299f03b Binary files /dev/null and b/data/imgen/viz3/003594a888_0.png differ diff --git a/data/imgen/viz3/003594a888_1.png b/data/imgen/viz3/003594a888_1.png new file mode 100644 index 0000000000000000000000000000000000000000..155405fc831fece49e93a45722365c753c641bbb Binary files /dev/null and b/data/imgen/viz3/003594a888_1.png differ diff --git a/data/imgen/viz3/003594a888_10.png b/data/imgen/viz3/003594a888_10.png new file mode 100644 index 0000000000000000000000000000000000000000..6b6b10227f48a815c7cd4bc8bbcc2af42b436295 Binary files /dev/null and b/data/imgen/viz3/003594a888_10.png differ diff --git a/data/imgen/viz3/003594a888_2.png b/data/imgen/viz3/003594a888_2.png new file mode 100644 index 0000000000000000000000000000000000000000..8cf2c229e8fd2f3148327026fc971341692cec8d Binary files /dev/null and b/data/imgen/viz3/003594a888_2.png differ diff --git a/data/imgen/viz3/003594a888_3.png b/data/imgen/viz3/003594a888_3.png new file mode 100644 index 0000000000000000000000000000000000000000..2c57bceb5cd24ae3688bb7d57c153a29ba7f50cf Binary files /dev/null and b/data/imgen/viz3/003594a888_3.png differ diff --git a/data/imgen/viz3/003594a888_4.png b/data/imgen/viz3/003594a888_4.png new file mode 100644 index 0000000000000000000000000000000000000000..8480efea521e6d8d58f6baaddbeaa037c9fb7161 Binary files /dev/null and b/data/imgen/viz3/003594a888_4.png differ diff --git a/data/imgen/viz3/003594a888_5.png b/data/imgen/viz3/003594a888_5.png new file mode 100644 index 0000000000000000000000000000000000000000..c953ead8317882b8ddbd3a78b9d1397ab7d3325a Binary files /dev/null and b/data/imgen/viz3/003594a888_5.png differ diff --git a/data/imgen/viz3/003594a888_6.png b/data/imgen/viz3/003594a888_6.png new file mode 100644 index 0000000000000000000000000000000000000000..8368e843cda8444236945543ce68aa95de689058 Binary files /dev/null and b/data/imgen/viz3/003594a888_6.png differ diff --git a/data/imgen/viz3/003594a888_7.png b/data/imgen/viz3/003594a888_7.png new file mode 100644 index 0000000000000000000000000000000000000000..2bcc835e0c70c01ba6ea8e0805d2ad7a1d7073f3 Binary files /dev/null and b/data/imgen/viz3/003594a888_7.png differ diff --git a/data/imgen/viz3/003594a888_8.png b/data/imgen/viz3/003594a888_8.png new file mode 100644 index 0000000000000000000000000000000000000000..1f6fc69c124eb2e9b2a457a650bfb4c0a6b873fb Binary files /dev/null and b/data/imgen/viz3/003594a888_8.png differ diff --git a/data/imgen/viz3/003594a888_9.png b/data/imgen/viz3/003594a888_9.png new file mode 100644 index 0000000000000000000000000000000000000000..7b5c7729493e3a8c2edfb3a0635176c2f30f2347 Binary files /dev/null and b/data/imgen/viz3/003594a888_9.png differ diff --git a/data/imgen/viz3/02d9f671c1_0.png b/data/imgen/viz3/02d9f671c1_0.png new file mode 100644 index 0000000000000000000000000000000000000000..18595db376c5a72be5b76c9b49166ce7df50e393 Binary files /dev/null and b/data/imgen/viz3/02d9f671c1_0.png differ diff --git a/data/imgen/viz3/02d9f671c1_1.png b/data/imgen/viz3/02d9f671c1_1.png new file mode 100644 index 0000000000000000000000000000000000000000..4f1aca183f6bf51ae36bd887df0e7431c2497186 Binary files /dev/null and b/data/imgen/viz3/02d9f671c1_1.png differ diff --git a/data/imgen/viz3/02d9f671c1_10.png b/data/imgen/viz3/02d9f671c1_10.png new file mode 100644 index 0000000000000000000000000000000000000000..bc6a35684dead2f7f626cd845a777b6971ff02b7 Binary files /dev/null and b/data/imgen/viz3/02d9f671c1_10.png differ diff --git a/data/imgen/viz3/02d9f671c1_2.png b/data/imgen/viz3/02d9f671c1_2.png new file mode 100644 index 0000000000000000000000000000000000000000..5e6deb76446045b50e2700e4cedc76e9d7c09d2d Binary files /dev/null and b/data/imgen/viz3/02d9f671c1_2.png differ diff --git a/data/imgen/viz3/02d9f671c1_3.png b/data/imgen/viz3/02d9f671c1_3.png new file mode 100644 index 0000000000000000000000000000000000000000..8862927d70b22988726e670d490d2d0a45b56341 Binary files /dev/null and b/data/imgen/viz3/02d9f671c1_3.png differ diff --git a/data/imgen/viz3/02d9f671c1_4.png b/data/imgen/viz3/02d9f671c1_4.png new file mode 100644 index 0000000000000000000000000000000000000000..a751af0ca581e99b1bea136b74dc848d14b45dd9 Binary files /dev/null and b/data/imgen/viz3/02d9f671c1_4.png differ diff --git a/data/imgen/viz3/02d9f671c1_5.png b/data/imgen/viz3/02d9f671c1_5.png new file mode 100644 index 0000000000000000000000000000000000000000..b049168e2ba9c36850ac89f89bb4f1ca2f9773bd Binary files /dev/null and b/data/imgen/viz3/02d9f671c1_5.png differ diff --git a/data/imgen/viz3/02d9f671c1_6.png b/data/imgen/viz3/02d9f671c1_6.png new file mode 100644 index 0000000000000000000000000000000000000000..243d0f2baa9b7c808d2e97360fb1bf94e720872a Binary files /dev/null and b/data/imgen/viz3/02d9f671c1_6.png differ diff --git a/data/imgen/viz3/02d9f671c1_7.png b/data/imgen/viz3/02d9f671c1_7.png new file mode 100644 index 0000000000000000000000000000000000000000..ffdedc7c141499413d613e124d9743b9e3f2204a Binary files /dev/null and b/data/imgen/viz3/02d9f671c1_7.png differ diff --git a/data/imgen/viz3/02d9f671c1_8.png b/data/imgen/viz3/02d9f671c1_8.png new file mode 100644 index 0000000000000000000000000000000000000000..6111ac5058662a74ba405c310a301c4c2766050c Binary files /dev/null and b/data/imgen/viz3/02d9f671c1_8.png differ diff --git a/data/imgen/viz3/02d9f671c1_9.png b/data/imgen/viz3/02d9f671c1_9.png new file mode 100644 index 0000000000000000000000000000000000000000..1a09821cdd2357d79d9996aaf9e5d15f6570659d Binary files /dev/null and b/data/imgen/viz3/02d9f671c1_9.png differ diff --git a/data/imgen/viz3/22550e7610_0.png b/data/imgen/viz3/22550e7610_0.png new file mode 100644 index 0000000000000000000000000000000000000000..ceca2709e985b9e607b16068bf7b35e7058fb300 Binary files /dev/null and b/data/imgen/viz3/22550e7610_0.png differ diff --git a/data/imgen/viz3/22550e7610_1.png b/data/imgen/viz3/22550e7610_1.png new file mode 100644 index 0000000000000000000000000000000000000000..3a6315ca06380f4ea15790c7496325429d8fc843 Binary files /dev/null and b/data/imgen/viz3/22550e7610_1.png differ diff --git a/data/imgen/viz3/22550e7610_10.png b/data/imgen/viz3/22550e7610_10.png new file mode 100644 index 0000000000000000000000000000000000000000..d0b811ea3b8398cd3df758abd8e684291101f8bf Binary files /dev/null and b/data/imgen/viz3/22550e7610_10.png differ diff --git a/data/imgen/viz3/22550e7610_2.png b/data/imgen/viz3/22550e7610_2.png new file mode 100644 index 0000000000000000000000000000000000000000..ece873c9dc0748d2038d3b07408d2c1865aa512f Binary files /dev/null and b/data/imgen/viz3/22550e7610_2.png differ diff --git a/data/imgen/viz3/22550e7610_3.png b/data/imgen/viz3/22550e7610_3.png new file mode 100644 index 0000000000000000000000000000000000000000..0076378e732b27c8fb47bf6f6245b84b1c85bcf2 Binary files /dev/null and b/data/imgen/viz3/22550e7610_3.png differ diff --git a/data/imgen/viz3/22550e7610_4.png b/data/imgen/viz3/22550e7610_4.png new file mode 100644 index 0000000000000000000000000000000000000000..a440507959da300933fe3d22dfe0d666bfee1c0f Binary files /dev/null and b/data/imgen/viz3/22550e7610_4.png differ diff --git a/data/imgen/viz3/22550e7610_5.png b/data/imgen/viz3/22550e7610_5.png new file mode 100644 index 0000000000000000000000000000000000000000..ac45723d1fdd6086471ac90cbdc20f45d593a4ee Binary files /dev/null and b/data/imgen/viz3/22550e7610_5.png differ diff --git a/data/imgen/viz3/22550e7610_6.png b/data/imgen/viz3/22550e7610_6.png new file mode 100644 index 0000000000000000000000000000000000000000..d42ad0f693dc0586ea1df923eb22cd6fc4f94858 Binary files /dev/null and b/data/imgen/viz3/22550e7610_6.png differ diff --git a/data/imgen/viz3/22550e7610_7.png b/data/imgen/viz3/22550e7610_7.png new file mode 100644 index 0000000000000000000000000000000000000000..f5e16e94c59adbfca4cfc3854f132b922dbe8015 Binary files /dev/null and b/data/imgen/viz3/22550e7610_7.png differ diff --git a/data/imgen/viz3/22550e7610_8.png b/data/imgen/viz3/22550e7610_8.png new file mode 100644 index 0000000000000000000000000000000000000000..e48d156688551dc63329b74ca8e2437c4acf1094 Binary files /dev/null and b/data/imgen/viz3/22550e7610_8.png differ diff --git a/data/imgen/viz3/22550e7610_9.png b/data/imgen/viz3/22550e7610_9.png new file mode 100644 index 0000000000000000000000000000000000000000..b018511d78cbf9504cf81231e756ebc6a159721e Binary files /dev/null and b/data/imgen/viz3/22550e7610_9.png differ diff --git a/data/imgen/viz3/576a8b2407_0.png b/data/imgen/viz3/576a8b2407_0.png new file mode 100644 index 0000000000000000000000000000000000000000..fce21182b89289b465ceb99450801055cae80a7c Binary files /dev/null and b/data/imgen/viz3/576a8b2407_0.png differ diff --git a/data/imgen/viz3/576a8b2407_1.png b/data/imgen/viz3/576a8b2407_1.png new file mode 100644 index 0000000000000000000000000000000000000000..5d914d54f90bb9ba525381cf6c70e7b184f673db Binary files /dev/null and b/data/imgen/viz3/576a8b2407_1.png differ diff --git a/data/imgen/viz3/576a8b2407_10.png b/data/imgen/viz3/576a8b2407_10.png new file mode 100644 index 0000000000000000000000000000000000000000..a7575b7a64b179056960c13e24ec302e81c6ed98 Binary files /dev/null and b/data/imgen/viz3/576a8b2407_10.png differ diff --git a/data/imgen/viz3/576a8b2407_2.png b/data/imgen/viz3/576a8b2407_2.png new file mode 100644 index 0000000000000000000000000000000000000000..dcaafb93234b6a4d4b2cea2a983b4f1a6dbe65c5 Binary files /dev/null and b/data/imgen/viz3/576a8b2407_2.png differ diff --git a/data/imgen/viz3/576a8b2407_3.png b/data/imgen/viz3/576a8b2407_3.png new file mode 100644 index 0000000000000000000000000000000000000000..f7ae98afa0207d1dc1f8f5283bd81f0852c99c86 Binary files /dev/null and b/data/imgen/viz3/576a8b2407_3.png differ diff --git a/data/imgen/viz3/576a8b2407_4.png b/data/imgen/viz3/576a8b2407_4.png new file mode 100644 index 0000000000000000000000000000000000000000..6311ec555ec44e787461537f750dc57c16c125e0 Binary files /dev/null and b/data/imgen/viz3/576a8b2407_4.png differ diff --git a/data/imgen/viz3/576a8b2407_5.png b/data/imgen/viz3/576a8b2407_5.png new file mode 100644 index 0000000000000000000000000000000000000000..f9fc1b2a52ecb7349dc9cf04e1450893027ccaee Binary files /dev/null and b/data/imgen/viz3/576a8b2407_5.png differ diff --git a/data/imgen/viz3/576a8b2407_6.png b/data/imgen/viz3/576a8b2407_6.png new file mode 100644 index 0000000000000000000000000000000000000000..d82b86f1e6571c6638915760023d1a4f3640523f Binary files /dev/null and b/data/imgen/viz3/576a8b2407_6.png differ diff --git a/data/imgen/viz3/576a8b2407_7.png b/data/imgen/viz3/576a8b2407_7.png new file mode 100644 index 0000000000000000000000000000000000000000..c8a1c0039fee284b5ef6e662117cedcea1eff526 Binary files /dev/null and b/data/imgen/viz3/576a8b2407_7.png differ diff --git a/data/imgen/viz3/576a8b2407_8.png b/data/imgen/viz3/576a8b2407_8.png new file mode 100644 index 0000000000000000000000000000000000000000..32ad6af7c56737a5a259434048b3536158d78f29 Binary files /dev/null and b/data/imgen/viz3/576a8b2407_8.png differ diff --git a/data/imgen/viz3/576a8b2407_9.png b/data/imgen/viz3/576a8b2407_9.png new file mode 100644 index 0000000000000000000000000000000000000000..4f6e2c8df76c6ee635f928a61485b9858b7d5175 Binary files /dev/null and b/data/imgen/viz3/576a8b2407_9.png differ diff --git a/data/imgen/viz3/7d071c8065_0.png b/data/imgen/viz3/7d071c8065_0.png new file mode 100644 index 0000000000000000000000000000000000000000..327b082e57e167975ea7f58024d3419367ff1b6a Binary files /dev/null and b/data/imgen/viz3/7d071c8065_0.png differ diff --git a/data/imgen/viz3/7d071c8065_1.png b/data/imgen/viz3/7d071c8065_1.png new file mode 100644 index 0000000000000000000000000000000000000000..523eae9d62e05a158726cb4245e985ba225cfb1b Binary files /dev/null and b/data/imgen/viz3/7d071c8065_1.png differ diff --git a/data/imgen/viz3/7d071c8065_10.png b/data/imgen/viz3/7d071c8065_10.png new file mode 100644 index 0000000000000000000000000000000000000000..65c9e2f4e83a47f3d714e9d25838212020454fe6 Binary files /dev/null and b/data/imgen/viz3/7d071c8065_10.png differ diff --git a/data/imgen/viz3/7d071c8065_2.png b/data/imgen/viz3/7d071c8065_2.png new file mode 100644 index 0000000000000000000000000000000000000000..44aa9a4f837e895ef8f19337861adda5e19c2de1 Binary files /dev/null and b/data/imgen/viz3/7d071c8065_2.png differ diff --git a/data/imgen/viz3/7d071c8065_3.png b/data/imgen/viz3/7d071c8065_3.png new file mode 100644 index 0000000000000000000000000000000000000000..6e26db0cb6b99edd11ddb80a446b828d339724db Binary files /dev/null and b/data/imgen/viz3/7d071c8065_3.png differ diff --git a/data/imgen/viz3/7d071c8065_4.png b/data/imgen/viz3/7d071c8065_4.png new file mode 100644 index 0000000000000000000000000000000000000000..959b95377806a5e0d5991ce35d094fb2101020eb Binary files /dev/null and b/data/imgen/viz3/7d071c8065_4.png differ diff --git a/data/imgen/viz3/7d071c8065_5.png b/data/imgen/viz3/7d071c8065_5.png new file mode 100644 index 0000000000000000000000000000000000000000..bf0af65119e7de378252f0761a8459f2813d1b9c Binary files /dev/null and b/data/imgen/viz3/7d071c8065_5.png differ diff --git a/data/imgen/viz3/7d071c8065_6.png b/data/imgen/viz3/7d071c8065_6.png new file mode 100644 index 0000000000000000000000000000000000000000..174dd4285858038165851906fc57eeeedd73da98 Binary files /dev/null and b/data/imgen/viz3/7d071c8065_6.png differ diff --git a/data/imgen/viz3/7d071c8065_7.png b/data/imgen/viz3/7d071c8065_7.png new file mode 100644 index 0000000000000000000000000000000000000000..dd6219da4f93305b317c03d97cdc7777b4e24770 Binary files /dev/null and b/data/imgen/viz3/7d071c8065_7.png differ diff --git a/data/imgen/viz3/7d071c8065_8.png b/data/imgen/viz3/7d071c8065_8.png new file mode 100644 index 0000000000000000000000000000000000000000..b6d5ca3204992af0deccbe184e576f1c869e5aa1 Binary files /dev/null and b/data/imgen/viz3/7d071c8065_8.png differ diff --git a/data/imgen/viz3/7d071c8065_9.png b/data/imgen/viz3/7d071c8065_9.png new file mode 100644 index 0000000000000000000000000000000000000000..87af6d9d2dd73b49921bdfcadfd2fa4005abf3a2 Binary files /dev/null and b/data/imgen/viz3/7d071c8065_9.png differ diff --git a/data/imgen/viz3/9bbdc711f7_0.png b/data/imgen/viz3/9bbdc711f7_0.png new file mode 100644 index 0000000000000000000000000000000000000000..be2f34026be8fc40a8e45cdf54819159c24cccc5 Binary files /dev/null and b/data/imgen/viz3/9bbdc711f7_0.png differ diff --git a/data/imgen/viz3/9bbdc711f7_1.png b/data/imgen/viz3/9bbdc711f7_1.png new file mode 100644 index 0000000000000000000000000000000000000000..ea2168b98ba9d29d8b5e72be22111fe15c5f69c0 Binary files /dev/null and b/data/imgen/viz3/9bbdc711f7_1.png differ diff --git a/data/imgen/viz3/9bbdc711f7_10.png b/data/imgen/viz3/9bbdc711f7_10.png new file mode 100644 index 0000000000000000000000000000000000000000..aeda9fcec0ec63ea079db1ebc914c0e683bb154c Binary files /dev/null and b/data/imgen/viz3/9bbdc711f7_10.png differ diff --git a/data/imgen/viz3/9bbdc711f7_2.png b/data/imgen/viz3/9bbdc711f7_2.png new file mode 100644 index 0000000000000000000000000000000000000000..b6357f2f1213486db9276ea5d2fc5eecbe91d3b3 Binary files /dev/null and b/data/imgen/viz3/9bbdc711f7_2.png differ diff --git a/data/imgen/viz3/9bbdc711f7_3.png b/data/imgen/viz3/9bbdc711f7_3.png new file mode 100644 index 0000000000000000000000000000000000000000..faa2afe15802406bffa76f086ef29dc86c74ae77 Binary files /dev/null and b/data/imgen/viz3/9bbdc711f7_3.png differ diff --git a/data/imgen/viz3/9bbdc711f7_4.png b/data/imgen/viz3/9bbdc711f7_4.png new file mode 100644 index 0000000000000000000000000000000000000000..86b04d5b2e5154e4ca0c809b43106e050908c36c Binary files /dev/null and b/data/imgen/viz3/9bbdc711f7_4.png differ diff --git a/data/imgen/viz3/9bbdc711f7_5.png b/data/imgen/viz3/9bbdc711f7_5.png new file mode 100644 index 0000000000000000000000000000000000000000..15dbfd80aef3304700af439874d7147649defd64 Binary files /dev/null and b/data/imgen/viz3/9bbdc711f7_5.png differ diff --git a/data/imgen/viz3/9bbdc711f7_6.png b/data/imgen/viz3/9bbdc711f7_6.png new file mode 100644 index 0000000000000000000000000000000000000000..e3a684469a6ca1dfc8c4a02b235860d55d1ab762 Binary files /dev/null and b/data/imgen/viz3/9bbdc711f7_6.png differ diff --git a/data/imgen/viz3/9bbdc711f7_7.png b/data/imgen/viz3/9bbdc711f7_7.png new file mode 100644 index 0000000000000000000000000000000000000000..56230d0bdfdbedd841b5460b84f8322789bd6ae2 Binary files /dev/null and b/data/imgen/viz3/9bbdc711f7_7.png differ diff --git a/data/imgen/viz3/9bbdc711f7_8.png b/data/imgen/viz3/9bbdc711f7_8.png new file mode 100644 index 0000000000000000000000000000000000000000..bd7a94b0333c492cd1280b3b6fe4632f0f7129b6 Binary files /dev/null and b/data/imgen/viz3/9bbdc711f7_8.png differ diff --git a/data/imgen/viz3/9bbdc711f7_9.png b/data/imgen/viz3/9bbdc711f7_9.png new file mode 100644 index 0000000000000000000000000000000000000000..4a51892c207dd87f91ec49ead35c0d04a340bf57 Binary files /dev/null and b/data/imgen/viz3/9bbdc711f7_9.png differ diff --git a/data/imgen/viz3/9f9a9c83a9_0.png b/data/imgen/viz3/9f9a9c83a9_0.png new file mode 100644 index 0000000000000000000000000000000000000000..86b45d88eca90d8879e70da87cc99c7f7c1f4275 Binary files /dev/null and b/data/imgen/viz3/9f9a9c83a9_0.png differ diff --git a/data/imgen/viz3/9f9a9c83a9_1.png b/data/imgen/viz3/9f9a9c83a9_1.png new file mode 100644 index 0000000000000000000000000000000000000000..79d6759c055c4ca729edada8e2abd9f6b350a0e3 Binary files /dev/null and b/data/imgen/viz3/9f9a9c83a9_1.png differ diff --git a/data/imgen/viz3/9f9a9c83a9_10.png b/data/imgen/viz3/9f9a9c83a9_10.png new file mode 100644 index 0000000000000000000000000000000000000000..e42221e9f077ed61c30059b2d873472c9c16bb21 Binary files /dev/null and b/data/imgen/viz3/9f9a9c83a9_10.png differ diff --git a/data/imgen/viz3/9f9a9c83a9_2.png b/data/imgen/viz3/9f9a9c83a9_2.png new file mode 100644 index 0000000000000000000000000000000000000000..bceff560aa18786e23b972394b04667636864236 Binary files /dev/null and b/data/imgen/viz3/9f9a9c83a9_2.png differ diff --git a/data/imgen/viz3/9f9a9c83a9_3.png b/data/imgen/viz3/9f9a9c83a9_3.png new file mode 100644 index 0000000000000000000000000000000000000000..6d0bca83a38fa3df24bb4670576e6923264fd37f Binary files /dev/null and b/data/imgen/viz3/9f9a9c83a9_3.png differ diff --git a/data/imgen/viz3/9f9a9c83a9_4.png b/data/imgen/viz3/9f9a9c83a9_4.png new file mode 100644 index 0000000000000000000000000000000000000000..909805478af3404134c682727e1f6ba140fd7173 Binary files /dev/null and b/data/imgen/viz3/9f9a9c83a9_4.png differ diff --git a/data/imgen/viz3/9f9a9c83a9_5.png b/data/imgen/viz3/9f9a9c83a9_5.png new file mode 100644 index 0000000000000000000000000000000000000000..53987a84a0ecec56e7ced3e53871f29fddbd63b0 Binary files /dev/null and b/data/imgen/viz3/9f9a9c83a9_5.png differ diff --git a/data/imgen/viz3/9f9a9c83a9_6.png b/data/imgen/viz3/9f9a9c83a9_6.png new file mode 100644 index 0000000000000000000000000000000000000000..8f1440e2383b12116fc40e767acb8f785b4f9235 Binary files /dev/null and b/data/imgen/viz3/9f9a9c83a9_6.png differ diff --git a/data/imgen/viz3/9f9a9c83a9_7.png b/data/imgen/viz3/9f9a9c83a9_7.png new file mode 100644 index 0000000000000000000000000000000000000000..7c8f73f826daf81e6fdd30898f65392afe023d1e Binary files /dev/null and b/data/imgen/viz3/9f9a9c83a9_7.png differ diff --git a/data/imgen/viz3/9f9a9c83a9_8.png b/data/imgen/viz3/9f9a9c83a9_8.png new file mode 100644 index 0000000000000000000000000000000000000000..eb2692397bef4663e8c3c619b80a2735f93f3a65 Binary files /dev/null and b/data/imgen/viz3/9f9a9c83a9_8.png differ diff --git a/data/imgen/viz3/9f9a9c83a9_9.png b/data/imgen/viz3/9f9a9c83a9_9.png new file mode 100644 index 0000000000000000000000000000000000000000..69c7af8731020fc9c724927346987633f00c9fd8 Binary files /dev/null and b/data/imgen/viz3/9f9a9c83a9_9.png differ diff --git a/data/imgen/viz3/a21c215fb1_0.png b/data/imgen/viz3/a21c215fb1_0.png new file mode 100644 index 0000000000000000000000000000000000000000..521a9f325177f6944b386f1276be03a0e4eb7b0b Binary files /dev/null and b/data/imgen/viz3/a21c215fb1_0.png differ diff --git a/data/imgen/viz3/a21c215fb1_1.png b/data/imgen/viz3/a21c215fb1_1.png new file mode 100644 index 0000000000000000000000000000000000000000..a9205096fc64bc9512133023c333bfb1a64d404a Binary files /dev/null and b/data/imgen/viz3/a21c215fb1_1.png differ diff --git a/data/imgen/viz3/a21c215fb1_10.png b/data/imgen/viz3/a21c215fb1_10.png new file mode 100644 index 0000000000000000000000000000000000000000..507970382fe0c7cf462fb8f6901e80fd3bc4e4d3 Binary files /dev/null and b/data/imgen/viz3/a21c215fb1_10.png differ diff --git a/data/imgen/viz3/a21c215fb1_2.png b/data/imgen/viz3/a21c215fb1_2.png new file mode 100644 index 0000000000000000000000000000000000000000..eb2b4e65c9b61cd47f438c643ac4a833ff63e689 Binary files /dev/null and b/data/imgen/viz3/a21c215fb1_2.png differ diff --git a/data/imgen/viz3/a21c215fb1_3.png b/data/imgen/viz3/a21c215fb1_3.png new file mode 100644 index 0000000000000000000000000000000000000000..154a99aa2030762bf619b8f54d964587071fa236 Binary files /dev/null and b/data/imgen/viz3/a21c215fb1_3.png differ diff --git a/data/imgen/viz3/a21c215fb1_4.png b/data/imgen/viz3/a21c215fb1_4.png new file mode 100644 index 0000000000000000000000000000000000000000..4df13d26565cdbe19a542a026d8d8869e02797b7 Binary files /dev/null and b/data/imgen/viz3/a21c215fb1_4.png differ diff --git a/data/imgen/viz3/a21c215fb1_5.png b/data/imgen/viz3/a21c215fb1_5.png new file mode 100644 index 0000000000000000000000000000000000000000..83301d6d3c906eb13f5208b197dd0118beb2f57b Binary files /dev/null and b/data/imgen/viz3/a21c215fb1_5.png differ diff --git a/data/imgen/viz3/a21c215fb1_6.png b/data/imgen/viz3/a21c215fb1_6.png new file mode 100644 index 0000000000000000000000000000000000000000..b2fbee46c077d1c82eeb427a831c2c947e297162 Binary files /dev/null and b/data/imgen/viz3/a21c215fb1_6.png differ diff --git a/data/imgen/viz3/a21c215fb1_7.png b/data/imgen/viz3/a21c215fb1_7.png new file mode 100644 index 0000000000000000000000000000000000000000..079b65427c59d2302d907baec0f9c6954208c61a Binary files /dev/null and b/data/imgen/viz3/a21c215fb1_7.png differ diff --git a/data/imgen/viz3/a21c215fb1_8.png b/data/imgen/viz3/a21c215fb1_8.png new file mode 100644 index 0000000000000000000000000000000000000000..d06b1cd159b0f09e82fdcc57b7002b31594cc963 Binary files /dev/null and b/data/imgen/viz3/a21c215fb1_8.png differ diff --git a/data/imgen/viz3/a21c215fb1_9.png b/data/imgen/viz3/a21c215fb1_9.png new file mode 100644 index 0000000000000000000000000000000000000000..4b92c4cb3bcf6175b48f2648ba489e529ba1041a Binary files /dev/null and b/data/imgen/viz3/a21c215fb1_9.png differ diff --git a/data/imgen/viz3/c2367c5e8c_0.png b/data/imgen/viz3/c2367c5e8c_0.png new file mode 100644 index 0000000000000000000000000000000000000000..5ab03b18f15bb925152eef42b26e821be9952942 Binary files /dev/null and b/data/imgen/viz3/c2367c5e8c_0.png differ diff --git a/data/imgen/viz3/c2367c5e8c_1.png b/data/imgen/viz3/c2367c5e8c_1.png new file mode 100644 index 0000000000000000000000000000000000000000..9eacd9413c92ed3b7d3f4c3a4419998f68e676f4 Binary files /dev/null and b/data/imgen/viz3/c2367c5e8c_1.png differ diff --git a/data/imgen/viz3/c2367c5e8c_10.png b/data/imgen/viz3/c2367c5e8c_10.png new file mode 100644 index 0000000000000000000000000000000000000000..e0a93461945ca93c6937ffe8c1df7029ff1430e7 Binary files /dev/null and b/data/imgen/viz3/c2367c5e8c_10.png differ diff --git a/data/imgen/viz3/c2367c5e8c_2.png b/data/imgen/viz3/c2367c5e8c_2.png new file mode 100644 index 0000000000000000000000000000000000000000..0b1a81381fced351567e0d874f983e6b37ca43b5 Binary files /dev/null and b/data/imgen/viz3/c2367c5e8c_2.png differ diff --git a/data/imgen/viz3/c2367c5e8c_3.png b/data/imgen/viz3/c2367c5e8c_3.png new file mode 100644 index 0000000000000000000000000000000000000000..6334fe2eba6bd81403b402cebadccb01bb3278b3 Binary files /dev/null and b/data/imgen/viz3/c2367c5e8c_3.png differ diff --git a/data/imgen/viz3/c2367c5e8c_4.png b/data/imgen/viz3/c2367c5e8c_4.png new file mode 100644 index 0000000000000000000000000000000000000000..b9409a909b1356a20316f73c3a4f1ba04928ff30 Binary files /dev/null and b/data/imgen/viz3/c2367c5e8c_4.png differ diff --git a/data/imgen/viz3/c2367c5e8c_5.png b/data/imgen/viz3/c2367c5e8c_5.png new file mode 100644 index 0000000000000000000000000000000000000000..84adaf788ab58aaf99bff546897a9d07af7b222f Binary files /dev/null and b/data/imgen/viz3/c2367c5e8c_5.png differ diff --git a/data/imgen/viz3/c2367c5e8c_6.png b/data/imgen/viz3/c2367c5e8c_6.png new file mode 100644 index 0000000000000000000000000000000000000000..7a501bc17739519c9aa487faa8a826149964504e Binary files /dev/null and b/data/imgen/viz3/c2367c5e8c_6.png differ diff --git a/data/imgen/viz3/c2367c5e8c_7.png b/data/imgen/viz3/c2367c5e8c_7.png new file mode 100644 index 0000000000000000000000000000000000000000..690c0e434a964ce8936f5b645b3733e9be39a6fb Binary files /dev/null and b/data/imgen/viz3/c2367c5e8c_7.png differ diff --git a/data/imgen/viz3/c2367c5e8c_8.png b/data/imgen/viz3/c2367c5e8c_8.png new file mode 100644 index 0000000000000000000000000000000000000000..cf024434a1b27b115c10a1f9de2e391ac0e959f2 Binary files /dev/null and b/data/imgen/viz3/c2367c5e8c_8.png differ diff --git a/data/imgen/viz3/c2367c5e8c_9.png b/data/imgen/viz3/c2367c5e8c_9.png new file mode 100644 index 0000000000000000000000000000000000000000..950975819725e11fac02db17d70121371bb66ed6 Binary files /dev/null and b/data/imgen/viz3/c2367c5e8c_9.png differ diff --git a/data/imgen/viz3/cf7615bf4c_0.png b/data/imgen/viz3/cf7615bf4c_0.png new file mode 100644 index 0000000000000000000000000000000000000000..9d09189066ff38614557169be32d7f2b4dc637e6 Binary files /dev/null and b/data/imgen/viz3/cf7615bf4c_0.png differ diff --git a/data/imgen/viz3/cf7615bf4c_1.png b/data/imgen/viz3/cf7615bf4c_1.png new file mode 100644 index 0000000000000000000000000000000000000000..506c3b482e925e1131b5cb71c1ebf487d415a397 Binary files /dev/null and b/data/imgen/viz3/cf7615bf4c_1.png differ diff --git a/data/imgen/viz3/cf7615bf4c_10.png b/data/imgen/viz3/cf7615bf4c_10.png new file mode 100644 index 0000000000000000000000000000000000000000..ab476864cc867234d28a94024158593f1dcd6a73 Binary files /dev/null and b/data/imgen/viz3/cf7615bf4c_10.png differ diff --git a/data/imgen/viz3/cf7615bf4c_2.png b/data/imgen/viz3/cf7615bf4c_2.png new file mode 100644 index 0000000000000000000000000000000000000000..00ae2e077d09a06fd3c9e36fac828b5440971dfd Binary files /dev/null and b/data/imgen/viz3/cf7615bf4c_2.png differ diff --git a/data/imgen/viz3/cf7615bf4c_3.png b/data/imgen/viz3/cf7615bf4c_3.png new file mode 100644 index 0000000000000000000000000000000000000000..7ac860c45317e4f39ac2ecd49c3f359bdfb052b1 Binary files /dev/null and b/data/imgen/viz3/cf7615bf4c_3.png differ diff --git a/data/imgen/viz3/cf7615bf4c_4.png b/data/imgen/viz3/cf7615bf4c_4.png new file mode 100644 index 0000000000000000000000000000000000000000..02cf531209cf1c5dd4bc81cd36c1f5875b04d0fb Binary files /dev/null and b/data/imgen/viz3/cf7615bf4c_4.png differ diff --git a/data/imgen/viz3/cf7615bf4c_5.png b/data/imgen/viz3/cf7615bf4c_5.png new file mode 100644 index 0000000000000000000000000000000000000000..361509a295919db5f94a6397b581aea5516e4e66 Binary files /dev/null and b/data/imgen/viz3/cf7615bf4c_5.png differ diff --git a/data/imgen/viz3/cf7615bf4c_6.png b/data/imgen/viz3/cf7615bf4c_6.png new file mode 100644 index 0000000000000000000000000000000000000000..a2a2619bde96f94db610e8fa882fdf1969c624eb Binary files /dev/null and b/data/imgen/viz3/cf7615bf4c_6.png differ diff --git a/data/imgen/viz3/cf7615bf4c_7.png b/data/imgen/viz3/cf7615bf4c_7.png new file mode 100644 index 0000000000000000000000000000000000000000..7d6a9a350da444c5fb4cc54720a86a2a03a664ba Binary files /dev/null and b/data/imgen/viz3/cf7615bf4c_7.png differ diff --git a/data/imgen/viz3/cf7615bf4c_8.png b/data/imgen/viz3/cf7615bf4c_8.png new file mode 100644 index 0000000000000000000000000000000000000000..931bf757b11d46fa4dff62579e10b880fd5a4bdf Binary files /dev/null and b/data/imgen/viz3/cf7615bf4c_8.png differ diff --git a/data/imgen/viz3/cf7615bf4c_9.png b/data/imgen/viz3/cf7615bf4c_9.png new file mode 100644 index 0000000000000000000000000000000000000000..e3e4c444ac310c2574e15cdedfea2979044d464c Binary files /dev/null and b/data/imgen/viz3/cf7615bf4c_9.png differ diff --git a/data/imgen/viz3/dd1fe83e32_0.png b/data/imgen/viz3/dd1fe83e32_0.png new file mode 100644 index 0000000000000000000000000000000000000000..024bdb64ed13d85467ecbd61ccb60773f8393510 Binary files /dev/null and b/data/imgen/viz3/dd1fe83e32_0.png differ diff --git a/data/imgen/viz3/dd1fe83e32_1.png b/data/imgen/viz3/dd1fe83e32_1.png new file mode 100644 index 0000000000000000000000000000000000000000..bdfb87672f8b979f8cd402486c58da6c334db6b5 Binary files /dev/null and b/data/imgen/viz3/dd1fe83e32_1.png differ diff --git a/data/imgen/viz3/dd1fe83e32_10.png b/data/imgen/viz3/dd1fe83e32_10.png new file mode 100644 index 0000000000000000000000000000000000000000..de3c61911543b19e02d48f6ce055caa274f9d7cb Binary files /dev/null and b/data/imgen/viz3/dd1fe83e32_10.png differ diff --git a/data/imgen/viz3/dd1fe83e32_2.png b/data/imgen/viz3/dd1fe83e32_2.png new file mode 100644 index 0000000000000000000000000000000000000000..2c03ae2597f1e1955138b8f620477ac397ce2f0e Binary files /dev/null and b/data/imgen/viz3/dd1fe83e32_2.png differ diff --git a/data/imgen/viz3/dd1fe83e32_3.png b/data/imgen/viz3/dd1fe83e32_3.png new file mode 100644 index 0000000000000000000000000000000000000000..6e99a8f793128015acb113aab7b083fecd86fe75 Binary files /dev/null and b/data/imgen/viz3/dd1fe83e32_3.png differ diff --git a/data/imgen/viz3/dd1fe83e32_4.png b/data/imgen/viz3/dd1fe83e32_4.png new file mode 100644 index 0000000000000000000000000000000000000000..f65969a2f5435f59c87f73a18fbc3205dbbb94d4 Binary files /dev/null and b/data/imgen/viz3/dd1fe83e32_4.png differ diff --git a/data/imgen/viz3/dd1fe83e32_5.png b/data/imgen/viz3/dd1fe83e32_5.png new file mode 100644 index 0000000000000000000000000000000000000000..cbedacfa2a052c3ce4fc5184c2a53cd3aea7195d Binary files /dev/null and b/data/imgen/viz3/dd1fe83e32_5.png differ diff --git a/data/imgen/viz3/dd1fe83e32_6.png b/data/imgen/viz3/dd1fe83e32_6.png new file mode 100644 index 0000000000000000000000000000000000000000..13dc49e3ef3102b8277f42640c34339dca5bc01d Binary files /dev/null and b/data/imgen/viz3/dd1fe83e32_6.png differ diff --git a/data/imgen/viz3/dd1fe83e32_7.png b/data/imgen/viz3/dd1fe83e32_7.png new file mode 100644 index 0000000000000000000000000000000000000000..44bdafdea2176d56a87f60eafa4cfeb263738b94 Binary files /dev/null and b/data/imgen/viz3/dd1fe83e32_7.png differ diff --git a/data/imgen/viz3/dd1fe83e32_8.png b/data/imgen/viz3/dd1fe83e32_8.png new file mode 100644 index 0000000000000000000000000000000000000000..113c75808df41707a60692541ae0f2484f360cea Binary files /dev/null and b/data/imgen/viz3/dd1fe83e32_8.png differ diff --git a/data/imgen/viz3/dd1fe83e32_9.png b/data/imgen/viz3/dd1fe83e32_9.png new file mode 100644 index 0000000000000000000000000000000000000000..686c7b676de96d5c067568492316b022d8b35338 Binary files /dev/null and b/data/imgen/viz3/dd1fe83e32_9.png differ diff --git a/data/locomotion/pareto/humanoid_averse_taker_with_morl.pkl b/data/locomotion/pareto/humanoid_averse_taker_with_morl.pkl new file mode 100644 index 0000000000000000000000000000000000000000..d85ccba2c91e975c2d93145529ccfa97638b46dc Binary files /dev/null and b/data/locomotion/pareto/humanoid_averse_taker_with_morl.pkl differ diff --git a/data/locomotion/trajectories/0.html b/data/locomotion/trajectories/0.html new file mode 100644 index 0000000000000000000000000000000000000000..d57fc73a56f75bfccf59adf1325cb36d67a7a017 --- /dev/null +++ b/data/locomotion/trajectories/0.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/locomotion/trajectories/1.html b/data/locomotion/trajectories/1.html new file mode 100644 index 0000000000000000000000000000000000000000..021d7af124814b79e971bad67e5035c5b10c6666 --- /dev/null +++ b/data/locomotion/trajectories/1.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/locomotion/trajectories/10.html b/data/locomotion/trajectories/10.html new file mode 100644 index 0000000000000000000000000000000000000000..499c9dcda7f6fa040269f3cc99241eea18c7cc36 --- /dev/null +++ b/data/locomotion/trajectories/10.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/locomotion/trajectories/11.html b/data/locomotion/trajectories/11.html new file mode 100644 index 0000000000000000000000000000000000000000..0d9c4f4b5336a8254c400e9c600e060720f264ba --- /dev/null +++ b/data/locomotion/trajectories/11.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/locomotion/trajectories/12.html b/data/locomotion/trajectories/12.html new file mode 100644 index 0000000000000000000000000000000000000000..34cf193c79db3cb39884d031305f11ad7b5b9ef2 --- /dev/null +++ b/data/locomotion/trajectories/12.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/locomotion/trajectories/13.html b/data/locomotion/trajectories/13.html new file mode 100644 index 0000000000000000000000000000000000000000..f98cb5d09b1198b201aaee18746f2e8589945f8a --- /dev/null +++ b/data/locomotion/trajectories/13.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/locomotion/trajectories/14.html b/data/locomotion/trajectories/14.html new file mode 100644 index 0000000000000000000000000000000000000000..210b2d992d0bedb4b80ba1cc14242d7fc747547c --- /dev/null +++ b/data/locomotion/trajectories/14.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/locomotion/trajectories/15.html b/data/locomotion/trajectories/15.html new file mode 100644 index 0000000000000000000000000000000000000000..7dc34d47b58c547cfc51c2fade45a5e4871ce182 --- /dev/null +++ b/data/locomotion/trajectories/15.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/locomotion/trajectories/16.html b/data/locomotion/trajectories/16.html new file mode 100644 index 0000000000000000000000000000000000000000..eacabf867b3a1c40ea931dc8659d5cd00bd175f0 --- /dev/null +++ b/data/locomotion/trajectories/16.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/locomotion/trajectories/17.html b/data/locomotion/trajectories/17.html new file mode 100644 index 0000000000000000000000000000000000000000..2988216ba8c52c0230b7feb1e52ae07f3d431230 --- /dev/null +++ b/data/locomotion/trajectories/17.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/locomotion/trajectories/18.html b/data/locomotion/trajectories/18.html new file mode 100644 index 0000000000000000000000000000000000000000..56554550fd6de29c85e10eab62068f94148a4b9f --- /dev/null +++ b/data/locomotion/trajectories/18.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/locomotion/trajectories/19.html b/data/locomotion/trajectories/19.html new file mode 100644 index 0000000000000000000000000000000000000000..1083a2d57d16d3dd5558368160b5dd13fd8f1021 --- /dev/null +++ b/data/locomotion/trajectories/19.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/locomotion/trajectories/2.html b/data/locomotion/trajectories/2.html new file mode 100644 index 0000000000000000000000000000000000000000..d433dbea8eb2d1fb5aeff5f48182a040f81b6851 --- /dev/null +++ b/data/locomotion/trajectories/2.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/locomotion/trajectories/20.html b/data/locomotion/trajectories/20.html new file mode 100644 index 0000000000000000000000000000000000000000..48e288188baab4ea89f2a58a030da11783128317 --- /dev/null +++ b/data/locomotion/trajectories/20.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/locomotion/trajectories/3.html b/data/locomotion/trajectories/3.html new file mode 100644 index 0000000000000000000000000000000000000000..d06410e3a1e34432da3d40ef256127f0f30c1662 --- /dev/null +++ b/data/locomotion/trajectories/3.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/locomotion/trajectories/4.html b/data/locomotion/trajectories/4.html new file mode 100644 index 0000000000000000000000000000000000000000..edf9a3cb539c122ba04a53f6248346832b75442e --- /dev/null +++ b/data/locomotion/trajectories/4.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/locomotion/trajectories/5.html b/data/locomotion/trajectories/5.html new file mode 100644 index 0000000000000000000000000000000000000000..6e8c341a243a4ca66006b9e7b36d8e991657c21a --- /dev/null +++ b/data/locomotion/trajectories/5.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/locomotion/trajectories/6.html b/data/locomotion/trajectories/6.html new file mode 100644 index 0000000000000000000000000000000000000000..85aa284e72c18f6f263d2ae1008a29ee02d4db1b --- /dev/null +++ b/data/locomotion/trajectories/6.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/locomotion/trajectories/7.html b/data/locomotion/trajectories/7.html new file mode 100644 index 0000000000000000000000000000000000000000..3658ddf63dfdfcd4348cff7ba9d601d6a28f2b77 --- /dev/null +++ b/data/locomotion/trajectories/7.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/locomotion/trajectories/8.html b/data/locomotion/trajectories/8.html new file mode 100644 index 0000000000000000000000000000000000000000..d5d382cd05237168258b9c2874d287b212256b09 --- /dev/null +++ b/data/locomotion/trajectories/8.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/locomotion/trajectories/9.html b/data/locomotion/trajectories/9.html new file mode 100644 index 0000000000000000000000000000000000000000..b94815543e5d58f1bde222fbfc003e95b3e222f6 --- /dev/null +++ b/data/locomotion/trajectories/9.html @@ -0,0 +1,48 @@ + + + + brax visualizer + + + + +
+ + + diff --git a/data/textgen/data.pkl b/data/textgen/data.pkl new file mode 100644 index 0000000000000000000000000000000000000000..d9fed94055a30a5733e323dc73f3fbfc00e526da Binary files /dev/null and b/data/textgen/data.pkl differ diff --git a/data/textgen/data_prompt.pkl b/data/textgen/data_prompt.pkl new file mode 100644 index 0000000000000000000000000000000000000000..f3dcafe177b01cf48839142a3fcbc024a5715d96 Binary files /dev/null and b/data/textgen/data_prompt.pkl differ diff --git a/data/textgen/data_title.pkl b/data/textgen/data_title.pkl new file mode 100644 index 0000000000000000000000000000000000000000..abdd81d9b4b4018c1bcb2dd4669c65c7d9b47507 Binary files /dev/null and b/data/textgen/data_title.pkl differ diff --git "a/pages/01_\342\234\215\357\270\217_News_summarization.py" "b/pages/01_\342\234\215\357\270\217_News_summarization.py" new file mode 100644 index 0000000000000000000000000000000000000000..e7260b907afd3b62352dad3fd17ebf98c86259f7 --- /dev/null +++ "b/pages/01_\342\234\215\357\270\217_News_summarization.py" @@ -0,0 +1,296 @@ +import streamlit as st +from PIL import Image +import codecs +import streamlit.components.v1 as components +from utils import inject_custom_css +import streamlit as st +from streamlit_plotly_events import plotly_events +import pickle +import matplotlib.pyplot as plt +import plotly.graph_objects as go +import typing as tp +import colorsys + +plt.style.use('default') +plt.rcParams['text.usetex'] = True +plt.rcParams['font.family'] = 'serif' + + +def interpolate_color(color1, color2, factor): + """Interpolates between two RGB colors. Factor is between 0 and 1.""" + color1 = colorsys.rgb_to_hls( + int(color1[1:3], 16) / 255.0, + int(color1[3:5], 16) / 255.0, + int(color1[5:], 16) / 255.0 + ) + color2 = colorsys.rgb_to_hls( + int(color2[1:3], 16) / 255.0, + int(color2[3:5], 16) / 255.0, + int(color2[5:], 16) / 255.0 + ) + new_color = [color1[i] * (1 - factor) + color2[i] * factor for i in range(3)] + new_color = colorsys.hls_to_rgb(*new_color) + return '#{:02x}{:02x}{:02x}'.format( + int(new_color[0] * 255), int(new_color[1] * 255), int(new_color[2] * 255) + ) + + +color1 = "#fa7659" +color2 = "#6dafd7" + +shapes = [ + dict( + type="rect", + xref="paper", + yref="paper", + x0=0, + y0=0, + x1=1, + y1=1, + line=dict( + color="Black", + width=2, + ), + ) +] + +shapes = [ + dict( + type="rect", + xref="paper", + yref="paper", + x0=0, + y0=0, + x1=1, + y1=1, + line=dict( + color="Black", + width=2, + ), + ) +] + + +def plot_pareto(dict_results: tp.Dict): + + reward1_key = "R1" + reward2_key = "R2" + + # Series for "wa" + dict_results["wa_d"] = [x for i, x in enumerate(dict_results["wa_d"]) if i % 2 == 0] + lambda_values_wa = [ + round(i / (len(dict_results["wa_d"]) - 1), 2) for i in range(len(dict_results["wa_d"])) + ][::-1] + reward1_values_wa = [item[reward1_key] for item in dict_results["wa_d"]] + reward2_values_wa = [item[reward2_key] for item in dict_results["wa_d"]] + + # Series for "morl" + # Series for "init" + reward1_values_morl = [dict_results["morl"][reward1_key]] + reward2_values_morl = [dict_results["morl"][reward2_key]] + + # Series for "init" + reward1_values_init = [dict_results["init"][reward1_key]] + reward2_values_init = [dict_results["init"][reward2_key]] + + layout = go.Layout(autosize=False, width=1000, height=1000) + fig = go.Figure(layout=layout) + + for i in range(len(reward1_values_wa) - 1): + fig.add_trace( + go.Scatter( + x=reward1_values_wa[i:i + 2], + y=reward2_values_wa[i:i + 2], + mode='lines', + hoverinfo='skip', + line=dict( + color=interpolate_color(color1, color2, i / (len(reward1_values_wa) - 1)), + width=2 + ), + showlegend=False + ) + ) + + # Plot for "wa" + fig.add_trace( + go.Scatter( + x=reward1_values_wa, + y=reward2_values_wa, + mode='markers', + name='Rewarded soups: 0鈮の烩墹1', + hoverinfo='text', + hovertext=[f'位={lmbda}' for lmbda in lambda_values_wa], + marker=dict( + color=[ + interpolate_color(color1, color2, i / len(lambda_values_wa)) + for i in range(len(lambda_values_wa)) + ], + size=10 + ) + ) + ) + + # Plot for "morl" + fig.add_trace( + go.Scatter( + x=reward1_values_morl, + y=reward2_values_morl, + mode='markers', + name='MORL: 渭=0.5', + hoverinfo='skip', + marker=dict(color='#A45EE9', size=15, symbol="star"), + ) + ) + + # Plot for "init" + fig.add_trace( + go.Scatter( + x=reward1_values_init, + y=reward2_values_init, + mode='markers', + name='Pre-trained init', + hoverinfo='skip', + marker=dict(color='#9f9bc8', size=15, symbol="star"), + ) + ) + + fig.update_layout( + xaxis=dict( + #range = [5.21,5.31], + #nticks=6, + showticklabels=True, + ticks='outside', + tickfont=dict(size=18,), + title=dict(text="R1", font=dict(size=18), standoff=10), + showgrid=False, + zeroline=False, + hoverformat='.2f' + ), + yaxis=dict( + #range = [0.78,0.825], + #nticks=7, + showticklabels=True, + ticks='outside', + tickfont=dict(size=18,), + title=dict(text="R2", font=dict(size=18), standoff=10), + showgrid=False, + zeroline=False, + hoverformat='.2f' + ), + font=dict(family="Roboto", size=12, color="Black"), + hovermode='x unified', + autosize=False, + width=500, + height=500, + margin=dict(l=100, r=50, b=150, t=20, pad=0), + paper_bgcolor="White", + plot_bgcolor="White", + shapes=shapes, + legend=dict( + x=0.5, + y=0.03, + traceorder="normal", + font=dict(family="Roboto", size=12, color="black"), + bgcolor="White", + bordercolor="Black", + borderwidth=1 + ) + ) + + return fig + + +def run(): + + st.write( + f""" + + + +

RLHF of LLaMA for diverse news summarization

""", + unsafe_allow_html=True + ) + + st.markdown( + r""" +Given the importance of RLHF to train LLMs, we begin our experiments with text-to-text generation. +Our pre-trained network is LLaMA-7b, instruction fine-tuned on Alpaca. +For RL training with PPO, we employ the trl package and the setup from with low-rank adapters (LoRA) for efficiency. +Here we consider summarization on Reuter news. +To evaluate the summary in the absence of supervision, we utilized two different reward models, available on HuggingFace: [$R_1$](https://huggingface.co./Tristan/gpt2_reward_summarization) follows the Summarize from Human Feedback paper while [$R_2$](https://huggingface.co./CogComp/bart-faithful-summary-detector) leverages contrast candidate generation. + +Our results below reveal the following insights. The front defined by rewarded soups between the two weights specialized on $R_1$ (i.e., $\lambda=0.0$) and $R_2$ (i.e., $\lambda=1.0$) is above the straight line connecting those two points; this validates what we call in the paper *the linear mode connectivity hypothesis*. Moreover, the front intersects the point obtained by multi-objective RL (MORL) fine-tuning on $(1-\mu) \times R_1 + \mu \times R_2$ for $\mu=0.5$ (i.e., the average of the two rewards). Interestingly, when we compare both full fronts in the paper, they exhibit qualitatively the same shape. The qualitative visual inspections of the generations show that increasing $\lambda$ leads to shorter but more factual summaries; this is because $R_2$ evaluates faithfulness in priority.""", + unsafe_allow_html=True + ) + st.markdown( + """

Click on a rewarded soup point on the left and select a subject on the right!

""", + unsafe_allow_html=True + ) + + files = [] + + with open("streamlit_app/data/textgen/data.pkl", "rb") as f: + data = pickle.load(f) + with open("streamlit_app/data/textgen/data_prompt.pkl", "rb") as f: + data_prompt = pickle.load(f) + with open("streamlit_app/data/textgen/data_title.pkl", "rb") as f: + data_title = pickle.load(f) + + left, right = st.columns((2, 2)) + with left: + fig = plot_pareto(data) + onclick = plotly_events(fig, click_event=True) + with right: + option = st.selectbox('', data_title.keys()) + + subject = data_title[option] + st.markdown( + f""" +
+
+ Text to summarize: +
+
+ {data_prompt[subject]['query']} +
+
+ """, + unsafe_allow_html=True + ) + st.markdown("
", unsafe_allow_html=True) + + summary1 = data_prompt[subject]['outs'][0]["out"] + summary3 = data_prompt[subject]['outs'][-1]["out"] + nb_summaries = len(data_prompt[subject]['outs']) + if len(onclick) > 0: + idx = onclick[0]["pointIndex"] + else: + idx = 5 + lambda2 = round(1 - idx / (len(data["wa_d"]) - 1), 2) + summary2 = data_prompt[subject]['outs'][idx]["out"] + bgcolor = interpolate_color(color2, color1, lambda2) + + st.markdown( + f""" +
+
+ Generated summaries: +
+
+
位=0.0
{summary3}
+
位={lambda2}
{summary2}

+
位=1.0
{summary1}

+
+
+ """, + unsafe_allow_html=True + ) + + +if __name__ == "__main__": + img = Image.open("streamlit_app/assets/images/icon.png") + st.set_page_config(page_title="Rewarded soups", page_icon=img, layout="wide") + inject_custom_css("streamlit_app/assets/styles.css") + st.set_option('deprecation.showPyplotGlobalUse', False) + run() diff --git "a/pages/02_\360\237\216\250_Image_generation.py" "b/pages/02_\360\237\216\250_Image_generation.py" new file mode 100644 index 0000000000000000000000000000000000000000..1145c6e1ba1d21e18c352728edf81c64e7fa07b5 --- /dev/null +++ "b/pages/02_\360\237\216\250_Image_generation.py" @@ -0,0 +1,252 @@ +import streamlit as st +from PIL import Image +import codecs +import streamlit.components.v1 as components +from utils import inject_custom_css +import streamlit as st +from streamlit_plotly_events import plotly_events +import pickle +import matplotlib.pyplot as plt +import plotly.graph_objects as go +import typing as tp +import colorsys + +plt.style.use('default') + +def interpolate_color(color1, color2, factor): + """Interpolates between two RGB colors. Factor is between 0 and 1.""" + color1 = colorsys.rgb_to_hls(int(color1[1:3], 16)/255.0, int(color1[3:5], 16)/255.0, int(color1[5:], 16)/255.0) + color2 = colorsys.rgb_to_hls(int(color2[1:3], 16)/255.0, int(color2[3:5], 16)/255.0, int(color2[5:], 16)/255.0) + new_color = [color1[i] * (1 - factor) + color2[i] * factor for i in range(3)] + new_color = colorsys.hls_to_rgb(*new_color) + return '#{:02x}{:02x}{:02x}'.format(int(new_color[0]*255), int(new_color[1]*255), int(new_color[2]*255)) + + +color1 = "#fa7659" +color2 = "#6dafd7" + +shapes=[ + dict( + type="rect", + xref="paper", + yref="paper", + x0=0, + y0=0, + x1=1, + y1=1, + line=dict( + color="Black", + width=2, + ), + ) +] + +def plot_pareto(dict_results: tp.Dict): + + reward1_key = "ava" + reward2_key = "cafe" + + # Series for "wa" + lambda_values_wa = [round(1 - i/(len(dict_results["wa_d"])-1),2) for i in range(len(dict_results["wa_d"]))] + reward1_values_wa = [item[reward1_key] for item in dict_results["wa_d"]] + reward2_values_wa = [item[reward2_key] for item in dict_results["wa_d"]] + + # Series for "morl" + mu_values_morl = [round(1 - i/(len(dict_results["morl_d"])-1),2) for i in range(len(dict_results["morl_d"]))] + reward1_values_morl = [item[reward1_key] for item in dict_results["morl_d"]][3] + reward2_values_morl = [item[reward2_key] for item in dict_results["morl_d"]][3] + + # Series for "init" + reward1_values_init = [dict_results["init"][reward1_key]] + reward2_values_init = [dict_results["init"][reward2_key]] + + + layout = go.Layout(autosize=False,width=1000,height=1000) + fig = go.Figure(layout=layout) + + + for i in range(len(reward1_values_wa) - 1): + fig.add_trace(go.Scatter( + x=reward1_values_wa[i:i+2], + y=reward2_values_wa[i:i+2], + mode='lines', + hoverinfo='skip', + line=dict( + color=interpolate_color(color1, color2, i/(len(reward1_values_wa)-1)), + width=2 + ), + showlegend=False + )) + + + # Plot for "wa" + fig.add_trace( + go.Scatter( + x=reward1_values_wa, + y=reward2_values_wa, + mode='markers', + name='Rewarded soups: 0鈮の烩墹1', + hoverinfo='text', + hovertext=[f'位={lmbda}' for lmbda in lambda_values_wa], + marker=dict( + color=[ + interpolate_color(color1, color2, i / len(lambda_values_wa)) + for i in range(len(lambda_values_wa)) + ], + size=10 + ) + ) + ) + + # Plot for "morl" + fig.add_trace( + go.Scatter( + x=[reward1_values_morl], + y=[reward2_values_morl], + mode='markers', + name='MORL: 渭=0.5', + hoverinfo='skip', + marker=dict(color='#A45EE9', size=15, symbol="star") + ) + ) + + # Plot for "init" + fig.add_trace( + go.Scatter( + x=reward1_values_init, + y=reward2_values_init, + mode='markers', + name='Pre-trained init', + hoverinfo='skip', + marker=dict(color='#9f9bc8', size=15, symbol="star"), + ) + ) + + fig.update_layout( + xaxis=dict( + showticklabels=True, + ticks='outside', + tickfont=dict(size=18,), + title=dict(text="Ava reward", font=dict(size=18), standoff=10), + showgrid=False, + zeroline=False, + hoverformat='.2f' + ), + yaxis=dict( + showticklabels=True, + ticks='outside', + tickfont=dict(size=18,), + title=dict(text="Cafe reward", font=dict(size=18), standoff=10), + showgrid=False, + zeroline=False, + hoverformat='.2f' + ), + font=dict(family="Roboto", size=12, color="Black"), + hovermode='x unified', + autosize=False, + width=500, + height=500, + margin=dict(l=100, r=50, b=150, t=20, pad=0), + paper_bgcolor="White", + plot_bgcolor="White", + shapes=shapes, + legend=dict( + x=0.5, + y=0.03, + traceorder="normal", + font=dict(family="Roboto", size=12, color="black"), + bgcolor="White", + bordercolor="Black", + borderwidth=1 + ) + ) + + return fig + +def run(): + + st.write( + f""" + + + +

RLHF of diffusion model for diverse human aesthetics

""",unsafe_allow_html=True) + + st.markdown( + r""" +Beyond text generation, we now apply RS to align text-to-image generation with human feedbacks. +Here, we demonstrate how rewarded soups allows to interpolate between models fine-tuned for different aesthetic metrics. +Our network is a diffusion model with 2.2B parameters, pre-trained on an internal dataset of 300M images; it reaches similar quality as Stable Diffusion, which was not used for copyright reasons. +To represent the subjectivity of human aesthetics, we employ $N=2$ open-source reward models: [*ava*](https://github.com/christophschuhmann/improved-aesthetic-predictor/), trained on the AVA dataset, and [*cafe*](https://huggingface.co./cafeai/cafe_aesthetic), trained on a mix of real-life and manga images. +We first generate 10000 images; then, for each reward, we remove half of the images with the lowest reward's score and fine-tune 10\% of the parameters on the reward-weighted negative log-likelihood. + +Our results below show that interpolating between the expert models unveils a Pareto-optimal front, enabling alignment with a variety of aesthetic preferences. +Specifically, all interpolated models produce images of similar quality compared to fine-tuned models, demonstrating linear mode connectivity between the two fine-tuned models. +This ability to adapt at test time paves the way for a new form of user interaction with text-to-image models, beyond prompt engineering. +""", + unsafe_allow_html=True + ) + st.markdown("""

Click on a rewarded soup point on the left and select a prompt on the right!

""",unsafe_allow_html=True) + + files = [] + + with open("streamlit_app/data/imgen/data.pkl","rb") as f: + data = pickle.load(f) + with open("streamlit_app/data/imgen/data_images.pkl","rb") as f: + data_images = pickle.load(f) + + row_0_1,row_0_2 = st.columns([2,3]) + with row_0_1: + fig = plot_pareto(data) + onclick = plotly_events(fig, click_event=True) + with row_0_2: + option = st.selectbox('',data_images.keys()) + for i in range(11): + filename = f'https://github.com/continual-subspace/hidden_soup/blob/main/{data_images[option]["filename"]}_{i}.png?raw=true' + files.append(filename) + row_1_1,row_1_2,row_1_3 = st.columns([1,1,1]) + if len(onclick) > 0: + idx = onclick[-1]['pointIndex'] + else: + idx = 5 + img = files[idx] + bgcolor = interpolate_color(color2,color1,round(1 - idx/(len(files)-1),2)) + lambda2 = round(1 - idx/(len(files)-1),2) + + img1 = files[0] + img0 = files[-1] + + st.markdown( + f""" +
+
+ Generated images: +
+
+
+
位=0.0
+
{img0}
+
+
+
位={lambda2}
+
{img}
+
+
+
位=1.0
+
{img1}
+
+
+
+ """, + unsafe_allow_html=True + ) + + + +if __name__ == "__main__": + img = Image.open("streamlit_app/assets/images/icon.png") + st.set_page_config(page_title="Rewarded soups",page_icon=img,layout="wide") + inject_custom_css("streamlit_app/assets/styles.css") + st.set_option('deprecation.showPyplotGlobalUse', False) + run() diff --git "a/pages/03_\360\237\246\277_Locomotion.py" "b/pages/03_\360\237\246\277_Locomotion.py" new file mode 100644 index 0000000000000000000000000000000000000000..f0ec94bcf4dcaddcbe3abf8acd19f49150c7b030 --- /dev/null +++ "b/pages/03_\360\237\246\277_Locomotion.py" @@ -0,0 +1,216 @@ +import streamlit as st +from PIL import Image +import codecs +import streamlit.components.v1 as components +from utils import inject_custom_css +import streamlit as st +from streamlit_plotly_events import plotly_events +import pickle +import matplotlib.pyplot as plt +import plotly.graph_objects as go +import typing as tp + +plt.style.use('default') + +shapes=[ + dict( + type="rect", + xref="paper", + yref="paper", + x0=0, + y0=0, + x1=1, + y1=1, + line=dict( + color="Black", + width=2, + ), + ) +] + +import colorsys + +def interpolate_color(color1, color2, factor): + """Interpolates between two RGB colors. Factor is between 0 and 1.""" + color1 = colorsys.rgb_to_hls(int(color1[1:3], 16)/255.0, int(color1[3:5], 16)/255.0, int(color1[5:], 16)/255.0) + color2 = colorsys.rgb_to_hls(int(color2[1:3], 16)/255.0, int(color2[3:5], 16)/255.0, int(color2[5:], 16)/255.0) + new_color = [color1[i] * (1 - factor) + color2[i] * factor for i in range(3)] + new_color = colorsys.hls_to_rgb(*new_color) + return '#{:02x}{:02x}{:02x}'.format(int(new_color[0]*255), int(new_color[1]*255), int(new_color[2]*255)) + + +color1 = "#fa7659" +color2 = "#6dafd7" + +def plot_pareto(dict_results: tp.Dict): + keys = list(dict_results["wa"][0].keys()) + lambda_key, reward2_key, reward1_key = keys + + # Series for "wa" + dict_results["wa"] = [x for i,x in enumerate(dict_results["wa"]) if i%2==0] + lambda_values_wa = [item[lambda_key] for item in dict_results["wa"]][::-1] + reward1_values_wa = [item[reward1_key] for item in dict_results["wa"]][::-1] + reward2_values_wa = [item[reward2_key] for item in dict_results["wa"]][::-1] + + # Series for "init" + reward1_values_init = [item[reward1_key] for item in dict_results["init"]] + reward2_values_init = [item[reward2_key] for item in dict_results["init"]] + + layout = go.Layout(autosize=False,width=1000,height=1000) + fig = go.Figure(layout=layout) + + for i in range(len(reward1_values_wa) - 1): + fig.add_trace(go.Scatter( + x=reward1_values_wa[i:i+2], + y=reward2_values_wa[i:i+2], + mode='lines', + hoverinfo='skip', + line=dict( + color=interpolate_color(color1, color2, i/(len(reward1_values_wa)-1)), + width=2 + ), + showlegend=False + )) + + # Plot for "wa" + fig.add_trace( + go.Scatter( + x=reward1_values_wa, + y=reward2_values_wa, + mode='markers', + name='Rewarded soups: 0鈮の烩墹1', + hoverinfo='text', + hovertext=[f'位={lmbda}' for lmbda in lambda_values_wa], + marker=dict( + color=[ + interpolate_color(color1, color2, i / len(lambda_values_wa)) + for i in range(len(lambda_values_wa)) + ], + size=10 + ) + ) + ) + + # Plot for "morl" + fig.add_trace( + go.Scatter( + x=[6400.], + y=[3300.], + mode='markers', + name='MORL: 渭=0.5', + hoverinfo='skip', + marker=dict(color='#A45EE9', size=15, symbol="star"), + ) + ) + # Plot for "init" + fig.add_trace( + go.Scatter( + x=reward1_values_init, + y=reward2_values_init, + mode='markers', + name='Pre-trained init', + hoverinfo='skip', + marker=dict(color='#9f9bc8', size=15, symbol="star"), + ) + ) + + fig.update_layout( + xaxis=dict( + range=[3000, 7000], + nticks=6, + showticklabels=True, + ticks='outside', + tickfont=dict(size=18,), + title=dict(text="Risky reward", font=dict(size=18), standoff=10), + showgrid=False, + zeroline=False, + hoverformat='.2f' + ), + yaxis=dict( + range=[-1000, 4500], + nticks=7, + showticklabels=True, + ticks='outside', + tickfont=dict(size=18,), + title=dict(text="Cautious reward", font=dict(size=18), standoff=10), + showgrid=False, + zeroline=False, + hoverformat='.2f' + ), + font=dict(family="Roboto", size=12, color="Black"), + hovermode='x unified', + autosize=False, + width=500, + height=500, + margin=dict(l=100, r=50, b=150, t=20, pad=0), + paper_bgcolor="White", + plot_bgcolor="White", + shapes=shapes, + legend=dict( + x=0.5, + y=0.03, + traceorder="normal", + font=dict(family="Roboto", size=12, color="black"), + bgcolor="White", + bordercolor="Black", + borderwidth=1 + ) + ) + + return fig + +def run(): + + st.write( + f""" + + + +

Making humanoid run more naturally with diverse engineered rewards

""",unsafe_allow_html=True) + + st.markdown( + r""" +Teaching humanoids to walk in a human-like manner serves as a benchmark to evaluate RL strategies for continuous control. One of the key challenges is shaping a suitable proxy reward, given the intricate coordination and balance involved in human locomotion. It is standard to consider the dense reward at each timestep: ${r(t)=velocity-\alpha \times \sum_t a^{2}_{t}}$, controlling the agent's velocity while penalizing wide actions. Yet, the penalty coefficient $\alpha$ is challenging to set. To tackle this, we devised two rewards in the Brax physics engine: a *risky* one with $\alpha=0$, and a *cautious* one $\alpha=1$. + +Below in the interactive animation, you will see the humanoids trained with these two rewards: the humanoid for $\alpha=0$ is the fastest but the most chaotic, while the one for $\alpha=1$ is more cautious but slower. For intermediate values of $\lambda$, the policy is obtained by linear interpolation of those extreme weights, arguably resulting in smoother motion patterns. +""", unsafe_allow_html=True + ) + st.markdown("""

Click on a rewarded soup point!

""",unsafe_allow_html=True) + + files = [] + for i in range(21): + filename = f'streamlit_app/data/locomotion/trajectories/{i}.html' + files.append(codecs.open(filename, "r", "utf-8").read()) + files = [x for i,x in enumerate(files) if i%2==0] + + row_0_1,row_0_2,row_0_3,row_0_4 = st.columns([3,1,1,1]) + with row_0_1: + with open("streamlit_app/data/locomotion/pareto/humanoid_averse_taker_with_morl.pkl","rb") as f: + dict_results = pickle.load(f) + fig = plot_pareto(dict_results) + onclick = plotly_events(fig, click_event=True) + with row_0_4: + st.markdown(f"""
位=1.0
""",unsafe_allow_html=True) + components.html(files[-1],width=150,height=300) + with row_0_3: + if len(onclick) > 0: + idx = onclick[-1]['pointIndex'] + else: + idx = 5 + st.markdown( + f"""
位={round(1-idx/(len(files)-1),2)}
""", + unsafe_allow_html=True + ) + components.html(files[idx], width=150, height=300) + with row_0_2: + st.markdown(f"""
位=0.0
""",unsafe_allow_html=True) + components.html(files[0],width=150,height=300) + + +if __name__ == "__main__": + img = Image.open("streamlit_app/assets/images/icon.png") + st.set_page_config(page_title="Rewarded soups",page_icon=img,layout="wide") + inject_custom_css("streamlit_app/assets/styles.css") + st.set_option('deprecation.showPyplotGlobalUse', False) + run() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1f24451d73f1996afc26d1bb546f1cf486d27b7d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +plotly==5.14.1 +matplotlib +streamlit-plotly-events \ No newline at end of file diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6ea464f5cdff764979afe96c574b6483740f04d5 --- /dev/null +++ b/utils.py @@ -0,0 +1,5 @@ +import streamlit as st + +def inject_custom_css(file): + with open(file) as f: + st.markdown(f'', unsafe_allow_html=True) \ No newline at end of file