yesssssssss commited on
Commit
a531787
·
1 Parent(s): c7943d0
Files changed (1) hide show
  1. Untitled.ipynb +0 -260
Untitled.ipynb DELETED
@@ -1,260 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 4,
6
- "id": "1e0cd6a7",
7
- "metadata": {},
8
- "outputs": [],
9
- "source": [
10
- "import sys\n",
11
- "sys.path.insert(0,'..')"
12
- ]
13
- },
14
- {
15
- "cell_type": "code",
16
- "execution_count": 5,
17
- "id": "ba81c2ba",
18
- "metadata": {},
19
- "outputs": [],
20
- "source": [
21
- "from scripts.transformer_prediction_interface import TabPFNClassifier"
22
- ]
23
- },
24
- {
25
- "cell_type": "code",
26
- "execution_count": 56,
27
- "id": "0fe8a920",
28
- "metadata": {},
29
- "outputs": [
30
- {
31
- "name": "stdout",
32
- "output_type": "stream",
33
- "text": [
34
- "/Users/samuelmueller/TabPFN/TabPFN\r\n"
35
- ]
36
- }
37
- ],
38
- "source": [
39
- "!pwd"
40
- ]
41
- },
42
- {
43
- "cell_type": "code",
44
- "execution_count": 49,
45
- "id": "fd08a53d",
46
- "metadata": {},
47
- "outputs": [
48
- {
49
- "name": "stdout",
50
- "output_type": "stream",
51
- "text": [
52
- "Caching examples at: '/Users/samuelmueller/TabPFN/TabPFN/gradio_cached_examples/670/log.csv'\n"
53
- ]
54
- },
55
- {
56
- "name": "stderr",
57
- "output_type": "stream",
58
- "text": [
59
- "/Users/samuelmueller/opt/anaconda3/envs/TabPFN/lib/python3.7/site-packages/gradio/networking.py:59: ResourceWarning: unclosed <socket.socket fd=280, family=AddressFamily.AF_INET, type=SocketKind.SOCK_STREAM, proto=0, laddr=('0.0.0.0', 0)>\n",
60
- " s = socket.socket() # create a socket object\n",
61
- "ResourceWarning: Enable tracemalloc to get the object allocation traceback\n",
62
- "/Users/samuelmueller/opt/anaconda3/envs/TabPFN/lib/python3.7/site-packages/gradio/networking.py:59: ResourceWarning: unclosed <socket.socket fd=285, family=AddressFamily.AF_INET, type=SocketKind.SOCK_STREAM, proto=0, laddr=('0.0.0.0', 0)>\n",
63
- " s = socket.socket() # create a socket object\n",
64
- "ResourceWarning: Enable tracemalloc to get the object allocation traceback\n"
65
- ]
66
- },
67
- {
68
- "name": "stdout",
69
- "output_type": "stream",
70
- "text": [
71
- "Running on local URL: http://127.0.0.1:7898/\n",
72
- "\n",
73
- "To create a public link, set `share=True` in `launch()`.\n"
74
- ]
75
- },
76
- {
77
- "data": {
78
- "text/html": [
79
- "<div><iframe src=\"http://127.0.0.1:7898/\" width=\"900\" height=\"500\" allow=\"autoplay; camera; microphone;\" frameborder=\"0\" allowfullscreen></iframe></div>"
80
- ],
81
- "text/plain": [
82
- "<IPython.core.display.HTML object>"
83
- ]
84
- },
85
- "metadata": {},
86
- "output_type": "display_data"
87
- },
88
- {
89
- "data": {
90
- "text/plain": [
91
- "(<gradio.routes.App at 0x7fa954c66a90>, 'http://127.0.0.1:7898/', None)"
92
- ]
93
- },
94
- "execution_count": 49,
95
- "metadata": {},
96
- "output_type": "execute_result"
97
- }
98
- ],
99
- "source": [
100
- "import numpy as np\n",
101
- "import pandas as pd\n",
102
- "import torch\n",
103
- "import gradio as gr\n",
104
- "import openml\n",
105
- "\n",
106
- "\n",
107
- "def compute(table: np.array):\n",
108
- " vfunc = np.vectorize(lambda s: len(s))\n",
109
- " non_empty_row_mask = (vfunc(table).sum(1) != 0)\n",
110
- " print(table)\n",
111
- " table = table[non_empty_row_mask]\n",
112
- " empty_mask = table == ''\n",
113
- " empty_inds = np.where(empty_mask)\n",
114
- " assert np.all(empty_inds[1][0] == empty_inds[1])\n",
115
- " y_column = empty_inds[1][0]\n",
116
- " eval_lines = empty_inds[0]\n",
117
- "\n",
118
- " train_table = np.delete(table, eval_lines, axis=0)\n",
119
- " eval_table = table[eval_lines]\n",
120
- "\n",
121
- " try:\n",
122
- " x_train = torch.tensor(np.delete(train_table, y_column, axis=1).astype(np.float32))\n",
123
- " x_eval = torch.tensor(np.delete(eval_table, y_column, axis=1).astype(np.float32))\n",
124
- "\n",
125
- " y_train = train_table[:, y_column]\n",
126
- " except ValueError:\n",
127
- " return \"Please only add numbers (to the inputs) or leave fields empty.\", None\n",
128
- "\n",
129
- " classifier = TabPFNClassifier(base_path='..', device='cpu')\n",
130
- " classifier.fit(x_train, y_train)\n",
131
- " y_eval, p_eval = classifier.predict(x_eval, return_winning_probability=True)\n",
132
- " print(x_train, y_train, x_eval, y_eval)\n",
133
- "\n",
134
- " # print(file, type(file))\n",
135
- " out_table = table.copy().astype(str)\n",
136
- " out_table[eval_lines, y_column] = [f\"{y_e} (p={p_e:.2f})\" for y_e, p_e in zip(y_eval, p_eval)]\n",
137
- " return None, out_table\n",
138
- "\n",
139
- "\n",
140
- "def upload_file(file):\n",
141
- " if file.name.endswith('.arff'):\n",
142
- " dataset = openml.datasets.OpenMLDataset('t', 'test', data_file=file.name)\n",
143
- " X_, _, categorical_indicator_, attribute_names_ = dataset.get_data(\n",
144
- " dataset_format=\"array\"\n",
145
- " )\n",
146
- " return X_\n",
147
- " elif file.name.endswith('.csv') or file.name.endswith('.data'):\n",
148
- " df = pd.read_csv(file.name)\n",
149
- " return df.to_numpy()\n",
150
- "\n",
151
- "\n",
152
- "example = \\\n",
153
- " [\n",
154
- " [1, 2, 1],\n",
155
- " [2, 1, 1],\n",
156
- " [1, 1, 1],\n",
157
- " [2, 2, 2],\n",
158
- " [3, 4, 2],\n",
159
- " [3, 2, 2],\n",
160
- " [2, 3, '']\n",
161
- " ]\n",
162
- "\n",
163
- "with gr.Blocks() as demo:\n",
164
- " gr.Markdown(\"\"\"This demo allows you to play with the **TabPFN**.\n",
165
- " You can either change the table manually (we have filled it with a toy benchmark, sum up to 3 has label 1 and over that label 2).\n",
166
- " The network predicts fields you leave empty. Only one column can have empty entries that are predicted.\n",
167
- " Please, provide everything but the label column as numeric values. It is ok to encode classes as integers.\n",
168
- " \"\"\")\n",
169
- " inp_table = gr.DataFrame(type='numpy', value=example, headers=[''] * 3)\n",
170
- " inp_file = gr.File(\n",
171
- " label='Drop either a .csv (without header, only numeric values for all but the labels) or a .arff file.')\n",
172
- " btn = gr.Button(\"Predict Empty Table Cells\")\n",
173
- "\n",
174
- " inp_file.change(fn=upload_file, inputs=inp_file, outputs=inp_table)\n",
175
- "\n",
176
- " out_text = gr.Textbox()\n",
177
- " out_table = gr.DataFrame()\n",
178
- "\n",
179
- " btn.click(fn=compute, inputs=inp_table, outputs=[out_text, out_table])\n",
180
- " examples = gr.Examples(examples=['./iris.csv'],\n",
181
- " inputs=[inp_file],\n",
182
- " outputs=[inp_table],\n",
183
- " fn=upload_file,\n",
184
- " cache_examples=True)\n",
185
- "\n",
186
- "demo.launch()"
187
- ]
188
- },
189
- {
190
- "cell_type": "code",
191
- "execution_count": 52,
192
- "id": "c4510232",
193
- "metadata": {},
194
- "outputs": [],
195
- "source": [
196
- "df = pd.DataFrame({'hi':[1,2,'j']})"
197
- ]
198
- },
199
- {
200
- "cell_type": "code",
201
- "execution_count": 59,
202
- "id": "2403f193",
203
- "metadata": {},
204
- "outputs": [
205
- {
206
- "data": {
207
- "text/plain": [
208
- "[[1], [2], ['j']]"
209
- ]
210
- },
211
- "execution_count": 59,
212
- "metadata": {},
213
- "output_type": "execute_result"
214
- },
215
- {
216
- "name": "stderr",
217
- "output_type": "stream",
218
- "text": [
219
- "sys:1: ResourceWarning: unclosed socket <zmq.Socket(zmq.PUSH) at 0x7fa9569da910>\n",
220
- "ResourceWarning: Enable tracemalloc to get the object allocation traceback\n"
221
- ]
222
- }
223
- ],
224
- "source": [
225
- "df.to_numpy().tolist()"
226
- ]
227
- },
228
- {
229
- "cell_type": "code",
230
- "execution_count": null,
231
- "id": "adf1a91c",
232
- "metadata": {},
233
- "outputs": [],
234
- "source": [
235
- "k"
236
- ]
237
- }
238
- ],
239
- "metadata": {
240
- "kernelspec": {
241
- "display_name": "Python 3 (ipykernel)",
242
- "language": "python",
243
- "name": "python3"
244
- },
245
- "language_info": {
246
- "codemirror_mode": {
247
- "name": "ipython",
248
- "version": 3
249
- },
250
- "file_extension": ".py",
251
- "mimetype": "text/x-python",
252
- "name": "python",
253
- "nbconvert_exporter": "python",
254
- "pygments_lexer": "ipython3",
255
- "version": "3.7.13"
256
- }
257
- },
258
- "nbformat": 4,
259
- "nbformat_minor": 5
260
- }