Spaces:
Runtime error
Runtime error
yesssssssss
commited on
Commit
·
1ccdd5a
1
Parent(s):
534982a
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- README.md +6 -7
- TabPFN/PrepareDatasets.ipynb +373 -0
- TabPFN/README.md +23 -0
- TabPFN/SyntheticGPAblation.ipynb +392 -0
- TabPFN/TabPFNPredictionOnly.ipynb +253 -0
- TabPFN/TabularEvaluationVisualization.ipynb +0 -0
- TabPFN/TrainingTuningAndPrediction.ipynb +0 -0
- TabPFN/__pycache__/encoders.cpython-37.pyc +0 -0
- TabPFN/__pycache__/layer.cpython-37.pyc +0 -0
- TabPFN/__pycache__/model_builder.cpython-37.pyc +0 -0
- TabPFN/__pycache__/notebook_utils.cpython-37.pyc +0 -0
- TabPFN/__pycache__/positional_encodings.cpython-37.pyc +0 -0
- TabPFN/__pycache__/train.cpython-37.pyc +0 -0
- TabPFN/__pycache__/transformer.cpython-37.pyc +0 -0
- TabPFN/__pycache__/utils.cpython-37.pyc +0 -0
- TabPFN/__pycache__/utils.cpython-38.pyc +0 -0
- TabPFN/datasets/__init__.py +149 -0
- TabPFN/datasets/utils.py +8 -0
- TabPFN/decoders.py +30 -0
- TabPFN/differentiable_pfn_evaluation.py +345 -0
- TabPFN/encoders.py +225 -0
- TabPFN/initializers.py +9 -0
- TabPFN/layer.py +125 -0
- TabPFN/losses.py +41 -0
- TabPFN/model_builder.py +273 -0
- TabPFN/models_diff/gp_ablation_model.cpkt +3 -0
- TabPFN/models_diff/prior_diff_real_checkpoint_n_8x_lr0.0003_epoch_49.cpkt +3 -0
- TabPFN/notebook_utils.py +32 -0
- TabPFN/positional_encodings.py +70 -0
- TabPFN/prior_tuning_result.pkl +3 -0
- TabPFN/priors/__init__.py +4 -0
- TabPFN/priors/__pycache__/__init__.cpython-37.pyc +0 -0
- TabPFN/priors/__pycache__/__init__.cpython-38.pyc +0 -0
- TabPFN/priors/__pycache__/differentiable_prior.cpython-37.pyc +0 -0
- TabPFN/priors/__pycache__/fast_gp.cpython-37.pyc +0 -0
- TabPFN/priors/__pycache__/fast_gp.cpython-38.pyc +0 -0
- TabPFN/priors/__pycache__/flexible_categorical.cpython-37.pyc +0 -0
- TabPFN/priors/__pycache__/mlp.cpython-37.pyc +0 -0
- TabPFN/priors/__pycache__/prior.cpython-37.pyc +0 -0
- TabPFN/priors/__pycache__/prior_bag.cpython-37.pyc +0 -0
- TabPFN/priors/__pycache__/utils.cpython-37.pyc +0 -0
- TabPFN/priors/differentiable_prior.py +293 -0
- TabPFN/priors/fast_gp.py +144 -0
- TabPFN/priors/flexible_categorical.py +240 -0
- TabPFN/priors/mlp.py +173 -0
- TabPFN/priors/prior.py +12 -0
- TabPFN/priors/prior_bag.py +32 -0
- TabPFN/priors/utils.py +163 -0
- TabPFN/requirements.txt +15 -0
.gitattributes
CHANGED
@@ -29,3 +29,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
29 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
30 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
31 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
29 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
30 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
31 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.cpkt filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,13 +1,12 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.1.
|
8 |
app_file: app.py
|
9 |
-
pinned:
|
10 |
-
license: apache-2.0
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: TabPFN
|
3 |
+
emoji: 🐨
|
4 |
+
colorFrom: gray
|
5 |
+
colorTo: blue
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.1.1
|
8 |
app_file: app.py
|
9 |
+
pinned: true
|
|
|
10 |
---
|
11 |
|
12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
TabPFN/PrepareDatasets.ipynb
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import numpy as np\n",
|
10 |
+
"\n",
|
11 |
+
"import openml\n",
|
12 |
+
"import pandas as pd"
|
13 |
+
]
|
14 |
+
},
|
15 |
+
{
|
16 |
+
"cell_type": "code",
|
17 |
+
"execution_count": 2,
|
18 |
+
"metadata": {},
|
19 |
+
"outputs": [],
|
20 |
+
"source": [
|
21 |
+
"from tqdm import tqdm\n",
|
22 |
+
"\n",
|
23 |
+
"from datasets import load_openml_list, test_dids_classification, valid_large_classification, open_cc_dids, open_cc_valid_dids\n"
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"cell_type": "code",
|
28 |
+
"execution_count": 6,
|
29 |
+
"metadata": {},
|
30 |
+
"outputs": [
|
31 |
+
{
|
32 |
+
"name": "stdout",
|
33 |
+
"output_type": "stream",
|
34 |
+
"text": [
|
35 |
+
"The autoreload extension is already loaded. To reload it, use:\n",
|
36 |
+
" %reload_ext autoreload\n"
|
37 |
+
]
|
38 |
+
}
|
39 |
+
],
|
40 |
+
"source": [
|
41 |
+
"%load_ext autoreload\n",
|
42 |
+
"\n",
|
43 |
+
"%autoreload 2"
|
44 |
+
]
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"cell_type": "markdown",
|
48 |
+
"metadata": {
|
49 |
+
"tags": []
|
50 |
+
},
|
51 |
+
"source": [
|
52 |
+
"### Prepare test datasets"
|
53 |
+
]
|
54 |
+
},
|
55 |
+
{
|
56 |
+
"cell_type": "code",
|
57 |
+
"execution_count": 7,
|
58 |
+
"metadata": {},
|
59 |
+
"outputs": [],
|
60 |
+
"source": [
|
61 |
+
"renamer = {'name': 'Name', 'NumberOfFeatures': '# Features', 'NumberOfSymbolicFeatures': '# Categorical Features', 'NumberOfInstances': '# Instances', 'NumberOfMissingValues': '# NaNs', 'NumberOfClasses': '# Classes', 'MinorityClassSize': 'Minority Class Size'}\n"
|
62 |
+
]
|
63 |
+
},
|
64 |
+
{
|
65 |
+
"cell_type": "code",
|
66 |
+
"execution_count": 8,
|
67 |
+
"metadata": {},
|
68 |
+
"outputs": [
|
69 |
+
{
|
70 |
+
"data": {
|
71 |
+
"text/plain": [
|
72 |
+
"OrderedDict([(99,\n",
|
73 |
+
" {'id': 99,\n",
|
74 |
+
" 'alias': 'OpenML-CC18',\n",
|
75 |
+
" 'main_entity_type': 'task',\n",
|
76 |
+
" 'name': 'OpenML-CC18 Curated Classification benchmark',\n",
|
77 |
+
" 'status': 'active',\n",
|
78 |
+
" 'creation_date': '2019-02-21 18:47:13',\n",
|
79 |
+
" 'creator': 1}),\n",
|
80 |
+
" (225,\n",
|
81 |
+
" {'id': 225,\n",
|
82 |
+
" 'alias': 'OpenML-friendly',\n",
|
83 |
+
" 'main_entity_type': 'task',\n",
|
84 |
+
" 'name': 'OpenML100-friendly',\n",
|
85 |
+
" 'status': 'active',\n",
|
86 |
+
" 'creation_date': '2019-09-16 19:41:46',\n",
|
87 |
+
" 'creator': 1})])"
|
88 |
+
]
|
89 |
+
},
|
90 |
+
"execution_count": 8,
|
91 |
+
"metadata": {},
|
92 |
+
"output_type": "execute_result"
|
93 |
+
}
|
94 |
+
],
|
95 |
+
"source": [
|
96 |
+
"openml.study.list_suites()"
|
97 |
+
]
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"cell_type": "code",
|
101 |
+
"execution_count": 9,
|
102 |
+
"metadata": {},
|
103 |
+
"outputs": [],
|
104 |
+
"source": [
|
105 |
+
"suite = openml.study.get_suite(suite_id=99)\n",
|
106 |
+
"tasks = openml.tasks.list_tasks(output_format=\"dataframe\")"
|
107 |
+
]
|
108 |
+
},
|
109 |
+
{
|
110 |
+
"cell_type": "code",
|
111 |
+
"execution_count": 10,
|
112 |
+
"metadata": {},
|
113 |
+
"outputs": [],
|
114 |
+
"source": [
|
115 |
+
"# Using ``@`` in `pd.DataFrame.query <\n",
|
116 |
+
"# https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.query.html>`_\n",
|
117 |
+
"# accesses variables outside of the current dataframe.\n",
|
118 |
+
"tasks = tasks.query(\"tid in @suite.tasks\")"
|
119 |
+
]
|
120 |
+
},
|
121 |
+
{
|
122 |
+
"cell_type": "code",
|
123 |
+
"execution_count": 11,
|
124 |
+
"metadata": {},
|
125 |
+
"outputs": [],
|
126 |
+
"source": [
|
127 |
+
"tids = list(tasks[np.logical_and(np.logical_and((tasks.NumberOfInstances <= 2000), (tasks.NumberOfFeatures <= 100))\n",
|
128 |
+
" , (tasks.NumberOfClasses <= 10))].tid)"
|
129 |
+
]
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"cell_type": "code",
|
133 |
+
"execution_count": 12,
|
134 |
+
"metadata": {},
|
135 |
+
"outputs": [
|
136 |
+
{
|
137 |
+
"data": {
|
138 |
+
"text/plain": [
|
139 |
+
"30"
|
140 |
+
]
|
141 |
+
},
|
142 |
+
"execution_count": 12,
|
143 |
+
"metadata": {},
|
144 |
+
"output_type": "execute_result"
|
145 |
+
}
|
146 |
+
],
|
147 |
+
"source": [
|
148 |
+
"len(tids)"
|
149 |
+
]
|
150 |
+
},
|
151 |
+
{
|
152 |
+
"cell_type": "code",
|
153 |
+
"execution_count": 13,
|
154 |
+
"metadata": {},
|
155 |
+
"outputs": [],
|
156 |
+
"source": [
|
157 |
+
"tids = list(tasks[tasks.NumberOfInstances <= 2000].tid)"
|
158 |
+
]
|
159 |
+
},
|
160 |
+
{
|
161 |
+
"cell_type": "code",
|
162 |
+
"execution_count": 14,
|
163 |
+
"metadata": {},
|
164 |
+
"outputs": [],
|
165 |
+
"source": [
|
166 |
+
"open_cc_dids = [openml.tasks.get_task(task_id).get_dataset().id for task_id in tids]"
|
167 |
+
]
|
168 |
+
},
|
169 |
+
{
|
170 |
+
"cell_type": "code",
|
171 |
+
"execution_count": null,
|
172 |
+
"outputs": [],
|
173 |
+
"source": [
|
174 |
+
"open_ml_datasets, open_ml_datasets_df = load_openml_list(test_dids_classification, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = 100000, num_feats=100, return_capped=True)\n"
|
175 |
+
],
|
176 |
+
"metadata": {
|
177 |
+
"collapsed": false,
|
178 |
+
"pycharm": {
|
179 |
+
"name": "#%%\n"
|
180 |
+
}
|
181 |
+
}
|
182 |
+
},
|
183 |
+
{
|
184 |
+
"cell_type": "code",
|
185 |
+
"execution_count": 16,
|
186 |
+
"metadata": {},
|
187 |
+
"outputs": [],
|
188 |
+
"source": [
|
189 |
+
"open_ml_datasets_df = open_ml_datasets_df[open_ml_datasets_df.NumberOfInstances > 10000]"
|
190 |
+
]
|
191 |
+
},
|
192 |
+
{
|
193 |
+
"cell_type": "code",
|
194 |
+
"execution_count": 17,
|
195 |
+
"metadata": {},
|
196 |
+
"outputs": [
|
197 |
+
{
|
198 |
+
"name": "stdout",
|
199 |
+
"output_type": "stream",
|
200 |
+
"text": [
|
201 |
+
"\\begin{tabular}{lrrrrrrr}\n",
|
202 |
+
"\\toprule\n",
|
203 |
+
" Name & \\# Features & \\# Categorical Features & \\# Instances & \\# Classes & \\# NaNs & Minority Class Size & id \\\\\n",
|
204 |
+
"\\midrule\n",
|
205 |
+
" KDDCup09\\_appetency & 231 & 39 & 50000 & 2 & 8024152 & 890 & 1111 \\\\\n",
|
206 |
+
" airlines & 8 & 5 & 539383 & 2 & 0 & 240264 & 1169 \\\\\n",
|
207 |
+
" bank-marketing & 17 & 10 & 45211 & 2 & 0 & 5289 & 1461 \\\\\n",
|
208 |
+
" nomao & 119 & 30 & 34465 & 2 & 0 & 9844 & 1486 \\\\\n",
|
209 |
+
" adult & 15 & 9 & 48842 & 2 & 6465 & 11687 & 1590 \\\\\n",
|
210 |
+
" covertype & 55 & 45 & 581012 & 7 & 0 & 2747 & 1596 \\\\\n",
|
211 |
+
" numerai28.6 & 22 & 1 & 96320 & 2 & 0 & 47662 & 23517 \\\\\n",
|
212 |
+
" connect-4 & 43 & 43 & 67557 & 3 & 0 & 6449 & 40668 \\\\\n",
|
213 |
+
"jungle\\_chess\\_2pcs\\_raw\\_endgame\\_complete & 7 & 1 & 44819 & 3 & 0 & 4335 & 41027 \\\\\n",
|
214 |
+
" APSFailure & 171 & 1 & 76000 & 2 & 1078695 & 1375 & 41138 \\\\\n",
|
215 |
+
" albert & 79 & 53 & 425240 & 2 & 2734000 & 212620 & 41147 \\\\\n",
|
216 |
+
" MiniBooNE & 51 & 1 & 130064 & 2 & 0 & 36499 & 41150 \\\\\n",
|
217 |
+
" guillermo & 4297 & 1 & 20000 & 2 & 0 & 8003 & 41159 \\\\\n",
|
218 |
+
" riccardo & 4297 & 1 & 20000 & 2 & 0 & 5000 & 41161 \\\\\n",
|
219 |
+
" volkert & 181 & 1 & 58310 & 10 & 0 & 1361 & 41166 \\\\\n",
|
220 |
+
" dionis & 61 & 1 & 416188 & 355 & 0 & 878 & 41167 \\\\\n",
|
221 |
+
" jannis & 55 & 1 & 83733 & 4 & 0 & 1687 & 41168 \\\\\n",
|
222 |
+
" helena & 28 & 1 & 65196 & 100 & 0 & 111 & 41169 \\\\\n",
|
223 |
+
"\\bottomrule\n",
|
224 |
+
"\\end{tabular}\n",
|
225 |
+
"\n"
|
226 |
+
]
|
227 |
+
}
|
228 |
+
],
|
229 |
+
"source": [
|
230 |
+
"print_table = open_ml_datasets_df\n",
|
231 |
+
"print_table = print_table[['name', 'NumberOfFeatures', 'NumberOfSymbolicFeatures', 'NumberOfInstances', 'NumberOfClasses', 'NumberOfMissingValues', 'MinorityClassSize']].copy()\n",
|
232 |
+
"print_table['id'] = print_table.index\n",
|
233 |
+
"print_table[['NumberOfFeatures', 'NumberOfSymbolicFeatures', 'NumberOfInstances', 'NumberOfClasses', 'NumberOfMissingValues', 'MinorityClassSize']] = print_table[['NumberOfFeatures', 'NumberOfSymbolicFeatures', 'NumberOfInstances', 'NumberOfClasses', 'NumberOfMissingValues', 'MinorityClassSize']].astype(int)\n",
|
234 |
+
"print_table = print_table.rename(columns=renamer)\n",
|
235 |
+
"print(print_table.to_latex(index=False))"
|
236 |
+
]
|
237 |
+
},
|
238 |
+
{
|
239 |
+
"cell_type": "markdown",
|
240 |
+
"metadata": {
|
241 |
+
"tags": []
|
242 |
+
},
|
243 |
+
"source": [
|
244 |
+
"### Prepare Validation datasets"
|
245 |
+
]
|
246 |
+
},
|
247 |
+
{
|
248 |
+
"cell_type": "code",
|
249 |
+
"execution_count": null,
|
250 |
+
"outputs": [],
|
251 |
+
"source": [
|
252 |
+
"open_cc_datasets, open_cc_datasets_df = load_openml_list(open_cc_dids, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = 2000, num_feats=100, return_capped=True)\n",
|
253 |
+
"\n",
|
254 |
+
"def extend_datasets(datasets, filtering = False):\n",
|
255 |
+
" extended_datasets = {}\n",
|
256 |
+
" i = 0\n",
|
257 |
+
" for d in tqdm(datasets):\n",
|
258 |
+
" if ((not 'NumberOfFeatures' in datasets[d])\n",
|
259 |
+
" or (not 'NumberOfClasses' in datasets[d])\n",
|
260 |
+
" or (not 'NumberOfInstances' in datasets[d])\n",
|
261 |
+
" # or datasets[d]['NumberOfFeatures'] >= num_feats\n",
|
262 |
+
" or datasets[d]['NumberOfClasses'] <= 0):\n",
|
263 |
+
" print(datasets[d])\n",
|
264 |
+
" continue\n",
|
265 |
+
" ds = openml.datasets.get_dataset(d, download_data=False)\n",
|
266 |
+
" if filtering and (datasets[d]['NumberOfInstances'] < 150\n",
|
267 |
+
" or datasets[d]['NumberOfInstances'] > 2000\n",
|
268 |
+
" or datasets[d]['NumberOfFeatures'] > 100\n",
|
269 |
+
" or datasets[d]['NumberOfClasses'] > 10):\n",
|
270 |
+
" continue\n",
|
271 |
+
" extended_datasets[d] = datasets[d]\n",
|
272 |
+
" extended_datasets[d].update(ds.qualities)\n",
|
273 |
+
" \n",
|
274 |
+
" return extended_datasets\n",
|
275 |
+
"\n",
|
276 |
+
"# All datasets\n",
|
277 |
+
"openml_list = openml.datasets.list_datasets()\n",
|
278 |
+
"openml_list = pd.DataFrame.from_dict(openml_list, orient=\"index\")\n",
|
279 |
+
"\n",
|
280 |
+
"# Select only classification\n",
|
281 |
+
"openml_list = openml_list[~openml_list['MajorityClassSize'].isna()]\n",
|
282 |
+
"\n",
|
283 |
+
"# Remove duplicated datasets\n",
|
284 |
+
"duplicated = openml_list.duplicated(subset=['MajorityClassSize', 'MaxNominalAttDistinctValues', 'MinorityClassSize',\n",
|
285 |
+
" 'NumberOfClasses', 'NumberOfFeatures', 'NumberOfInstances',\n",
|
286 |
+
" 'NumberOfInstancesWithMissingValues', 'NumberOfMissingValues',\n",
|
287 |
+
" 'NumberOfNumericFeatures', 'NumberOfSymbolicFeatures'], keep='first')\n",
|
288 |
+
"openml_list = openml_list[~duplicated]\n",
|
289 |
+
"\n",
|
290 |
+
"duplicated = openml_list.duplicated(subset=['name'], keep='first')\n",
|
291 |
+
"openml_list = openml_list[~duplicated]\n",
|
292 |
+
"\n",
|
293 |
+
"# Filter out datasets that don't have meta information or Don't fulfill other criteria\n",
|
294 |
+
"openml_list = openml_list.to_dict(orient='index')\n",
|
295 |
+
"openml_list = pd.DataFrame.from_dict(extend_datasets(openml_list, filtering=True), orient=\"index\")\n",
|
296 |
+
"\n",
|
297 |
+
"# Filter out datasets in Open CC\n",
|
298 |
+
"openml_list = openml_list[~openml_list.name.apply(lambda x: x in test_datasets_multiclass_df.name.values)]\n",
|
299 |
+
"openml_list['CFI'] = openml_list.apply(lambda x: str(x.NumberOfClasses) + '_' + str(x.NumberOfFeatures) + '_' + str(x.NumberOfInstances), axis = 1)\n",
|
300 |
+
"test_datasets_multiclass_df['CFI'] = test_datasets_multiclass_df.apply(lambda x: str(x.NumberOfClasses) + '_' + str(x.NumberOfFeatures) + '_' + str(x.NumberOfInstances), axis = 1)\n",
|
301 |
+
"openml_list = openml_list[~openml_list.CFI.apply(lambda x: x in test_datasets_multiclass_df.CFI.values)]\n",
|
302 |
+
"\n",
|
303 |
+
"# Remove time series and artificial data\n",
|
304 |
+
"openml_list = openml_list[~openml_list.name.apply(lambda x: 'autoUniv' in x)]\n",
|
305 |
+
"openml_list = openml_list[~openml_list.name.apply(lambda x: 'fri_' in x)]\n",
|
306 |
+
"openml_list = openml_list[~openml_list.name.apply(lambda x: 'FOREX' in x)]\n",
|
307 |
+
"\n",
|
308 |
+
"# Remove datasets that overlapped with Open CC closely by name\n",
|
309 |
+
"openml_list = openml_list[~openml_list.name.apply(lambda x: 'ilpd' in x)]\n",
|
310 |
+
"openml_list = openml_list[~openml_list.name.apply(lambda x: 'car' in x)]\n",
|
311 |
+
"openml_list = openml_list[~openml_list.name.apply(lambda x: 'pc1' in x)]\n",
|
312 |
+
"\n",
|
313 |
+
"# Remove datasets that didn't load\n",
|
314 |
+
"openml_list = openml_list[~openml_list.did.apply(lambda x: x in {1065, 40589, 41496, 770, 43097, 43148, 43255, 43595, 43786, 41701})]\n",
|
315 |
+
"\n",
|
316 |
+
"# Remove class skew\n",
|
317 |
+
"openml_list = openml_list[(openml_list.MinorityClassSize / openml_list.MajorityClassSize) > 0.05]\n",
|
318 |
+
"openml_list = openml_list[openml_list.AutoCorrelation != 1]\n",
|
319 |
+
"\n",
|
320 |
+
"# Remove too easy\n",
|
321 |
+
"openml_list = openml_list[openml_list.CfsSubsetEval_DecisionStumpAUC != 1]"
|
322 |
+
],
|
323 |
+
"metadata": {
|
324 |
+
"collapsed": false,
|
325 |
+
"pycharm": {
|
326 |
+
"name": "#%%\n"
|
327 |
+
}
|
328 |
+
}
|
329 |
+
},
|
330 |
+
{
|
331 |
+
"cell_type": "code",
|
332 |
+
"execution_count": null,
|
333 |
+
"metadata": {},
|
334 |
+
"outputs": [],
|
335 |
+
"source": [
|
336 |
+
"print_table = openml_list\n",
|
337 |
+
"print_table = print_table[['name', 'NumberOfFeatures', 'NumberOfSymbolicFeatures', 'NumberOfInstances', 'NumberOfClasses', 'NumberOfMissingValues', 'MinorityClassSize']].copy()\n",
|
338 |
+
"print_table['id'] = print_table.index\n",
|
339 |
+
"print_table[['NumberOfFeatures', 'NumberOfSymbolicFeatures', 'NumberOfInstances', 'NumberOfClasses', 'NumberOfMissingValues', 'MinorityClassSize']] = print_table[['NumberOfFeatures', 'NumberOfSymbolicFeatures', 'NumberOfInstances', 'NumberOfClasses', 'NumberOfMissingValues', 'MinorityClassSize']].astype(int)\n",
|
340 |
+
"print_table = print_table.rename(columns=renamer)\n",
|
341 |
+
"print(print_table.to_latex(index=False))"
|
342 |
+
]
|
343 |
+
},
|
344 |
+
{
|
345 |
+
"cell_type": "code",
|
346 |
+
"execution_count": null,
|
347 |
+
"metadata": {},
|
348 |
+
"outputs": [],
|
349 |
+
"source": []
|
350 |
+
}
|
351 |
+
],
|
352 |
+
"metadata": {
|
353 |
+
"kernelspec": {
|
354 |
+
"display_name": "Python 3 (ipykernel)",
|
355 |
+
"language": "python",
|
356 |
+
"name": "python3"
|
357 |
+
},
|
358 |
+
"language_info": {
|
359 |
+
"codemirror_mode": {
|
360 |
+
"name": "ipython",
|
361 |
+
"version": 3
|
362 |
+
},
|
363 |
+
"file_extension": ".py",
|
364 |
+
"mimetype": "text/x-python",
|
365 |
+
"name": "python",
|
366 |
+
"nbconvert_exporter": "python",
|
367 |
+
"pygments_lexer": "ipython3",
|
368 |
+
"version": "3.7.13"
|
369 |
+
}
|
370 |
+
},
|
371 |
+
"nbformat": 4,
|
372 |
+
"nbformat_minor": 4
|
373 |
+
}
|
TabPFN/README.md
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# TabPFN
|
2 |
+
|
3 |
+
## Installation
|
4 |
+
```
|
5 |
+
git clone [email protected]:automl/TabPFN.git
|
6 |
+
cd TabPFN
|
7 |
+
conda create -n TabPFN python=3.7
|
8 |
+
conda activate TabPFN
|
9 |
+
pip install -r requirements.txt
|
10 |
+
```
|
11 |
+
|
12 |
+
To run the autogluon baseline please create a separate environment and install autogluon==0.4.0, installation in the same environment as our other baselines is not possible.
|
13 |
+
|
14 |
+
## Usage
|
15 |
+
TrainingTuningAndPrediction: Train a TabPFN, Prior Tune and predict using a pretrained model.
|
16 |
+
|
17 |
+
TabularEvaluationVisualization: Run Baselines and load Baseline and TabPFN Results for comparison and plotting.
|
18 |
+
|
19 |
+
PrepareDatasets: Notebook used to inspect Datasets (Not needed to run baselines / TabPFN).
|
20 |
+
|
21 |
+
SytheticGPAblation: Ablation experiments for Gaussian Process fitting with differentiable Hyper Parameters.
|
22 |
+
|
23 |
+
|
TabPFN/SyntheticGPAblation.ipynb
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"%load_ext autoreload\n",
|
10 |
+
"\n",
|
11 |
+
"%autoreload 2"
|
12 |
+
]
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"cell_type": "code",
|
16 |
+
"execution_count": 2,
|
17 |
+
"metadata": {},
|
18 |
+
"outputs": [],
|
19 |
+
"source": [
|
20 |
+
"import os\n",
|
21 |
+
"import time\n",
|
22 |
+
"\n",
|
23 |
+
"import torch\n",
|
24 |
+
"\n",
|
25 |
+
"import numpy as np\n",
|
26 |
+
"\n",
|
27 |
+
"import matplotlib.pyplot as plt\n",
|
28 |
+
"\n",
|
29 |
+
"from model_builder import get_model, get_default_spec, save_model, load_model\n",
|
30 |
+
"\n",
|
31 |
+
"from scripts.model_configs import *"
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "markdown",
|
36 |
+
"metadata": {
|
37 |
+
"tags": []
|
38 |
+
},
|
39 |
+
"source": [
|
40 |
+
"# Setting params"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "code",
|
45 |
+
"execution_count": 6,
|
46 |
+
"metadata": {},
|
47 |
+
"outputs": [],
|
48 |
+
"source": [
|
49 |
+
"device = 'cuda'\n",
|
50 |
+
"base_path = os.path.join('.')"
|
51 |
+
]
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"cell_type": "code",
|
55 |
+
"execution_count": 7,
|
56 |
+
"metadata": {},
|
57 |
+
"outputs": [],
|
58 |
+
"source": [
|
59 |
+
"def train_function(config_sample, i, add_name=''):\n",
|
60 |
+
" start_time = time.time()\n",
|
61 |
+
" N_epochs_to_save = 50\n",
|
62 |
+
" \n",
|
63 |
+
" def save_callback(model, epoch):\n",
|
64 |
+
" if not hasattr(model, 'last_saved_epoch'):\n",
|
65 |
+
" model.last_saved_epoch = 0\n",
|
66 |
+
" if ((time.time() - start_time) / (maximum_runtime * 60 / N_epochs_to_save)) > model.last_saved_epoch:\n",
|
67 |
+
" print('Saving model..')\n",
|
68 |
+
" config_sample['epoch_in_training'] = epoch\n",
|
69 |
+
" save_model(model, base_path, f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{model.last_saved_epoch}.cpkt',\n",
|
70 |
+
" config_sample)\n",
|
71 |
+
" model.last_saved_epoch = model.last_saved_epoch + 1 # TODO: Rename to checkpoint\n",
|
72 |
+
" \n",
|
73 |
+
" model = get_model(config_sample\n",
|
74 |
+
" , device\n",
|
75 |
+
" , should_train=True\n",
|
76 |
+
" , verbose=1\n",
|
77 |
+
" , epoch_callback = save_callback)\n",
|
78 |
+
" \n",
|
79 |
+
" return"
|
80 |
+
]
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"cell_type": "markdown",
|
84 |
+
"metadata": {
|
85 |
+
"heading_collapsed": true,
|
86 |
+
"tags": []
|
87 |
+
},
|
88 |
+
"source": [
|
89 |
+
"# Check synthetic data fitting"
|
90 |
+
]
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"cell_type": "markdown",
|
94 |
+
"metadata": {
|
95 |
+
"tags": []
|
96 |
+
},
|
97 |
+
"source": [
|
98 |
+
"#### Workflow functions"
|
99 |
+
]
|
100 |
+
},
|
101 |
+
{
|
102 |
+
"cell_type": "code",
|
103 |
+
"execution_count": 8,
|
104 |
+
"metadata": {
|
105 |
+
"hidden": true,
|
106 |
+
"tags": []
|
107 |
+
},
|
108 |
+
"outputs": [],
|
109 |
+
"source": [
|
110 |
+
"def generate_test_data(test_gp_params):\n",
|
111 |
+
" # Generate test data\n",
|
112 |
+
" config = {**test_gp_params}\n",
|
113 |
+
"\n",
|
114 |
+
" config['verbose'] = False\n",
|
115 |
+
" config['differentiable'] = False\n",
|
116 |
+
" #config['bptt'] = config['bptt_in_training']\n",
|
117 |
+
"\n",
|
118 |
+
" model_test_data = get_model(config, device, should_train=False, verbose=True)\n",
|
119 |
+
" (hp_embedding, data, targets_), targets = next(iter(model_test_data[3]))\n",
|
120 |
+
" (hp_embedding, data, targets_), targets = (hp_embedding, data.to(device), targets_.to(device)), targets.to(device)\n",
|
121 |
+
" \n",
|
122 |
+
" return (hp_embedding, data, targets_), targets\n",
|
123 |
+
"\n",
|
124 |
+
"def evaluate_hp_range(model, hparam_true, vary_hparam_ind, data, targets, eval_pos, plot_step_size):\n",
|
125 |
+
" losses, hparams = [], []\n",
|
126 |
+
" for l in np.arange(-1.74, 1.74, plot_step_size):\n",
|
127 |
+
" hparam = [*hparam_true]\n",
|
128 |
+
" hparam[vary_hparam_ind] = l\n",
|
129 |
+
" hp_embedding_used = torch.tensor(hparam).to(device).float()\n",
|
130 |
+
" with torch.inference_mode():\n",
|
131 |
+
" outputs = torch.sigmoid(model[2]((hp_embedding_used.repeat(data.shape[1], 1), data, targets.float()), single_eval_pos=eval_pos)).squeeze(-1)\n",
|
132 |
+
" \n",
|
133 |
+
" loss = torch.nn.BCELoss()(outputs.flatten(), targets[eval_pos:].flatten()).detach().cpu()\n",
|
134 |
+
" losses += [loss]\n",
|
135 |
+
" hparam_real = [diff_hparams_f[i][1](hp) for i, hp in enumerate(hparam)]\n",
|
136 |
+
" hparams += [hparam_real]\n",
|
137 |
+
" \n",
|
138 |
+
" print(loss, hparam_real, hparam, outputs.shape)\n",
|
139 |
+
" return np.array(losses), np.array(hparams)"
|
140 |
+
]
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"cell_type": "code",
|
144 |
+
"execution_count": 9,
|
145 |
+
"metadata": {},
|
146 |
+
"outputs": [],
|
147 |
+
"source": [
|
148 |
+
"def differentiable_hparam_tuning_workflow(config_sample, hparam_label, batch_size=4, N_grad_steps=50, plot_step_size=0.1):\n",
|
149 |
+
" test_gp_params = {\n",
|
150 |
+
" \"lengthscale\": 1.0,\n",
|
151 |
+
" #\"lengthscale_mean\": true_lengthscale,\n",
|
152 |
+
" #\"lengthscale_std\": 0.5,\n",
|
153 |
+
" \"noise\": 0.2,\n",
|
154 |
+
" \"outputscale\": 1.0,\n",
|
155 |
+
" 'batch_size': batch_size\n",
|
156 |
+
" }\n",
|
157 |
+
" config_sample.update(test_gp_params)\n",
|
158 |
+
" (hp_embedding, data, targets_), targets = generate_test_data(config_sample)\n",
|
159 |
+
" hparam_true = [diff_hparams_f[i][0](test_gp_params[hp]) for i, hp in enumerate(diff_hparams_keys)]\n",
|
160 |
+
" #hparam_true = [test_gp_params[hp] for i, hp in enumerate(diff_hparams_keys)]\n",
|
161 |
+
"\n",
|
162 |
+
" for vary_hparam_ind, vary_hparam_name in hparam_label:\n",
|
163 |
+
" print(vary_hparam_name)\n",
|
164 |
+
"\n",
|
165 |
+
" losses, hparams = evaluate_hp_range(model, hparam_true, vary_hparam_ind, data, targets, eval_pos, plot_step_size=plot_step_size)\n",
|
166 |
+
"\n",
|
167 |
+
" # TODO: Make only one parameter diffable\n",
|
168 |
+
" hparam = torch.tensor([*hparam_true]).to(device).float()\n",
|
169 |
+
" hparam[vary_hparam_ind] = hparam[vary_hparam_ind] + 0.1 #random.random() * 2 - 1\n",
|
170 |
+
" hparam = torch.nn.Parameter(hparam, requires_grad=True)\n",
|
171 |
+
" hparam_grad_mask = torch.zeros_like(hparam)\n",
|
172 |
+
" hparam_grad_mask[vary_hparam_ind] = 1\n",
|
173 |
+
"\n",
|
174 |
+
" optimizer = torch.optim.Adam([hparam], lr=0.1)\n",
|
175 |
+
" \n",
|
176 |
+
" for t in range(N_grad_steps):\n",
|
177 |
+
" style = hparam.repeat(data.shape[1], 1)\n",
|
178 |
+
" outputs = torch.sigmoid(model[2]((style, data, targets.float()), single_eval_pos=eval_pos)).squeeze(-1)\n",
|
179 |
+
" loss = torch.nn.BCELoss()(outputs.flatten(), targets[eval_pos:].flatten())\n",
|
180 |
+
" optimizer.zero_grad()\n",
|
181 |
+
" loss.backward()\n",
|
182 |
+
" with torch.no_grad():\n",
|
183 |
+
" hparam.grad *= hparam_grad_mask\n",
|
184 |
+
" optimizer.step()\n",
|
185 |
+
" print('loss:', loss, 'hparams', diff_hparams_f[vary_hparam_ind][1](hparam[vary_hparam_ind]), 'true', diff_hparams_f[vary_hparam_ind][1](hparam_true[vary_hparam_ind]))\n",
|
186 |
+
" inferred_param = diff_hparams_f[vary_hparam_ind][1](hparam[vary_hparam_ind].cpu().detach().numpy())\n",
|
187 |
+
" return hparams, losses, inferred_param, vary_hparam_ind, hparam_true\n",
|
188 |
+
" "
|
189 |
+
]
|
190 |
+
},
|
191 |
+
{
|
192 |
+
"cell_type": "markdown",
|
193 |
+
"metadata": {
|
194 |
+
"tags": []
|
195 |
+
},
|
196 |
+
"source": [
|
197 |
+
"#### Fitting a PFN with HP-Diffable GP Prior"
|
198 |
+
]
|
199 |
+
},
|
200 |
+
{
|
201 |
+
"cell_type": "code",
|
202 |
+
"execution_count": 10,
|
203 |
+
"metadata": {
|
204 |
+
"hidden": true,
|
205 |
+
"tags": []
|
206 |
+
},
|
207 |
+
"outputs": [],
|
208 |
+
"source": [
|
209 |
+
"num_features = 5\n",
|
210 |
+
"bptt = 200\n",
|
211 |
+
"eval_positions = [100]\n",
|
212 |
+
"\n",
|
213 |
+
"config_general = get_general_config(num_features, bptt, eval_positions)\n",
|
214 |
+
"config_flexible_categorical = get_flexible_categorical_config(num_features)\n",
|
215 |
+
"\n",
|
216 |
+
"config_gp = {'noise': 0.2, \"lengthscale\": 1.0, \"outputscale\": 1.0}\n",
|
217 |
+
"config_diff_gp = {'differentiable_hyperparameters': {\n",
|
218 |
+
" 'outputscale': {'distribution': 'uniform', 'min': 0., 'max': 10.0},\n",
|
219 |
+
" 'lengthscale': {'distribution': 'uniform', 'min': 0., 'max': 10.0},\n",
|
220 |
+
" 'noise': {'distribution': 'uniform', 'min': 0.0000001, 'max': 0.5},\n",
|
221 |
+
" }\n",
|
222 |
+
"}\n",
|
223 |
+
"\n",
|
224 |
+
"config = {**config_general, **config_flexible_categorical, **config_diff_gp, **config_gp}\n",
|
225 |
+
"\n",
|
226 |
+
"config['prior_type'], config['differentiable'], config['flexible'] = 'gp', True, True\n",
|
227 |
+
"config['num_features'], config['num_features_used'] = num_features, num_features\n",
|
228 |
+
"config['epochs'], config['num_steps'], config['verbose'] = 500, 100, False\n",
|
229 |
+
"config[\"lr\"] = 0.00001\n",
|
230 |
+
"config[\"dropout\"] = 0\n",
|
231 |
+
"config[\"emsize\"] = 512\n",
|
232 |
+
"config[\"batch_size\"] = 128\n",
|
233 |
+
"config[\"aggregate_k_gradients\"] = 1\n",
|
234 |
+
"config['set_value_to_nan'] = 0.0\n",
|
235 |
+
"config['output_multiclass_ordered_p'] = 1.0\n",
|
236 |
+
"config['categorical_feature_p'] = 0.0\n",
|
237 |
+
"config['nan_prob_a_reason'] = 0.0\n",
|
238 |
+
"config['nan_prob_no_reason'] = 0.0\n",
|
239 |
+
"config['nan_prob_unknown_reason'] = 0.0\n",
|
240 |
+
"config[\"nlayers\"] = 8\n",
|
241 |
+
"\n",
|
242 |
+
"# TODO: This should not be sampled, but be one config\n",
|
243 |
+
"# TODO: This uses old hyperparam sampler throws error\n",
|
244 |
+
"config_sample = evaluate_hypers(config)"
|
245 |
+
]
|
246 |
+
},
|
247 |
+
{
|
248 |
+
"cell_type": "code",
|
249 |
+
"execution_count": 11,
|
250 |
+
"metadata": {
|
251 |
+
"hidden": true,
|
252 |
+
"tags": []
|
253 |
+
},
|
254 |
+
"outputs": [
|
255 |
+
{
|
256 |
+
"name": "stdout",
|
257 |
+
"output_type": "stream",
|
258 |
+
"text": [
|
259 |
+
"Using style prior: True\n",
|
260 |
+
"Using cpu:0 device\n",
|
261 |
+
"Not using distributed\n",
|
262 |
+
"DataLoader.__dict__ {'num_steps': 100, 'fuse_x_y': False, 'get_batch_kwargs': {'batch_size': 128, 'seq_len': 200, 'seq_len_maximum': 200, 'device': 'cpu:0', 'num_features': 5, 'hyperparameters': {'lr': 1e-05, 'dropout': 0, 'emsize': 512, 'batch_size': 128, 'nlayers': 8, 'num_features': 5, 'nhead': 4, 'nhid_factor': 2, 'bptt': 200, 'eval_positions': None, 'seq_len_used': 200, 'sampling': 'normal', 'epochs': 500, 'num_steps': 100, 'verbose': False, 'pre_sample_causes': True, 'mix_activations': False, 'nan_prob_unknown_reason_reason_prior': 1.0, 'categorical_feature_p': 0.0, 'nan_prob_no_reason': 0.0, 'nan_prob_unknown_reason': 0.0, 'nan_prob_a_reason': 0.0, 'max_num_classes': 2, 'num_classes': 2, 'noise_type': 'Gaussian', 'balanced': True, 'normalize_to_ranking': False, 'set_value_to_nan': 0.0, 'normalize_by_used_features': True, 'num_features_used': 5, 'differentiable_hyperparameters': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'noise': 0.2, 'lengthscale': 1.0, 'outputscale': 1.0, 'prior_type': 'gp', 'differentiable': True, 'flexible': True, 'aggregate_k_gradients': 1, 'output_multiclass_ordered_p': 1.0, 'recompute_attn': False}, 'num_outputs': 1, 'dynamic_batch_size': 2, 'get_batch': <function get_model.<locals>.make_get_batch.<locals>.<lambda> at 0x7f39ad8dcf80>, 'differentiable_hyperparameters': {'outputscale': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'lengthscale': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'noise': {'distribution': 'uniform', 'min': 1e-07, 'max': 0.5}}}, 'num_features': 5, 'num_outputs': 1}\n",
|
263 |
+
"Using a Transformer with 17.35 M parameters\n"
|
264 |
+
]
|
265 |
+
}
|
266 |
+
],
|
267 |
+
"source": [
|
268 |
+
"device = 'cuda'\n",
|
269 |
+
"train_function(config_sample, 0, add_name='gp_experiments_diff_with_noise_no_meta_new')"
|
270 |
+
]
|
271 |
+
},
|
272 |
+
{
|
273 |
+
"cell_type": "markdown",
|
274 |
+
"metadata": {
|
275 |
+
"tags": []
|
276 |
+
},
|
277 |
+
"source": [
|
278 |
+
"#### Evaluating a PFN (with pretrained model)"
|
279 |
+
]
|
280 |
+
},
|
281 |
+
{
|
282 |
+
"cell_type": "code",
|
283 |
+
"execution_count": 13,
|
284 |
+
"metadata": {
|
285 |
+
"hidden": true,
|
286 |
+
"tags": []
|
287 |
+
},
|
288 |
+
"outputs": [
|
289 |
+
{
|
290 |
+
"name": "stdout",
|
291 |
+
"output_type": "stream",
|
292 |
+
"text": [
|
293 |
+
"Using style prior: True\n",
|
294 |
+
"Using cpu:0 device\n",
|
295 |
+
"Not using distributed\n",
|
296 |
+
"DataLoader.__dict__ {'num_steps': 100, 'fuse_x_y': False, 'get_batch_kwargs': {'batch_size': 1, 'seq_len': 10, 'seq_len_maximum': 10, 'device': 'cpu:0', 'num_features': 5, 'hyperparameters': {'lr': 1e-05, 'dropout': 0, 'emsize': 512, 'batch_size': 1, 'nlayers': 8, 'num_features': 5, 'nhead': 4, 'nhid_factor': 2, 'bptt': 10, 'eval_positions': [190], 'seq_len_used': 200, 'sampling': 'normal', 'epochs': 500, 'num_steps': 100, 'verbose': False, 'pre_sample_causes': True, 'mix_activations': False, 'nan_prob_unknown_reason_reason_prior': 1.0, 'output_multiclass_ordered_p': 1.0, 'categorical_feature_p': 0.0, 'nan_prob_no_reason': 0.0, 'nan_prob_unknown_reason': 0.0, 'nan_prob_a_reason': 0.0, 'max_num_classes': 2, 'num_classes': 2, 'noise_type': 'Gaussian', 'balanced': True, 'multiclass_type': 'rank', 'normalize_to_ranking': False, 'set_value_to_nan': 0.0, 'normalize_by_used_features': True, 'num_features_used': <function load_model.<locals>.<lambda> at 0x7f39ad8534d0>, 'differentiable_hyperparameters': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'noise': 0.03, 'lengthscale': 1.0, 'outputscale': 1.0, 'prior_type': 'gp', 'differentiable': True, 'flexible': True, 'aggregate_k_gradients': 1, 'recompute_attn': False, 'bptt_extra_samples': None, 'epoch_in_training': 0.998, 'categorical_features_sampler': <function load_model.<locals>.<lambda> at 0x7f39ad853680>, 'num_features_used_in_training': 5, 'num_classes_in_training': 2, 'batch_size_in_training': 128, 'bptt_in_training': 200, 'bptt_extra_samples_in_training': None}, 'num_outputs': 1, 'dynamic_batch_size': 2, 'get_batch': <function get_model.<locals>.make_get_batch.<locals>.<lambda> at 0x7f39ad81ab90>, 'differentiable_hyperparameters': {'outputscale': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'lengthscale': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'noise': {'distribution': 'uniform', 'min': 1e-07, 'max': 0.5}}}, 'num_features': 5, 'num_outputs': 1}\n",
|
297 |
+
"Using a Transformer with 17.35 M parameters\n"
|
298 |
+
]
|
299 |
+
}
|
300 |
+
],
|
301 |
+
"source": [
|
302 |
+
"device = 'cpu'\n",
|
303 |
+
"model, c = load_model(base_path, f'models_diff/gp_ablation_model.cpkt', device, eval_positions, verbose=False)"
|
304 |
+
]
|
305 |
+
},
|
306 |
+
{
|
307 |
+
"cell_type": "code",
|
308 |
+
"execution_count": 14,
|
309 |
+
"metadata": {},
|
310 |
+
"outputs": [],
|
311 |
+
"source": [
|
312 |
+
"from priors.differentiable_prior import DifferentiableHyperparameterList\n",
|
313 |
+
"diff_list = DifferentiableHyperparameterList(c['differentiable_hyperparameters'], 512, device)\n",
|
314 |
+
"diff_hparams_keys, diff_hparams_f = diff_list.get_hyperparameter_info()"
|
315 |
+
]
|
316 |
+
},
|
317 |
+
{
|
318 |
+
"cell_type": "code",
|
319 |
+
"execution_count": null,
|
320 |
+
"metadata": {
|
321 |
+
"tags": []
|
322 |
+
},
|
323 |
+
"outputs": [],
|
324 |
+
"source": [
|
325 |
+
"model[2].eval()\n",
|
326 |
+
"eval_pos = 100\n",
|
327 |
+
"\n",
|
328 |
+
"hparam_label = [(1, 'outputscale')]\n",
|
329 |
+
"hparam_label = [(0, 'lengthscale')]\n",
|
330 |
+
"hparam_label = [(2, 'noise')]\n",
|
331 |
+
"hparam_labels = [[(1, 'outputscale')], [(2, 'noise')], [(0, 'lengthscale')]]\n",
|
332 |
+
"#hparam_labels = [[(2, 'noise')]]\n",
|
333 |
+
"\n",
|
334 |
+
"hparams, losses, inferred_param, vary_hparam_ind, hparam_true = {}, {}, {}, {}, {}\n",
|
335 |
+
"\n",
|
336 |
+
"for hparam_label in hparam_labels:\n",
|
337 |
+
" (hparams[hparam_label[0][1]], losses[hparam_label[0][1]], inferred_param[hparam_label[0][1]], vary_hparam_ind[hparam_label[0][1]], \n",
|
338 |
+
" hparam_true[hparam_label[0][1]]) = differentiable_hparam_tuning_workflow(config_sample, \n",
|
339 |
+
" hparam_label=hparam_label, \n",
|
340 |
+
" batch_size=256, \n",
|
341 |
+
" N_grad_steps=50,\n",
|
342 |
+
" plot_step_size=0.05)\n"
|
343 |
+
]
|
344 |
+
},
|
345 |
+
{
|
346 |
+
"cell_type": "code",
|
347 |
+
"execution_count": null,
|
348 |
+
"metadata": {},
|
349 |
+
"outputs": [],
|
350 |
+
"source": [
|
351 |
+
"label = 'lengthscale'\n",
|
352 |
+
"\n",
|
353 |
+
"#import tikzplotlib\n",
|
354 |
+
"\n",
|
355 |
+
"inferred = losses[label]\n",
|
356 |
+
"\n",
|
357 |
+
"plt.plot(hparams[label][:, vary_hparam_ind[label]], losses[label])\n",
|
358 |
+
"true = diff_hparams_f[vary_hparam_ind[label]][1](hparam_true[label][vary_hparam_ind[label]])\n",
|
359 |
+
"plt.axvline(x=inferred_param[label], linestyle='solid', color='red')\n",
|
360 |
+
"plt.axvline(x=true, linestyle='dashed')\n",
|
361 |
+
"\n",
|
362 |
+
"plt.ylabel('Cross entropy Loss')\n",
|
363 |
+
"plt.xlabel(label)\n",
|
364 |
+
"\n",
|
365 |
+
"#tikzplotlib.save(f'diff_inferred_params_{label}.tex', axis_height='5.2cm', axis_width='5.2cm', strict=True)\n",
|
366 |
+
"\n",
|
367 |
+
"plt.show()"
|
368 |
+
]
|
369 |
+
}
|
370 |
+
],
|
371 |
+
"metadata": {
|
372 |
+
"kernelspec": {
|
373 |
+
"display_name": "Python 3 (ipykernel)",
|
374 |
+
"language": "python",
|
375 |
+
"name": "python3"
|
376 |
+
},
|
377 |
+
"language_info": {
|
378 |
+
"codemirror_mode": {
|
379 |
+
"name": "ipython",
|
380 |
+
"version": 3
|
381 |
+
},
|
382 |
+
"file_extension": ".py",
|
383 |
+
"mimetype": "text/x-python",
|
384 |
+
"name": "python",
|
385 |
+
"nbconvert_exporter": "python",
|
386 |
+
"pygments_lexer": "ipython3",
|
387 |
+
"version": "3.7.13"
|
388 |
+
}
|
389 |
+
},
|
390 |
+
"nbformat": 4,
|
391 |
+
"nbformat_minor": 4
|
392 |
+
}
|
TabPFN/TabPFNPredictionOnly.ipynb
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"This notebook shows how to use TabPFN for tabular prediction with a scikit learn wrapper.\n",
|
8 |
+
"\n",
|
9 |
+
"classifier = TabPFNClassifier(device='cpu')\n",
|
10 |
+
"classifier.fit(train_xs, train_ys)\n",
|
11 |
+
"prediction_ = classifier.predict(test_xs)\n",
|
12 |
+
"\n",
|
13 |
+
"The fit function does not perform any computations, but only saves the training data. Computations are only done at inference time, when calling predict.\n",
|
14 |
+
"Note that the presaved models were trained for up to 100 features, 10 classes and 1000 samples. While the model does not have a hard bound on the number of samples, the features and classes are restricted and larger sizes lead to an error."
|
15 |
+
]
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"cell_type": "markdown",
|
19 |
+
"metadata": {
|
20 |
+
"tags": []
|
21 |
+
},
|
22 |
+
"source": [
|
23 |
+
"### Setup"
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"cell_type": "code",
|
28 |
+
"execution_count": null,
|
29 |
+
"metadata": {},
|
30 |
+
"outputs": [],
|
31 |
+
"source": [
|
32 |
+
"%load_ext autoreload\n",
|
33 |
+
"\n",
|
34 |
+
"%autoreload 2"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "code",
|
39 |
+
"execution_count": null,
|
40 |
+
"metadata": {},
|
41 |
+
"outputs": [],
|
42 |
+
"source": [
|
43 |
+
"import time\n",
|
44 |
+
"import torch\n",
|
45 |
+
"import numpy as np\n",
|
46 |
+
"import os\n",
|
47 |
+
"import random\n",
|
48 |
+
"\n",
|
49 |
+
"from model_builder import get_model, get_default_spec, save_model, load_model\n",
|
50 |
+
"from scripts.transformer_prediction_interface import transformer_predict, get_params_from_config, TabPFNClassifier\n",
|
51 |
+
"\n",
|
52 |
+
"from datasets import load_openml_list, open_cc_dids, open_cc_valid_dids\n",
|
53 |
+
"\n",
|
54 |
+
"from scripts import tabular_metrics"
|
55 |
+
]
|
56 |
+
},
|
57 |
+
{
|
58 |
+
"cell_type": "code",
|
59 |
+
"execution_count": null,
|
60 |
+
"metadata": {},
|
61 |
+
"outputs": [],
|
62 |
+
"source": [
|
63 |
+
"base_path = '.'"
|
64 |
+
]
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"cell_type": "markdown",
|
68 |
+
"metadata": {
|
69 |
+
"tags": []
|
70 |
+
},
|
71 |
+
"source": [
|
72 |
+
"### Load datasets"
|
73 |
+
]
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"cell_type": "code",
|
77 |
+
"execution_count": null,
|
78 |
+
"metadata": {
|
79 |
+
"jupyter": {
|
80 |
+
"outputs_hidden": true
|
81 |
+
},
|
82 |
+
"tags": []
|
83 |
+
},
|
84 |
+
"outputs": [],
|
85 |
+
"source": [
|
86 |
+
"max_samples = 10000\n",
|
87 |
+
"bptt = 10000\n",
|
88 |
+
"\n",
|
89 |
+
"cc_test_datasets_multiclass, cc_test_datasets_multiclass_df = load_openml_list(open_cc_dids, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = max_samples, num_feats=100, return_capped=True)\n",
|
90 |
+
"cc_valid_datasets_multiclass, cc_valid_datasets_multiclass_df = load_openml_list(open_cc_valid_dids, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = max_samples, num_feats=100, return_capped=True)\n",
|
91 |
+
"\n",
|
92 |
+
"# Loading longer OpenML Datasets for generalization experiments (optional)\n",
|
93 |
+
"# test_datasets_multiclass, test_datasets_multiclass_df = load_openml_list(test_dids_classification, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = 10000, num_feats=100, return_capped=True)\n",
|
94 |
+
"\n",
|
95 |
+
"random.seed(0)\n",
|
96 |
+
"random.shuffle(cc_valid_datasets_multiclass)"
|
97 |
+
]
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"cell_type": "code",
|
101 |
+
"execution_count": null,
|
102 |
+
"metadata": {},
|
103 |
+
"outputs": [],
|
104 |
+
"source": [
|
105 |
+
"from datasets import get_openml_classification"
|
106 |
+
]
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"cell_type": "code",
|
110 |
+
"execution_count": null,
|
111 |
+
"metadata": {},
|
112 |
+
"outputs": [],
|
113 |
+
"source": [
|
114 |
+
"dataset = openml.datasets.get_dataset(31)\n",
|
115 |
+
"X, y, categorical_indicator, attribute_names = dataset.get_data(\n",
|
116 |
+
" dataset_format=\"array\", target=dataset.default_target_attribute\n",
|
117 |
+
" )"
|
118 |
+
]
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"cell_type": "code",
|
122 |
+
"execution_count": null,
|
123 |
+
"metadata": {},
|
124 |
+
"outputs": [],
|
125 |
+
"source": [
|
126 |
+
"def get_datasets(selector, task_type, suite='cc'):\n",
|
127 |
+
" if task_type == 'binary':\n",
|
128 |
+
" ds = valid_datasets_binary if selector == 'valid' else test_datasets_binary\n",
|
129 |
+
" else:\n",
|
130 |
+
" if suite == 'openml':\n",
|
131 |
+
" ds = valid_datasets_multiclass if selector == 'valid' else test_datasets_multiclass\n",
|
132 |
+
" elif suite == 'cc':\n",
|
133 |
+
" ds = cc_valid_datasets_multiclass if selector == 'valid' else cc_test_datasets_multiclass\n",
|
134 |
+
" else:\n",
|
135 |
+
" raise Exception(\"Unknown suite\")\n",
|
136 |
+
" return ds"
|
137 |
+
]
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"cell_type": "code",
|
141 |
+
"execution_count": null,
|
142 |
+
"metadata": {},
|
143 |
+
"outputs": [],
|
144 |
+
"source": [
|
145 |
+
"model_string, longer, task_type = '', 1, 'multiclass'\n",
|
146 |
+
"eval_positions = [1000]\n",
|
147 |
+
"bptt = 2000\n",
|
148 |
+
" \n",
|
149 |
+
"test_datasets, valid_datasets = get_datasets('test', task_type, suite='cc'), get_datasets('valid', task_type, suite='cc')"
|
150 |
+
]
|
151 |
+
},
|
152 |
+
{
|
153 |
+
"cell_type": "markdown",
|
154 |
+
"metadata": {
|
155 |
+
"jp-MarkdownHeadingCollapsed": true,
|
156 |
+
"tags": []
|
157 |
+
},
|
158 |
+
"source": [
|
159 |
+
"### Select a dataset for prediction"
|
160 |
+
]
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"cell_type": "code",
|
164 |
+
"execution_count": null,
|
165 |
+
"metadata": {},
|
166 |
+
"outputs": [],
|
167 |
+
"source": [
|
168 |
+
"[(i, test_datasets[i][0]) for i in range(len(test_datasets))]"
|
169 |
+
]
|
170 |
+
},
|
171 |
+
{
|
172 |
+
"cell_type": "code",
|
173 |
+
"execution_count": null,
|
174 |
+
"metadata": {},
|
175 |
+
"outputs": [],
|
176 |
+
"source": [
|
177 |
+
"evaluation_dataset_index = 4 # Index of the dataset to predict\n",
|
178 |
+
"ds = test_datasets[evaluation_dataset_index]\n",
|
179 |
+
"print(f'Evaluation dataset name: {ds[0]} shape {ds[1].shape}')"
|
180 |
+
]
|
181 |
+
},
|
182 |
+
{
|
183 |
+
"cell_type": "code",
|
184 |
+
"execution_count": null,
|
185 |
+
"metadata": {},
|
186 |
+
"outputs": [],
|
187 |
+
"source": [
|
188 |
+
"xs, ys = ds[1].clone(), ds[2].clone()\n",
|
189 |
+
"eval_position = xs.shape[0] // 2\n",
|
190 |
+
"train_xs, train_ys = xs[0:eval_position], ys[0:eval_position]\n",
|
191 |
+
"test_xs, test_ys = xs[eval_position:], ys[eval_position:]"
|
192 |
+
]
|
193 |
+
},
|
194 |
+
{
|
195 |
+
"cell_type": "markdown",
|
196 |
+
"metadata": {
|
197 |
+
"tags": []
|
198 |
+
},
|
199 |
+
"source": [
|
200 |
+
"### Predict using a Fitted and Tuned Model"
|
201 |
+
]
|
202 |
+
},
|
203 |
+
{
|
204 |
+
"cell_type": "code",
|
205 |
+
"execution_count": null,
|
206 |
+
"metadata": {},
|
207 |
+
"outputs": [],
|
208 |
+
"source": [
|
209 |
+
"classifier = TabPFNClassifier(device='cpu')\n",
|
210 |
+
"classifier.fit(train_xs, train_ys)\n",
|
211 |
+
"prediction_ = classifier.predict_proba(test_xs)"
|
212 |
+
]
|
213 |
+
},
|
214 |
+
{
|
215 |
+
"cell_type": "code",
|
216 |
+
"execution_count": null,
|
217 |
+
"metadata": {},
|
218 |
+
"outputs": [],
|
219 |
+
"source": [
|
220 |
+
"roc, ce = tabular_metrics.auc_metric(test_ys, prediction_), tabular_metrics.cross_entropy(test_ys, prediction_)\n",
|
221 |
+
"'AUC', float(roc), 'Cross Entropy', float(ce)"
|
222 |
+
]
|
223 |
+
},
|
224 |
+
{
|
225 |
+
"cell_type": "code",
|
226 |
+
"execution_count": null,
|
227 |
+
"metadata": {},
|
228 |
+
"outputs": [],
|
229 |
+
"source": []
|
230 |
+
}
|
231 |
+
],
|
232 |
+
"metadata": {
|
233 |
+
"kernelspec": {
|
234 |
+
"display_name": "Python 3 (ipykernel)",
|
235 |
+
"language": "python",
|
236 |
+
"name": "python3"
|
237 |
+
},
|
238 |
+
"language_info": {
|
239 |
+
"codemirror_mode": {
|
240 |
+
"name": "ipython",
|
241 |
+
"version": 3
|
242 |
+
},
|
243 |
+
"file_extension": ".py",
|
244 |
+
"mimetype": "text/x-python",
|
245 |
+
"name": "python",
|
246 |
+
"nbconvert_exporter": "python",
|
247 |
+
"pygments_lexer": "ipython3",
|
248 |
+
"version": "3.7.13"
|
249 |
+
}
|
250 |
+
},
|
251 |
+
"nbformat": 4,
|
252 |
+
"nbformat_minor": 4
|
253 |
+
}
|
TabPFN/TabularEvaluationVisualization.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
TabPFN/TrainingTuningAndPrediction.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
TabPFN/__pycache__/encoders.cpython-37.pyc
ADDED
Binary file (9.49 kB). View file
|
|
TabPFN/__pycache__/layer.cpython-37.pyc
ADDED
Binary file (4.49 kB). View file
|
|
TabPFN/__pycache__/model_builder.cpython-37.pyc
ADDED
Binary file (9.77 kB). View file
|
|
TabPFN/__pycache__/notebook_utils.cpython-37.pyc
ADDED
Binary file (1.47 kB). View file
|
|
TabPFN/__pycache__/positional_encodings.cpython-37.pyc
ADDED
Binary file (2.95 kB). View file
|
|
TabPFN/__pycache__/train.cpython-37.pyc
ADDED
Binary file (12 kB). View file
|
|
TabPFN/__pycache__/transformer.cpython-37.pyc
ADDED
Binary file (7.99 kB). View file
|
|
TabPFN/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (10.1 kB). View file
|
|
TabPFN/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (10.2 kB). View file
|
|
TabPFN/datasets/__init__.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import openml
|
5 |
+
|
6 |
+
|
7 |
+
def get_openml_classification(did, max_samples, multiclass=True, shuffled=True):
|
8 |
+
dataset = openml.datasets.get_dataset(did)
|
9 |
+
X, y, categorical_indicator, attribute_names = dataset.get_data(
|
10 |
+
dataset_format="array", target=dataset.default_target_attribute
|
11 |
+
)
|
12 |
+
|
13 |
+
if not multiclass:
|
14 |
+
X = X[y < 2]
|
15 |
+
y = y[y < 2]
|
16 |
+
|
17 |
+
if multiclass and not shuffled:
|
18 |
+
raise NotImplementedError("This combination of multiclass and shuffling isn't implemented")
|
19 |
+
|
20 |
+
if not isinstance(X, np.ndarray) or not isinstance(y, np.ndarray):
|
21 |
+
print('Not a NP Array, skipping')
|
22 |
+
return None, None, None, None
|
23 |
+
|
24 |
+
if not shuffled:
|
25 |
+
sort = np.argsort(y) if y.mean() < 0.5 else np.argsort(-y)
|
26 |
+
pos = int(y.sum()) if y.mean() < 0.5 else int((1 - y).sum())
|
27 |
+
X, y = X[sort][-pos * 2:], y[sort][-pos * 2:]
|
28 |
+
y = torch.tensor(y).reshape(2, -1).transpose(0, 1).reshape(-1).flip([0]).float()
|
29 |
+
X = torch.tensor(X).reshape(2, -1, X.shape[1]).transpose(0, 1).reshape(-1, X.shape[1]).flip([0]).float()
|
30 |
+
else:
|
31 |
+
order = np.arange(y.shape[0])
|
32 |
+
np.random.seed(13)
|
33 |
+
np.random.shuffle(order)
|
34 |
+
X, y = torch.tensor(X[order]), torch.tensor(y[order])
|
35 |
+
if max_samples:
|
36 |
+
X, y = X[:max_samples], y[:max_samples]
|
37 |
+
|
38 |
+
return X, y, list(np.where(categorical_indicator)[0]), attribute_names
|
39 |
+
|
40 |
+
def load_openml_list(dids, filter_for_nan=False
|
41 |
+
, num_feats=100
|
42 |
+
, min_samples = 100
|
43 |
+
, max_samples=400
|
44 |
+
, multiclass=True
|
45 |
+
, max_num_classes=10
|
46 |
+
, shuffled=True
|
47 |
+
, return_capped = False):
|
48 |
+
datasets = []
|
49 |
+
openml_list = openml.datasets.list_datasets(dids)
|
50 |
+
print(f'Number of datasets: {len(openml_list)}')
|
51 |
+
|
52 |
+
datalist = pd.DataFrame.from_dict(openml_list, orient="index")
|
53 |
+
if filter_for_nan:
|
54 |
+
datalist = datalist[datalist['NumberOfInstancesWithMissingValues'] == 0]
|
55 |
+
print(f'Number of datasets after Nan and feature number filtering: {len(datalist)}')
|
56 |
+
|
57 |
+
for ds in datalist.index:
|
58 |
+
modifications = {'samples_capped': False, 'classes_capped': False, 'feats_capped': False}
|
59 |
+
entry = datalist.loc[ds]
|
60 |
+
|
61 |
+
print('Loading', entry['name'], entry.did, '..')
|
62 |
+
|
63 |
+
if entry['NumberOfClasses'] == 0.0:
|
64 |
+
raise Exception("Regression not supported")
|
65 |
+
#X, y, categorical_feats, attribute_names = get_openml_regression(int(entry.did), max_samples)
|
66 |
+
else:
|
67 |
+
X, y, categorical_feats, attribute_names = get_openml_classification(int(entry.did), max_samples
|
68 |
+
, multiclass=multiclass, shuffled=shuffled)
|
69 |
+
if X is None:
|
70 |
+
continue
|
71 |
+
|
72 |
+
if X.shape[1] > num_feats:
|
73 |
+
if return_capped:
|
74 |
+
X = X[:, 0:num_feats]
|
75 |
+
categorical_feats = [c for c in categorical_feats if c < num_feats]
|
76 |
+
modifications['feats_capped'] = True
|
77 |
+
else:
|
78 |
+
print('Too many features')
|
79 |
+
continue
|
80 |
+
if X.shape[0] == max_samples:
|
81 |
+
modifications['samples_capped'] = True
|
82 |
+
|
83 |
+
if X.shape[0] < min_samples:
|
84 |
+
print(f'Too few samples left')
|
85 |
+
continue
|
86 |
+
|
87 |
+
if len(np.unique(y)) > max_num_classes:
|
88 |
+
if return_capped:
|
89 |
+
X = X[y < np.unique(y)[10]]
|
90 |
+
y = y[y < np.unique(y)[10]]
|
91 |
+
modifications['classes_capped'] = True
|
92 |
+
else:
|
93 |
+
print(f'Too many classes')
|
94 |
+
continue
|
95 |
+
|
96 |
+
datasets += [[entry['name'], X, y, categorical_feats, attribute_names, modifications]]
|
97 |
+
|
98 |
+
return datasets, datalist
|
99 |
+
|
100 |
+
|
101 |
+
# Classification
|
102 |
+
valid_dids_classification = [13, 59, 4, 15, 40710, 43, 1498]
|
103 |
+
test_dids_classification = [973, 1596, 40981, 1468, 40984, 40975, 41163, 41147, 1111, 41164, 1169, 1486, 41143, 1461, 41167, 40668, 41146, 41169, 41027, 23517, 41165, 41161, 41159, 41138, 1590, 41166, 1464, 41168, 41150, 1489, 41142, 3, 12, 31, 54, 1067]
|
104 |
+
valid_large_classification = [ 943, 23512, 49, 838, 1131, 767, 1142, 748, 1112,
|
105 |
+
1541, 384, 912, 1503, 796, 20, 30, 903, 4541,
|
106 |
+
961, 805, 1000, 4135, 1442, 816, 1130, 906, 1511,
|
107 |
+
184, 181, 137, 1452, 1481, 949, 449, 50, 913,
|
108 |
+
1071, 831, 843, 9, 896, 1532, 311, 39, 451,
|
109 |
+
463, 382, 778, 474, 737, 1162, 1538, 820, 188,
|
110 |
+
452, 1156, 37, 957, 911, 1508, 1054, 745, 1220,
|
111 |
+
763, 900, 25, 387, 38, 757, 1507, 396, 4153,
|
112 |
+
806, 779, 746, 1037, 871, 717, 1480, 1010, 1016,
|
113 |
+
981, 1547, 1002, 1126, 1459, 846, 837, 1042, 273,
|
114 |
+
1524, 375, 1018, 1531, 1458, 6332, 1546, 1129, 679,
|
115 |
+
389]
|
116 |
+
|
117 |
+
open_cc_dids = [11,
|
118 |
+
14,
|
119 |
+
15,
|
120 |
+
16,
|
121 |
+
18,
|
122 |
+
22,
|
123 |
+
23,
|
124 |
+
29,
|
125 |
+
31,
|
126 |
+
37,
|
127 |
+
50,
|
128 |
+
54,
|
129 |
+
188,
|
130 |
+
458,
|
131 |
+
469,
|
132 |
+
1049,
|
133 |
+
1050,
|
134 |
+
1063,
|
135 |
+
1068,
|
136 |
+
1510,
|
137 |
+
1494,
|
138 |
+
1480,
|
139 |
+
1462,
|
140 |
+
1464,
|
141 |
+
6332,
|
142 |
+
23381,
|
143 |
+
40966,
|
144 |
+
40982,
|
145 |
+
40994,
|
146 |
+
40975]
|
147 |
+
# Filtered by N_samples < 2000, N feats < 100, N classes < 10
|
148 |
+
|
149 |
+
open_cc_valid_dids = [13,25,35,40,41,43,48,49,51,53,55,56,59,61,187,285,329,333,334,335,336,337,338,377,446,450,451,452,460,463,464,466,470,475,481,679,694,717,721,724,733,738,745,747,748,750,753,756,757,764,765,767,774,778,786,788,795,796,798,801,802,810,811,814,820,825,826,827,831,839,840,841,844,852,853,854,860,880,886,895,900,906,907,908,909,915,925,930,931,934,939,940,941,949,966,968,984,987,996,1048,1054,1071,1073,1100,1115,1412,1442,1443,1444,1446,1447,1448,1451,1453,1488,1490,1495,1498,1499,1506,1508,1511,1512,1520,1523,4153,23499,40496,40646,40663,40669,40680,40682,40686,40690,40693,40705,40706,40710,40711,40981,41430,41538,41919,41976,42172,42261,42544,42585,42638]
|
TabPFN/datasets/utils.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def normalize_data(eval_xs):
|
2 |
+
mean = eval_xs.mean(0)
|
3 |
+
std = eval_xs.std(0) + .000001
|
4 |
+
eval_xs = (eval_xs - mean) / std
|
5 |
+
|
6 |
+
return eval_xs
|
7 |
+
|
8 |
+
|
TabPFN/decoders.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import random
|
4 |
+
|
5 |
+
|
6 |
+
class ScaledDecoder(nn.Module):
|
7 |
+
def __init__(self, ninp, nhid, nout):
|
8 |
+
super().__init__()
|
9 |
+
self.linear = nn.Linear(ninp, nhid)
|
10 |
+
self.linear1 = nn.Linear(nhid, nout)
|
11 |
+
self.linear2 = nn.Linear(nhid, 10)
|
12 |
+
|
13 |
+
def forward(self, x):
|
14 |
+
#return torch.cat([self.linear1(x), self.linear2(x)], -1)
|
15 |
+
x = self.linear(x)
|
16 |
+
x = nn.GELU()(x)
|
17 |
+
temps = self.linear2(x).softmax(-1) @ torch.tensor([1.,1.4,1.7,2.,5.,10.,20.,40.,80.,160.], device=x.device)
|
18 |
+
if random.random() > .99:
|
19 |
+
print(temps.shape,temps[:,:2])
|
20 |
+
return self.linear1(x) / temps.unsqueeze(-1)
|
21 |
+
|
22 |
+
class FixedScaledDecoder(nn.Module):
|
23 |
+
def __init__(self, ninp, nhid, nout):
|
24 |
+
super().__init__()
|
25 |
+
self.mapper = nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, nout))
|
26 |
+
self.T = nn.Parameter(torch.ones(10000)/10000)
|
27 |
+
|
28 |
+
def forward(self, x):
|
29 |
+
return self.mapper(x)/self.T.sum()
|
30 |
+
|
TabPFN/differentiable_pfn_evaluation.py
ADDED
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import time
|
5 |
+
import pickle
|
6 |
+
from scripts import tabular_metrics
|
7 |
+
from scripts.tabular_metrics import calculate_score_per_method
|
8 |
+
from scripts.tabular_evaluation import evaluate
|
9 |
+
from priors.differentiable_prior import draw_random_style
|
10 |
+
from tqdm import tqdm
|
11 |
+
import random
|
12 |
+
from scripts.transformer_prediction_interface import get_params_from_config, load_model_workflow
|
13 |
+
|
14 |
+
"""
|
15 |
+
===============================
|
16 |
+
PUBLIC FUNCTIONS FOR EVALUATION
|
17 |
+
===============================
|
18 |
+
"""
|
19 |
+
|
20 |
+
|
21 |
+
def eval_model_range(i_range, *args, **kwargs):
|
22 |
+
for i in i_range:
|
23 |
+
eval_model(i, *args, **kwargs)
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
def eval_model(i, e, valid_datasets, test_datasets, train_datasets, eval_positions_valid, eval_positions_test,
|
28 |
+
bptt_valid,
|
29 |
+
bptt_test, add_name, base_path, device='cpu', eval_addition='', **extra_tuning_args):
|
30 |
+
"""
|
31 |
+
Differentiable model evaliation workflow. Evaluates and saves results to disk.
|
32 |
+
|
33 |
+
:param i:
|
34 |
+
:param e:
|
35 |
+
:param valid_datasets:
|
36 |
+
:param test_datasets:
|
37 |
+
:param train_datasets:
|
38 |
+
:param eval_positions_valid:
|
39 |
+
:param eval_positions_test:
|
40 |
+
:param bptt_valid:
|
41 |
+
:param bptt_test:
|
42 |
+
:param add_name:
|
43 |
+
:param base_path:
|
44 |
+
:param device:
|
45 |
+
:param eval_addition:
|
46 |
+
:param extra_tuning_args:
|
47 |
+
:return:
|
48 |
+
"""
|
49 |
+
model, c, results_file = load_model_workflow(i, e, add_name, base_path, device, eval_addition)
|
50 |
+
params = {'bptt': bptt_valid
|
51 |
+
, 'bptt_final': bptt_test
|
52 |
+
, 'eval_positions': eval_positions_valid
|
53 |
+
, 'eval_positions_test': eval_positions_test
|
54 |
+
, 'valid_datasets': valid_datasets
|
55 |
+
, 'test_datasets': test_datasets
|
56 |
+
, 'train_datasets': train_datasets
|
57 |
+
, 'verbose': True
|
58 |
+
, 'device': device
|
59 |
+
}
|
60 |
+
|
61 |
+
params.update(get_params_from_config(c))
|
62 |
+
|
63 |
+
start = time.time()
|
64 |
+
metrics, metrics_valid, style, temperature, optimization_route = evaluate_differentiable_model(model, **params,
|
65 |
+
**extra_tuning_args)
|
66 |
+
print('Evaluation time: ', time.time() - start)
|
67 |
+
|
68 |
+
print(results_file)
|
69 |
+
r = [c.copy(), metrics, metrics_valid, style.to('cpu'), temperature.to('cpu'), optimization_route]
|
70 |
+
with open(results_file, 'wb') as output:
|
71 |
+
del r[0]['num_features_used']
|
72 |
+
del r[0]['categorical_features_sampler']
|
73 |
+
pickle.dump(r, output)
|
74 |
+
|
75 |
+
_, _, _, style, temperature, _ = r
|
76 |
+
|
77 |
+
return r, model
|
78 |
+
|
79 |
+
"""
|
80 |
+
===============================
|
81 |
+
INTERNAL HELPER FUNCTIONS
|
82 |
+
===============================
|
83 |
+
"""
|
84 |
+
|
85 |
+
def evaluate_differentiable_model(model
|
86 |
+
, valid_datasets
|
87 |
+
, test_datasets
|
88 |
+
, train_datasets
|
89 |
+
, N_draws=100
|
90 |
+
, N_grad_steps=10
|
91 |
+
, eval_positions=None
|
92 |
+
, eval_positions_test=None
|
93 |
+
, bptt=100
|
94 |
+
, bptt_final=200
|
95 |
+
, style=None
|
96 |
+
, n_parallel_configurations=1
|
97 |
+
, device='cpu'
|
98 |
+
, selection_metric='auc'
|
99 |
+
, final_splits=[1, 2, 3, 4, 5]
|
100 |
+
, N_ensemble_configurations_list=[1, 5, 10, 20, 50, 100]
|
101 |
+
, **kwargs):
|
102 |
+
"""
|
103 |
+
Evaluation function for diffable model evaluation. Returns a list of results.
|
104 |
+
|
105 |
+
:param model:
|
106 |
+
:param valid_datasets:
|
107 |
+
:param test_datasets:
|
108 |
+
:param train_datasets:
|
109 |
+
:param N_draws:
|
110 |
+
:param N_grad_steps:
|
111 |
+
:param eval_positions:
|
112 |
+
:param eval_positions_test:
|
113 |
+
:param bptt:
|
114 |
+
:param bptt_final:
|
115 |
+
:param style:
|
116 |
+
:param n_parallel_configurations:
|
117 |
+
:param device:
|
118 |
+
:param selection_metric:
|
119 |
+
:param final_splits:
|
120 |
+
:param N_ensemble_configurations_list:
|
121 |
+
:param kwargs:
|
122 |
+
:return:
|
123 |
+
"""
|
124 |
+
torch.manual_seed(0)
|
125 |
+
np.random.seed(0)
|
126 |
+
random.seed(0)
|
127 |
+
|
128 |
+
diffable_metric = tabular_metrics.cross_entropy
|
129 |
+
evaluation_metric = tabular_metrics.auc_metric
|
130 |
+
if selection_metric in ('auc', 'roc'):
|
131 |
+
selection_metric_min_max = 'max'
|
132 |
+
selection_metric = tabular_metrics.auc_metric
|
133 |
+
evaluation_metric = selection_metric
|
134 |
+
elif selection_metric in ('ce', 'selection_metric'):
|
135 |
+
selection_metric_min_max = 'min'
|
136 |
+
selection_metric = tabular_metrics.cross_entropy
|
137 |
+
evaluation_metric = selection_metric
|
138 |
+
|
139 |
+
print('Diffable metric', diffable_metric, ' Selection metric', selection_metric, ' Evaluation metric',
|
140 |
+
evaluation_metric)
|
141 |
+
print('N PARALLEL CONFIGURATIONS', n_parallel_configurations)
|
142 |
+
print('eval_positions', eval_positions)
|
143 |
+
|
144 |
+
def evaluate_valid(style, softmax_temperature, results, results_tracked):
|
145 |
+
result_valid = eval_step(valid_datasets, style, softmax_temperature=softmax_temperature,
|
146 |
+
return_tensor=False, inference_mode=True, selection_metric=selection_metric,
|
147 |
+
evaluation_metric=evaluation_metric, eval_positions=eval_positions, bptt=bptt, model=model[2])
|
148 |
+
result_valid = [float(result_valid[f'mean_select_at_{pos}']) for pos in eval_positions]
|
149 |
+
results += [result_valid]
|
150 |
+
results_tracked += [np.nanmean(result_valid)]
|
151 |
+
|
152 |
+
model[2].to(device)
|
153 |
+
model[2].eval()
|
154 |
+
|
155 |
+
results_on_valid, results_on_valid_tracked = [], []
|
156 |
+
best_style, best_softmax_temperature = style, torch.cat(
|
157 |
+
[torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)], 0)
|
158 |
+
optimization_routes = []
|
159 |
+
|
160 |
+
best_style = torch.cat([draw_random_style(model[3], device).detach() for n in range(0, n_parallel_configurations)],
|
161 |
+
0)
|
162 |
+
best_softmax_temperature = torch.cat([torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)],
|
163 |
+
0)
|
164 |
+
|
165 |
+
|
166 |
+
for _ in tqdm(range(0, N_draws), desc='Iterate over Optimization initializations'): # Evaluates N hparam draws
|
167 |
+
style = torch.cat([draw_random_style(model[3], device).detach() for n in range(0, n_parallel_configurations)],
|
168 |
+
0)
|
169 |
+
softmax_temperature = torch.cat([torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)],
|
170 |
+
0)
|
171 |
+
|
172 |
+
evaluate_valid(style, softmax_temperature, results_on_valid, results_on_valid_tracked)
|
173 |
+
|
174 |
+
print(f'Draw --> Valid Selection metric: {results_on_valid[-1]}')
|
175 |
+
|
176 |
+
if N_grad_steps > 0:
|
177 |
+
gradient_optimize_result = gradient_optimize_style(model, style, N_grad_steps
|
178 |
+
, softmax_temperature=softmax_temperature
|
179 |
+
, model=model[2]
|
180 |
+
, train_datasets=train_datasets
|
181 |
+
, valid_datasets=valid_datasets
|
182 |
+
, selection_metric_min_max=selection_metric_min_max
|
183 |
+
, **kwargs)
|
184 |
+
optimization_routes += [gradient_optimize_result['optimization_route']]
|
185 |
+
|
186 |
+
evaluate_valid(gradient_optimize_result['best_style']
|
187 |
+
, gradient_optimize_result['best_temperature']
|
188 |
+
, results_on_valid, results_on_valid_tracked)
|
189 |
+
|
190 |
+
print(f'After diff --> Valid Selection metric: {results_on_valid[-1]}')
|
191 |
+
|
192 |
+
if selection_metric_min_max == 'min':
|
193 |
+
is_best = (results_on_valid_tracked[-1] <= min(results_on_valid_tracked))
|
194 |
+
else:
|
195 |
+
is_best = (results_on_valid_tracked[-1] >= max(results_on_valid_tracked))
|
196 |
+
|
197 |
+
if is_best or best_style is None:
|
198 |
+
best_style = gradient_optimize_result['best_style'].clone()
|
199 |
+
best_softmax_temperature = gradient_optimize_result['best_temperature'].clone()
|
200 |
+
torch.cuda.empty_cache()
|
201 |
+
|
202 |
+
def final_evaluation():
|
203 |
+
print('Running eval dataset with final params (no gradients)..')
|
204 |
+
print(best_style, best_softmax_temperature)
|
205 |
+
result_test = []
|
206 |
+
for N_ensemble_configurations in N_ensemble_configurations_list:
|
207 |
+
print(f'Running with {N_ensemble_configurations} ensemble_configurations')
|
208 |
+
kwargs['N_ensemble_configurations'] = N_ensemble_configurations
|
209 |
+
splits = []
|
210 |
+
for split in final_splits:
|
211 |
+
splits += [eval_step(test_datasets, best_style, softmax_temperature=best_softmax_temperature
|
212 |
+
, return_tensor=False, eval_positions=eval_positions_test,
|
213 |
+
bptt=bptt_final, inference_mode=True, split_number=split, model=model[2]
|
214 |
+
, selection_metric=selection_metric, evaluation_metric=evaluation_metric)]
|
215 |
+
result_test += [splits]
|
216 |
+
|
217 |
+
print('Running valid dataset with final params (no gradients)..')
|
218 |
+
result_valid = eval_step(valid_datasets, best_style, softmax_temperature=best_softmax_temperature
|
219 |
+
, return_tensor=False, eval_positions=eval_positions_test,
|
220 |
+
bptt=bptt_final, inference_mode=True, model=model[2]
|
221 |
+
, selection_metric=selection_metric, evaluation_metric=evaluation_metric)
|
222 |
+
|
223 |
+
return result_test, result_valid
|
224 |
+
|
225 |
+
result_test, result_valid = final_evaluation()
|
226 |
+
|
227 |
+
return result_test, result_valid, best_style, best_softmax_temperature, optimization_routes
|
228 |
+
|
229 |
+
|
230 |
+
def eval_step(ds, used_style, selection_metric, evaluation_metric, eval_positions, return_tensor=True, **kwargs):
|
231 |
+
def step():
|
232 |
+
return evaluate(datasets=ds,
|
233 |
+
method='transformer'
|
234 |
+
, overwrite=True
|
235 |
+
, style=used_style
|
236 |
+
, eval_positions=eval_positions
|
237 |
+
, metric_used=selection_metric
|
238 |
+
, save=False
|
239 |
+
, path_interfix=None
|
240 |
+
, base_path=None
|
241 |
+
, verbose=True
|
242 |
+
, **kwargs)
|
243 |
+
|
244 |
+
if return_tensor:
|
245 |
+
r = step()
|
246 |
+
else:
|
247 |
+
with torch.no_grad():
|
248 |
+
r = step()
|
249 |
+
|
250 |
+
calculate_score_per_method(selection_metric, 'select', r, ds, eval_positions, aggregator='mean')
|
251 |
+
calculate_score_per_method(evaluation_metric, 'eval', r, ds, eval_positions, aggregator='mean')
|
252 |
+
|
253 |
+
return r
|
254 |
+
|
255 |
+
|
256 |
+
def gradient_optimize_style(model, init_style, steps, softmax_temperature, train_datasets, valid_datasets, learning_rate=0.03, optimize_all=False,
|
257 |
+
limit_style=True, N_datasets_sampled=90, optimize_softmax_temperature=True, selection_metric_min_max='max', **kwargs):
|
258 |
+
"""
|
259 |
+
Uses gradient based methods to optimize 'style' on the 'train_datasets' and uses stopping with 'valid_datasets'.
|
260 |
+
|
261 |
+
:param model:
|
262 |
+
:param init_style:
|
263 |
+
:param steps:
|
264 |
+
:param learning_rate:
|
265 |
+
:param softmax_temperature:
|
266 |
+
:param train_datasets:
|
267 |
+
:param valid_datasets:
|
268 |
+
:param optimize_all:
|
269 |
+
:param limit_style:
|
270 |
+
:param N_datasets_sampled:
|
271 |
+
:param optimize_softmax_temperature:
|
272 |
+
:param selection_metric_min_max:
|
273 |
+
:param kwargs:
|
274 |
+
:return:
|
275 |
+
"""
|
276 |
+
grad_style = torch.nn.Parameter(init_style.detach(), requires_grad=True)
|
277 |
+
|
278 |
+
best_style, best_temperature, best_selection_metric, best_diffable_metric = grad_style.detach(), softmax_temperature.detach(), None, None
|
279 |
+
softmax_temperature = torch.nn.Parameter(softmax_temperature.detach(), requires_grad=optimize_softmax_temperature)
|
280 |
+
variables_to_optimize = model[2].parameters() if optimize_all else [grad_style, softmax_temperature]
|
281 |
+
optimizer = torch.optim.Adam(variables_to_optimize, lr=learning_rate)
|
282 |
+
|
283 |
+
optimization_route_selection, optimization_route_diffable = [], []
|
284 |
+
optimization_route_selection_valid, optimization_route_diffable_valid = [], []
|
285 |
+
|
286 |
+
def eval_opt(ds, return_tensor=True, inference_mode=False):
|
287 |
+
result = eval_step(ds, grad_style, softmax_temperature=softmax_temperature, return_tensor=return_tensor
|
288 |
+
, inference_mode=inference_mode, model=model[2], **kwargs)
|
289 |
+
|
290 |
+
diffable_metric = result['mean_metric']
|
291 |
+
selection_metric = result['mean_select']
|
292 |
+
|
293 |
+
return diffable_metric, selection_metric
|
294 |
+
|
295 |
+
def eval_all_datasets(datasets, propagate=True):
|
296 |
+
selection_metrics_this_step, diffable_metrics_this_step = [], []
|
297 |
+
for ds in datasets:
|
298 |
+
diffable_metric_train, selection_metric_train = eval_opt([ds], inference_mode=(not propagate))
|
299 |
+
if not torch.isnan(diffable_metric_train).any():
|
300 |
+
if propagate and diffable_metric_train.requires_grad == True:
|
301 |
+
diffable_metric_train.backward()
|
302 |
+
selection_metrics_this_step += [selection_metric_train]
|
303 |
+
diffable_metrics_this_step += [float(diffable_metric_train.detach().cpu().numpy())]
|
304 |
+
diffable_metric_train = np.nanmean(diffable_metrics_this_step)
|
305 |
+
selection_metric_train = np.nanmean(selection_metrics_this_step)
|
306 |
+
|
307 |
+
return diffable_metric_train, selection_metric_train
|
308 |
+
|
309 |
+
for t in tqdm(range(steps), desc='Iterate over Optimization steps'):
|
310 |
+
optimizer.zero_grad()
|
311 |
+
|
312 |
+
# Select subset of datasets
|
313 |
+
random.seed(t)
|
314 |
+
train_datasets_ = random.sample(train_datasets, N_datasets_sampled)
|
315 |
+
|
316 |
+
# Get score on train
|
317 |
+
diffable_metric_train, selection_metric_train = eval_all_datasets(train_datasets_, propagate=True)
|
318 |
+
optimization_route_selection += [float(selection_metric_train)]
|
319 |
+
optimization_route_diffable += [float(diffable_metric_train)]
|
320 |
+
|
321 |
+
# Get score on valid
|
322 |
+
diffable_metric_valid, selection_metric_valid = eval_all_datasets(valid_datasets, propagate=False)
|
323 |
+
optimization_route_selection_valid += [float(selection_metric_valid)]
|
324 |
+
optimization_route_diffable_valid += [float(diffable_metric_valid)]
|
325 |
+
|
326 |
+
is_best = (selection_metric_min_max == 'min' and best_selection_metric > selection_metric_valid)
|
327 |
+
is_best = is_best or (selection_metric_min_max == 'max' and best_selection_metric < selection_metric_valid)
|
328 |
+
if (best_selection_metric is None) or (not np.isnan(selection_metric_valid) and is_best):
|
329 |
+
print('New best', best_selection_metric, selection_metric_valid)
|
330 |
+
best_style = grad_style.detach().clone()
|
331 |
+
best_temperature = softmax_temperature.detach().clone()
|
332 |
+
best_selection_metric, best_diffable_metric = selection_metric_valid, diffable_metric_valid
|
333 |
+
|
334 |
+
optimizer.step()
|
335 |
+
|
336 |
+
if limit_style:
|
337 |
+
grad_style = grad_style.detach().clamp(-1.74, 1.74)
|
338 |
+
|
339 |
+
print(f'Valid: Diffable metric={diffable_metric_valid} Selection metric={selection_metric_valid};' +
|
340 |
+
f'Train: Diffable metric={diffable_metric_train} Selection metric={selection_metric_train}')
|
341 |
+
|
342 |
+
print(f'Return best:{best_style} {best_selection_metric}')
|
343 |
+
return {'best_style': best_style, 'best_temperature': best_temperature
|
344 |
+
, 'optimization_route': {'select': optimization_route_selection, 'loss': optimization_route_diffable,
|
345 |
+
'test_select': optimization_route_selection_valid, 'test_loss': optimization_route_diffable_valid}}
|
TabPFN/encoders.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from utils import normalize_data
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
8 |
+
|
9 |
+
|
10 |
+
class StyleEncoder(nn.Module):
|
11 |
+
def __init__(self, em_size, hyperparameter_definitions):
|
12 |
+
super().__init__()
|
13 |
+
# self.embeddings = {}
|
14 |
+
self.em_size = em_size
|
15 |
+
# self.hyperparameter_definitions = {}
|
16 |
+
# for hp in hyperparameter_definitions:
|
17 |
+
# self.embeddings[hp] = nn.Linear(1, self.em_size)
|
18 |
+
# self.embeddings = nn.ModuleDict(self.embeddings)
|
19 |
+
self.embedding = nn.Linear(hyperparameter_definitions.shape[0], self.em_size)
|
20 |
+
|
21 |
+
def forward(self, hyperparameters): # T x B x num_features
|
22 |
+
# Make faster by using matrices
|
23 |
+
# sampled_embeddings = [torch.stack([
|
24 |
+
# self.embeddings[hp](torch.tensor([batch[hp]], device=self.embeddings[hp].weight.device, dtype=torch.float))
|
25 |
+
# for hp in batch
|
26 |
+
# ], -1).sum(-1) for batch in hyperparameters]
|
27 |
+
# return torch.stack(sampled_embeddings, 0)
|
28 |
+
return self.embedding(hyperparameters)
|
29 |
+
|
30 |
+
|
31 |
+
class _PositionalEncoding(nn.Module):
|
32 |
+
def __init__(self, d_model, dropout=0.):
|
33 |
+
super().__init__()
|
34 |
+
self.dropout = nn.Dropout(p=dropout)
|
35 |
+
self.d_model = d_model
|
36 |
+
self.device_test_tensor = nn.Parameter(torch.tensor(1.))
|
37 |
+
|
38 |
+
def forward(self, x):# T x B x num_features
|
39 |
+
assert self.d_model % x.shape[-1]*2 == 0
|
40 |
+
d_per_feature = self.d_model // x.shape[-1]
|
41 |
+
pe = torch.zeros(*x.shape, d_per_feature, device=self.device_test_tensor.device)
|
42 |
+
#position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
43 |
+
interval_size = 10
|
44 |
+
div_term = (1./interval_size) * 2*math.pi*torch.exp(torch.arange(0, d_per_feature, 2, device=self.device_test_tensor.device).float()*math.log(math.sqrt(2)))
|
45 |
+
#print(div_term/2/math.pi)
|
46 |
+
pe[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term)
|
47 |
+
pe[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term)
|
48 |
+
return self.dropout(pe).view(x.shape[0],x.shape[1],self.d_model)
|
49 |
+
|
50 |
+
|
51 |
+
Positional = lambda _, emsize: _PositionalEncoding(d_model=emsize)
|
52 |
+
|
53 |
+
class EmbeddingEncoder(nn.Module):
|
54 |
+
def __init__(self, num_features, em_size, num_embs=100):
|
55 |
+
super().__init__()
|
56 |
+
self.num_embs = num_embs
|
57 |
+
self.embeddings = nn.Embedding(num_embs * num_features, em_size, max_norm=True)
|
58 |
+
self.init_weights(.1)
|
59 |
+
self.min_max = (-2,+2)
|
60 |
+
|
61 |
+
@property
|
62 |
+
def width(self):
|
63 |
+
return self.min_max[1] - self.min_max[0]
|
64 |
+
|
65 |
+
def init_weights(self, initrange):
|
66 |
+
self.embeddings.weight.data.uniform_(-initrange, initrange)
|
67 |
+
|
68 |
+
def discretize(self, x):
|
69 |
+
split_size = self.width / self.num_embs
|
70 |
+
return (x - self.min_max[0] // split_size).int().clamp(0, self.num_embs - 1)
|
71 |
+
|
72 |
+
def forward(self, x): # T x B x num_features
|
73 |
+
x_idxs = self.discretize(x)
|
74 |
+
x_idxs += torch.arange(x.shape[-1], device=x.device).view(1, 1, -1) * self.num_embs
|
75 |
+
# print(x_idxs,self.embeddings.weight.shape)
|
76 |
+
return self.embeddings(x_idxs).mean(-2)
|
77 |
+
|
78 |
+
|
79 |
+
class Normalize(nn.Module):
|
80 |
+
def __init__(self, mean, std):
|
81 |
+
super().__init__()
|
82 |
+
self.mean = mean
|
83 |
+
self.std = std
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
return (x-self.mean)/self.std
|
87 |
+
|
88 |
+
|
89 |
+
def get_normalized_uniform_encoder(encoder_creator):
|
90 |
+
"""
|
91 |
+
This can be used to wrap an encoder that is fed uniform samples in [0,1] and normalizes these to 0 mean and 1 std.
|
92 |
+
For example, it can be used as `encoder_creator = get_normalized_uniform_encoder(encoders.Linear)`, now this can
|
93 |
+
be initialized with `encoder_creator(feature_dim, in_dim)`.
|
94 |
+
:param encoder:
|
95 |
+
:return:
|
96 |
+
"""
|
97 |
+
return lambda in_dim, out_dim: nn.Sequential(Normalize(.5, math.sqrt(1/12)), encoder_creator(in_dim, out_dim))
|
98 |
+
|
99 |
+
|
100 |
+
Linear = nn.Linear
|
101 |
+
MLP = lambda num_features, emsize: nn.Sequential(nn.Linear(num_features+1,emsize*2),
|
102 |
+
nn.ReLU(),
|
103 |
+
nn.Linear(emsize*2,emsize))
|
104 |
+
|
105 |
+
class NanHandlingEncoder(nn.Module):
|
106 |
+
def __init__(self, num_features, emsize, keep_nans=True):
|
107 |
+
super().__init__()
|
108 |
+
self.num_features = 2 * num_features if keep_nans else num_features
|
109 |
+
self.emsize = emsize
|
110 |
+
self.keep_nans = keep_nans
|
111 |
+
self.layer = nn.Linear(self.num_features, self.emsize)
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
if self.keep_nans:
|
115 |
+
x = torch.cat([torch.nan_to_num(x, nan=0.0), normalize_data(torch.isnan(x) * -1
|
116 |
+
+ torch.logical_and(torch.isinf(x), torch.sign(x) == 1) * 1
|
117 |
+
+ torch.logical_and(torch.isinf(x), torch.sign(x) == -1) * 2
|
118 |
+
)], -1)
|
119 |
+
else:
|
120 |
+
x = torch.nan_to_num(x, nan=0.0)
|
121 |
+
return self.layer(x)
|
122 |
+
|
123 |
+
class Linear(nn.Linear):
|
124 |
+
def __init__(self, num_features, emsize):
|
125 |
+
super().__init__(num_features, emsize)
|
126 |
+
self.num_features = num_features
|
127 |
+
self.emsize = emsize
|
128 |
+
|
129 |
+
def forward(self, x):
|
130 |
+
x = torch.nan_to_num(x, nan=0.0)
|
131 |
+
return super().forward(x)
|
132 |
+
|
133 |
+
class SequenceSpanningEncoder(nn.Module):
|
134 |
+
# Regular Encoder transforms Seq_len, B, S -> Seq_len, B, E attending only to last dimension
|
135 |
+
# This Encoder accesses the Seq_Len dimension additionally
|
136 |
+
|
137 |
+
# Why would we want this? We can learn normalization and embedding of features
|
138 |
+
# , this might be more important for e.g. categorical, ordinal feats, nan detection
|
139 |
+
# However maybe this can be easily learned through transformer as well?
|
140 |
+
# A problem is to make this work across any sequence length and be independent of ordering
|
141 |
+
|
142 |
+
# We could use average and maximum pooling and use those with a linear layer
|
143 |
+
|
144 |
+
|
145 |
+
# Another idea !! Similar to this we would like to encode features so that their number is variable
|
146 |
+
# We would like to embed features, also using knowledge of the features in the entire sequence
|
147 |
+
|
148 |
+
# We could use convolution or another transformer
|
149 |
+
# Convolution:
|
150 |
+
|
151 |
+
# Transformer/Conv across sequence dimension that encodes and normalizes features
|
152 |
+
# -> Transformer across feature dimension that encodes features to a constant size
|
153 |
+
|
154 |
+
# Conv with flexible features but no sequence info: S,B,F -(reshape)-> S*B,1,F
|
155 |
+
# -(Conv1d)-> S*B,N,F -(AvgPool,MaxPool)-> S*B,N,1 -> S,B,N
|
156 |
+
# This probably won't work since it's missing a way to recognize which feature is encoded
|
157 |
+
|
158 |
+
# Transformer with flexible features: S,B,F -> F,B*S,1 -> F2,B*S,1 -> S,B,F2
|
159 |
+
|
160 |
+
def __init__(self, num_features, em_size):
|
161 |
+
super().__init__()
|
162 |
+
|
163 |
+
raise NotImplementedError()
|
164 |
+
# Seq_len, B, S -> Seq_len, B, E
|
165 |
+
#
|
166 |
+
self.convs = torch.nn.ModuleList([nn.Conv1d(64 if i else 1, 64, 3) for i in range(5)])
|
167 |
+
# self.linear = nn.Linear(64, emsize)
|
168 |
+
|
169 |
+
class TransformerBasedFeatureEncoder(nn.Module):
|
170 |
+
def __init__(self, num_features, emsize):
|
171 |
+
super().__init__()
|
172 |
+
|
173 |
+
hidden_emsize = emsize
|
174 |
+
encoder = Linear(1, hidden_emsize)
|
175 |
+
n_out = emsize
|
176 |
+
nhid = 2*emsize
|
177 |
+
dropout =0.0
|
178 |
+
nhead=4
|
179 |
+
nlayers=4
|
180 |
+
model = nn.Transformer(nhead=nhead, num_encoder_layers=4, num_decoder_layers=4, d_model=1)
|
181 |
+
|
182 |
+
def forward(self, *input):
|
183 |
+
# S,B,F -> F,S*B,1 -> F2,S*B,1 -> S,B,F2
|
184 |
+
input = input.transpose()
|
185 |
+
self.model(input)
|
186 |
+
|
187 |
+
class Conv(nn.Module):
|
188 |
+
def __init__(self, input_size, emsize):
|
189 |
+
super().__init__()
|
190 |
+
self.convs = torch.nn.ModuleList([nn.Conv2d(64 if i else 1, 64, 3) for i in range(5)])
|
191 |
+
self.linear = nn.Linear(64,emsize)
|
192 |
+
|
193 |
+
|
194 |
+
def forward(self, x):
|
195 |
+
size = math.isqrt(x.shape[-1])
|
196 |
+
assert size*size == x.shape[-1]
|
197 |
+
x = x.reshape(*x.shape[:-1], 1, size, size)
|
198 |
+
for conv in self.convs:
|
199 |
+
if x.shape[-1] < 4:
|
200 |
+
break
|
201 |
+
x = conv(x)
|
202 |
+
x.relu_()
|
203 |
+
x = nn.AdaptiveAvgPool2d((1,1))(x).squeeze(-1).squeeze(-1)
|
204 |
+
return self.linear(x)
|
205 |
+
|
206 |
+
|
207 |
+
|
208 |
+
|
209 |
+
class CanEmb(nn.Embedding):
|
210 |
+
def __init__(self, num_features, num_embeddings: int, embedding_dim: int, *args, **kwargs):
|
211 |
+
assert embedding_dim % num_features == 0
|
212 |
+
embedding_dim = embedding_dim // num_features
|
213 |
+
super().__init__(num_embeddings, embedding_dim, *args, **kwargs)
|
214 |
+
|
215 |
+
def forward(self, x):
|
216 |
+
lx = x.long()
|
217 |
+
assert (lx == x).all(), "CanEmb only works with tensors of whole numbers"
|
218 |
+
x = super().forward(lx)
|
219 |
+
return x.view(*x.shape[:-2], -1)
|
220 |
+
|
221 |
+
def get_Canonical(num_classes):
|
222 |
+
return lambda num_features, emsize: CanEmb(num_features, num_classes, emsize)
|
223 |
+
|
224 |
+
def get_Embedding(num_embs_per_feature=100):
|
225 |
+
return lambda num_features, emsize: EmbeddingEncoder(num_features, emsize, num_embs=num_embs_per_feature)
|
TabPFN/initializers.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
def get_NormalInitializer(std):
|
5 |
+
def initializer(m):
|
6 |
+
if isinstance(m, nn.Linear):
|
7 |
+
nn.init.normal_(m.weight, 0, std)
|
8 |
+
nn.init.normal_(m.bias, 0, std)
|
9 |
+
return initializer
|
TabPFN/layer.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn.modules.transformer import *
|
5 |
+
from torch.nn.modules.transformer import _get_activation_fn
|
6 |
+
|
7 |
+
from torch.utils.checkpoint import checkpoint
|
8 |
+
|
9 |
+
|
10 |
+
class TransformerEncoderLayer(Module):
|
11 |
+
r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
|
12 |
+
This standard encoder layer is based on the paper "Attention Is All You Need".
|
13 |
+
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
|
14 |
+
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
|
15 |
+
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
|
16 |
+
in a different way during application.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
d_model: the number of expected features in the input (required).
|
20 |
+
nhead: the number of heads in the multiheadattention models (required).
|
21 |
+
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
22 |
+
dropout: the dropout value (default=0.1).
|
23 |
+
activation: the activation function of intermediate layer, relu or gelu (default=relu).
|
24 |
+
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
|
25 |
+
batch_first: If ``True``, then the input and output tensors are provided
|
26 |
+
as (batch, seq, feature). Default: ``False``.
|
27 |
+
|
28 |
+
Examples::
|
29 |
+
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
|
30 |
+
>>> src = torch.rand(10, 32, 512)
|
31 |
+
>>> out = encoder_layer(src)
|
32 |
+
|
33 |
+
Alternatively, when ``batch_first`` is ``True``:
|
34 |
+
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
|
35 |
+
>>> src = torch.rand(32, 10, 512)
|
36 |
+
>>> out = encoder_layer(src)
|
37 |
+
"""
|
38 |
+
__constants__ = ['batch_first']
|
39 |
+
|
40 |
+
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
|
41 |
+
layer_norm_eps=1e-5, batch_first=False, pre_norm=False,
|
42 |
+
device=None, dtype=None, recompute_attn=False) -> None:
|
43 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
44 |
+
super().__init__()
|
45 |
+
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
|
46 |
+
**factory_kwargs)
|
47 |
+
# Implementation of Feedforward model
|
48 |
+
self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
|
49 |
+
self.dropout = Dropout(dropout)
|
50 |
+
self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)
|
51 |
+
|
52 |
+
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
|
53 |
+
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
|
54 |
+
self.dropout1 = Dropout(dropout)
|
55 |
+
self.dropout2 = Dropout(dropout)
|
56 |
+
self.pre_norm = pre_norm
|
57 |
+
self.recompute_attn = recompute_attn
|
58 |
+
|
59 |
+
self.activation = _get_activation_fn(activation)
|
60 |
+
|
61 |
+
def __setstate__(self, state):
|
62 |
+
if 'activation' not in state:
|
63 |
+
state['activation'] = F.relu
|
64 |
+
super().__setstate__(state)
|
65 |
+
|
66 |
+
def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
|
67 |
+
r"""Pass the input through the encoder layer.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
src: the sequence to the encoder layer (required).
|
71 |
+
src_mask: the mask for the src sequence (optional).
|
72 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
73 |
+
|
74 |
+
Shape:
|
75 |
+
see the docs in Transformer class.
|
76 |
+
"""
|
77 |
+
if self.pre_norm:
|
78 |
+
src_ = self.norm1(src)
|
79 |
+
else:
|
80 |
+
src_ = src
|
81 |
+
if isinstance(src_mask, tuple):
|
82 |
+
# global attention setup
|
83 |
+
assert not self.self_attn.batch_first
|
84 |
+
assert src_key_padding_mask is None
|
85 |
+
|
86 |
+
global_src_mask, trainset_src_mask, valset_src_mask = src_mask
|
87 |
+
|
88 |
+
num_global_tokens = global_src_mask.shape[0]
|
89 |
+
num_train_tokens = trainset_src_mask.shape[0]
|
90 |
+
|
91 |
+
global_tokens_src = src_[:num_global_tokens]
|
92 |
+
train_tokens_src = src_[num_global_tokens:num_global_tokens+num_train_tokens]
|
93 |
+
global_and_train_tokens_src = src_[:num_global_tokens+num_train_tokens]
|
94 |
+
eval_tokens_src = src_[num_global_tokens+num_train_tokens:]
|
95 |
+
|
96 |
+
|
97 |
+
attn = partial(checkpoint, self.self_attn) if self.recompute_attn else self.self_attn
|
98 |
+
|
99 |
+
global_tokens_src2 = attn(global_tokens_src, global_and_train_tokens_src, global_and_train_tokens_src, None, True, global_src_mask)[0]
|
100 |
+
train_tokens_src2 = attn(train_tokens_src, global_tokens_src, global_tokens_src, None, True, trainset_src_mask)[0]
|
101 |
+
eval_tokens_src2 = attn(eval_tokens_src, src_, src_,
|
102 |
+
None, True, valset_src_mask)[0]
|
103 |
+
|
104 |
+
src2 = torch.cat([global_tokens_src2, train_tokens_src2, eval_tokens_src2], dim=0)
|
105 |
+
|
106 |
+
else:
|
107 |
+
if self.recompute_attn:
|
108 |
+
src2 = checkpoint(self.self_attn, src_, src_, src_, src_key_padding_mask, True, src_mask)[0]
|
109 |
+
else:
|
110 |
+
src2 = self.self_attn(src_, src_, src_, attn_mask=src_mask,
|
111 |
+
key_padding_mask=src_key_padding_mask)[0]
|
112 |
+
src = src + self.dropout1(src2)
|
113 |
+
if not self.pre_norm:
|
114 |
+
src = self.norm1(src)
|
115 |
+
|
116 |
+
if self.pre_norm:
|
117 |
+
src_ = self.norm2(src)
|
118 |
+
else:
|
119 |
+
src_ = src
|
120 |
+
src2 = self.linear2(self.dropout(self.activation(self.linear1(src_))))
|
121 |
+
src = src + self.dropout2(src2)
|
122 |
+
|
123 |
+
if not self.pre_norm:
|
124 |
+
src = self.norm2(src)
|
125 |
+
return src
|
TabPFN/losses.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
class CrossEntropyForMulticlassLoss(torch.nn.CrossEntropyLoss):
|
5 |
+
# This loss applies cross entropy after reducing the number of prediction
|
6 |
+
# dimensions to the number of classes in the target
|
7 |
+
|
8 |
+
# TODO: loss.item() doesn't work so the displayed losses are Nans
|
9 |
+
def __init__(self, num_classes, weight=None, size_average=None, ignore_index: int = -100,
|
10 |
+
reduce=None, reduction: str = 'mean', label_smoothing: float = 0.0) -> None:
|
11 |
+
super().__init__(size_average=size_average, reduce=reduce, reduction=reduction, ignore_index=ignore_index)
|
12 |
+
self.num_classes = num_classes
|
13 |
+
|
14 |
+
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
15 |
+
loss = torch.zeros_like(input[:, :, 0])
|
16 |
+
for b in range(target.shape[1]):
|
17 |
+
l = super().forward(input[:, b, 0:len(torch.unique(target[:, b]))], target[:, b])
|
18 |
+
loss[:, b] += l
|
19 |
+
return loss.flatten()
|
20 |
+
|
21 |
+
def JointBCELossWithLogits(output, target):
|
22 |
+
# output shape: (S, B, NS) with NS = Number of sequences
|
23 |
+
# target shape: (S, B, SL)
|
24 |
+
# Loss = -log(mean_NS(prod_SL(p(target_SL, output_NS))))
|
25 |
+
# Here at the moment NS = SL
|
26 |
+
output = output.unsqueeze(-1).repeat(1, 1, 1, target.shape[-1]) # (S, B, NS, SL)
|
27 |
+
output = output.permute(2, 0, 1, 3) # (NS, S, B, SL)
|
28 |
+
print(target.shape, output.shape)
|
29 |
+
loss = (target * torch.sigmoid(output)) + ((1-target) * (1-torch.sigmoid(output)))
|
30 |
+
loss = loss.prod(-1)
|
31 |
+
loss = loss.mean(0)
|
32 |
+
loss = -torch.log(loss)
|
33 |
+
loss = loss.mean()
|
34 |
+
return loss
|
35 |
+
|
36 |
+
class ScaledSoftmaxCE(nn.Module):
|
37 |
+
def forward(self, x, label):
|
38 |
+
logits = x[..., :-10]
|
39 |
+
temp_scales = x[..., -10:]
|
40 |
+
|
41 |
+
logprobs = logits.softmax(-1)
|
TabPFN/model_builder.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from train import train, Losses
|
2 |
+
import priors
|
3 |
+
import encoders
|
4 |
+
|
5 |
+
from collections import defaultdict
|
6 |
+
|
7 |
+
from priors.utils import trunc_norm_sampler_f, gamma_sampler_f
|
8 |
+
from utils import get_uniform_single_eval_pos_sampler
|
9 |
+
import torch
|
10 |
+
import math
|
11 |
+
|
12 |
+
def save_model(model, path, filename, config_sample):
|
13 |
+
config_sample = {**config_sample}
|
14 |
+
|
15 |
+
def make_serializable(config_sample):
|
16 |
+
if isinstance(config_sample, dict):
|
17 |
+
config_sample = {k: make_serializable(config_sample[k]) for k in config_sample}
|
18 |
+
if isinstance(config_sample, list):
|
19 |
+
config_sample = [make_serializable(v) for v in config_sample]
|
20 |
+
if callable(config_sample):
|
21 |
+
config_sample = str(config_sample)
|
22 |
+
return config_sample
|
23 |
+
|
24 |
+
#if 'num_features_used' in config_sample:
|
25 |
+
# del config_sample['num_features_used']
|
26 |
+
|
27 |
+
#config_sample['num_classes_as_str'] = str(config_sample['num_classes'])
|
28 |
+
#del config_sample['num_classes']
|
29 |
+
|
30 |
+
config_sample = make_serializable(config_sample)
|
31 |
+
|
32 |
+
torch.save((model.state_dict(), None, config_sample), os.path.join(path, filename))
|
33 |
+
|
34 |
+
|
35 |
+
import subprocess as sp
|
36 |
+
import os
|
37 |
+
|
38 |
+
def get_gpu_memory():
|
39 |
+
command = "nvidia-smi"
|
40 |
+
memory_free_info = sp.check_output(command.split()).decode('ascii')
|
41 |
+
return memory_free_info
|
42 |
+
|
43 |
+
|
44 |
+
def load_model(path, filename, device, eval_positions, verbose):
|
45 |
+
# TODO: This function only restores evaluation functionality but training canät be continued. It is also not flexible.
|
46 |
+
|
47 |
+
model_state, optimizer_state, config_sample = torch.load(
|
48 |
+
os.path.join(path, filename), map_location='cpu')
|
49 |
+
if ('differentiable_hyperparameters' in config_sample
|
50 |
+
and 'prior_mlp_activations' in config_sample['differentiable_hyperparameters']):
|
51 |
+
config_sample['differentiable_hyperparameters']['prior_mlp_activations']['choice_values_used'] = config_sample[
|
52 |
+
'differentiable_hyperparameters'][
|
53 |
+
'prior_mlp_activations'][
|
54 |
+
'choice_values']
|
55 |
+
config_sample['differentiable_hyperparameters']['prior_mlp_activations']['choice_values'] = [
|
56 |
+
torch.nn.Tanh for k in config_sample['differentiable_hyperparameters']['prior_mlp_activations']['choice_values']]
|
57 |
+
|
58 |
+
config_sample['categorical_features_sampler'] = lambda: lambda x: ([], [], [])
|
59 |
+
config_sample['num_features_used_in_training'] = config_sample['num_features_used']
|
60 |
+
config_sample['num_features_used'] = lambda: config_sample['num_features']
|
61 |
+
config_sample['num_classes_in_training'] = config_sample['num_classes']
|
62 |
+
config_sample['num_classes'] = 2
|
63 |
+
config_sample['batch_size_in_training'] = config_sample['batch_size']
|
64 |
+
config_sample['batch_size'] = 1
|
65 |
+
config_sample['bptt_in_training'] = config_sample['bptt']
|
66 |
+
config_sample['bptt'] = 10
|
67 |
+
config_sample['bptt_extra_samples_in_training'] = config_sample['bptt_extra_samples']
|
68 |
+
config_sample['bptt_extra_samples'] = None
|
69 |
+
|
70 |
+
#print('Memory', str(get_gpu_memory()))
|
71 |
+
|
72 |
+
model = get_model(config_sample, device=device, should_train=False, verbose=verbose)
|
73 |
+
module_prefix = 'module.'
|
74 |
+
model_state = {k.replace(module_prefix, ''): v for k, v in model_state.items()}
|
75 |
+
model[2].load_state_dict(model_state)
|
76 |
+
model[2].to(device)
|
77 |
+
|
78 |
+
return model, config_sample
|
79 |
+
|
80 |
+
def fix_loaded_config_sample(loaded_config_sample, config):
|
81 |
+
def copy_to_sample(*k):
|
82 |
+
t,s = loaded_config_sample, config
|
83 |
+
for k_ in k[:-1]:
|
84 |
+
t = t[k_]
|
85 |
+
s = s[k_]
|
86 |
+
t[k[-1]] = s[k[-1]]
|
87 |
+
copy_to_sample('num_features_used')
|
88 |
+
copy_to_sample('num_classes')
|
89 |
+
copy_to_sample('differentiable_hyperparameters','prior_mlp_activations','choice_values')
|
90 |
+
|
91 |
+
def load_config_sample(path, template_config):
|
92 |
+
model_state, optimizer_state, loaded_config_sample = torch.load(path, map_location='cpu')
|
93 |
+
fix_loaded_config_sample(loaded_config_sample, template_config)
|
94 |
+
return loaded_config_sample
|
95 |
+
|
96 |
+
def get_default_spec(test_datasets, valid_datasets):
|
97 |
+
bptt = 10000
|
98 |
+
eval_positions = [1000, 2000, 3000, 4000, 5000] # list(2 ** np.array([4, 5, 6, 7, 8, 9, 10, 11, 12]))
|
99 |
+
max_features = max([X.shape[1] for (_, X, _, _, _, _) in test_datasets] + [X.shape[1] for (_, X, _, _, _, _) in valid_datasets])
|
100 |
+
max_splits = 5
|
101 |
+
|
102 |
+
return bptt, eval_positions, max_features, max_splits
|
103 |
+
|
104 |
+
def get_mlp_prior_hyperparameters(config):
|
105 |
+
config = {hp: (list(config[hp].values())[0]) if type(config[hp]) is dict else config[hp] for hp in config}
|
106 |
+
|
107 |
+
if "prior_sigma_gamma_k" in config:
|
108 |
+
sigma_sampler = gamma_sampler_f(config["prior_sigma_gamma_k"], config["prior_sigma_gamma_theta"])
|
109 |
+
config['init_std'] = sigma_sampler
|
110 |
+
if "prior_noise_std_gamma_k" in config:
|
111 |
+
noise_std_sampler = gamma_sampler_f(config["prior_noise_std_gamma_k"], config["prior_noise_std_gamma_theta"])
|
112 |
+
config['noise_std'] = noise_std_sampler
|
113 |
+
|
114 |
+
return config
|
115 |
+
|
116 |
+
|
117 |
+
def get_gp_mix_prior_hyperparameters(config):
|
118 |
+
return {'lengthscale_concentration': config["prior_lengthscale_concentration"],
|
119 |
+
'nu': config["prior_nu"],
|
120 |
+
'outputscale_concentration': config["prior_outputscale_concentration"],
|
121 |
+
'categorical_data': config["prior_y_minmax_norm"],
|
122 |
+
'y_minmax_norm': config["prior_lengthscale_concentration"],
|
123 |
+
'noise_concentration': config["prior_noise_concentration"],
|
124 |
+
'noise_rate': config["prior_noise_rate"]}
|
125 |
+
|
126 |
+
def get_gp_prior_hyperparameters(config):
|
127 |
+
return {hp: (list(config[hp].values())[0]) if type(config[hp]) is dict else config[hp] for hp in config}
|
128 |
+
|
129 |
+
|
130 |
+
def get_meta_gp_prior_hyperparameters(config):
|
131 |
+
config = {hp: (list(config[hp].values())[0]) if type(config[hp]) is dict else config[hp] for hp in config}
|
132 |
+
|
133 |
+
if "outputscale_mean" in config:
|
134 |
+
outputscale_sampler = trunc_norm_sampler_f(config["outputscale_mean"]
|
135 |
+
, config["outputscale_mean"] * config["outputscale_std_f"])
|
136 |
+
config['outputscale'] = outputscale_sampler
|
137 |
+
if "lengthscale_mean" in config:
|
138 |
+
lengthscale_sampler = trunc_norm_sampler_f(config["lengthscale_mean"],
|
139 |
+
config["lengthscale_mean"] * config["lengthscale_std_f"])
|
140 |
+
config['lengthscale'] = lengthscale_sampler
|
141 |
+
|
142 |
+
return config
|
143 |
+
|
144 |
+
|
145 |
+
def get_model(config, device, should_train=True, verbose=False, state_dict=None, epoch_callback=None):
|
146 |
+
extra_kwargs = {}
|
147 |
+
verbose_train, verbose_prior = verbose >= 1, verbose >= 2
|
148 |
+
config['verbose'] = verbose_prior
|
149 |
+
|
150 |
+
if 'aggregate_k_gradients' not in config or config['aggregate_k_gradients'] is None:
|
151 |
+
config['aggregate_k_gradients'] = math.ceil(config['batch_size'] * ((config['nlayers'] * config['emsize'] * config['bptt'] * config['bptt']) / 10824640000))
|
152 |
+
|
153 |
+
config['num_steps'] = math.ceil(config['num_steps'] * config['aggregate_k_gradients'])
|
154 |
+
config['batch_size'] = math.ceil(config['batch_size'] / config['aggregate_k_gradients'])
|
155 |
+
config['recompute_attn'] = config['recompute_attn'] if 'recompute_attn' in config else False
|
156 |
+
|
157 |
+
def make_get_batch(model_proto, **extra_kwargs):
|
158 |
+
extra_kwargs = defaultdict(lambda: None, **extra_kwargs)
|
159 |
+
return (lambda batch_size, seq_len, num_features, hyperparameters
|
160 |
+
, device, model_proto=model_proto, get_batch=extra_kwargs['get_batch']
|
161 |
+
, prior_bag_priors=extra_kwargs['prior_bag_priors']: model_proto.get_batch(
|
162 |
+
batch_size=batch_size
|
163 |
+
, seq_len=seq_len
|
164 |
+
, device=device
|
165 |
+
, get_batch=get_batch
|
166 |
+
, hyperparameters=hyperparameters
|
167 |
+
, num_features=num_features))
|
168 |
+
|
169 |
+
if config['prior_type'] == 'prior_bag':
|
170 |
+
# Prior bag combines priors
|
171 |
+
get_batch_gp = make_get_batch(priors.fast_gp)
|
172 |
+
get_batch_mlp = make_get_batch(priors.mlp)
|
173 |
+
if 'flexible' in config and config['flexible']:
|
174 |
+
get_batch_gp = make_get_batch(priors.flexible_categorical, **{'get_batch': get_batch_gp})
|
175 |
+
get_batch_mlp = make_get_batch(priors.flexible_categorical, **{'get_batch': get_batch_mlp})
|
176 |
+
prior_bag_hyperparameters = {'prior_bag_get_batch': (get_batch_gp, get_batch_mlp)
|
177 |
+
, 'prior_bag_exp_weights_1': 2.0}
|
178 |
+
prior_hyperparameters = {**get_mlp_prior_hyperparameters(config), **get_gp_prior_hyperparameters(config)
|
179 |
+
, **prior_bag_hyperparameters}
|
180 |
+
model_proto = priors.prior_bag
|
181 |
+
else:
|
182 |
+
if config['prior_type'] == 'mlp':
|
183 |
+
prior_hyperparameters = get_mlp_prior_hyperparameters(config)
|
184 |
+
model_proto = priors.mlp
|
185 |
+
elif config['prior_type'] == 'gp':
|
186 |
+
prior_hyperparameters = get_gp_prior_hyperparameters(config)
|
187 |
+
model_proto = priors.fast_gp
|
188 |
+
elif config['prior_type'] == 'gp_mix':
|
189 |
+
prior_hyperparameters = get_gp_mix_prior_hyperparameters(config)
|
190 |
+
model_proto = priors.fast_gp_mix
|
191 |
+
else:
|
192 |
+
raise Exception()
|
193 |
+
|
194 |
+
if 'flexible' in config and config['flexible']:
|
195 |
+
get_batch_base = make_get_batch(model_proto)
|
196 |
+
extra_kwargs['get_batch'] = get_batch_base
|
197 |
+
model_proto = priors.flexible_categorical
|
198 |
+
|
199 |
+
use_style = False
|
200 |
+
|
201 |
+
if 'differentiable' in config and config['differentiable']:
|
202 |
+
get_batch_base = make_get_batch(model_proto, **extra_kwargs)
|
203 |
+
extra_kwargs = {'get_batch': get_batch_base, 'differentiable_hyperparameters': config['differentiable_hyperparameters']}
|
204 |
+
model_proto = priors.differentiable_prior
|
205 |
+
use_style = True
|
206 |
+
print(f"Using style prior: {use_style}")
|
207 |
+
|
208 |
+
if (('nan_prob_no_reason' in config and config['nan_prob_no_reason'] > 0.0) or
|
209 |
+
('nan_prob_a_reason' in config and config['nan_prob_a_reason'] > 0.0) or
|
210 |
+
('nan_prob_unknown_reason' in config and config['nan_prob_unknown_reason'] > 0.0)):
|
211 |
+
encoder = encoders.NanHandlingEncoder
|
212 |
+
else:
|
213 |
+
encoder = encoders.Linear
|
214 |
+
|
215 |
+
num_outputs = config['num_outputs'] if 'num_outputs' in config else 1
|
216 |
+
if config['max_num_classes'] == 2:
|
217 |
+
if 'joint_loss' in config and config['joint_loss']:
|
218 |
+
loss = JointBCELossWithLogits
|
219 |
+
else:
|
220 |
+
loss = Losses.bce
|
221 |
+
elif config['max_num_classes'] > 2:
|
222 |
+
loss = Losses.ce(torch.ones((config['max_num_classes'])))
|
223 |
+
else:
|
224 |
+
loss = BarDistribution(borders=get_bucket_limits(500, full_range=(-10, 10)))
|
225 |
+
|
226 |
+
aggregate_k_gradients = 1 if 'aggregate_k_gradients' not in config else config['aggregate_k_gradients']
|
227 |
+
check_is_compatible = False if 'multiclass_loss_type' not in config else (config['multiclass_loss_type'] == 'compatible')
|
228 |
+
config['multiclass_type'] = config['multiclass_type'] if 'multiclass_type' in config else 'rank'
|
229 |
+
config['mix_activations'] = config['mix_activations'] if 'mix_activations' in config else False
|
230 |
+
|
231 |
+
config['bptt_extra_samples'] = config['bptt_extra_samples'] if 'bptt_extra_samples' in config else None
|
232 |
+
config['eval_positions'] = [int(config['bptt'] * 0.95)] if config['bptt_extra_samples'] is None else [int(config['bptt'])]
|
233 |
+
|
234 |
+
epochs = 0 if not should_train else config['epochs']
|
235 |
+
model = train(model_proto.DataLoader
|
236 |
+
, loss
|
237 |
+
, encoder
|
238 |
+
, style_encoder_generator = encoders.StyleEncoder if use_style else None
|
239 |
+
, emsize=config['emsize']
|
240 |
+
, nhead=config['nhead']
|
241 |
+
, y_encoder_generator= encoders.get_Canonical(config['max_num_classes']) if config.get('canonical_y_encoder', False) else encoders.Linear
|
242 |
+
, pos_encoder_generator=None
|
243 |
+
, batch_size=config['batch_size']
|
244 |
+
, nlayers=config['nlayers']
|
245 |
+
, nhid=config['emsize'] * config['nhid_factor']
|
246 |
+
, epochs=epochs
|
247 |
+
, total_available_time_in_s=config.get('total_available_time_in_s', None)
|
248 |
+
, warmup_epochs=20
|
249 |
+
, bptt=config['bptt']
|
250 |
+
, gpu_device=device
|
251 |
+
, dropout=config['dropout']
|
252 |
+
, steps_per_epoch=config['num_steps']
|
253 |
+
, single_eval_pos_gen=get_uniform_single_eval_pos_sampler(config['bptt'])
|
254 |
+
, load_weights_from_this_state_dict=state_dict
|
255 |
+
, aggregate_k_gradients=aggregate_k_gradients
|
256 |
+
, check_is_compatible=check_is_compatible
|
257 |
+
, recompute_attn=config['recompute_attn']
|
258 |
+
, epoch_callback=epoch_callback
|
259 |
+
, bptt_extra_samples = config['bptt_extra_samples']
|
260 |
+
, extra_prior_kwargs_dict={
|
261 |
+
'num_features': config['num_features']
|
262 |
+
, 'fuse_x_y': False
|
263 |
+
, 'hyperparameters': prior_hyperparameters
|
264 |
+
, 'num_outputs':num_outputs
|
265 |
+
, 'dynamic_batch_size': 1 if ('num_global_att_tokens' in config and config['num_global_att_tokens']) else 2
|
266 |
+
, **extra_kwargs
|
267 |
+
}
|
268 |
+
, lr=config['lr']
|
269 |
+
, verbose=verbose_train,
|
270 |
+
weight_decay=config.get('weight_decay', 0.0),
|
271 |
+
normalize_labels=True)
|
272 |
+
|
273 |
+
return model
|
TabPFN/models_diff/gp_ablation_model.cpkt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c7b0c8febc553cca3fdee265b5a1cd7567dbf83da855969940be4707a9218ffb
|
3 |
+
size 69460013
|
TabPFN/models_diff/prior_diff_real_checkpoint_n_8x_lr0.0003_epoch_49.cpkt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dae97f45bd53d719fc2b23fac4ec55eab16d63892196d939b1bb1c3b408be242
|
3 |
+
size 103616779
|
TabPFN/notebook_utils.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import io
|
5 |
+
import torch
|
6 |
+
import pickle
|
7 |
+
|
8 |
+
def print_models(base_path, model_string):
|
9 |
+
print(model_string)
|
10 |
+
|
11 |
+
for i in range(80):
|
12 |
+
for e in range(50):
|
13 |
+
exists = Path(os.path.join(base_path, f'models_diff/prior_diff_real_checkpoint{model_string}_n_{i}_epoch_{e}.cpkt')).is_file()
|
14 |
+
if exists:
|
15 |
+
print(os.path.join(base_path, f'models_diff/prior_diff_real_checkpoint{model_string}_n_{i}_epoch_{e}.cpkt'))
|
16 |
+
print()
|
17 |
+
|
18 |
+
class CustomUnpickler(pickle.Unpickler):
|
19 |
+
def find_class(self, module, name):
|
20 |
+
if name == 'Manager':
|
21 |
+
from settings import Manager
|
22 |
+
return Manager
|
23 |
+
try:
|
24 |
+
return self.find_class_cpu(module, name)
|
25 |
+
except:
|
26 |
+
return None
|
27 |
+
|
28 |
+
def find_class_cpu(self, module, name):
|
29 |
+
if module == 'torch.storage' and name == '_load_from_bytes':
|
30 |
+
return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
|
31 |
+
else:
|
32 |
+
return super().find_class(module, name)
|
TabPFN/positional_encodings.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
|
7 |
+
# Protocol for positonal encodings.
|
8 |
+
# __init__(d_model, max_len=..[, more optionals])
|
9 |
+
# forward(x: (seq_len, bs, d_model)) -> Tensor of shape (*x.shape[:2],d_model) containing pos. embeddings
|
10 |
+
|
11 |
+
|
12 |
+
class NoPositionalEncoding(nn.Module):
|
13 |
+
def __init__(self, d_model, max_len=None):
|
14 |
+
super(NoPositionalEncoding, self).__init__()
|
15 |
+
pass
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
return x #* math.sqrt(x.shape[-1])
|
19 |
+
|
20 |
+
|
21 |
+
class PositionalEncoding(nn.Module):
|
22 |
+
def __init__(self, d_model, max_len=5000):
|
23 |
+
super(PositionalEncoding, self).__init__()
|
24 |
+
pe = torch.zeros(max_len, d_model)
|
25 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
26 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
27 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
28 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
29 |
+
pe = pe.unsqueeze(0).transpose(0, 1)
|
30 |
+
self.register_buffer('pe', pe)
|
31 |
+
|
32 |
+
def forward(self, x):
|
33 |
+
x = self.pe[:x.size(0), :] + x # * math.sqrt(x.shape[-1])
|
34 |
+
return x
|
35 |
+
|
36 |
+
|
37 |
+
class LearnedPositionalEncoding(nn.Module):
|
38 |
+
def __init__(self, d_model, max_len=5000):
|
39 |
+
super(LearnedPositionalEncoding, self).__init__()
|
40 |
+
self.max_seq_len = max_len
|
41 |
+
#self.positional_embeddings = nn.Embedding(max_len, d_model)
|
42 |
+
self.positional_embeddings = nn.Parameter(torch.empty(max_len, d_model))
|
43 |
+
nn.init.normal_(self.positional_embeddings, mean=0, std=d_model ** -0.5)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
seq_len, bs, d_model = x.shape
|
47 |
+
assert seq_len <= len(self.positional_embeddings), 'seq_len can be at most max_len.'
|
48 |
+
pos_emb = self.positional_embeddings[:seq_len]
|
49 |
+
return pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x #* math.sqrt(x.shape[-1])
|
50 |
+
|
51 |
+
|
52 |
+
class PairedScrambledPositionalEncodings(LearnedPositionalEncoding):
|
53 |
+
# TODO check whether it is a problem to use the same perm. for full batch
|
54 |
+
def forward(self, x):
|
55 |
+
seq_len, bs, d_model = x.shape
|
56 |
+
assert seq_len <= len(self.positional_embeddings), 'seq_len can be at most max_len.'
|
57 |
+
assert len(self.positional_embeddings) % 2 == 0, 'Please specify an even max_len.'
|
58 |
+
|
59 |
+
paired_embs = self.positional_embeddings.view(len(self.positional_embeddings), -1, 2)
|
60 |
+
pos_emb = paired_embs[torch.randperm(len(paired_embs))].view(*self.positional_embeddings.shape)[:seq_len]
|
61 |
+
|
62 |
+
return pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x #* math.sqrt(x.shape[-1])
|
63 |
+
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
|
TabPFN/prior_tuning_result.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:24d2189bbc836aeea888cf6c540f2c1b45b5351822931189e8bf10a0bc80a0b6
|
3 |
+
size 18668851
|
TabPFN/priors/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from . import fast_gp, mlp, flexible_categorical, differentiable_prior, prior_bag
|
2 |
+
|
3 |
+
|
4 |
+
|
TabPFN/priors/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (240 Bytes). View file
|
|
TabPFN/priors/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (244 Bytes). View file
|
|
TabPFN/priors/__pycache__/differentiable_prior.cpython-37.pyc
ADDED
Binary file (15.9 kB). View file
|
|
TabPFN/priors/__pycache__/fast_gp.cpython-37.pyc
ADDED
Binary file (4.35 kB). View file
|
|
TabPFN/priors/__pycache__/fast_gp.cpython-38.pyc
ADDED
Binary file (4.4 kB). View file
|
|
TabPFN/priors/__pycache__/flexible_categorical.cpython-37.pyc
ADDED
Binary file (8.81 kB). View file
|
|
TabPFN/priors/__pycache__/mlp.cpython-37.pyc
ADDED
Binary file (6.72 kB). View file
|
|
TabPFN/priors/__pycache__/prior.cpython-37.pyc
ADDED
Binary file (320 Bytes). View file
|
|
TabPFN/priors/__pycache__/prior_bag.cpython-37.pyc
ADDED
Binary file (1.49 kB). View file
|
|
TabPFN/priors/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (7.79 kB). View file
|
|
TabPFN/priors/differentiable_prior.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import math
|
4 |
+
|
5 |
+
from .utils import get_batch_to_dataloader
|
6 |
+
from utils import default_device
|
7 |
+
from .utils import order_by_y, normalize_by_used_features_f
|
8 |
+
|
9 |
+
from .utils import trunc_norm_sampler_f, beta_sampler_f, gamma_sampler_f, uniform_sampler_f, zipf_sampler_f, scaled_beta_sampler_f, uniform_int_sampler_f
|
10 |
+
|
11 |
+
|
12 |
+
def unpack_dict_of_tuples(d):
|
13 |
+
# Returns list of dicts where each dict i contains values of tuple position i
|
14 |
+
# {'a': (1,2), 'b': (3,4)} -> [{'a': 1, 'b': 3}, {'a': 2, 'b': 4}]
|
15 |
+
return [dict(zip(d.keys(), v)) for v in list(zip(*list(d.values())))]
|
16 |
+
|
17 |
+
class DifferentiableHyperparameter(nn.Module):
|
18 |
+
## We can sample this and get a hyperparameter value and a normalized hyperparameter indicator
|
19 |
+
def __init__(self, distribution, embedding_dim, device, **args):
|
20 |
+
super(DifferentiableHyperparameter, self).__init__()
|
21 |
+
|
22 |
+
self.distribution = distribution
|
23 |
+
self.embedding_dim = embedding_dim
|
24 |
+
self.device=device
|
25 |
+
for key in args:
|
26 |
+
setattr(self, key, args[key])
|
27 |
+
|
28 |
+
def get_sampler():
|
29 |
+
#if self.distribution == "beta":
|
30 |
+
# return beta_sampler_f(self.a, self.b), 0, 1
|
31 |
+
#elif self.distribution == "gamma":
|
32 |
+
# return gamma_sampler_f(self.a, self.b), 0, 1
|
33 |
+
#elif self.distribution == "beta_int":
|
34 |
+
# return scaled_beta_sampler_f(self.a, self.b, self.scale, self.min), self.scale + self.min, self.min, self.a / (self.a + self.b)
|
35 |
+
if self.distribution == "uniform":
|
36 |
+
if not hasattr(self, 'sample'):
|
37 |
+
return uniform_sampler_f(self.min, self.max), self.min, self.max, (self.max+self.min) / 2, math.sqrt(1/12*(self.max-self.min)*(self.max-self.min))
|
38 |
+
else:
|
39 |
+
return lambda: self.sample, self.min, self.max, None, None
|
40 |
+
elif self.distribution == "uniform_int":
|
41 |
+
return uniform_int_sampler_f(self.min, self.max), self.min, self.max, (self.max+self.min) / 2, math.sqrt(1/12*(self.max-self.min)*(self.max-self.min))
|
42 |
+
|
43 |
+
if self.distribution.startswith("meta"):
|
44 |
+
self.hparams = {}
|
45 |
+
def sample_meta(f):
|
46 |
+
indicators, passed = unpack_dict_of_tuples({hp: self.hparams[hp]() for hp in self.hparams})
|
47 |
+
# sampled_embeddings = list(itertools.chain.from_iterable([sampled_embeddings[k] for k in sampled_embeddings]))
|
48 |
+
meta_passed = f(**passed)
|
49 |
+
return indicators, meta_passed
|
50 |
+
|
51 |
+
args_passed = {'device': device, 'embedding_dim': embedding_dim}
|
52 |
+
if self.distribution == "meta_beta":
|
53 |
+
## Truncated normal where std and mean are drawn randomly logarithmically scaled
|
54 |
+
if hasattr(self, 'b') and hasattr(self, 'k'):
|
55 |
+
self.hparams = {'b': lambda: (None, self.b), 'k': lambda: (None, self.k)}
|
56 |
+
else:
|
57 |
+
self.hparams = {"b": DifferentiableHyperparameter(distribution="uniform", min=self.min
|
58 |
+
, max=self.max, **args_passed)
|
59 |
+
, "k": DifferentiableHyperparameter(distribution="uniform", min=self.min
|
60 |
+
, max=self.max, **args_passed)}
|
61 |
+
def make_beta(b, k):
|
62 |
+
return lambda b=b, k=k: self.scale * beta_sampler_f(b, k)()
|
63 |
+
self.sampler = lambda make_beta=make_beta : sample_meta(make_beta)
|
64 |
+
elif self.distribution == "meta_trunc_norm_log_scaled":
|
65 |
+
# these choices are copied down below, don't change these without changing `replace_differentiable_distributions`
|
66 |
+
self.min_std = self.min_std if hasattr(self, 'min_std') else 0.001
|
67 |
+
self.max_std = self.max_std if hasattr(self, 'max_std') else self.max_mean
|
68 |
+
## Truncated normal where std and mean are drawn randomly logarithmically scaled
|
69 |
+
if not hasattr(self, 'log_mean'):
|
70 |
+
self.hparams = {"log_mean": DifferentiableHyperparameter(distribution="uniform", min=math.log(self.min_mean)
|
71 |
+
, max=math.log(self.max_mean), **args_passed)
|
72 |
+
, "log_std": DifferentiableHyperparameter(distribution="uniform", min=math.log(self.min_std)
|
73 |
+
, max=math.log(self.max_std), **args_passed)}
|
74 |
+
else:
|
75 |
+
self.hparams = {'log_mean': lambda: (None, self.log_mean), 'log_std': lambda: (None, self.log_std)}
|
76 |
+
def make_trunc_norm(log_mean, log_std):
|
77 |
+
return ((lambda : self.lower_bound + round(trunc_norm_sampler_f(math.exp(log_mean), math.exp(log_std))())) if self.round
|
78 |
+
else (lambda: self.lower_bound + trunc_norm_sampler_f(math.exp(log_mean), math.exp(log_std))()))
|
79 |
+
|
80 |
+
self.sampler = lambda make_trunc_norm=make_trunc_norm: sample_meta(make_trunc_norm)
|
81 |
+
elif self.distribution == "meta_trunc_norm":
|
82 |
+
self.min_std = self.min_std if hasattr(self, 'min_std') else 0
|
83 |
+
self.max_std = self.max_std if hasattr(self, 'max_std') else self.max_mean
|
84 |
+
self.hparams = {"mean": DifferentiableHyperparameter(distribution="uniform", min=self.min_mean
|
85 |
+
, max=self.max_mean, **args_passed)
|
86 |
+
, "std": DifferentiableHyperparameter(distribution="uniform", min=self.min_std
|
87 |
+
, max=self.max_std, **args_passed)}
|
88 |
+
def make_trunc_norm(mean, std):
|
89 |
+
return ((lambda: self.lower_bound + round(
|
90 |
+
trunc_norm_sampler_f(math.exp(mean), math.exp(std))())) if self.round
|
91 |
+
else (
|
92 |
+
lambda make_trunc_norm=make_trunc_norm: self.lower_bound + trunc_norm_sampler_f(math.exp(mean), math.exp(std))()))
|
93 |
+
self.sampler = lambda : sample_meta(make_trunc_norm)
|
94 |
+
elif self.distribution == "meta_choice":
|
95 |
+
if hasattr(self, 'choice_1_weight'):
|
96 |
+
self.hparams = {f'choice_{i}_weight': lambda: (None, getattr(self, f'choice_{i}_weight')) for i in range(1, len(self.choice_values))}
|
97 |
+
else:
|
98 |
+
self.hparams = {f"choice_{i}_weight": DifferentiableHyperparameter(distribution="uniform", min=-5.0
|
99 |
+
, max=6.0, **args_passed) for i in range(1, len(self.choice_values))}
|
100 |
+
def make_choice(**choices):
|
101 |
+
weights = torch.softmax(torch.tensor([1.0] + [choices[i] for i in choices], dtype=torch.float), 0) # create a tensor of weights
|
102 |
+
sample = torch.multinomial(weights, 1, replacement=True).numpy()[0]
|
103 |
+
return self.choice_values[sample]
|
104 |
+
|
105 |
+
self.sampler = lambda make_choice=make_choice: sample_meta(make_choice)
|
106 |
+
elif self.distribution == "meta_choice_mixed":
|
107 |
+
if hasattr(self, 'choice_1_weight'):
|
108 |
+
self.hparams = {f'choice_{i}_weight': lambda: (None, getattr(self, f'choice_{i}_weight')) for i in range(1, len(self.choice_values))}
|
109 |
+
else:
|
110 |
+
self.hparams = {f"choice_{i}_weight": DifferentiableHyperparameter(distribution="uniform", min=-5.0
|
111 |
+
, max=6.0, **args_passed) for i in range(1, len(self.choice_values))}
|
112 |
+
def make_choice(**choices):
|
113 |
+
weights = torch.softmax(torch.tensor([1.0] + [choices[i] for i in choices], dtype=torch.float), 0) # create a tensor of weights
|
114 |
+
def sample():
|
115 |
+
s = torch.multinomial(weights, 1, replacement=True).numpy()[0]
|
116 |
+
return self.choice_values[s]()
|
117 |
+
return lambda: sample
|
118 |
+
|
119 |
+
self.sampler = lambda make_choice=make_choice: sample_meta(make_choice)
|
120 |
+
else:
|
121 |
+
def return_two(x, min, max, mean, std):
|
122 |
+
# Returns (a hyperparameter value, and an indicator value passed to the model)
|
123 |
+
if mean is not None:
|
124 |
+
ind = (x-mean)/std#(2 * (x-min) / (max-min) - 1)
|
125 |
+
else:
|
126 |
+
ind = None
|
127 |
+
return ind, x # normalize indicator to [-1, 1]
|
128 |
+
# def sample_standard(sampler_f, embedding):
|
129 |
+
# s = torch.tensor([sampler_f()], device = self.device)
|
130 |
+
# return s, embedding(s)
|
131 |
+
self.sampler_f, self.sampler_min, self.sampler_max, self.sampler_mean, self.sampler_std = get_sampler()
|
132 |
+
self.sampler = lambda : return_two(self.sampler_f(), min=self.sampler_min, max=self.sampler_max
|
133 |
+
, mean=self.sampler_mean, std=self.sampler_std)
|
134 |
+
# self.embedding_layer = nn.Linear(1, self.embedding_dim, device=self.device)
|
135 |
+
# self.embed = lambda x : self.embedding_layer(
|
136 |
+
# (x - self.sampler_min) / (self.sampler_max - self.sampler_min))
|
137 |
+
#self.sampler = lambda : sample_standard(self.sampler_f, self.embedding)
|
138 |
+
|
139 |
+
|
140 |
+
def forward(self):
|
141 |
+
s, s_passed = self.sampler()
|
142 |
+
return s, s_passed
|
143 |
+
|
144 |
+
|
145 |
+
|
146 |
+
class DifferentiableHyperparameterList(nn.Module):
|
147 |
+
def __init__(self, hyperparameters, embedding_dim, device):
|
148 |
+
super().__init__()
|
149 |
+
|
150 |
+
self.device = device
|
151 |
+
hyperparameters = {k: v for (k, v) in hyperparameters.items() if v}
|
152 |
+
self.hyperparameters = nn.ModuleDict({hp: DifferentiableHyperparameter(embedding_dim = embedding_dim
|
153 |
+
, name = hp
|
154 |
+
, device = device, **hyperparameters[hp]) for hp in hyperparameters})
|
155 |
+
def get_hyperparameter_info(self):
|
156 |
+
sampled_hyperparameters_f, sampled_hyperparameters_keys = [], []
|
157 |
+
def append_hp(hp_key, hp_val):
|
158 |
+
sampled_hyperparameters_keys.append(hp_key)
|
159 |
+
# Function remaps hyperparameters from [-1, 1] range to true value
|
160 |
+
s_min, s_max, s_mean, s_std = hp_val.sampler_min, hp_val.sampler_max, hp_val.sampler_mean, hp_val.sampler_std
|
161 |
+
sampled_hyperparameters_f.append((lambda x: (x-s_mean)/s_std, lambda y : (y * s_std)+s_mean))
|
162 |
+
#sampled_hyperparameters_f.append(((lambda x: ((x - s_min) / (s_max - s_min) * (2) - 1)
|
163 |
+
# , (lambda y: ((y + 1) * (1 / 2) * (s_max - s_min) + s_min))))
|
164 |
+
for hp in self.hyperparameters:
|
165 |
+
hp_val = self.hyperparameters[hp]
|
166 |
+
if hasattr(hp_val, 'hparams'):
|
167 |
+
for hp_ in hp_val.hparams:
|
168 |
+
append_hp(f'{hp}_{hp_}', hp_val.hparams[hp_])
|
169 |
+
else:
|
170 |
+
append_hp(hp, hp_val)
|
171 |
+
|
172 |
+
|
173 |
+
return sampled_hyperparameters_keys, sampled_hyperparameters_f
|
174 |
+
|
175 |
+
def sample_parameter_object(self):
|
176 |
+
sampled_hyperparameters, s_passed = {}, {}
|
177 |
+
for hp in self.hyperparameters:
|
178 |
+
sampled_hyperparameters_, s_passed_ = self.hyperparameters[hp]()
|
179 |
+
s_passed[hp] = s_passed_
|
180 |
+
if isinstance(sampled_hyperparameters_, dict):
|
181 |
+
sampled_hyperparameters_ = {hp + '_' + str(key): val for key, val in sampled_hyperparameters_.items()}
|
182 |
+
sampled_hyperparameters.update(sampled_hyperparameters_)
|
183 |
+
else:
|
184 |
+
sampled_hyperparameters[hp] = sampled_hyperparameters_
|
185 |
+
|
186 |
+
# s_passed contains the values passed to the get_batch function
|
187 |
+
# sampled_hyperparameters contains the indicator of the sampled value, i.e. only number that describe the sampled object
|
188 |
+
return s_passed, sampled_hyperparameters#self.pack_parameter_object(sampled_embeddings)
|
189 |
+
|
190 |
+
class DifferentiablePrior(torch.nn.Module):
|
191 |
+
def __init__(self, get_batch, hyperparameters, differentiable_hyperparameters, args):
|
192 |
+
super(DifferentiablePrior, self).__init__()
|
193 |
+
|
194 |
+
self.h = hyperparameters
|
195 |
+
self.args = args
|
196 |
+
self.get_batch = get_batch
|
197 |
+
self.differentiable_hyperparameters = DifferentiableHyperparameterList(differentiable_hyperparameters
|
198 |
+
, embedding_dim=self.h['emsize']
|
199 |
+
, device=self.args['device'])
|
200 |
+
|
201 |
+
def forward(self):
|
202 |
+
# Sample hyperparameters
|
203 |
+
sampled_hyperparameters_passed, sampled_hyperparameters_indicators = self.differentiable_hyperparameters.sample_parameter_object()
|
204 |
+
|
205 |
+
hyperparameters = {**self.h, **sampled_hyperparameters_passed}
|
206 |
+
x, y, y_ = self.get_batch(hyperparameters=hyperparameters, **self.args)
|
207 |
+
|
208 |
+
return x, y, y_, sampled_hyperparameters_indicators
|
209 |
+
|
210 |
+
|
211 |
+
# TODO: Make this a class that keeps objects
|
212 |
+
@torch.no_grad()
|
213 |
+
def get_batch(batch_size, seq_len, num_features, get_batch
|
214 |
+
, device=default_device, differentiable_hyperparameters={}
|
215 |
+
, hyperparameters=None, batch_size_per_gp_sample=None, **kwargs):
|
216 |
+
batch_size_per_gp_sample = batch_size_per_gp_sample or (min(64, batch_size))
|
217 |
+
num_models = batch_size // batch_size_per_gp_sample
|
218 |
+
assert num_models * batch_size_per_gp_sample == batch_size, f'Batch size ({batch_size}) not divisible by batch_size_per_gp_sample ({batch_size_per_gp_sample})'
|
219 |
+
|
220 |
+
args = {'device': device, 'seq_len': seq_len, 'num_features': num_features, 'batch_size': batch_size_per_gp_sample}
|
221 |
+
|
222 |
+
models = [DifferentiablePrior(get_batch, hyperparameters, differentiable_hyperparameters, args) for _ in range(num_models)]
|
223 |
+
sample = sum([[model()] for model in models], [])
|
224 |
+
|
225 |
+
x, y, y_, hyperparameter_dict = zip(*sample)
|
226 |
+
|
227 |
+
if 'verbose' in hyperparameters and hyperparameters['verbose']:
|
228 |
+
print('Hparams', hyperparameter_dict[0].keys())
|
229 |
+
|
230 |
+
hyperparameter_matrix = []
|
231 |
+
for batch in hyperparameter_dict:
|
232 |
+
hyperparameter_matrix.append([batch[hp] for hp in batch])
|
233 |
+
|
234 |
+
transposed_hyperparameter_matrix = list(zip(*hyperparameter_matrix))
|
235 |
+
assert all([all([hp is None for hp in hp_]) or all([hp is not None for hp in hp_]) for hp_ in transposed_hyperparameter_matrix]), 'it should always be the case that when a hyper-parameter is None, once it is always None'
|
236 |
+
# we remove columns that are only None (i.e. not sampled)
|
237 |
+
hyperparameter_matrix = [[hp for hp in hp_ if hp is not None] for hp_ in hyperparameter_matrix]
|
238 |
+
if len(hyperparameter_matrix[0]) > 0:
|
239 |
+
packed_hyperparameters = torch.tensor(hyperparameter_matrix)
|
240 |
+
packed_hyperparameters = torch.repeat_interleave(packed_hyperparameters, repeats=batch_size_per_gp_sample, dim=0).detach()
|
241 |
+
else:
|
242 |
+
packed_hyperparameters = None
|
243 |
+
|
244 |
+
x, y, y_, packed_hyperparameters = (torch.cat(x, 1).detach()
|
245 |
+
, torch.cat(y, 1).detach()
|
246 |
+
, torch.cat(y_, 1).detach()
|
247 |
+
, packed_hyperparameters)#list(itertools.chain.from_iterable(itertools.repeat(x, batch_size_per_gp_sample) for x in packed_hyperparameters)))#torch.repeat_interleave(torch.stack(packed_hyperparameters, 0).detach(), repeats=batch_size_per_gp_sample, dim=0))
|
248 |
+
|
249 |
+
return x, y, y_, packed_hyperparameters
|
250 |
+
|
251 |
+
DataLoader = get_batch_to_dataloader(get_batch)
|
252 |
+
DataLoader.num_outputs = 1
|
253 |
+
#DataLoader.validate = lambda : 0
|
254 |
+
|
255 |
+
def draw_random_style(dl, device):
|
256 |
+
(hp_embedding, data, targets_), targets = next(iter(dl))
|
257 |
+
return hp_embedding.to(device)[0:1, :]
|
258 |
+
|
259 |
+
def merge_style_with_info(diff_hparams_keys, diff_hparams_f, style, transform=True):
|
260 |
+
params = dict(zip(diff_hparams_keys, zip(diff_hparams_f, style.detach().cpu().numpy().tolist()[0])))
|
261 |
+
def t(v):
|
262 |
+
if transform:
|
263 |
+
return v[0][1](v[1])
|
264 |
+
else:
|
265 |
+
return v[1]
|
266 |
+
return {k : t(v) for k, v in params.items()}
|
267 |
+
|
268 |
+
|
269 |
+
import ConfigSpace.hyperparameters as CSH
|
270 |
+
|
271 |
+
def replace_differentiable_distributions(config):
|
272 |
+
diff_config = config['differentiable_hyperparameters']
|
273 |
+
for name, diff_hp_dict in diff_config.items():
|
274 |
+
distribution = diff_hp_dict['distribution']
|
275 |
+
if distribution == 'uniform':
|
276 |
+
diff_hp_dict['sample'] = CSH.UniformFloatHyperparameter(name, diff_hp_dict['min'], diff_hp_dict['max'])
|
277 |
+
elif distribution == 'meta_beta':
|
278 |
+
diff_hp_dict['k'] = CSH.UniformFloatHyperparameter(name+'_k', diff_hp_dict['min'], diff_hp_dict['max'])
|
279 |
+
diff_hp_dict['b'] = CSH.UniformFloatHyperparameter(name+'_b', diff_hp_dict['min'], diff_hp_dict['max'])
|
280 |
+
elif distribution == 'meta_choice':
|
281 |
+
for i in range(1, len(diff_hp_dict['choice_values'])):
|
282 |
+
diff_hp_dict[f'choice_{i}_weight'] = CSH.UniformFloatHyperparameter(name+f'choice_{i}_weight', -5.0, 6.0)
|
283 |
+
elif distribution == 'meta_choice_mixed':
|
284 |
+
for i in range(1, len(diff_hp_dict['choice_values'])):
|
285 |
+
diff_hp_dict[f'choice_{i}_weight'] = CSH.UniformFloatHyperparameter(name+f'choice_{i}_weight', -5.0, 6.0)
|
286 |
+
elif distribution == 'meta_trunc_norm_log_scaled':
|
287 |
+
diff_hp_dict['log_mean'] = CSH.UniformFloatHyperparameter(name+'_log_mean', math.log(diff_hp_dict['min_mean']), math.log(diff_hp_dict['max_mean']))
|
288 |
+
min_std = diff_hp_dict['min_std'] if 'min_std' in diff_hp_dict else 0.001
|
289 |
+
max_std = diff_hp_dict['max_std'] if 'max_std' in diff_hp_dict else diff_hp_dict['max_mean']
|
290 |
+
diff_hp_dict['log_std'] = CSH.UniformFloatHyperparameter(name+'_log_std', math.log(min_std), math.log(max_std))
|
291 |
+
else:
|
292 |
+
raise ValueError(f'Unknown distribution {distribution}')
|
293 |
+
|
TabPFN/priors/fast_gp.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
import gpytorch
|
6 |
+
|
7 |
+
from .utils import get_batch_to_dataloader
|
8 |
+
from utils import default_device
|
9 |
+
|
10 |
+
|
11 |
+
# We will use the simplest form of GP model, exact inference
|
12 |
+
class ExactGPModel(gpytorch.models.ExactGP):
|
13 |
+
def __init__(self, train_x, train_y, likelihood):
|
14 |
+
super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
|
15 |
+
self.mean_module = gpytorch.means.ConstantMean()
|
16 |
+
self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
|
17 |
+
|
18 |
+
def forward(self, x):
|
19 |
+
mean_x = self.mean_module(x)
|
20 |
+
covar_x = self.covar_module(x)
|
21 |
+
return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
|
22 |
+
|
23 |
+
|
24 |
+
def get_model(x, y, hyperparameters):
|
25 |
+
likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.GreaterThan(1.e-9))
|
26 |
+
model = ExactGPModel(x, y, likelihood)
|
27 |
+
model.likelihood.noise = torch.ones_like(model.likelihood.noise) * hyperparameters["noise"]
|
28 |
+
model.covar_module.outputscale = torch.ones_like(model.covar_module.outputscale) * hyperparameters["outputscale"]
|
29 |
+
model.covar_module.base_kernel.lengthscale = torch.ones_like(model.covar_module.base_kernel.lengthscale) * \
|
30 |
+
hyperparameters["lengthscale"]
|
31 |
+
return model, likelihood
|
32 |
+
|
33 |
+
|
34 |
+
@torch.no_grad()
|
35 |
+
def get_batch(batch_size, seq_len, num_features, device=default_device, hyperparameters=None,
|
36 |
+
equidistant_x=False, fix_x=None, **kwargs):
|
37 |
+
if isinstance(hyperparameters, (tuple, list)):
|
38 |
+
hyperparameters = {"noise": hyperparameters[0]
|
39 |
+
, "outputscale": hyperparameters[1]
|
40 |
+
, "lengthscale": hyperparameters[2]
|
41 |
+
, "is_binary_classification": hyperparameters[3]
|
42 |
+
# , "num_features_used": hyperparameters[4]
|
43 |
+
, "normalize_by_used_features": hyperparameters[5]
|
44 |
+
, "order_y": hyperparameters[6]
|
45 |
+
, "sampling": hyperparameters[7]
|
46 |
+
}
|
47 |
+
elif hyperparameters is None:
|
48 |
+
hyperparameters = {"noise": .1, "outputscale": .1, "lengthscale": .1}
|
49 |
+
|
50 |
+
if 'verbose' in hyperparameters and hyperparameters['verbose']:
|
51 |
+
print({"noise": hyperparameters['noise'], "outputscale": hyperparameters['outputscale']
|
52 |
+
, "lengthscale": hyperparameters['lengthscale'], 'batch_size': batch_size, 'sampling': hyperparameters['sampling']})
|
53 |
+
|
54 |
+
# hyperparameters = {k: hyperparameters[k]() if callable(hyperparameters[k]) else hyperparameters[k] for k in
|
55 |
+
# hyperparameters.keys()}
|
56 |
+
assert not (equidistant_x and (fix_x is not None))
|
57 |
+
|
58 |
+
with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations', (True, True, True))):
|
59 |
+
if equidistant_x:
|
60 |
+
assert num_features == 1
|
61 |
+
x = torch.linspace(0, 1., seq_len).unsqueeze(0).repeat(batch_size, 1).unsqueeze(-1)
|
62 |
+
elif fix_x is not None:
|
63 |
+
assert fix_x.shape == (seq_len, num_features)
|
64 |
+
x = fix_x.unsqueeze(0).repeat(batch_size, 1, 1).to(device)
|
65 |
+
else:
|
66 |
+
if hyperparameters.get('sampling','uniform') == 'uniform':
|
67 |
+
x = torch.rand(batch_size, seq_len, num_features, device=device)
|
68 |
+
else:
|
69 |
+
x = torch.randn(batch_size, seq_len, num_features, device=device)
|
70 |
+
model, likelihood = get_model(x, torch.Tensor(), hyperparameters)
|
71 |
+
model.to(device)
|
72 |
+
# trained_model = ExactGPModel(train_x, train_y, likelihood).cuda()
|
73 |
+
# trained_model.eval()
|
74 |
+
is_fitted = False
|
75 |
+
while not is_fitted:
|
76 |
+
try:
|
77 |
+
with gpytorch.settings.prior_mode(True):
|
78 |
+
model, likelihood = get_model(x, torch.Tensor(), hyperparameters)
|
79 |
+
model.to(device)
|
80 |
+
|
81 |
+
d = model(x)
|
82 |
+
d = likelihood(d)
|
83 |
+
sample = d.sample().transpose(0, 1)
|
84 |
+
is_fitted = True
|
85 |
+
except RuntimeError: # This can happen when torch.linalg.eigh fails. Restart with new init resolves this.
|
86 |
+
print('GP Fitting unsuccessful, retrying.. ')
|
87 |
+
print(x)
|
88 |
+
print(hyperparameters)
|
89 |
+
|
90 |
+
if bool(torch.any(torch.isnan(x)).detach().cpu().numpy()):
|
91 |
+
print({"noise": hyperparameters['noise'], "outputscale": hyperparameters['outputscale']
|
92 |
+
, "lengthscale": hyperparameters['lengthscale'], 'batch_size': batch_size})
|
93 |
+
|
94 |
+
# TODO: Multi output
|
95 |
+
return x.transpose(0, 1), sample, sample # x.shape = (T,B,H)
|
96 |
+
|
97 |
+
DataLoader = get_batch_to_dataloader(get_batch)
|
98 |
+
DataLoader.num_outputs = 1
|
99 |
+
|
100 |
+
def get_model_on_device(x,y,hyperparameters,device):
|
101 |
+
model, likelihood = get_model(x, y, hyperparameters)
|
102 |
+
model.to(device)
|
103 |
+
return model, likelihood
|
104 |
+
|
105 |
+
|
106 |
+
@torch.no_grad()
|
107 |
+
def evaluate(x, y, y_non_noisy, use_mse=False, hyperparameters={}, get_model_on_device=get_model_on_device, device=default_device, step_size=1, start_pos=0):
|
108 |
+
start_time = time.time()
|
109 |
+
losses_after_t = [.0] if start_pos == 0 else []
|
110 |
+
all_losses_after_t = []
|
111 |
+
|
112 |
+
with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations',(True,True,True))), gpytorch.settings.fast_pred_var(False):
|
113 |
+
for t in range(max(start_pos, 1), len(x), step_size):
|
114 |
+
loss_sum = 0.
|
115 |
+
model, likelihood = get_model_on_device(x[:t].transpose(0, 1), y[:t].transpose(0, 1), hyperparameters, device)
|
116 |
+
|
117 |
+
|
118 |
+
model.eval()
|
119 |
+
# print([t.shape for t in model.train_inputs])
|
120 |
+
# print(x[:t].transpose(0,1).shape, x[t].unsqueeze(1).shape, y[:t].transpose(0,1).shape)
|
121 |
+
f = model(x[t].unsqueeze(1))
|
122 |
+
l = likelihood(f)
|
123 |
+
means = l.mean.squeeze()
|
124 |
+
varis = l.covariance_matrix.squeeze()
|
125 |
+
# print(l.variance.squeeze(), l.mean.squeeze(), y[t])
|
126 |
+
|
127 |
+
assert len(means.shape) == len(varis.shape) == 1
|
128 |
+
assert len(means) == len(varis) == x.shape[1]
|
129 |
+
|
130 |
+
if use_mse:
|
131 |
+
c = nn.MSELoss(reduction='none')
|
132 |
+
ls = c(means, y[t])
|
133 |
+
else:
|
134 |
+
ls = -l.log_prob(y[t].unsqueeze(1))
|
135 |
+
|
136 |
+
losses_after_t.append(ls.mean())
|
137 |
+
all_losses_after_t.append(ls.flatten())
|
138 |
+
return torch.stack(all_losses_after_t).to('cpu'), torch.tensor(losses_after_t).to('cpu'), time.time() - start_time
|
139 |
+
|
140 |
+
if __name__ == '__main__':
|
141 |
+
hps = (.1,.1,.1)
|
142 |
+
for redo_idx in range(1):
|
143 |
+
print(
|
144 |
+
evaluate(*get_batch(1000, 10, hyperparameters=hps, num_features=10), use_mse=False, hyperparameters=hps))
|
TabPFN/priors/flexible_categorical.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import random
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
from .utils import get_batch_to_dataloader
|
8 |
+
from utils import normalize_data, nan_handling_missing_for_unknown_reason_value, nan_handling_missing_for_no_reason_value, nan_handling_missing_for_a_reason_value, to_ranking_low_mem, remove_outliers
|
9 |
+
from .utils import normalize_by_used_features_f, randomize_classes, CategoricalActivation
|
10 |
+
from .utils import uniform_int_sampler_f
|
11 |
+
|
12 |
+
time_it = False
|
13 |
+
|
14 |
+
class BalancedBinarize(nn.Module):
|
15 |
+
def __init__(self):
|
16 |
+
super().__init__()
|
17 |
+
|
18 |
+
def forward(self, x):
|
19 |
+
return (x > torch.median(x)).float()
|
20 |
+
|
21 |
+
def class_sampler_f(min_, max_):
|
22 |
+
def s():
|
23 |
+
if random.random() > 0.5:
|
24 |
+
return uniform_int_sampler_f(min_, max_)()
|
25 |
+
return 2
|
26 |
+
return s
|
27 |
+
|
28 |
+
class MulticlassRank(nn.Module):
|
29 |
+
def __init__(self, num_classes, ordered_p=0.5):
|
30 |
+
super().__init__()
|
31 |
+
self.num_classes = class_sampler_f(2, num_classes)()
|
32 |
+
self.ordered_p = ordered_p
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
# x has shape (T,B,H)
|
36 |
+
|
37 |
+
# CAUTION: This samples the same idx in sequence for each class boundary in a batch
|
38 |
+
class_boundaries = torch.randint(0, x.shape[0], (self.num_classes - 1,))
|
39 |
+
class_boundaries = x[class_boundaries].unsqueeze(1)
|
40 |
+
|
41 |
+
d = (x > class_boundaries).sum(axis=0)
|
42 |
+
|
43 |
+
randomized_classes = torch.rand((d.shape[1], )) > self.ordered_p
|
44 |
+
d[:, randomized_classes] = randomize_classes(d[:, randomized_classes], self.num_classes)
|
45 |
+
reverse_classes = torch.rand((d.shape[1],)) > 0.5
|
46 |
+
d[:, reverse_classes] = self.num_classes - 1 - d[:, reverse_classes]
|
47 |
+
return d
|
48 |
+
|
49 |
+
class MulticlassValue(nn.Module):
|
50 |
+
def __init__(self, num_classes, ordered_p=0.5):
|
51 |
+
super().__init__()
|
52 |
+
self.num_classes = class_sampler_f(2, num_classes)()
|
53 |
+
self.classes = nn.Parameter(torch.randn(num_classes-1), requires_grad=False)
|
54 |
+
self.ordered_p = ordered_p
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
# x has shape (T,B,H)
|
58 |
+
d = (x > (self.classes.unsqueeze(-1).unsqueeze(-1))).sum(axis=0)
|
59 |
+
|
60 |
+
randomized_classes = torch.rand((d.shape[1],)) > self.ordered_p
|
61 |
+
d[:, randomized_classes] = randomize_classes(d[:, randomized_classes], self.num_classes)
|
62 |
+
reverse_classes = torch.rand((d.shape[1],)) > 0.5
|
63 |
+
d[:, reverse_classes] = self.num_classes - 1 - d[:, reverse_classes]
|
64 |
+
return d
|
65 |
+
|
66 |
+
class MulticlassMultiNode(nn.Module):
|
67 |
+
def __init__(self, num_classes, ordered_p=0.5):
|
68 |
+
super().__init__()
|
69 |
+
self.num_classes = class_sampler_f(2, num_classes)()
|
70 |
+
self.classes = nn.Parameter(torch.randn(num_classes-1), requires_grad=False)
|
71 |
+
self.alt_multi_class = MulticlassValue(num_classes, ordered_p)
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
# x has shape T, B, H
|
75 |
+
if len(x.shape) == 2:
|
76 |
+
return self.alt_multi_class(x)
|
77 |
+
T = 3
|
78 |
+
x[torch.isnan(x)] = 0.00001
|
79 |
+
d = torch.multinomial(torch.pow(0.00001+torch.sigmoid(x[:, :, 0:self.num_classes]).reshape(-1, self.num_classes), T), 1, replacement=True).reshape(x.shape[0], x.shape[1]).float()
|
80 |
+
return d
|
81 |
+
|
82 |
+
|
83 |
+
class FlexibleCategorical(torch.nn.Module):
|
84 |
+
def __init__(self, get_batch, hyperparameters, args):
|
85 |
+
super(FlexibleCategorical, self).__init__()
|
86 |
+
|
87 |
+
self.h = {k: hyperparameters[k]() if callable(hyperparameters[k]) else hyperparameters[k] for k in
|
88 |
+
hyperparameters.keys()}
|
89 |
+
self.args = args
|
90 |
+
self.args_passed = {**self.args}
|
91 |
+
self.args_passed.update({'num_features': self.h['num_features_used']})
|
92 |
+
self.get_batch = get_batch
|
93 |
+
|
94 |
+
if self.h['num_classes'] > 1 and not self.h['balanced']:
|
95 |
+
if self.h['multiclass_type'] == 'rank':
|
96 |
+
self.class_assigner = MulticlassRank(self.h['num_classes']
|
97 |
+
, ordered_p=self.h['output_multiclass_ordered_p']
|
98 |
+
)
|
99 |
+
elif self.h['multiclass_type'] == 'value':
|
100 |
+
self.class_assigner = MulticlassValue(self.h['num_classes']
|
101 |
+
, ordered_p=self.h['output_multiclass_ordered_p']
|
102 |
+
)
|
103 |
+
elif self.h['multiclass_type'] == 'multi_node':
|
104 |
+
self.class_assigner = MulticlassMultiNode(self.h['num_classes'])
|
105 |
+
else:
|
106 |
+
raise ValueError("Unknow Multiclass type")
|
107 |
+
elif self.h['num_classes'] == 2 and self.h['balanced']:
|
108 |
+
self.class_assigner = BalancedBinarize()
|
109 |
+
elif self.h['num_classes'] > 2 and self.h['balanced']:
|
110 |
+
raise NotImplementedError("Balanced multiclass training is not possible")
|
111 |
+
else:
|
112 |
+
self.class_assigner = lambda x:x # Regression
|
113 |
+
|
114 |
+
def drop_for_reason(self, x, v):
|
115 |
+
nan_prob_sampler = CategoricalActivation(ordered_p=0.0
|
116 |
+
, categorical_p=1.0
|
117 |
+
, keep_activation_size=False,
|
118 |
+
num_classes_sampler=lambda: 20)
|
119 |
+
d = nan_prob_sampler(x)
|
120 |
+
# TODO: Make a different ordering for each activation
|
121 |
+
x[d < torch.rand((1,), device=x.device) * 20 * self.h['nan_prob_no_reason'] * random.random()] = v
|
122 |
+
return x
|
123 |
+
|
124 |
+
def drop_for_no_reason(self, x, v):
|
125 |
+
x[torch.rand(x.shape, device=self.args['device']) < self.h['nan_prob_no_reason']] = v
|
126 |
+
return x
|
127 |
+
|
128 |
+
def forward(self, batch_size):
|
129 |
+
start = time.time()
|
130 |
+
x, y, y_ = self.get_batch(hyperparameters=self.h, **self.args_passed)
|
131 |
+
if time_it:
|
132 |
+
print('Flex Forward Block 1', round(time.time() - start, 3))
|
133 |
+
|
134 |
+
start = time.time()
|
135 |
+
|
136 |
+
if self.h['nan_prob_no_reason']+self.h['nan_prob_a_reason']+self.h['nan_prob_unknown_reason'] > 0 and random.random() > 0.5: # Only one out of two datasets should have nans
|
137 |
+
if self.h['nan_prob_no_reason'] > 0 and random.random() > 0.5: # Missing for no reason
|
138 |
+
x = self.drop_for_no_reason(x, nan_handling_missing_for_no_reason_value(self.h['set_value_to_nan']))
|
139 |
+
|
140 |
+
if self.h['nan_prob_a_reason'] > 0 and random.random() > 0.5: # Missing for a reason
|
141 |
+
x = self.drop_for_reason(x, nan_handling_missing_for_a_reason_value(self.h['set_value_to_nan']))
|
142 |
+
|
143 |
+
if self.h['nan_prob_unknown_reason'] > 0: # Missing for unknown reason and random.random() > 0.5
|
144 |
+
if random.random() < self.h['nan_prob_unknown_reason_reason_prior']:
|
145 |
+
x = self.drop_for_no_reason(x, nan_handling_missing_for_unknown_reason_value(self.h['set_value_to_nan']))
|
146 |
+
else:
|
147 |
+
x = self.drop_for_reason(x, nan_handling_missing_for_unknown_reason_value(self.h['set_value_to_nan']))
|
148 |
+
|
149 |
+
# Categorical features
|
150 |
+
if 'categorical_feature_p' in self.h and random.random() > 1 - self.h['categorical_feature_p']:
|
151 |
+
p = random.random()
|
152 |
+
for col in range(x.shape[2]):
|
153 |
+
m = MulticlassRank(10, ordered_p=0.3)
|
154 |
+
if random.random() > p:
|
155 |
+
x[:, :, col] = m(x[:, :, col])
|
156 |
+
|
157 |
+
if time_it:
|
158 |
+
print('Flex Forward Block 2', round(time.time() - start, 3))
|
159 |
+
start = time.time()
|
160 |
+
|
161 |
+
if self.h['normalize_to_ranking']:
|
162 |
+
x = to_ranking_low_mem(x)
|
163 |
+
else:
|
164 |
+
x = remove_outliers(x)
|
165 |
+
x, y = normalize_data(x), normalize_data(y)
|
166 |
+
|
167 |
+
if time_it:
|
168 |
+
print('Flex Forward Block 3', round(time.time() - start, 3))
|
169 |
+
start = time.time()
|
170 |
+
|
171 |
+
# Cast to classification if enabled
|
172 |
+
y = self.class_assigner(y).float()
|
173 |
+
|
174 |
+
if time_it:
|
175 |
+
print('Flex Forward Block 4', round(time.time() - start, 3))
|
176 |
+
start = time.time()
|
177 |
+
if self.h['normalize_by_used_features']:
|
178 |
+
x = normalize_by_used_features_f(x, self.h['num_features_used'], self.args['num_features'], normalize_with_sqrt=self.h.get('normalize_with_sqrt',False))
|
179 |
+
if time_it:
|
180 |
+
print('Flex Forward Block 5', round(time.time() - start, 3))
|
181 |
+
|
182 |
+
start = time.time()
|
183 |
+
# Append empty features if enabled
|
184 |
+
x = torch.cat(
|
185 |
+
[x, torch.zeros((x.shape[0], x.shape[1], self.args['num_features'] - self.h['num_features_used']),
|
186 |
+
device=self.args['device'])], -1)
|
187 |
+
if time_it:
|
188 |
+
print('Flex Forward Block 6', round(time.time() - start, 3))
|
189 |
+
|
190 |
+
return x, y, y # x.shape = (T,B,H)
|
191 |
+
|
192 |
+
import torch.cuda as cutorch
|
193 |
+
|
194 |
+
@torch.no_grad()
|
195 |
+
def get_batch(batch_size, seq_len, num_features, get_batch, device, hyperparameters=None, batch_size_per_gp_sample=None, **kwargs):
|
196 |
+
batch_size_per_gp_sample = batch_size_per_gp_sample or (min(32, batch_size))
|
197 |
+
num_models = batch_size // batch_size_per_gp_sample
|
198 |
+
assert num_models > 0, f'Batch size ({batch_size}) is too small for batch_size_per_gp_sample ({batch_size_per_gp_sample})'
|
199 |
+
assert num_models * batch_size_per_gp_sample == batch_size, f'Batch size ({batch_size}) not divisible by batch_size_per_gp_sample ({batch_size_per_gp_sample})'
|
200 |
+
|
201 |
+
# Sample one seq_len for entire batch
|
202 |
+
seq_len = hyperparameters['seq_len_used']() if callable(hyperparameters['seq_len_used']) else seq_len
|
203 |
+
|
204 |
+
args = {'device': device, 'seq_len': seq_len, 'num_features': num_features, 'batch_size': batch_size_per_gp_sample}
|
205 |
+
|
206 |
+
models = [FlexibleCategorical(get_batch, hyperparameters, args).to(device) for _ in range(num_models)]
|
207 |
+
|
208 |
+
start = time.time()
|
209 |
+
sample = sum([[model(batch_size=batch_size_per_gp_sample)] for model in models], [])
|
210 |
+
#print('sample', time.time() - start)
|
211 |
+
|
212 |
+
x, y, y_ = zip(*sample)
|
213 |
+
x, y, y_ = torch.cat(x, 1).detach(), torch.cat(y, 1).detach(), torch.cat(y_, 1).detach()
|
214 |
+
|
215 |
+
# # TODO: Reintegrate this code (Doesn't work on batch dim), could be applied to each batch sample individually
|
216 |
+
# if hyperparameters['is_binary_classification'] and hyperparameters['order_y']:
|
217 |
+
# x, y = order_by_y(x, y)
|
218 |
+
|
219 |
+
return x, y, y_
|
220 |
+
|
221 |
+
# num_features_used = num_features_used_sampler()
|
222 |
+
# prior_outputscale = prior_outputscale_sampler()
|
223 |
+
# prior_lengthscale = prior_lengthscale_sampler()
|
224 |
+
#
|
225 |
+
# x, sample = normalize_data(x), normalize_data(sample)
|
226 |
+
#
|
227 |
+
# if is_binary_classification:
|
228 |
+
# sample = (sample > torch.median(sample, dim=0)[0]).float()
|
229 |
+
#
|
230 |
+
# if normalize_by_used_features:
|
231 |
+
# x = normalize_by_used_features_f(x, num_features_used, num_features)
|
232 |
+
#
|
233 |
+
# # # if is_binary_classification and order_y:
|
234 |
+
# # # x, sample = order_by_y(x, sample)
|
235 |
+
# #
|
236 |
+
# # Append empty features if enabled
|
237 |
+
# x = torch.cat([x, torch.zeros((x.shape[0], x.shape[1], num_features - num_features_used), device=device)], -1)
|
238 |
+
|
239 |
+
DataLoader = get_batch_to_dataloader(get_batch)
|
240 |
+
DataLoader.num_outputs = 1
|
TabPFN/priors/mlp.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from utils import default_device
|
9 |
+
from .utils import get_batch_to_dataloader
|
10 |
+
|
11 |
+
class GaussianNoise(nn.Module):
|
12 |
+
def __init__(self, std, device):
|
13 |
+
super().__init__()
|
14 |
+
self.std = std
|
15 |
+
self.device=device
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
return x + torch.normal(torch.zeros_like(x), self.std)
|
19 |
+
|
20 |
+
|
21 |
+
def causes_sampler_f(num_causes):
|
22 |
+
means = np.random.normal(0, 1, (num_causes))
|
23 |
+
std = np.abs(np.random.normal(0, 1, (num_causes)) * means)
|
24 |
+
return means, std
|
25 |
+
|
26 |
+
def get_batch(batch_size, seq_len, num_features, hyperparameters, device=default_device, num_outputs=1, sampling='normal', **kwargs):
|
27 |
+
if ('mix_activations' in hyperparameters) and hyperparameters['mix_activations']:
|
28 |
+
s = hyperparameters['prior_mlp_activations']()
|
29 |
+
hyperparameters['prior_mlp_activations'] = lambda : s
|
30 |
+
|
31 |
+
class MLP(torch.nn.Module):
|
32 |
+
def __init__(self, hyperparameters):
|
33 |
+
super(MLP, self).__init__()
|
34 |
+
|
35 |
+
with torch.no_grad():
|
36 |
+
|
37 |
+
for key in hyperparameters:
|
38 |
+
setattr(self, key, hyperparameters[key])
|
39 |
+
|
40 |
+
assert (self.num_layers >= 2)
|
41 |
+
|
42 |
+
if 'verbose' in hyperparameters and self.verbose:
|
43 |
+
print({k : hyperparameters[k] for k in ['is_causal', 'num_causes', 'prior_mlp_hidden_dim'
|
44 |
+
, 'num_layers', 'noise_std', 'y_is_effect', 'pre_sample_weights', 'prior_mlp_dropout_prob'
|
45 |
+
, 'pre_sample_causes']})
|
46 |
+
|
47 |
+
if self.is_causal:
|
48 |
+
self.prior_mlp_hidden_dim = max(self.prior_mlp_hidden_dim, num_outputs + 2 * num_features)
|
49 |
+
else:
|
50 |
+
self.num_causes = num_features
|
51 |
+
|
52 |
+
# This means that the mean and standard deviation of each cause is determined in advance
|
53 |
+
if self.pre_sample_causes:
|
54 |
+
self.causes_mean, self.causes_std = causes_sampler_f(self.num_causes)
|
55 |
+
self.causes_mean = torch.tensor(self.causes_mean, device=device).unsqueeze(0).unsqueeze(0).tile(
|
56 |
+
(seq_len, 1, 1))
|
57 |
+
self.causes_std = torch.tensor(self.causes_std, device=device).unsqueeze(0).unsqueeze(0).tile(
|
58 |
+
(seq_len, 1, 1))
|
59 |
+
|
60 |
+
def generate_module(layer_idx, out_dim):
|
61 |
+
# Determine std of each noise term in initialization, so that is shared in runs
|
62 |
+
# torch.abs(torch.normal(torch.zeros((out_dim)), self.noise_std)) - Change std for each dimension?
|
63 |
+
noise = (GaussianNoise(torch.abs(torch.normal(torch.zeros(size=(1, out_dim), device=device), float(self.noise_std))), device=device)
|
64 |
+
if self.pre_sample_weights else GaussianNoise(float(self.noise_std), device=device))
|
65 |
+
return [
|
66 |
+
nn.Sequential(*[self.prior_mlp_activations()
|
67 |
+
, nn.Linear(self.prior_mlp_hidden_dim, out_dim)
|
68 |
+
, noise])
|
69 |
+
]
|
70 |
+
|
71 |
+
self.layers = [nn.Linear(self.num_causes, self.prior_mlp_hidden_dim, device=device)]
|
72 |
+
self.layers += [module for layer_idx in range(self.num_layers-1) for module in generate_module(layer_idx, self.prior_mlp_hidden_dim)]
|
73 |
+
if not self.is_causal:
|
74 |
+
self.layers += generate_module(-1, num_outputs)
|
75 |
+
self.layers = nn.Sequential(*self.layers)
|
76 |
+
|
77 |
+
# Initialize Model parameters
|
78 |
+
for i, (n, p) in enumerate(self.layers.named_parameters()):
|
79 |
+
if self.block_wise_dropout:
|
80 |
+
if len(p.shape) == 2: # Only apply to weight matrices and not bias
|
81 |
+
nn.init.zeros_(p)
|
82 |
+
# TODO: N blocks should be a setting
|
83 |
+
n_blocks = random.randint(1, math.ceil(math.sqrt(min(p.shape[0], p.shape[1]))))
|
84 |
+
w, h = p.shape[0] // n_blocks, p.shape[1] // n_blocks
|
85 |
+
keep_prob = (n_blocks*w*h) / p.numel()
|
86 |
+
for block in range(0, n_blocks):
|
87 |
+
nn.init.normal_(p[w * block: w * (block+1), h * block: h * (block+1)], std=self.init_std / keep_prob**(1/2))
|
88 |
+
else:
|
89 |
+
if len(p.shape) == 2: # Only apply to weight matrices and not bias
|
90 |
+
dropout_prob = self.prior_mlp_dropout_prob if i > 0 else 0.0 # Don't apply dropout in first layer
|
91 |
+
dropout_prob = min(dropout_prob, 0.99)
|
92 |
+
nn.init.normal_(p, std=self.init_std / (1. - dropout_prob)**(1/2))
|
93 |
+
p *= torch.bernoulli(torch.zeros_like(p) + 1. - dropout_prob)
|
94 |
+
|
95 |
+
def forward(self):
|
96 |
+
def sample_normal():
|
97 |
+
if self.pre_sample_causes:
|
98 |
+
causes = torch.normal(self.causes_mean, self.causes_std.abs()).float()
|
99 |
+
else:
|
100 |
+
causes = torch.normal(0., 1., (seq_len, 1, self.num_causes), device=device).float()
|
101 |
+
return causes
|
102 |
+
|
103 |
+
if self.sampling == 'normal':
|
104 |
+
causes = sample_normal()
|
105 |
+
elif self.sampling == 'mixed':
|
106 |
+
zipf_p, multi_p, normal_p = random.random() * 0.66, random.random() * 0.66, random.random() * 0.66
|
107 |
+
def sample_cause(n):
|
108 |
+
if random.random() > normal_p:
|
109 |
+
if self.pre_sample_causes:
|
110 |
+
return torch.normal(self.causes_mean[:, :, n], self.causes_std[:, :, n].abs()).float()
|
111 |
+
else:
|
112 |
+
return torch.normal(0., 1., (seq_len, 1), device=device).float()
|
113 |
+
elif random.random() > multi_p:
|
114 |
+
x = torch.multinomial(torch.rand((random.randint(2, 10))), seq_len, replacement=True).to(device).unsqueeze(-1).float()
|
115 |
+
x = (x - torch.mean(x)) / torch.std(x)
|
116 |
+
return x
|
117 |
+
else:
|
118 |
+
x = torch.minimum(torch.tensor(np.random.zipf(2.0 + random.random() * 2, size=(seq_len)),
|
119 |
+
device=device).unsqueeze(-1).float(), torch.tensor(10.0, device=device))
|
120 |
+
return x - torch.mean(x)
|
121 |
+
causes = torch.cat([sample_cause(n).unsqueeze(-1) for n in range(self.num_causes)], -1)
|
122 |
+
elif self.sampling == 'uniform':
|
123 |
+
causes = torch.rand((seq_len, 1, self.num_causes), device=device)
|
124 |
+
else:
|
125 |
+
raise ValueError(f'Sampling is set to invalid setting: {sampling}.')
|
126 |
+
|
127 |
+
outputs = [causes]
|
128 |
+
for layer in self.layers:
|
129 |
+
outputs.append(layer(outputs[-1]))
|
130 |
+
outputs = outputs[2:]
|
131 |
+
|
132 |
+
if self.is_causal:
|
133 |
+
## Sample nodes from graph if model is causal
|
134 |
+
outputs_flat = torch.cat(outputs, -1)
|
135 |
+
|
136 |
+
if self.in_clique:
|
137 |
+
random_perm = random.randint(0, outputs_flat.shape[-1] - num_outputs - num_features) + torch.randperm(num_outputs + num_features, device=device)
|
138 |
+
else:
|
139 |
+
random_perm = torch.randperm(outputs_flat.shape[-1]-1, device=device)
|
140 |
+
|
141 |
+
random_idx_y = list(range(-num_outputs, -0)) if self.y_is_effect else random_perm[0:num_outputs]
|
142 |
+
random_idx = random_perm[num_outputs:num_outputs + num_features]
|
143 |
+
|
144 |
+
if self.sort_features:
|
145 |
+
random_idx, _ = torch.sort(random_idx)
|
146 |
+
y = outputs_flat[:, :, random_idx_y]
|
147 |
+
|
148 |
+
x = outputs_flat[:, :, random_idx]
|
149 |
+
else:
|
150 |
+
y = outputs[-1][:, :, :]
|
151 |
+
x = causes
|
152 |
+
|
153 |
+
if bool(torch.any(torch.isnan(x)).detach().cpu().numpy()) or bool(torch.any(torch.isnan(y)).detach().cpu().numpy()):
|
154 |
+
x[:] = 0.0
|
155 |
+
y[:] = 1.0
|
156 |
+
|
157 |
+
return x, y
|
158 |
+
|
159 |
+
model = MLP(hyperparameters).to(device)
|
160 |
+
|
161 |
+
sample = sum([[model()] for _ in range(0, batch_size)], [])
|
162 |
+
|
163 |
+
x, y = zip(*sample)
|
164 |
+
y = torch.cat(y, 1).detach().squeeze(2)
|
165 |
+
x = torch.cat(x, 1).detach()
|
166 |
+
x = x[..., torch.randperm(x.shape[-1])]
|
167 |
+
|
168 |
+
return x, y, y
|
169 |
+
|
170 |
+
|
171 |
+
DataLoader = get_batch_to_dataloader(get_batch)
|
172 |
+
DataLoader.num_outputs = 1
|
173 |
+
|
TabPFN/priors/prior.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import DataLoader
|
2 |
+
|
3 |
+
|
4 |
+
class PriorDataLoader(DataLoader):
|
5 |
+
pass
|
6 |
+
# init accepts num_steps as first argument
|
7 |
+
|
8 |
+
# has two attributes set on class or object level:
|
9 |
+
# num_features: int and
|
10 |
+
# num_outputs: int
|
11 |
+
# fuse_x_y: bool
|
12 |
+
# Optional: validate function that accepts a transformer model
|
TabPFN/priors/prior_bag.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from .utils import get_batch_to_dataloader
|
4 |
+
from utils import default_device
|
5 |
+
|
6 |
+
def get_batch(batch_size, seq_len, num_features, device=default_device
|
7 |
+
, hyperparameters=None, batch_size_per_gp_sample=None, **kwargs):
|
8 |
+
batch_size_per_gp_sample = batch_size_per_gp_sample or (min(64, batch_size))
|
9 |
+
num_models = batch_size // batch_size_per_gp_sample
|
10 |
+
assert num_models * batch_size_per_gp_sample == batch_size, f'Batch size ({batch_size}) not divisible by batch_size_per_gp_sample ({batch_size_per_gp_sample})'
|
11 |
+
|
12 |
+
args = {'device': device, 'seq_len': seq_len, 'num_features': num_features, 'batch_size': batch_size_per_gp_sample}
|
13 |
+
|
14 |
+
prior_bag_priors_get_batch = hyperparameters['prior_bag_get_batch']
|
15 |
+
prior_bag_priors_p = [1.0] + [hyperparameters[f'prior_bag_exp_weights_{i}'] for i in range(1, len(prior_bag_priors_get_batch))]
|
16 |
+
|
17 |
+
weights = torch.tensor(prior_bag_priors_p, dtype=torch.float) # create a tensor of weights
|
18 |
+
batch_assignments = torch.multinomial(torch.softmax(weights, 0), num_models, replacement=True).numpy()
|
19 |
+
|
20 |
+
if 'verbose' in hyperparameters and hyperparameters['verbose']:
|
21 |
+
print('PRIOR_BAG:', weights, batch_assignments)
|
22 |
+
|
23 |
+
sample = sum([[prior_bag_priors_get_batch[int(prior_idx)](hyperparameters=hyperparameters, **args)] for prior_idx in batch_assignments], [])
|
24 |
+
|
25 |
+
x, y, y_ = zip(*sample)
|
26 |
+
x, y, y_ = (torch.cat(x, 1).detach()
|
27 |
+
, torch.cat(y, 1).detach()
|
28 |
+
, torch.cat(y_, 1).detach())
|
29 |
+
return x, y, y_
|
30 |
+
|
31 |
+
DataLoader = get_batch_to_dataloader(get_batch)
|
32 |
+
DataLoader.num_outputs = 1
|
TabPFN/priors/utils.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from utils import set_locals_in_self
|
6 |
+
from .prior import PriorDataLoader
|
7 |
+
from torch import nn
|
8 |
+
import numpy as np
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import matplotlib.gridspec as gridspec
|
11 |
+
import scipy.stats as stats
|
12 |
+
import math
|
13 |
+
|
14 |
+
def get_batch_to_dataloader(get_batch_method_):
|
15 |
+
class DL(PriorDataLoader):
|
16 |
+
get_batch_method = get_batch_method_
|
17 |
+
|
18 |
+
# Caution, you might need to set self.num_features manually if it is not part of the args.
|
19 |
+
def __init__(self, num_steps, fuse_x_y=False, **get_batch_kwargs):
|
20 |
+
set_locals_in_self(locals())
|
21 |
+
# The stuff outside the or is set as class attribute before instantiation.
|
22 |
+
self.num_features = get_batch_kwargs.get('num_features') or self.num_features
|
23 |
+
self.num_outputs = get_batch_kwargs.get('num_outputs') or self.num_outputs
|
24 |
+
print('DataLoader.__dict__', self.__dict__)
|
25 |
+
|
26 |
+
@staticmethod
|
27 |
+
def gbm(*args, fuse_x_y=True, **kwargs):
|
28 |
+
dynamic_seq_len = callable(kwargs['seq_len'])
|
29 |
+
kwargs['seq_len'] = kwargs['seq_len']() if dynamic_seq_len else kwargs['seq_len']
|
30 |
+
# Scales the batch size dynamically with the power of 'dynamic_batch_size'.
|
31 |
+
# A transformer with quadratic memory usage in the seq len would need a power of 2 to keep memory constant.
|
32 |
+
if dynamic_seq_len and 'dynamic_batch_size' in kwargs and kwargs['dynamic_batch_size'] > 0:
|
33 |
+
kwargs['batch_size'] = kwargs['batch_size'] * math.floor(math.pow(kwargs['seq_len_maximum'], kwargs['dynamic_batch_size']) / math.pow(kwargs['seq_len'], kwargs['dynamic_batch_size']))
|
34 |
+
batch = get_batch_method_(*args, **kwargs)
|
35 |
+
x, y, target_y, style = batch if len(batch) == 4 else (batch[0], batch[1], batch[2], None)
|
36 |
+
if fuse_x_y:
|
37 |
+
return torch.cat([x, torch.cat([torch.zeros_like(y[:1]), y[:-1]], 0).unsqueeze(-1).float()],
|
38 |
+
-1), target_y
|
39 |
+
else:
|
40 |
+
return (style, x, y), target_y
|
41 |
+
|
42 |
+
def __len__(self):
|
43 |
+
return self.num_steps
|
44 |
+
|
45 |
+
def __iter__(self):
|
46 |
+
return iter(self.gbm(**self.get_batch_kwargs, fuse_x_y=self.fuse_x_y) for _ in range(self.num_steps))
|
47 |
+
|
48 |
+
|
49 |
+
return DL
|
50 |
+
|
51 |
+
import seaborn as sns
|
52 |
+
def plot_features(data, targets, fig=None):
|
53 |
+
if torch.is_tensor(data):
|
54 |
+
data = data.detach().cpu().numpy()
|
55 |
+
targets = targets.detach().cpu().numpy()
|
56 |
+
#data = np.concatenate([data, data[:, -1:]], -1)
|
57 |
+
#df = pd.DataFrame(data, columns=list(range(0, data.shape[1])))
|
58 |
+
#g = sns.pairplot(df, hue=data.shape[1]-1, palette="Set2", diag_kind="kde", height=2.5)
|
59 |
+
#plt.legend([], [], frameon=False)
|
60 |
+
#g._legend.remove()
|
61 |
+
#g = sns.PairGrid(df, hue=data.shape[1]-1)
|
62 |
+
#g.map_diag(sns.histplot)
|
63 |
+
#g.map_offdiag(sns.scatterplot)
|
64 |
+
#g._legend.remove()
|
65 |
+
|
66 |
+
fig2 = fig if fig else plt.figure(figsize=(8, 8))
|
67 |
+
spec2 = gridspec.GridSpec(ncols=data.shape[1], nrows=data.shape[1], figure=fig2)
|
68 |
+
for d in range(0, data.shape[1]):
|
69 |
+
for d2 in range(0, data.shape[1]):
|
70 |
+
sub_ax = fig2.add_subplot(spec2[d, d2])
|
71 |
+
if d == d2:
|
72 |
+
sns.kdeplot(data[:, d],hue=targets[:],ax=sub_ax,legend=False, palette="deep")
|
73 |
+
sub_ax.set(ylabel=None)
|
74 |
+
else:
|
75 |
+
sns.scatterplot(x=data[:, d], y=data[:, d2],
|
76 |
+
hue=targets[:],legend=False, palette="deep")
|
77 |
+
#plt.scatter(data[:, d], data[:, d2],
|
78 |
+
# c=targets[:])
|
79 |
+
sub_ax.get_xaxis().set_ticks([])
|
80 |
+
sub_ax.get_yaxis().set_ticks([])
|
81 |
+
plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
82 |
+
fig2.show()
|
83 |
+
|
84 |
+
|
85 |
+
def plot_prior(prior):
|
86 |
+
s = np.array([prior() for _ in range(0, 1000)])
|
87 |
+
count, bins, ignored = plt.hist(s, 50, density=True)
|
88 |
+
print(s.min())
|
89 |
+
plt.show()
|
90 |
+
|
91 |
+
trunc_norm_sampler_f = lambda mu, sigma : lambda: stats.truncnorm((0 - mu) / sigma, (1000000 - mu) / sigma, loc=mu, scale=sigma).rvs(1)[0]
|
92 |
+
beta_sampler_f = lambda a, b : lambda : np.random.beta(a, b)
|
93 |
+
gamma_sampler_f = lambda a, b : lambda : np.random.gamma(a, b)
|
94 |
+
uniform_sampler_f = lambda a, b : lambda : np.random.uniform(a, b)
|
95 |
+
uniform_int_sampler_f = lambda a, b : lambda : round(np.random.uniform(a, b))
|
96 |
+
def zipf_sampler_f(a, b, c):
|
97 |
+
x = np.arange(b, c)
|
98 |
+
weights = x ** (-a)
|
99 |
+
weights /= weights.sum()
|
100 |
+
return lambda : stats.rv_discrete(name='bounded_zipf', values=(x, weights)).rvs(1)
|
101 |
+
scaled_beta_sampler_f = lambda a, b, scale, minimum : lambda : minimum + round(beta_sampler_f(a, b)() * (scale - minimum))
|
102 |
+
|
103 |
+
|
104 |
+
def normalize_by_used_features_f(x, num_features_used, num_features, normalize_with_sqrt=False):
|
105 |
+
if normalize_with_sqrt:
|
106 |
+
return x / (num_features_used / num_features)**(1 / 2)
|
107 |
+
return x / (num_features_used / num_features)
|
108 |
+
|
109 |
+
|
110 |
+
def order_by_y(x, y):
|
111 |
+
order = torch.argsort(y if random.randint(0, 1) else -y, dim=0)[:, 0, 0]
|
112 |
+
order = order.reshape(2, -1).transpose(0, 1).reshape(-1)#.reshape(seq_len)
|
113 |
+
x = x[order] # .reshape(2, -1).transpose(0, 1).reshape(-1).flip([0]).reshape(seq_len, 1, -1)
|
114 |
+
y = y[order] # .reshape(2, -1).transpose(0, 1).reshape(-1).reshape(seq_len, 1, -1)
|
115 |
+
|
116 |
+
return x, y
|
117 |
+
|
118 |
+
def randomize_classes(x, num_classes):
|
119 |
+
classes = torch.arange(0, num_classes, device=x.device)
|
120 |
+
random_classes = torch.randperm(num_classes, device=x.device).type(x.type())
|
121 |
+
x = ((x.unsqueeze(-1) == classes) * random_classes).sum(-1)
|
122 |
+
return x
|
123 |
+
|
124 |
+
|
125 |
+
class CategoricalActivation(nn.Module):
|
126 |
+
def __init__(self, categorical_p=0.1, ordered_p=0.7
|
127 |
+
, keep_activation_size=False
|
128 |
+
, num_classes_sampler=zipf_sampler_f(0.8, 1, 10)):
|
129 |
+
self.categorical_p = categorical_p
|
130 |
+
self.ordered_p = ordered_p
|
131 |
+
self.keep_activation_size = keep_activation_size
|
132 |
+
self.num_classes_sampler = num_classes_sampler
|
133 |
+
|
134 |
+
super().__init__()
|
135 |
+
|
136 |
+
def forward(self, x):
|
137 |
+
# x shape: T, B, H
|
138 |
+
|
139 |
+
x = nn.Softsign()(x)
|
140 |
+
|
141 |
+
num_classes = self.num_classes_sampler()
|
142 |
+
hid_strength = torch.abs(x).mean(0).unsqueeze(0) if self.keep_activation_size else None
|
143 |
+
|
144 |
+
categorical_classes = torch.rand((x.shape[1], x.shape[2])) < self.categorical_p
|
145 |
+
class_boundaries = torch.zeros((num_classes - 1, x.shape[1], x.shape[2]), device=x.device, dtype=x.dtype)
|
146 |
+
# Sample a different index for each hidden dimension, but shared for all batches
|
147 |
+
for b in range(x.shape[1]):
|
148 |
+
for h in range(x.shape[2]):
|
149 |
+
ind = torch.randint(0, x.shape[0], (num_classes - 1,))
|
150 |
+
class_boundaries[:, b, h] = x[ind, b, h]
|
151 |
+
|
152 |
+
for b in range(x.shape[1]):
|
153 |
+
x_rel = x[:, b, categorical_classes[b]]
|
154 |
+
boundaries_rel = class_boundaries[:, b, categorical_classes[b]].unsqueeze(1)
|
155 |
+
x[:, b, categorical_classes[b]] = (x_rel > boundaries_rel).sum(dim=0).float() - num_classes / 2
|
156 |
+
|
157 |
+
ordered_classes = torch.rand((x.shape[1],x.shape[2])) < self.ordered_p
|
158 |
+
ordered_classes = torch.logical_and(ordered_classes, categorical_classes)
|
159 |
+
x[:, ordered_classes] = randomize_classes(x[:, ordered_classes], num_classes)
|
160 |
+
|
161 |
+
x = x * hid_strength if self.keep_activation_size else x
|
162 |
+
|
163 |
+
return x
|
TabPFN/requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Please use python V 3.7 to be compatible with all packages
|
2 |
+
gpytorch==1.5.0
|
3 |
+
torch==1.9.0
|
4 |
+
scikit-learn==0.24.2
|
5 |
+
pyyaml==5.4.1
|
6 |
+
seaborn==0.11.2
|
7 |
+
xgboost==1.4.0
|
8 |
+
tqdm==4.62.1
|
9 |
+
numpy==1.21.2
|
10 |
+
openml==0.12.2
|
11 |
+
catboost==0.26.1
|
12 |
+
auto-sklearn==0.14.5
|
13 |
+
hyperopt==0.2.5
|
14 |
+
configspace==0.4.21
|
15 |
+
# autogluon==0.4.0
|