import matplotlib
import matplotlib.pyplot as plt
import numpy as np
def plot_digits(instances, images_per_row=5, size=28, **options):
    """Draw a grid of square grayscale digit images on the current axes.

    Each instance is reshaped to ``size`` x ``size`` pixels. The images are
    tiled left-to-right, top-to-bottom, with at most ``images_per_row`` per
    row. If the last row is not full, it is padded with blank (all-zero) tiles.

    Args:
        instances: Sequence or 2-D array of flattened images, one per row.
            Each must contain exactly ``size * size`` values.
        images_per_row: Maximum number of images per grid row (must be >= 1).
        size: Side length, in pixels, of each square image (28 for MNIST).
        **options: Extra keyword arguments forwarded to ``plt.imshow``.

    Raises:
        ValueError: If ``instances`` is empty, if ``images_per_row < 1``, or if
            an instance cannot be reshaped to ``(size, size)``.
    """
    # Convert to an array so that, e.g., a DataFrame iterates over its rows
    # (image data) rather than over its column labels.
    instances = np.asarray(instances)
    n_instances = len(instances)
    if n_instances == 0:
        raise ValueError("plot_digits() needs at least one instance")
    if images_per_row < 1:
        raise ValueError("images_per_row must be at least 1")

    # Never make a row wider than the number of images available.
    images_per_row = min(n_instances, images_per_row)
    images = [instance.reshape(size, size) for instance in instances]
    n_rows = -(-n_instances // images_per_row)  # ceiling division

    # Pad the final row with blank tiles so every row has the same width
    # and the rows can be stacked vertically.
    n_empty = n_rows * images_per_row - n_instances
    images.extend(np.zeros((size, size)) for _ in range(n_empty))

    row_images = [
        np.concatenate(
            images[row * images_per_row:(row + 1) * images_per_row], axis=1
        )
        for row in range(n_rows)
    ]
    image = np.concatenate(row_images, axis=0)
    plt.imshow(image, cmap="binary", **options)
    plt.axis("off")
# Show a sample of training digits next to their recovered counterparts.
# X_train and X_train_recovered are defined earlier in the file (not shown
# here). Every 2100th row is taken from each so the two panels line up.
plt.figure(figsize=(7, 4))

# Left panel: plotting the 'original' images.
plt.subplot(121)
plot_digits(X_train[::2100])
plt.title("Original", fontsize=16)

# Right panel: plotting the corresponding 'recovered' images.
plt.subplot(122)
plot_digits(X_train_recovered[::2100])
plt.title("Compressed", fontsize=16)

plt.show()