Skip to content

Commit 2ed6465

Browse files
authored
Merge pull request #129 from egesko/master
Update notebook to work with new version of capsa
2 parents 3d7b511 + 03d53b1 commit 2ed6465

File tree

1 file changed

+33
-15
lines changed

1 file changed

+33
-15
lines changed

lab3/Part2_BiasAndUncertainty.ipynb

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
]
4141
},
4242
{
43+
"attachments": {},
4344
"cell_type": "markdown",
4445
"metadata": {
4546
"id": "IgYKebt871EK"
@@ -59,6 +60,7 @@
5960
]
6061
},
6162
{
63+
"attachments": {},
6264
"cell_type": "markdown",
6365
"metadata": {
6466
"id": "6JTRoM7E71EU"
@@ -95,6 +97,7 @@
9597
]
9698
},
9799
{
100+
"attachments": {},
98101
"cell_type": "markdown",
99102
"metadata": {
100103
"id": "6VKVqLb371EV"
@@ -130,6 +133,7 @@
130133
]
131134
},
132135
{
136+
"attachments": {},
133137
"cell_type": "markdown",
134138
"metadata": {
135139
"id": "cREmhMWJ71EX"
@@ -143,6 +147,7 @@
143147
]
144148
},
145149
{
150+
"attachments": {},
146151
"cell_type": "markdown",
147152
"metadata": {
148153
"id": "1NhotGiT71EY"
@@ -199,6 +204,7 @@
199204
]
200205
},
201206
{
207+
"attachments": {},
202208
"cell_type": "markdown",
203209
"metadata": {
204210
"id": "LgTG6buf71Ea"
@@ -256,6 +262,7 @@
256262
]
257263
},
258264
{
265+
"attachments": {},
259266
"cell_type": "markdown",
260267
"metadata": {
261268
"id": "SzFGcrhv71Ed"
@@ -303,11 +310,12 @@
303310
"# Get all faces from the testing dataset\n",
304311
"test_imgs = test_loader.get_all_faces()\n",
305312
"\n",
306-
"# Call the Capsa-wrapped classifier to generate outputs: predictions, uncertainty, and bias!\n",
307-
"predictions, uncertainty, bias = wrapped_model.predict(test_imgs, batch_size=512)"
313+
"# Call the Capsa-wrapped classifier to generate outputs: a RiskTensor dictionary consisting of predictions, uncertainty, and bias!\n",
314+
"out = wrapped_model.predict(test_imgs, batch_size=512)\n"
308315
]
309316
},
310317
{
318+
"attachments": {},
311319
"cell_type": "markdown",
312320
"metadata": {
313321
"id": "629ng-_H6WOk"
@@ -329,10 +337,10 @@
329337
"### Analyzing representation bias scores ###\n",
330338
"\n",
331339
"# Sort according to lowest to highest representation scores\n",
332-
"indices = np.argsort(bias, axis=None) # sort the score values themselves\n",
340+
"indices = np.argsort(out.bias, axis=None) # sort the score values themselves\n",
333341
"sorted_images = test_imgs[indices] # sort images from lowest to highest representations\n",
334-
"sorted_biases = bias[indices] # order the representation bias scores\n",
335-
"sorted_preds = predictions[indices] # order the prediction values\n",
342+
"sorted_biases = out.bias.numpy()[indices] # order the representation bias scores\n",
343+
"sorted_preds = out.y_hat.numpy()[indices] # order the prediction values\n",
336344
"\n",
337345
"\n",
338346
"# Visualize the 20 images with the lowest and highest representation in the test dataset\n",
@@ -345,6 +353,7 @@
345353
]
346354
},
347355
{
356+
"attachments": {},
348357
"cell_type": "markdown",
349358
"metadata": {
350359
"id": "-JYmGMJF71Ef"
@@ -368,6 +377,7 @@
368377
]
369378
},
370379
{
380+
"attachments": {},
371381
"cell_type": "markdown",
372382
"metadata": {
373383
"id": "i8ERzg2-71Ef"
@@ -389,6 +399,7 @@
389399
]
390400
},
391401
{
402+
"attachments": {},
392403
"cell_type": "markdown",
393404
"metadata": {
394405
"id": "cRNV-3SU71Eg"
@@ -404,6 +415,7 @@
404415
]
405416
},
406417
{
418+
"attachments": {},
407419
"cell_type": "markdown",
408420
"metadata": {
409421
"id": "ww5lx7ue71Eg"
@@ -420,6 +432,7 @@
420432
]
421433
},
422434
{
435+
"attachments": {},
423436
"cell_type": "markdown",
424437
"metadata": {
425438
"id": "NEfeWo2p7wKm"
@@ -442,10 +455,10 @@
442455
"### Analyzing epistemic uncertainty estimates ###\n",
443456
"\n",
444457
"# Sort according to epistemic uncertainty estimates\n",
445-
"epistemic_indices = np.argsort(uncertainty, axis=None) # sort the uncertainty values\n",
458+
"epistemic_indices = np.argsort(out.epistemic, axis=None) # sort the uncertainty values\n",
446459
"epistemic_images = test_imgs[epistemic_indices] # sort images from lowest to highest uncertainty\n",
447-
"sorted_epistemic = uncertainty[epistemic_indices] # order the uncertainty scores\n",
448-
"sorted_epistemic_preds = predictions[epistemic_indices] # order the prediction values\n",
460+
"sorted_epistemic = out.epistemic.numpy()[epistemic_indices] # order the uncertainty scores\n",
461+
"sorted_epistemic_preds = out.y_hat.numpy()[epistemic_indices] # order the prediction values\n",
449462
"\n",
450463
"\n",
451464
"# Visualize the 20 images with the LEAST and MOST epistemic uncertainty\n",
@@ -458,6 +471,7 @@
458471
]
459472
},
460473
{
474+
"attachments": {},
461475
"cell_type": "markdown",
462476
"metadata": {
463477
"id": "L0dA8EyX71Eh"
@@ -481,6 +495,7 @@
481495
]
482496
},
483497
{
498+
"attachments": {},
484499
"cell_type": "markdown",
485500
"metadata": {
486501
"id": "iyn0IE6x71Eh"
@@ -496,6 +511,7 @@
496511
]
497512
},
498513
{
514+
"attachments": {},
499515
"cell_type": "markdown",
500516
"metadata": {
501517
"id": "XbwRbesM71Eh"
@@ -561,11 +577,11 @@
561577
"\n",
562578
" # After the epoch is done, recompute data sampling proabilities \n",
563579
" # according to the inverse of the bias\n",
564-
" pred, unc, bias = wrapper(train_imgs)\n",
580+
" out = wrapper(train_imgs)\n",
565581
"\n",
566582
" # Increase the probability of sampling under-represented datapoints by setting \n",
567583
" # the probability to the **inverse** of the biases\n",
568-
" inverse_bias = 1.0 / (bias.numpy() + 1e-7)\n",
584+
" inverse_bias = 1.0 / (np.mean(out.bias.numpy(),axis=-1) + 1e-7)\n",
569585
"\n",
570586
" # Normalize the inverse biases in order to convert them to probabilities\n",
571587
" p_faces = inverse_bias / np.sum(inverse_bias)\n",
@@ -575,6 +591,7 @@
575591
]
576592
},
577593
{
594+
"attachments": {},
578595
"cell_type": "markdown",
579596
"metadata": {
580597
"id": "SwXrAeBo71Ej"
@@ -598,13 +615,13 @@
598615
"### Evaluation of debiased model ###\n",
599616
"\n",
600617
"# Get classification predictions, uncertainties, and representation bias scores\n",
601-
"pred, unc, bias = wrapper.predict(test_imgs)\n",
618+
"out = wrapper.predict(test_imgs)\n",
602619
"\n",
603620
"# Sort according to lowest to highest representation scores\n",
604-
"indices = np.argsort(bias, axis=None)\n",
621+
"indices = np.argsort(out.bias, axis=None)\n",
605622
"bias_images = test_imgs[indices] # sort the images\n",
606-
"sorted_bias = bias[indices] # sort the representation bias scores\n",
607-
"sorted_bias_preds = pred[indices] # sort the predictions\n",
623+
"sorted_bias = out.bias.numpy()[indices] # sort the representation bias scores\n",
624+
"sorted_bias_preds = out.y_hat.numpy()[indices] # sort the predictions\n",
608625
"\n",
609626
"# Plot the representation bias vs. the accuracy\n",
610627
"plt.xlabel(\"Density (Representation)\")\n",
@@ -613,6 +630,7 @@
613630
]
614631
},
615632
{
633+
"attachments": {},
616634
"cell_type": "markdown",
617635
"metadata": {
618636
"id": "d1cEEnII71Ej"
@@ -681,7 +699,7 @@
681699
"name": "python",
682700
"nbconvert_exporter": "python",
683701
"pygments_lexer": "ipython3",
684-
"version": "3.9.6 (default, Oct 18 2022, 12:41:40) \n[Clang 14.0.0 (clang-1400.0.29.202)]"
702+
"version": "3.9.16"
685703
},
686704
"vscode": {
687705
"interpreter": {

0 commit comments

Comments
 (0)