{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# cosa c'è in questo notebook \n",
"\n",
"Training di 4 reti neurali in grado di riconoscere dei set di gesti associati alla lingua americana dei segni (AmericanSignLanguage), tramite l'analisi di immagini statiche e metodi di image recognition.\n",
"\n",
"Ho considerato la ASL anzichè la LIS (Lingua Italiana dei Segni) dato che ho trovato molti piu esempi da scopiazzare per questa fase iniziale ma sopratutto perchè la quantita di dati disponibile è (forse?) ovviamente maggiore.\n",
"\n",
"non credere all'hype, più che intelligenti le AI sono ingorde: + DATI + RISULTATO, qualsiasi la struttura del dato\n",
"detto questo se +dato comporta +risultato, allora perchè non cambiare il dato ;-)\n",
"\n",
"Ciò nonostante il meccanismo è del tutto analogo, tanto dal punto di vista piu astratto di relazione di insiemi quanto da quello piu pratico per cui tra l'italiano e l'inglese-americano _comuni_ c'è un buon grado di traducibilità; ovvero dal punto di vista semiotico il residuo comunicativo non è eccessivamente significativo in quanto appartanenti alla stessa cultura occidentale ed aventi una simile origine a livello geografico e linguistico.\n",
"\n",
"Il linguaggio utilizzato è __python3__ ed il framework AI/ML di paciocco è __tensorflow__ (di google) con __keras__ come API di alto livello.\n",
"\n",
"I 4 modelli sotto allenati e pacioccati sono ordinati per crescente profondità di architettura (sono tutti modelli ConvolutionalNeuralNetwork) e quantita/complessità di dato, quindi si parte da scemo-piccolo e si arriva a saccente-grande.\n",
"\n",
"PS per sapere cosa sono CNN framework LIS, ASL ecc ecc dai un occhiata al wiki:\n",
"https://git.tropici.net/agropunx/geografia/wiki/_pages\n",
"\n",
"PPS2\n",
"IO HO GIA SCARICATO PROCESSATO I DATI QUINDI IN CASO PROVA A CONTATTARMI DIRETTAMENTE"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"__General Imports__"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# general\n",
"import os, random,pickle\n",
"from time import time\n",
"from tqdm import tqdm\n",
"\n",
"#numeric and df\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"#plot\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# image and computer vision\n",
"import cv2\n",
"from skimage.transform import resize\n",
"from IPython.display import Image\n",
"\n",
"# data shuffling and metrics\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"\n",
"# ML/AI stuff\n",
"import keras\n",
"from keras.layers import MaxPooling2D,MaxPool2D, Dense, Flatten, Dropout\n",
"from keras.callbacks import TensorBoard, ReduceLROnPlateau\n",
"from keras.preprocessing.image import ImageDataGenerator\n",
"from keras.applications.inception_v3 import InceptionV3\n",
"\n",
"from tensorflow.keras.layers import Conv2D\n",
"from tensorflow.keras.optimizers import RMSprop, Adam, SGD\n",
"from tensorflow.keras.utils import to_categorical"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# modello ASL_static_A: 3-layer-CNN con dataset MNIST-AmericanSignLanguage"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"asl_static_A\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" input_1 (InputLayer) [(None, 28, 28, 1)] 0 \n",
" \n",
" conv2d_2 (Conv2D) (None, 26, 26, 128) 1280 \n",
" \n",
" max_pooling2d_2 (MaxPooling (None, 13, 13, 128) 0 \n",
" 2D) \n",
" \n",
" dropout_2 (Dropout) (None, 13, 13, 128) 0 \n",
" \n",
" flatten (Flatten) (None, 21632) 0 \n",
" \n",
" dense (Dense) (None, 512) 11076096 \n",
" \n",
" dropout_3 (Dropout) (None, 512) 0 \n",
" \n",
" dense_1 (Dense) (None, 25) 12825 \n",
" \n",
"=================================================================\n",
"Total params: 11,090,201\n",
"Trainable params: 11,090,201\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"#Defining the Convolutional Neural Network\n",
"\n",
"img_input = keras.Input(shape=(28, 28, 1))\n",
"x = Conv2D(32, (3, 3), input_shape = (28,28,1), activation='relu')(img_input)\n",
"x = MaxPooling2D(pool_size = (2, 2))(x)\n",
"x = Dropout(0.25)(x)\n",
"\n",
"x = Conv2D(64, (3, 3), input_shape = (28,28,1), activation='relu')(img_input)\n",
"x = MaxPooling2D(pool_size = (2, 2))(x)\n",
"x = Dropout(0.25)(x)\n",
"\n",
"x = Conv2D(128, (3, 3), input_shape = (28,28,1), activation='relu')(img_input)\n",
"x = MaxPooling2D(pool_size = (2, 2))(x)\n",
"x = Dropout(0.25)(x)\n",
"\n",
"x = Flatten()(x)\n",
"x = units = Dense(512, activation = 'relu')(x)\n",
"x = Dropout(0.25)(x)\n",
"out = Dense(units = 25, activation = 'softmax')(x)\n",
"\n",
"asl_static_A = keras.Model(inputs=img_input,outputs=out, name='asl_static_A')\n",
"\n",
"asl_static_A.summary()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def train(model, X_in, y_in, optimizer = Adam(learning_rate=.001), loss='categorical_crossentropy', metrics=['accuracy'],\n",
" batch_size = 512, epochs = 50, verbose = 1):\n",
"\n",
" model.compile(optimizer=optimizer, loss=loss, metrics=metrics)\n",
"\n",
" start = time()\n",
" history = model.fit(X_in['train'], y_in['train'], batch_size=batch_size, epochs=epochs, validation_split=0.1, shuffle = True, verbose=1)\n",
" train_time = time() - start\n",
"\n",
" #model.summary()\n",
" plt.figure(figsize=(12, 12))\n",
" plt.subplot(3, 2, 1)\n",
" plt.plot(history.history['accuracy'], label = 'train_accuracy')\n",
" plt.plot(history.history['val_accuracy'], label = 'val_accuracy')\n",
" plt.xlabel('epoch')\n",
" plt.ylabel('accuracy')\n",
" plt.legend()\n",
" plt.subplot(3, 2, 2)\n",
" plt.plot(history.history['loss'], label = 'train_loss')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Dataset source\n",
"\n",
"MNIST dataset for the American Sign Language (ASL)\n",
"\n",
"warnnn download tocca loggarsi\n",
"\n",
"url https://www.kaggle.com/datamunge/sign-language-mnist\n",
"\n",
"messo in ./data/mnist"
]
},
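{
"cell_type": "markdown",
"metadata": {},
"source": [
"A possible way to fetch the dataset from the command line (a sketch, assuming the `kaggle` CLI is installed and a Kaggle API token is configured; otherwise download it manually from the URL above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: fetch the Sign Language MNIST dataset with the kaggle CLI.\n",
"# Assumes `pip install kaggle` and a valid ~/.kaggle/kaggle.json API token;\n",
"# otherwise download manually from the Kaggle page linked above.\n",
"!kaggle datasets download -d datamunge/sign-language-mnist -p data/mnist --unzip"
]
},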
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAApkAAAHXCAYAAAD3FUHTAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAAJcEhZcwAADsMAAA7DAcdvqGQAAP+lSURBVHhe7P33kxzXlucJru0fsD+srdnY7K7t2Kxtb3evzUzvtO6uqq6qV6/6ST5qLUACJEiCAgQIAgQVCJAACK01MpFaax0ZWmuZIVLrhEhomTq/+z3XwzMDeHz1yNnKGmszmOXX3MPDwzPCr597Pudc9X+Yn5/HYz3WYz3WYz3WYz3WYz3WP6YeQ+ZjPdZjPdZjPdZjPdZj/aPrMWQ+1mM91mM91mM91mM91j+6HkPmYz3WYz3WYz3WYz3WY/2j6zFkPtZjPdZjPdZjPdZjPdY/uh5D5mM91mM91mM91mM91mP9o+vPQubc3Bymp6dx79493LlzB7dv336sFZDcW7nHcq/lnv9YWfxcPS67fxo9Lrv/dvW47P7b1eOy+29Xj8vuv1393LL7s5ApF7J1WpTsnfasHHAYnbAb3UoOk0fJbuFrymFzw2nnMYdLye3yK3ndAfg8wSUF3CElv0vbD7nDCHsiiHjiSmF37I8U9SYQDcQRCyaU1H6kK6s4ouEYQqGIUiQSQzyaQDwUQzKSQCqSUkqEkzzG88NpxCMZREMptZ+IptEVy6Arzv0Yj2WvF+W1I6G4uk5XOI4Mt6lABKlgCOlwBOkYX8f5XjzJz3UhyeukImme14tMuBspqov/VyT/e/n7PqyJiQlViHLPf6wsfq7kOhXlRaisKEZVZQmqq0qXVFNdhtqactTVVig11FcvqamxVqmxoWZJ8rq5qQ4tzfVqq5+j7+vvtVIdzQ1oa+Tr+hq08L02vtfWUo+21ga0ctvaXIP2pjKeV8bPlaOhuQKNbTWoa6tEo4HfobUMLW3laG8sgq2uELb6ApgaCmFoKuFny9DQVIp6fq6+qQp1DdWoratUv0VUL6quUKqrrkRNTVVW3Oex6qoy3o+SJZWXFfEeFautqKy0EKUlBSgpvqBUXJSvVFSYh8KC87iQf1YpL/8Mzuedxrnzp5RWouzcbvdPksfjUVufz6fk9XrVMV3yWo6r8zwueHz8DOX2uuDyONXWQ3m9bnjlfbcTLpcDLtnytdNNO3bSph1ezY49fvj8AXj9Pji8rAecNridHtox/4fDCY/dRlnhspq4tcBrM8NtNsNlscJt47VYP7hdAThYR7i47+Exj5X7Vqk7fLA5/LC6fbC43LA6bDxm4XlmOK0WWC0mGC1mdPJ6ZosNFspkssBoNHGryWg0LqmzsxMGg2FJHR0dPGZaksFgXJGyK7lwHqUFeUrlRRdQWVKIKj5b1XzGaiuKUF9VwuezhM8pRXus53PZUE1b4DPayOe1sbqaot3V1KG5lrZWI/u0Rb7XVFeNZtqWvm2mDTTTbluyaqW9iv21NdbzNe1TzlN2XI36uirU11ehoZH/o7YILSWn4a06i+6a8xioOY2uwv3w5O2DpfwsWmv4PRvKUc3za/jZirpyVDXRthpZR9RW8nvJdy1TaqotR1N1qSb+nobKMtRXlqKOv6uW2xrWQdVlxagt4++makuLuOVrbquLC1DD+1NTIvu8T7S/CtpeeVEeygrPoaTgLEoKz6K44Ax1DoW0vQLanmxXouzsfHYdfO6cTrEDPpt8DkVeN22JCvj8CPL5l63f60MgEIDfz31K9n0B2gcl9iHSXy8f43Uof1De8/IzPoTD9IHc+nweXjukJNcKBnk8xP8X5r4uHhOFAvSXQfq6AF/z+4SCfoRD8n5YKRCSzwbV/9GuwetSgQB9Lf1jOEifSkVCUeXnwmF+Ts6hf/Pw2m5e08NrB+RalJ+fUwou/8Yfk9/vfUjym35MK1F2ba2NaG9rop03o8PQAkNnKzqNbWor6jC2oL2T73HbaW6D0dwBk7ldE88zmdpZr3TyuCaThfUIzzHyXJOlHQZu2ymDtQPtfN3Oz3VaWa9wa7CwbqHMtg5YKYtVrsXPmVupNu7zf/Fzpuz5nbxuJ/+fkbLxf1oNbeQqfpbvmfmeubMNFtZZFpOR/5f/z8zvbGri92yF0cDPK7VR/I2UoYO/l+po5+/j72+nWtsa0dbOe9LRpLai1jb64axaWllHUM0t9Nfim1t4jD5c+femWtYT9MdUY5NWfzSwLhBe+Dll92chU4hVAabRugSZAphKWbh0mr1KCjDpLHIh00nnJI7JQ8D8U5Cp61HI1CVgqW9jvi7EgzmgKLCZA26asUSVBDIFFHU4TPK8rohIoJKfJ2Dm6lHIVIBKxaJJXncZVjO8ZjoYRToU5n4UqWgUyVhcAaZ8TiAzGenmOT08J6P2fwpkxgirUnhyz3+sLH6u5DoCmT8GmLmQKY5HB0kdGH9M8vDp+ofebxXAbKhFK4/JQ7sEl7Llg93cQshsqSQw0jFx29hKx9fKB5lqJWy2EyAN9SXwtlUh1FAKX00B7FV5MFdf4PFitMnneE5dfTlqBJL54NfTAGSr9sWR0gnWUQKX1YTNajrvGjpxgcwqOj1dApi6dNAUyHwUNHXI1HWBzi7/wtklrUTZPQqTf0oCkfpW16OAqUOmR0CSDk6cm2x14PSJ0xOn4OP5XvmcdtxDZ+D2yPV9FJ2Oh3bsZdDIayn49DgIqU7aM52viN8j4LQjQLgMEgyjDiuidNohgqHPYiFM2gmLBFiXB2Zu7Xbu2+nIHfwthFinwwc7YVVBJq9lJcCK03cSNJ021kFWKytvkQ02m0PJymtaCLCaLKygzQ8BZy5kih6FzJUou7LCfAWXFQQoDTAJUISqGj5ndZXFhEmCmILKh9XI57WplhCnoJK2WEtbq6V9ETRFLXU1hEjaFm2slU5gSY0M6AiYbXQIHc0M7CS4I2S2NdJpNNAelV3zmrQTUTPPa2+iDZafx8WQBXeDRlwzVSFTfAj2I9+i/fhuGCvz+b8Ii40M4qgqAmdVI4M47jeJ6hmc1hIyCZ9qn7Cp/S6RBs0S9NVxv471Ti3robpyQnVWApYCmSLZF8CszsJ4pQJNHTbzUUrgLKHdFTPAK8o7g8LzhMzzZ1ak7DTA5HOahUy3UwNNAUwfbUGBZRYwlbKAKTaWC5lL0EUoy5XYnoCfAKOCP8JhJEK441ZsUANMea3B5EOASUUikkCJKF8ngCgKhUIERoFHbgmLIfo/HSqXP8v3IxowKsjkOfpWfKVcU/tsHF7Cq89PQJXvwXN00FSw+RMhU6BZ3xfpcKm/XomyayNgtrUTMnMAUyBOJPtLwCmgRpjTIJNARwC0yNbE1yYDlYVMM+sLk4AiIdREWKQEHEUWAqOZwGix81xbp4LNTpv22iqSfR6Ta5sJpHK+heep8wU2ebyjs4V1Vgfs/D92fi+HoZks1Qob/6fN2E7u4vkCmTyn3cLfY27Wvgdh1GQ0wNhJUM1CppFQKjIIcApQ8x7ocJkLmY+CpgJRXYR08dniz3XIFMBsYp3SxDpGJKD5c8ruz0KmpEYVYFIOo2QwHXCaCI+iLFyKXBafAkynj
a8lQ0GH4aLjEAlc5gKm38toiAp6dDGCohRgevnAU1FfTFMWLGO+FKJ+gh8BsytMEMxKwSbhLU4Q1GFNMo8KDGVfweGyEmEt0ygwmYh28zxu4z1Zyb6IkEkJbIrisbS6boLGl4p2oZvgmaaBSxazOxrjMQJsFjLjAqcClOr79aitQKYArBzPBWL5ziL1PSm73Y7BwUF1z3+sLH6u5Dp6FlMHSx0udcD8uZApD6AuHSqX4FIBJR9Ueb9ZO0c9tPIgM1pS23Y+wK10kATK5pYqNBMqm9prCZh0lFR7YxVMDRXw8aEecZkw1FGHdG0xvAUnYco/BmPZWXTWFdFxlvA7V9AQ+NDzwa8XCWjmwGZdnUBmBQGTcEnVcv9R0JRsZi5oiiSbqWc0c2GziM5OALPgwjkUFD4MmitRdjoo/jnpUKlnMbWtJg0wWfln3/MQID0EScmoyFbb1+TjMX/2sx45l/tKfO32+OFyi2Q/+76AKAHT7XUQTJ2ETAdCHjvCDjPidhN6nFb0OSxIsoIPM7oOs/L
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#Hand gestures of ASL\n",
"Image('data/mnist/amer_sign2.png')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>label</th>\n",
" <th>pixel1</th>\n",
" <th>pixel2</th>\n",
" <th>pixel3</th>\n",
" <th>pixel4</th>\n",
" <th>pixel5</th>\n",
" <th>pixel6</th>\n",
" <th>pixel7</th>\n",
" <th>pixel8</th>\n",
" <th>pixel9</th>\n",
" <th>...</th>\n",
" <th>pixel775</th>\n",
" <th>pixel776</th>\n",
" <th>pixel777</th>\n",
" <th>pixel778</th>\n",
" <th>pixel779</th>\n",
" <th>pixel780</th>\n",
" <th>pixel781</th>\n",
" <th>pixel782</th>\n",
" <th>pixel783</th>\n",
" <th>pixel784</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>3</td>\n",
" <td>107</td>\n",
" <td>118</td>\n",
" <td>127</td>\n",
" <td>134</td>\n",
" <td>139</td>\n",
" <td>143</td>\n",
" <td>146</td>\n",
" <td>150</td>\n",
" <td>153</td>\n",
" <td>...</td>\n",
" <td>207</td>\n",
" <td>207</td>\n",
" <td>207</td>\n",
" <td>207</td>\n",
" <td>206</td>\n",
" <td>206</td>\n",
" <td>206</td>\n",
" <td>204</td>\n",
" <td>203</td>\n",
" <td>202</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>6</td>\n",
" <td>155</td>\n",
" <td>157</td>\n",
" <td>156</td>\n",
" <td>156</td>\n",
" <td>156</td>\n",
" <td>157</td>\n",
" <td>156</td>\n",
" <td>158</td>\n",
" <td>158</td>\n",
" <td>...</td>\n",
" <td>69</td>\n",
" <td>149</td>\n",
" <td>128</td>\n",
" <td>87</td>\n",
" <td>94</td>\n",
" <td>163</td>\n",
" <td>175</td>\n",
" <td>103</td>\n",
" <td>135</td>\n",
" <td>149</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2</td>\n",
" <td>187</td>\n",
" <td>188</td>\n",
" <td>188</td>\n",
" <td>187</td>\n",
" <td>187</td>\n",
" <td>186</td>\n",
" <td>187</td>\n",
" <td>188</td>\n",
" <td>187</td>\n",
" <td>...</td>\n",
" <td>202</td>\n",
" <td>201</td>\n",
" <td>200</td>\n",
" <td>199</td>\n",
" <td>198</td>\n",
" <td>199</td>\n",
" <td>198</td>\n",
" <td>195</td>\n",
" <td>194</td>\n",
" <td>195</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2</td>\n",
" <td>211</td>\n",
" <td>211</td>\n",
" <td>212</td>\n",
" <td>212</td>\n",
" <td>211</td>\n",
" <td>210</td>\n",
" <td>211</td>\n",
" <td>210</td>\n",
" <td>210</td>\n",
" <td>...</td>\n",
" <td>235</td>\n",
" <td>234</td>\n",
" <td>233</td>\n",
" <td>231</td>\n",
" <td>230</td>\n",
" <td>226</td>\n",
" <td>225</td>\n",
" <td>222</td>\n",
" <td>229</td>\n",
" <td>163</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>13</td>\n",
" <td>164</td>\n",
" <td>167</td>\n",
" <td>170</td>\n",
" <td>172</td>\n",
" <td>176</td>\n",
" <td>179</td>\n",
" <td>180</td>\n",
" <td>184</td>\n",
" <td>185</td>\n",
" <td>...</td>\n",
" <td>92</td>\n",
" <td>105</td>\n",
" <td>105</td>\n",
" <td>108</td>\n",
" <td>133</td>\n",
" <td>163</td>\n",
" <td>157</td>\n",
" <td>163</td>\n",
" <td>164</td>\n",
" <td>179</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 785 columns</p>\n",
"</div>"
],
"text/plain": [
" label pixel1 pixel2 pixel3 pixel4 pixel5 pixel6 pixel7 pixel8 \\\n",
"0 3 107 118 127 134 139 143 146 150 \n",
"1 6 155 157 156 156 156 157 156 158 \n",
"2 2 187 188 188 187 187 186 187 188 \n",
"3 2 211 211 212 212 211 210 211 210 \n",
"4 13 164 167 170 172 176 179 180 184 \n",
"\n",
" pixel9 ... pixel775 pixel776 pixel777 pixel778 pixel779 pixel780 \\\n",
"0 153 ... 207 207 207 207 206 206 \n",
"1 158 ... 69 149 128 87 94 163 \n",
"2 187 ... 202 201 200 199 198 199 \n",
"3 210 ... 235 234 233 231 230 226 \n",
"4 185 ... 92 105 105 108 133 163 \n",
"\n",
" pixel781 pixel782 pixel783 pixel784 \n",
"0 206 204 203 202 \n",
"1 175 103 135 149 \n",
"2 198 195 194 195 \n",
"3 225 222 229 163 \n",
"4 157 163 164 179 \n",
"\n",
"[5 rows x 785 columns]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_df = pd.read_csv('data/mnist/sign_mnist_train.csv')\n",
"test_df = pd.read_csv('data/mnist/sign_mnist_test.csv')\n",
"\n",
"train_df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"__prepara il dataset da allenare__"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"#labels\n",
"class_names = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y' ]\n",
"X = {\n",
" 'train': train_df.values[:, 1:] / 255,\n",
" 'test': test_df.values[:, 1:] / 255\n",
"}\n",
"\n",
"y = {\n",
" 'train': train_df.values[:, 0],\n",
" 'test' : test_df.values[:,0]\n",
"}\n",
"\n",
"\n",
"X['train'] = X['train'].reshape(X['train'].shape[0], *(28, 28, 1))\n",
"X['test'] = X['test'].reshape(X['test'].shape[0], *(28, 28, 1))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"__nota__: 28x28x1 ----> immagini 28x28 con 1 canale cromatico (b/n)"
]
},
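{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (a sketch using the arrays built above): show the first training sample and its decoded label, to confirm the 28x28 grayscale shape. It assumes `class_names` can be indexed directly with the numeric label."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: plot the first training image and its decoded label.\n",
"# Assumes class_names[label] is a valid lookup for the numeric labels.\n",
"sample_idx = 0\n",
"plt.imshow(X['train'][sample_idx].reshape(28, 28), cmap='gray')\n",
"plt.title(f\"label {y['train'][sample_idx]} -> {class_names[int(y['train'][sample_idx])]}\")\n",
"plt.axis('off')\n",
"plt.show()"
]
},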
{
"cell_type": "markdown",
"metadata": {},
"source": [
"__Train and Evaluate__"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/30\n",
"49/49 [==============================] - 24s 480ms/step - loss: 2.3351 - accuracy: 0.3831 - val_loss: 1.1610 - val_accuracy: 0.6934\n",
"Epoch 2/30\n",
" 7/49 [===>..........................] - ETA: 19s - loss: 1.2259 - accuracy: 0.6562"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"ERROR:root:Internal Python error in the inspect module.\n",
"Below is the traceback from this internal error.\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py\", line 3331, in run_code\n",
" exec(code_obj, self.user_global_ns, self.user_ns)\n",
" File \"<ipython-input-9-a7a727486f0a>\", line 11, in <module>\n",
" train(asl_static_A, X,y, **params)\n",
" File \"<ipython-input-5-dfa5e24106e0>\", line 7, in train\n",
" history = model.fit(X_in['train'], y_in['train'], batch_size=batch_size, epochs=epochs, validation_split=0.1, shuffle = True, verbose=1)\n",
" File \"/home/agropunx/.local/lib/python3.8/site-packages/keras/utils/traceback_utils.py\", line 64, in error_handler\n",
" return fn(*args, **kwargs)\n",
" File \"/home/agropunx/.local/lib/python3.8/site-packages/keras/engine/training.py\", line 1216, in fit\n",
" tmp_logs = self.train_function(iterator)\n",
" File \"/home/agropunx/.local/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py\", line 150, in error_handler\n",
" return fn(*args, **kwargs)\n",
" File \"/home/agropunx/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py\", line 910, in __call__\n",
" result = self._call(*args, **kwds)\n",
" File \"/home/agropunx/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py\", line 942, in _call\n",
" return self._stateless_fn(*args, **kwds) # pylint: disable=not-callable\n",
" File \"/home/agropunx/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py\", line 3130, in __call__\n",
" return graph_function._call_flat(\n",
" File \"/home/agropunx/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py\", line 1959, in _call_flat\n",
" return self._build_call_outputs(self._inference_function.call(\n",
" File \"/home/agropunx/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py\", line 598, in call\n",
" outputs = execute.execute(\n",
" File \"/home/agropunx/.local/lib/python3.8/site-packages/tensorflow/python/eager/execute.py\", line 58, in quick_execute\n",
" tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,\n",
"KeyboardInterrupt\n",
"\n",
"During handling of the above exception, another exception occurred:\n",
"\n",
"Traceback (most recent call last):\n",
" File \"/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py\", line 2044, in showtraceback\n",
" stb = value._render_traceback_()\n",
"AttributeError: 'KeyboardInterrupt' object has no attribute '_render_traceback_'\n",
"\n",
"During handling of the above exception, another exception occurred:\n",
"\n",
"Traceback (most recent call last):\n",
" File \"/usr/lib/python3/dist-packages/IPython/core/ultratb.py\", line 1148, in get_records\n",
" return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)\n",
" File \"/usr/lib/python3/dist-packages/IPython/core/ultratb.py\", line 316, in wrapped\n",
" return f(*args, **kwargs)\n",
" File \"/usr/lib/python3/dist-packages/IPython/core/ultratb.py\", line 350, in _fixed_getinnerframes\n",
" records = fix_frame_records_filenames(inspect.getinnerframes(etb, context))\n",
" File \"/usr/lib/python3.8/inspect.py\", line 1515, in getinnerframes\n",
" frameinfo = (tb.tb_frame,) + getframeinfo(tb, context)\n",
" File \"/usr/lib/python3.8/inspect.py\", line 1473, in getframeinfo\n",
" filename = getsourcefile(frame) or getfile(frame)\n",
" File \"/usr/lib/python3.8/inspect.py\", line 708, in getsourcefile\n",
" if getattr(getmodule(object, filename), '__loader__', None) is not None:\n",
" File \"/usr/lib/python3.8/inspect.py\", line 754, in getmodule\n",
" os.path.realpath(f)] = module.__name__\n",
" File \"/usr/lib/python3.8/posixpath.py\", line 391, in realpath\n",
" path, ok = _joinrealpath(filename[:0], filename, {})\n",
" File \"/usr/lib/python3.8/posixpath.py\", line 425, in _joinrealpath\n",
" if not islink(newpath):\n",
" File \"/usr/lib/python3.8/posixpath.py\", line 167, in islink\n",
" st = os.lstat(path)\n",
"KeyboardInterrupt\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m"
]
}
],
"source": [
"params = {\n",
" 'optimizer': Adam(learning_rate=.001),\n",
" 'loss' : 'sparse_categorical_crossentropy',\n",
" 'metrics' : ['accuracy'],\n",
" 'batch_size' : 512,\n",
" 'epochs' : 30,\n",
" 'verbose': 1, \n",
"}\n",
"\n",
"train(asl_static_A, X,y, **params)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"asl_static_A.save(\"models/asl_static_A\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# modello ASL_static_B: 3-layer-CNN con dataset kaggle-AmericanSignLanguage"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"il modello è analogo a quello di prima (ASL_static_B) a parte che in questo caso aggiungo altri due canali in ingresso al modello (fondamentalmente prima b/n adesso RGB).\n",
"\n",
"Il dataset invece è molto piu grande rispetto prima, quindi ci aspettiamo buoni risultati con poche epoche (ogni esempio viene visto dal modello tot epoche per training) prima dell'overfitting (momento in cui ri allenare sullo stesso dato diventa ridondante).\n",
"\n",
"riduco quindi a 4 epoche, prima erano 30 per comun regola a casaccio\n",
"\n",
"al dataset vengono aggiunti anche i segni relativi a _spazio_ , _cancella_ e _niente_"
]
},
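{
"cell_type": "markdown",
"metadata": {},
"source": [
"Instead of hand-picking the number of epochs, training could stop itself when the validation loss stops improving. A minimal sketch with Keras' built-in `EarlyStopping` callback (just an option, not used in the runs below; it would be passed via `callbacks=[...]` to `model.fit`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: stop training automatically when val_loss stops improving,\n",
"# instead of guessing a fixed number of epochs.\n",
"from keras.callbacks import EarlyStopping\n",
"\n",
"early_stop = EarlyStopping(monitor='val_loss', patience=2, restore_best_weights=True)\n",
"# e.g. model.fit(..., validation_split=0.1, callbacks=[early_stop])"
]
},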
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Dataset source:\n",
"\n",
"American Sign Language (ASL)\n",
"\n",
"url https://www.kaggle.com/grassknoted/asl-alphabet\n",
" \n",
"warnnnnn download ---> size~1GB\n",
"\n",
"messo in ./data/asl\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"train_dir = 'data/asl/asl_alphabet_train/asl_alphabet_train'\n",
"test_dir = 'data/asl/asl_alphabet_test/asl_alphabet_test'\n",
"classes = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', \n",
" 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', \n",
" 'W', 'X', 'Y', 'Z', 'nothing', 'space', 'del']\n",
"plt.figure(figsize=(11, 11))\n",
"for i in range (0,29):\n",
" plt.subplot(7,7,i+1)\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
" path = train_dir + \"/{0}/{0}1.jpg\".format(classes[i])\n",
" img = plt.imread(path)\n",
" plt.imshow(img)\n",
" plt.xlabel(classes[i])\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"\n",
"size=(32,32)\n",
"\n",
"images = []\n",
"labels = []\n",
"index = -1\n",
"for folder in os.listdir(train_dir):\n",
" index +=1\n",
" for image in os.listdir(train_dir + \"/\" + folder):\n",
" temp_img = cv2.imread(train_dir + '/' + folder + '/' + image)\n",
" temp_img = cv2.resize(temp_img, size)\n",
" images.append(temp_img)\n",
" labels.append(index)\n",
" \n",
"images = np.array(images)\n",
"images = images.astype('float32')/255.0\n",
"labels = tensorflow.keras.utils.to_categorical(labels)\n",
"x_train, x_test, y_train, y_test = train_test_split(images, labels, test_size = 0.1)\n",
" \n",
"print('Loaded', len(x_train),'images for training,','Train data shape =', x_train.shape)\n",
"print('Loaded', len(x_test),'images for testing','Test data shape =', x_test.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"input_img = keras.Input(shape=(32,32,3))\n",
"\n",
"x = Conv2D(64, (3, 3), padding='same', input_shape=(32, 32, 3), activation='relu')(input_img)\n",
"x = MaxPooling2D(pool_size=(2, 2))(x)\n",
"x = Conv2D(128, (3, 3), padding='same', input_shape=(32, 32, 3), activation='relu')(x)\n",
"x = MaxPooling2D(pool_size=(2, 2))(x)\n",
"\n",
"x = Conv2D(256, (3, 3), padding='same', input_shape=(32, 32, 3), activation='relu')(x)\n",
"x = MaxPooling2D(pool_size=(2, 2))(x)\n",
"x = keras.layers.BatchNormalization()(x)\n",
"\n",
"x = Flatten()(x)\n",
"x = Dropout(0.5)(x)\n",
"x = Dense(1024, activation='sigmoid')(x)\n",
"out = Dense(len(classes), activation='softmax')(x)\n",
"\n",
"asl_static_B = keras.Model(inputs=input_img,outputs=out)\n",
"\n",
"X = {\n",
" 'train':x_train,\n",
" 'test':x_test\n",
"}\n",
"y={\n",
" 'train':y_train,\n",
" 'test':y_test\n",
"}\n",
"params = {\n",
" 'optimizer': Adam(learning_rate=.001),\n",
" 'loss' : 'categorical_crossentropy',\n",
" 'metrics' : ['accuracy'],\n",
" 'batch_size' : 64,\n",
" 'epochs' : 4,\n",
" 'verbose': 1,\n",
" \n",
"}\n",
"\n",
"train(asl_static_B, X,y, **params)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"asl_static_B.save('models/asl_static_B')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# modello ASL_static_C: 3-layer-CNN con dataset kaggle-AmericanSignLanguage\n",
"\n",
"dataset source analog al precedente, ma con qualche parola intera in piu rispetto al solo vocabolario composto da alfabeto e numeri, in totale ci sono 51 labels qua\n",
"\n",
"il dataset è un po piu cicciotto quindi lancio il training differentemente, facendo batching degli esempi dalla cartella di origine (non ce la faccio a caricarlo tuttassieme nella ram)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"train_dir='data/asl/sign_digit_fewWords'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'ImageDataGenerator' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-2-7653e1af06e1>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m data_generator = ImageDataGenerator(\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0msamplewise_center\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0msamplewise_std_normalization\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mbrightness_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0.8\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1.0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mzoom_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1.0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1.2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mNameError\u001b[0m: name 'ImageDataGenerator' is not defined"
]
}
],
"source": [
"data_generator = ImageDataGenerator(\n",
" samplewise_center=True, \n",
" samplewise_std_normalization=True,\n",
" brightness_range=[0.8, 1.0],\n",
" zoom_range=[1.0, 1.2],\n",
" validation_split=0.1\n",
")\n",
"\n",
"train_generator = data_generator.flow_from_directory(train_dir, target_size=(200,200), shuffle=True, seed=13,\n",
" class_mode='categorical', batch_size=64, subset=\"training\")\n",
"\n",
"validation_generator = data_generator.flow_from_directory(train_dir, target_size=(200, 200), shuffle=True, seed=13,\n",
" class_mode='categorical', batch_size=64, subset=\"validation\")\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from keras.layers import BatchNormalization\n",
"from tensorflow.keras import regularizers\n",
"input_img = keras.Input(shape=(200,200,3))\n",
" \n",
"x = Conv2D(64, kernel_size = [3,3], padding = 'same', activation = 'relu', input_shape = (200,200,3))(input_img)\n",
"x = MaxPool2D(pool_size = [3,3])(x)\n",
" \n",
"x = Conv2D(128, kernel_size = [5,5], padding = 'same', activation = 'relu')(x)\n",
"x = MaxPool2D(pool_size = [3,3])(x)\n",
" \n",
"x = Conv2D(256, kernel_size = [3,3], padding = 'same', activation = 'relu')(x)\n",
"x = MaxPool2D(pool_size = [3,3])(x)\n",
" \n",
"x = BatchNormalization()(x)\n",
"x = Flatten()(x)\n",
"x = Dropout(0.5)(x)\n",
" \n",
"x = Dense(1024, activation = 'relu', kernel_regularizer = regularizers.l2(0.001))(x)\n",
"x = Dense(512, activation = 'relu', kernel_regularizer = regularizers.l2(0.001))(x)\n",
"out = Dense(51, activation = 'softmax')(x)\n",
"asl_static_C = keras.Model(inputs=input_img,outputs=out)\n",
"\n",
"asl_static_C.compile(optimizer = 'adam', loss = keras.losses.categorical_crossentropy, metrics = [\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_hist = asl_static_C.fit_generator(train_generator,\n",
" validation_data=validation_generator,\n",
" steps_per_epoch=200,\n",
" validation_steps=50,\n",
" epochs=15,\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"asl_static_C.save('models/asl_static_C.h5')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ASL_static_D - inception transfer learning\n",
"\n",
"Qua si prende un modello noto di google chiamato __inceptionv3__ https://keras.io/api/applications/inceptionv3/ come punto di partenza per il successivo training del task di classificazione immagini.\n",
"in questo caso si parla di transfer-learning.\n",
"\n",
"il dataset è identico a quello di prima (asl_static_C)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"training_dir=train_dir\n",
"content=sorted(os.listdir(training_dir))\n",
"print(content)\n",
"len(content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"WEIGHTS_FILE = './inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5'\n",
"\n",
"inception_v3_model = keras.applications.inception_v3.InceptionV3(\n",
" input_shape = (200, 200, 3), \n",
" include_top = False, \n",
" weights = 'imagenet'\n",
")\n",
"\n",
"inception_v3_model.summary()\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"inception_output_layer = inception_v3_model.get_layer('mixed7')\n",
"print('Inception model output shape:', inception_output_layer.output_shape)\n",
"\n",
"inception_output = inception_v3_model.output"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras import layers\n",
"x = layers.GlobalAveragePooling2D()(inception_output)\n",
"x = layers.Dense(1024, activation='relu')(x) \n",
"x = layers.Dense(51, activation='softmax')(x) \n",
"\n",
"asl_static_D = Model(inception_v3_model.input, x) \n",
"\n",
"asl_static_D.compile(\n",
" optimizer=SGD(lr=0.0001, momentum=0.9),\n",
" loss='categorical_crossentropy',\n",
" metrics=['acc']\n",
")\n",
"for layer in model.layers[:249]:\n",
" layer.trainable = False\n",
"for layer in model.layers[249:]:\n",
" layer.trainable = True\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"LOSS_THRESHOLD = 0.2\n",
"ACCURACY_THRESHOLD = 0.95\n",
"\n",
"class ModelCallback(tf.keras.callbacks.Callback):\n",
" def on_epoch_end(self, epoch, logs={}):\n",
" if logs.get('val_loss') <= LOSS_THRESHOLD and logs.get('val_acc') >= ACCURACY_THRESHOLD:\n",
" print(\"\\nReached\", ACCURACY_THRESHOLD * 100, \"accuracy, Stopping!\")\n",
" self.model.stop_training = True\n",
"\n",
"callback = ModelCallback()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"history = asl_static_D.fit_generator(\n",
" train_generator,\n",
" validation_data=validation_generator,\n",
" steps_per_epoch=200,\n",
" validation_steps=50,\n",
" epochs=20,\n",
" callbacks=[callback]\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"asl_static_D.save('models/asl_static_D.h5')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"acc = history.history['acc']\n",
"val_acc = history.history['val_acc']\n",
"loss = history.history['loss']\n",
"val_loss = history.history['val_loss']\n",
"\n",
"epochs = range(len(acc))\n",
"\n",
"plt.plot(epochs, acc, 'r', label='Training accuracy')\n",
"plt.plot(epochs, val_acc, 'b', label='Validation accuracy')\n",
"plt.title('Training and validation accuracy')\n",
"plt.legend(loc=0)\n",
"plt.figure()\n",
"\n",
"\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# asl_static_C VS asl_static_D\n",
"\n",
"Considerando lo stesso dataset e labelset, confronto tra il modello C (3Layer-3Channel shallow CNN) e D (transfer learning da modello con milamilioni di paramentri super sblindato di google)\n",
"\n",
"Qual'è meglio?\n",
"\n",
"A che prezzo?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"asl_static_C = keras.models.load_model('models/asl_static_C.h5')\n",
"asl_static_D = keras.models.load_model('models/asl_static_D.h5')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data_generator = ImageDataGenerator(\n",
" samplewise_center=True, \n",
" samplewise_std_normalization=True,\n",
" brightness_range=[0.8, 1.0],\n",
" zoom_range=[1.0, 1.2],\n",
" validation_split=0.1\n",
")\n",
"\n",
"test_generator = data_generator.flow_from_directory(train_dir, target_size=(200, 200), shuffle=False, seed=13,\n",
" class_mode='categorical', batch_size=64, subset=\"validation\")\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"predictions_C = asl_static_C.predict(test_generator, workers=-1, use_multiprocessing=True, verbose=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"predictions_D = asl_static_D.predict(test_generator, workers=-1, use_multiprocessing=True, verbose=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"assert predictions_D.shape==predictions_C.shape\n",
"predictions_D.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!du models/asl_static_C.h5"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!du models/asl_static_D.h5"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"yp_C = np.argmax(predictions_C,axis=1)\n",
"yp_D = np.argmax(predictions_D,axis=1)"
]
},
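{
"cell_type": "markdown",
"metadata": {},
"source": [
"Overall accuracy of the two models on the same validation generator, using the `accuracy_score` already imported in the general imports (a rough sketch: `test_generator.classes` gives the ground-truth indices because the generator was built with `shuffle=False`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Overall accuracy of each model on the held-out generator.\n",
"# Works because test_generator was created with shuffle=False,\n",
"# so .classes is aligned with the prediction order.\n",
"print('asl_static_C accuracy:', accuracy_score(test_generator.classes, yp_C))\n",
"print('asl_static_D accuracy:', accuracy_score(test_generator.classes, yp_D))"
]
},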
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import confusion_matrix\n",
"\n",
"confusion_matrix_C = confusion_matrix(test_generator.classes, yp_C)\n",
"confusion_matrix_D = confusion_matrix(test_generator.classes, yp_D)\n",
"\n",
"fig,axs = plt.subplots(1,2,figsize=(20,10))\n",
"sns.heatmap(confusion_matrix_C,ax=axs[0])\n",
"axs[0].set_title('asl_static_C')\n",
"sns.heatmap(confusion_matrix_D,ax=axs[1])\n",
"axs[1].set_title('asl_static_D - tf inception')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"piu stai sulla diagonale meglio è (considerate le stesse 51 classi, l'asse delle y corrisponde ai valori predetti, l'asse delle x invece ai valori attuali per i sample di test (ground_truth)"
]
},
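{
"cell_type": "markdown",
"metadata": {},
"source": [
"To read the diagonal numerically, per-class recall can be computed straight from each confusion matrix (a small numpy sketch; it assumes every class appears at least once in the validation subset):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Per-class recall = diagonal entry / row sum (true-label count).\n",
"# A higher mean means more mass on the diagonal of the heatmaps above.\n",
"recall_C = confusion_matrix_C.diagonal() / confusion_matrix_C.sum(axis=1)\n",
"recall_D = confusion_matrix_D.diagonal() / confusion_matrix_D.sum(axis=1)\n",
"print('mean per-class recall, asl_static_C:', recall_C.mean())\n",
"print('mean per-class recall, asl_static_D:', recall_D.mean())"
]
},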
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mapx={v:k for k,v in test_generator.class_indices.items()}\n",
"upto=5\n",
"print(\"asl_static_C prediction for 5 samples of '1' gesture samples:\")\n",
"print([mapx[xx] for xx in yp_C[:5]])\n",
"print('\\nsame for model asl_static_D:')\n",
"print([mapx[xx] for xx in yp_D[:5]])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print('these were the original photos')\n",
"fig,axs = plt.subplots(1,5,figsize=(50,10))\n",
"for filename,ax in zip(test_generator.filenames[:5],axs.flatten()):\n",
" img = cv2.imread(f'{train_dir}/{filename}')\n",
" ax.imshow(img)"
]
},
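{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before the verdict, a quick size comparison in terms of parameter counts (Keras models expose `count_params()`), to weigh against the difference in confidence seen above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare the two models by total number of parameters.\n",
"print('asl_static_C parameters:', asl_static_C.count_params())\n",
"print('asl_static_D parameters:', asl_static_D.count_params())"
]
},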
{
"cell_type": "markdown",
"metadata": {},
"source": [
"__conclusione__\n",
"\n",
"asl_static_D è il vincitore, ha 4 volte piu parametri (che in scala sono comunque molti anche per asl_static_C ~5'000'000) ma è molto piu sicuro."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}