Example: Image Segmentation (CamVid Dataset)
The dataset used for this image segmentation example can be downloaded from here.
Import libraries
import os
import numpy as np
import medicalai as ai
import tensorflow as tf
Define the hyperparameters
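Specify the dataset folder. It should contain `train` and `test` folders, each holding the input images and their segmentation masks (in the `image` and `masks` subfolders that are passed to the data generator).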
# Root folder of the CamVid data on disk.
datasetFolderPath = "../data/camvid/"

# Target (height, width) handed to the data generator.
IMG_HEIGHT, IMG_WIDTH = 256, 256

# Experiment bookkeeping: both the experiment id and the architecture name
# are encoded in the model save path.
EXPT_NAME = '1'
AI_NAME = 'unet'
# NOTE(review): the 'Medical_RSNA_' prefix looks like a leftover from an RSNA
# example; it is kept as-is so existing saved-model paths still match.
MODEL_SAVE_NAME = (
    f'../model/{AI_NAME}/Medical_RSNA_'
    f'{IMG_HEIGHT}x{IMG_WIDTH}_{AI_NAME}_EXPT_{EXPT_NAME}'
)

# Optimisation settings.
batch_size = 32
epochs = 10
learning_rate = 0.0001
Define the augmentation for the generator
(The following augmentation is applied to the input images only, not to the masks.)
# Only rescaling: multiply pixel values by 1/255 (maps 0-255 to [0, 1],
# assuming 8-bit input images).
augment = ai.AUGMENTATION(rescale=1.0 / 255)
Create the data generators. `flag_multi_class` is set to `True` because this is a multi-class segmentation task.
# Data handler for paired image/mask folders under datasetFolderPath.
# NOTE: `segmentaionGenerator` is spelled as the medicalai name being called;
# do not "correct" it here.
dsHandler = ai.segmentaionGenerator(
    folder=datasetFolderPath,
    targetDim=(IMG_HEIGHT, IMG_WIDTH),
    augmentation=augment,
    class_mode=None,
    batch_size=batch_size,
    image_folder_name="image",
    mask_folder_name="masks",
    # Multi-class masks (see num_class passed to saveResult at inference).
    flag_multi_class=True,
)

# One generator per split: training data and held-out test data.
trainGen = dsHandler.load_train_generator()
testGen = dsHandler.load_test_generator()
Train the model
Now that the data generators are ready, we can train the model. First, however, we define a TensorFlow callback (a model checkpoint) to use during training.
# Checkpoint callback that keeps only the best weights seen so far.
#
# Why these settings:
#  * save_freq='epoch': an integer save_freq means "every N batches". In that
#    mode the callback checks the monitored metric mid-epoch, where epoch-level
#    metrics are not available, so best-only saving was skipped.
#  * monitor='loss': the default monitor is 'val_loss', but the trainer below
#    is only given a training set (trainSet=trainGen), so 'val_loss' is never
#    logged and nothing would ever be saved.
#    NOTE(review): switch back to 'val_loss' if validation data is added.
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    MODEL_SAVE_NAME + 'best.h5',
    monitor='loss',
    mode='min',
    save_freq='epoch',
    save_best_only=True,
    verbose=0,
)
callbacks = [model_checkpoint]
# Batches per epoch: ceil(n / batch_size), where n is presumably the number of
# training images seen by the generator.
TRAIN_STEPS = int(np.ceil(dsHandler.imageGen.generator.n / dsHandler.batch_size))

# Train the segmentation network and save it under MODEL_SAVE_NAME.
# The hyper-parameters are the lower-case names defined in the hyperparameter
# cell. Upper-case BATCH_SIZE/EPOCHS/LEARNING_RATE/SAVE_BEST_MODEL are not
# defined anywhere in this script and would raise NameError.
trainer = ai.TRAIN_ENGINE()
trainer.train_and_save_segmentation(
    AI_NAME=AI_NAME,
    MODEL_SAVE_NAME=MODEL_SAVE_NAME,
    trainSet=trainGen,
    inputSize=(IMG_HEIGHT, IMG_WIDTH, 3),
    TRAIN_STEPS=TRAIN_STEPS,
    BATCH_SIZE=batch_size,
    EPOCHS=epochs,
    LEARNING_RATE=learning_rate,
    # Keep the best model, consistent with save_best_only=True on the
    # checkpoint callback. NOTE(review): confirm this is the intended value.
    SAVE_BEST_MODEL=True,
    callbacks=callbacks,
    showModel=False,
)
# Inference: load the model saved under MODEL_SAVE_NAME and predict masks for
# the test split.
infEngine = ai.INFERENCE_ENGINE(MODEL_SAVE_NAME)
predsG = infEngine.predict_segmentation(testGen)

# Write the predictions into the 'results' folder. flag_multi_class matches
# the setting used for the data generator.
infEngine.saveResult(
    save_path='results',
    npyfile=predsG,
    # NOTE(review): 32 is presumably the CamVid label count; confirm it
    # matches the masks.
    num_class=32,
    flag_multi_class=True,
)