Skip to content

Commit

Permalink
Merge pull request #1 from zanarmstrong/zanarmstrong-graphs
Browse files Browse the repository at this point in the history
Add graphs of percents for the prediction array
  • Loading branch information
zanarmstrong committed Jul 31, 2018
2 parents c2f05af + abe47ae commit 51cf530
Showing 1 changed file with 116 additions and 55 deletions.
171 changes: 116 additions & 55 deletions samples/core/tutorials/keras/basic_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,101 @@
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "ygh2yYC972ne",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"We can graph this to look at the full set of 10 channels"
]
},
{
"metadata": {
"id": "DvYmmrpIy6Y1",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"def plot_image(predictions_array, true_label, img):\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
" plt.grid('off')\n",
" plt.imshow(img, cmap=plt.cm.binary)\n",
"\n",
" predicted_label = np.argmax(predictions_array)\n",
" if predicted_label == true_label:\n",
" color = 'blue'\n",
" else:\n",
" color = 'red'\n",
" plt.xlabel(\"{} {} ({})\".format(class_names[predicted_label],\n",
" str(round(predictions_array[np.argmax(predictions_array)] * 100)) + \"%\",\n",
" class_names[true_label]),\n",
" color=color)\n",
"\n",
"def plot_value_array(predictions_array, true_label):\n",
" plt.grid('off')\n",
" thisplot = plt.bar(range(10), predictions_array, color=\"#777777\")\n",
" plt.ylim([0, 1]) \n",
" predicted_label = np.argmax(predictions_array)\n",
" \n",
" thisplot[predicted_label].set_color('red')\n",
" thisplot[true_label].set_color('blue')\n",
"\n",
" \n",
"# define plot to look at the image, predicted label, actual label, predicted percent for top prediction, and graph of all prediction values\n",
"def plot_fig_and_predarray(iter):\n",
" plt.figure(figsize=(6,3))\n",
" \n",
" # plot the image first\n",
" plt.subplot(1,2,1)\n",
" plot_image(predictions[iter], test_labels[iter], test_images[int(iter)])\n",
" \n",
" # then the graph of 10 values\n",
" plt.subplot(1,2,2)\n",
" plot_value_array(predictions[iter], test_labels[iter])"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "d4Ov9OFDMmOD",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"Let's look at the 0th image, predictions, and prediction array. "
]
},
{
"metadata": {
"id": "HV5jw-5HwSmO",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"plot_fig_and_predarray(0)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "Ko-uzOufSCSe",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"plot_fig_and_predarray(12)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
Expand All @@ -853,41 +948,29 @@
},
"cell_type": "markdown",
"source": [
"Let's plot several images with their predictions. Correct prediction labels are green and incorrect prediction labels are red."
"Let's plot several images with their predictions. Correct prediction labels are blue and incorrect prediction labels are red. The number gives the percent (out of 100) for the predicted label. Note that it can be wrong even when very confident. "
]
},
{
"metadata": {
"id": "YGBDAiziCaXR",
"id": "hQlnbqaw2Qu_",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
"colab": {}
},
"cell_type": "code",
"source": [
"# Plot the first 25 test images, their predicted label, and the true label\n",
"# Color correct predictions in green, incorrect predictions in red\n",
"plt.figure(figsize=(10,10))\n",
"for i in range(25):\n",
" plt.subplot(5,5,i+1)\n",
"# Plot the first X test images, their predicted label, and the true label\n",
"# Color correct predictions in blue, incorrect predictions in red\n",
"num_images = int(50)\n",
"plt.figure(figsize=(24,20))\n",
"for i in range(num_images):\n",
" plt.subplot(10,num_images / 5,2*i+1)\n",
" plot_image(predictions[i], test_labels[i], test_images[i])\n",
" \n",
" plt.subplot(10,num_images / 5,2*i+2)\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
" plt.grid('off')\n",
" plt.imshow(test_images[i], cmap=plt.cm.binary)\n",
" predicted_label = np.argmax(predictions[i])\n",
" true_label = test_labels[i]\n",
" if predicted_label == true_label:\n",
" color = 'green'\n",
" else:\n",
" color = 'red'\n",
" plt.xlabel(\"{} ({})\".format(class_names[predicted_label], \n",
" class_names[true_label]),\n",
" color=color)\n",
" "
" plot_value_array(predictions[i], test_labels[i])"
],
"execution_count": 0,
"outputs": []
Expand All @@ -906,12 +989,7 @@
"metadata": {
"id": "yRJ7JU7JCaXT",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
"colab": {}
},
"cell_type": "code",
"source": [
Expand All @@ -937,12 +1015,7 @@
"metadata": {
"id": "lDFh5yF_CaXW",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
"colab": {}
},
"cell_type": "code",
"source": [
Expand All @@ -968,18 +1041,13 @@
"metadata": {
"id": "o_rzNSdrCaXY",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
"colab": {}
},
"cell_type": "code",
"source": [
"predictions = model.predict(img)\n",
"predictions_single = model.predict(img)\n",
"\n",
"print(predictions)"
"print(predictions_single)"
],
"execution_count": 0,
"outputs": []
Expand All @@ -998,18 +1066,11 @@
"metadata": {
"id": "2tRmdq_8CaXb",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
"colab": {}
},
"cell_type": "code",
"source": [
"prediction = predictions[0]\n",
"\n",
"np.argmax(prediction)"
"np.argmax(predictions_single[0])"
],
"execution_count": 0,
"outputs": []
Expand All @@ -1025,4 +1086,4 @@
]
}
]
}
}

0 comments on commit 51cf530

Please sign in to comment.