import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from skimage.io import imread
from sklearn.model_selection import train_test_split
from IPython.display import set_matplotlib_formats
# Ask IPython's inline backend to render figures as SVG.
set_matplotlib_formats('svg')
# Image side length handed to the loader (and later to the network) as IMG_SIZE.
img_size = 28
# Ratio between the original image size and img_size; only used to build the
# folder name below.
downsize_factor = 4
original_size = img_size * downsize_factor
# With the values above this resolves to './output/ground_truth/112d4'.
ground_truth_folder = './output/ground_truth/%dd%d' % (original_size, downsize_factor)
from src.cnn import get_data
# NOTE(review): presumably X holds the images and Y the integer class labels
# (both are later passed to train_test_split and used as such) -- confirm
# against src.cnn.get_data.
X, Y = get_data(ground_truth_folder, IMG_SIZE=img_size)
from src.vis_utils import plot_array
# Visual sanity check of the loaded data across the 9 classes.
fig = plt.figure(figsize=(8, 6))
plot_array(fig, X, Y, num_classes=9)
plt.show()
# Hold out 400 samples as the test set; the fixed seed keeps the split
# reproducible between runs.
Xtr, Xte, Ytr, Yte = train_test_split(X, Y, test_size=400, random_state=42)
print 'Training data shape: ', Xtr.shape
print 'Training labels shape: ', Ytr.shape
print 'Test data shape: ', Xte.shape
print 'Test labels shape: ', Yte.shape
# Center the data. The mean is computed from the training portion only and
# then applied to the test set as well, so no test statistics leak into the
# preprocessing.
mean_image = np.mean(Xtr, axis=0) # take mean image over training set only
# In-place subtraction: Xtr and Xte are modified, not copied.
Xtr -= mean_image
Xte -= mean_image
# IPython magic (not plain Python): show matplotlib figures inline.
%matplotlib inline
# Display the first channel of the mean image.
fig, ax = plt.subplots(figsize=(4, 4))
ax.axis('off')
ax.imshow(mean_image[:,:,0])
Xtr, Xval, Ytr, Yval = train_test_split(Xtr, Ytr, test_size=130, random_state=42)
print 'Training data shape: ', Xtr.shape
print 'Training labels shape: ', Ytr.shape
print 'Training data shape: ', Xval.shape
print 'Training labels shape: ', Yval.shape
from src.cnn import ConvolutionalNeuralNetwork
# Build the network for single-channel img_size x img_size inputs and 9 output
# labels. The validation and test set sizes are passed in at construction
# time, taken from the splits made above.
model = ConvolutionalNeuralNetwork(IMG_SIZE=img_size, NUM_CHANNELS=1, NUM_LABELS=9,
BATCH_SIZE=64, NUM_VALIDATION=Xval.shape[0], NUM_TEST=Xte.shape[0])
# Train on the training split, passing the validation split alongside.
model.train(Xtr, Ytr, Xval, Yval, max_iters=1500)
# Predict on the held-out test set.
# NOTE(review): assumes test_model returns one predicted label per test
# sample, comparable element-wise with Yte -- confirm in src.cnn.
predictions = model.test_model(Xte)
correct = np.sum(predictions == Yte)
total = predictions.shape[0]
# Error rate in percent: 100 * (1 - accuracy). float() guards against
# Python 2 integer division.
print 'Test error: %.02f%%' % (100 * (1 - float(correct) / float(total)))
from src.vis_utils import plot_confusion_matrix

# Human-readable class names used to label the matrix.
# NOTE(review): presumably listed in label-index order (classes[k] is the
# name of label k) -- confirm against the data loader.
classes = ['interphase', 'large', 'prometaphase', 'metaphase', 'bright',
           'anaphase', 'early anaphase', 'polylobed', 'apoptosis']

# confusion_matrix[t][p] counts test samples whose true label is t and whose
# predicted label is p.
confusion_matrix = np.zeros((len(classes), len(classes)), np.int32)

# `predictions` were produced from Xte, so the matching ground truth is Yte.
# Indexing with the training labels (Ytr) pairs each prediction with an
# unrelated sample and yields a meaningless matrix.
for true_label, predicted_label in zip(Yte, predictions):
    confusion_matrix[true_label][predicted_label] += 1

fig, ax = plt.subplots(figsize=(8, 6))
plot_confusion_matrix(ax, confusion_matrix, classes, fontsize=15)