Top 5 Matplotlib Features for Deep Learning

Matplotlib is an incredibly powerful visualization tool that no data engineer should overlook. Before you work with data, you need to know how to visualize it.

Below are some of the features that I have found most useful.

Feature 1: Animate Language Models’ weights using matplotlib.animation.FuncAnimation

Sometimes it is very helpful to see what the model’s layers are doing during training. Visualizing changes in weights and biases can greatly aid in understanding the learning process.

You can plot all the layers after LayerNorm, or visualize the adjustments in the attention heads as the model converges to a solution.

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML  # inline display in a Jupyter notebook

fig, ax = plt.subplots()
line, = ax.plot([], [], lw=2)
xdata, ydata = [], []  # points drawn so far

def init():
    # Fix the axis limits and clear any data left from a previous run
    ax.set_ylim(-1.1, 1.1)
    ax.set_xlim(0, 2)
    xdata.clear()
    ydata.clear()
    line.set_data(xdata, ydata)
    return line,

def update(frame):
    # Append one point per frame and update the line in place
    xdata.append(frame)
    ydata.append(np.sin(frame))
    line.set_data(xdata, ydata)
    return line,

ani = animation.FuncAnimation(fig, update, frames=np.linspace(0, 2, 128),
                              init_func=init, blit=True)
HTML(ani.to_jshtml())
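
The sine-wave example shows the mechanics. Applied to weights, the same idea might look like the sketch below. It is not tied to any framework: the `snapshots` list is a hypothetical stand-in for weight matrices you would save at checkpoints during training, here simulated as an 8x8 matrix drifting from random noise toward the identity.

```python
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

rng = np.random.default_rng(0)

# Hypothetical stand-in for saved checkpoints: 30 snapshots of an 8x8 weight
# matrix that moves linearly from random noise toward a fixed target.
n_steps, shape = 30, (8, 8)
start = rng.normal(0.0, 1.0, shape)
target = np.eye(8)
snapshots = [start + (target - start) * t / (n_steps - 1) for t in range(n_steps)]

fig, ax = plt.subplots()
# Fixed vmin/vmax so colors stay comparable across frames
im = ax.imshow(snapshots[0], cmap="coolwarm", vmin=-2, vmax=2)
fig.colorbar(im, ax=ax)
title = ax.set_title("step 0")

def update(step):
    # Swap in the weights for this step instead of creating a new image
    im.set_data(snapshots[step])
    title.set_text(f"step {step}")
    return im, title

# blit=False because the title sits outside the axes area
ani = animation.FuncAnimation(fig, update, frames=n_steps, interval=100, blit=False)

# In a notebook: from IPython.display import HTML; HTML(ani.to_jshtml())
# In a script:   plt.show()
```

In a real training loop you would append a copy of the layer's weight array to `snapshots` every few steps and build the animation afterwards.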

Feature 2: 3D plots

When visualizing high-dimensional data, you often need dimensionality reduction techniques such as PCA. A 3D scatter plot is a natural fit here, especially while exploring which components capture the most variance in the data.

For example, after reducing dimensions, let’s say you’re curious about how your dataset’s samples spread across the three principal components.

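from mpl_toolkits.mplot3d import Axes3D  # only required on Matplotlib < 3.2
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# 'data' should be your PCA-reduced data with shape (n_samples, 3).
# Random placeholder values are used here so the snippet runs on its own.
data = np.random.default_rng(0).normal(size=(200, 3))
x, y, z = data[:, 0], data[:, 1], data[:, 2]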

ax.scatter(x, y, z, c='r', marker='o')

ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_zlabel('PC3')

plt.show()
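
If you do not already have a PCA-reduced array, the sketch below shows one way to produce it with plain NumPy, using the SVD of the centered data. The dataset `X` is a random placeholder, and the variable names are my own; in practice you might use a library implementation instead.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder dataset: 200 samples, 10 features with decreasing scale
X = rng.normal(size=(200, 10)) * np.linspace(3.0, 0.5, 10)

# PCA via SVD: center the data, then project onto the top right-singular vectors
Xc = X - X.mean(axis=0)
U, S, Vt = np.linalg.svd(Xc, full_matrices=False)

data = Xc @ Vt[:3].T                  # shape (n_samples, 3): PC1, PC2, PC3
explained = S**2 / np.sum(S**2)       # fraction of variance per component

print("explained variance ratio (first 3):", explained[:3])
```

The `explained` array tells you how much of the variance the three plotted axes actually cover, which is worth checking before reading too much into the 3D picture.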

Feature 3: Plot and continue

In short, call plt.show(block=False) when you plot.

The script then keeps running instead of waiting for the window to be closed. This is very useful if you plot several figures consecutively, or if you want the rest of your script to keep working while the plots stay open.

import matplotlib.pyplot as plt

plt.plot([1, 2, 3], [4, 5, 6])
plt.show(block=False)
plt.pause(1)  # Pause for 1 second before continuing execution
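
For several figures in a row, the same pattern can be wrapped in a small helper. This is a minimal sketch: `plot_consecutively` is a hypothetical name of my own, and the behavior of the windows depends on which GUI backend you have installed.

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_consecutively(curves, pause_s=0.5):
    """Open one figure per curve without blocking the script between them."""
    figures = []
    for name, (x, y) in curves.items():
        fig, ax = plt.subplots()
        ax.plot(x, y)
        ax.set_title(name)
        plt.show(block=False)   # return immediately instead of waiting
        plt.pause(pause_s)      # give the GUI time to draw the window
        figures.append(fig)
    return figures

if __name__ == "__main__":
    x = np.linspace(0, 2 * np.pi, 100)
    figs = plot_consecutively({"sin": (x, np.sin(x)), "cos": (x, np.cos(x))})
    print(f"{len(figs)} figures open, script still running")
    plt.show()  # one final blocking call keeps the windows open at the end
```

Without the final blocking plt.show(), a script run from the command line would exit and close all windows as soon as it reached the last line.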

Feature 4: Histograms with Gaussian fit

It is very important for the weights to stay in a reasonable range when training deep neural networks. Weights that are too small or too large hinder the model’s ability to learn effectively, causing vanishing or exploding gradients respectively.

Visualizing the distribution of your weights can be insightful, especially with a Gaussian overlay to check for normality.

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

data = np.random.normal(0, 1, 1000)

# Fit a normal distribution to the data
mu, std = norm.fit(data)

# Plot the histogram
plt.hist(data, bins=25, density=True, alpha=0.6, color='g')

# Plot the Gaussian fit
xmin, xmax = plt.xlim()
x = np.linspace(xmin, xmax, 100)
p = norm.pdf(x, mu, std)
plt.plot(x, p, 'k', linewidth=2)
title = "Fit results: mu = %.2f,  std = %.2f" % (mu, std)
plt.title(title)

plt.show()
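
To see why the scale of the weights matters, here is a rough sketch using a stack of purely linear layers with random Gaussian weights. The width, depth, and the three `gain` values are arbitrary choices of mine. It tracks the forward signal rather than gradients, but the backward pass multiplies by the same matrices, so gradients shrink or grow in a similar way.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
width, depth = 256, 20
x0 = rng.normal(size=(64, width))  # placeholder input batch

def activation_stds(gain):
    """Std of the signal after each layer of a purely linear stack."""
    h, stds = x0, []
    for _ in range(depth):
        # Weight std of gain / sqrt(width) scales the signal by about `gain` per layer
        W = rng.normal(0.0, gain / np.sqrt(width), size=(width, width))
        h = h @ W
        stds.append(h.std())
    return np.array(stds)

results = {gain: activation_stds(gain) for gain in (0.5, 1.0, 2.0)}

fig, ax = plt.subplots()
for gain, stds in results.items():
    ax.semilogy(range(1, depth + 1), stds, marker="o", label=f"gain = {gain}")
ax.set_xlabel("layer")
ax.set_ylabel("activation std (log scale)")
ax.legend()
plt.show()
```

On the log-scale plot the three settings appear as roughly straight lines: one falling, one flat, one rising. A weight histogram that is much narrower or wider than expected is an early hint of the first or third case.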

Feature 5: Faster rendering with Data Blitting

With large datasets or frequent updates, plotting can become very slow. Much of this slowness often comes from redrawing the static parts of the plot (axes, ticks, labels, background data) on every update. With blitting, only the parts of the plot that have changed are redrawn, which can speed up rendering significantly.

This technique is particularly useful for fast-updating plots, like those used in animations or interactive sessions.

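import time
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.set_xlim(0, 10)
ax.set_ylim(0, 1.1)
# animated=True excludes the line from normal draws, so it is not baked
# into the cached background below
line, = ax.plot([], [], 'r', animated=True)

plt.show(block=False)
plt.pause(0.1)  # let the window draw once before caching
background = fig.canvas.copy_from_bbox(fig.bbox)  # cache the static parts

xdata, ydata = [], []
for x in np.arange(0, 10, 0.5):
    xdata.append(x)
    ydata.append(np.exp(-x**2))
    line.set_data(xdata, ydata)
    fig.canvas.restore_region(background)  # restore the cached background
    ax.draw_artist(line)                   # redraw only the changed artist
    fig.canvas.blit(fig.bbox)              # push the result to the screen
    fig.canvas.flush_events()
    time.sleep(0.1)  # slow the loop down so the updates are visible

line.set_animated(False)  # include the line in normal draws again
plt.show()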
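
To get a feel for how much blitting saves, the rough timing sketch below compares a full canvas redraw against a blitted update on the off-screen Agg backend. The scatter background and frame count are placeholder choices of mine, and the numbers will differ by machine and backend, so treat the result as indicative only.

```python
import time
import numpy as np
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the timing runs anywhere
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
fig, ax = plt.subplots()
# Static, expensive background: 20,000 scatter points (placeholder data)
ax.scatter(rng.normal(size=20_000), rng.normal(size=20_000), s=2, alpha=0.3)
ax.set_xlim(-4, 4)
ax.set_ylim(-4, 4)
line, = ax.plot([], [], "r", lw=2, animated=True)

fig.canvas.draw()                                   # initial full draw
background = fig.canvas.copy_from_bbox(fig.bbox)    # cache everything static
xs = np.linspace(-4, 4, 200)
n_frames = 10

def time_frames(use_blit):
    """Average seconds per frame with or without blitting."""
    line.set_animated(use_blit)  # full redraws must include the line
    start = time.perf_counter()
    for i in range(n_frames):
        line.set_data(xs, np.sin(xs + 0.1 * i))
        if use_blit:
            fig.canvas.restore_region(background)
            ax.draw_artist(line)
            fig.canvas.blit(fig.bbox)
        else:
            fig.canvas.draw()    # redraws the scatter, axes, and line
    return (time.perf_counter() - start) / n_frames

full_s = time_frames(use_blit=False)
blit_s = time_frames(use_blit=True)
print(f"full redraw: {full_s * 1e3:.2f} ms/frame, blit: {blit_s * 1e3:.2f} ms/frame")
```

The heavier the static background, the larger the gap, since the blitted path never touches the scatter points again.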