Matplotlib is an incredibly powerful visualization tool that no data engineer should overlook. Before you work on data, you need to know how to visualize it.
Below are some of the features I have found most useful.
Feature 1: Animate a language model’s weights using matplotlib.animation
Sometimes it is very helpful to see what the model’s layers are doing during training. Visualizing changes in weights and biases can greatly aid in understanding the learning process.
You can plot all the layers after LayerNorm, or visualize the adjustments in the attention heads as the model converges to a solution.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
fig, ax = plt.subplots()
line, = ax.plot([], [], lw=2)
xdata, ydata = [], []
def init():
    # Reset the axes and clear any data left over from a previous run
    ax.set_ylim(-1.1, 1.1)
    ax.set_xlim(0, 2)
    del xdata[:]
    del ydata[:]
    line.set_data(xdata, ydata)
    return line,
def update(frame):
    # Append the next point (a sine wave stands in for a weight trajectory)
    xdata.append(frame)
    ydata.append(np.sin(frame))
    line.set_data(xdata, ydata)
    return line,
ani = animation.FuncAnimation(fig, update, frames=np.linspace(0, 2, 128),
                              init_func=init, blit=True)
HTML(ani.to_jshtml())
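The sine wave above is only a stand-in. To animate an actual weight matrix, one option (a minimal sketch; the random-walk `snapshots` array here is purely hypothetical, standing in for per-step copies of a layer’s weights that you would record during training) is to update an imshow image each frame:

```python
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# Hypothetical stand-in: random-walk snapshots of an 8x8 weight matrix,
# one per training step. Replace with your model's recorded weights.
rng = np.random.default_rng(0)
snapshots = np.cumsum(rng.normal(scale=0.05, size=(60, 8, 8)), axis=0)

fig, ax = plt.subplots()
im = ax.imshow(snapshots[0], cmap='coolwarm', vmin=-1, vmax=1)
ax.set_title('Weight matrix over training')
fig.colorbar(im)

def update(i):
    # Swap in the i-th snapshot; imshow keeps the same color scale
    im.set_data(snapshots[i])
    return im,

ani = animation.FuncAnimation(fig, update, frames=len(snapshots), blit=True)
plt.show()
```

Fixing `vmin`/`vmax` up front keeps the color scale constant across frames, so color changes reflect weight changes rather than rescaling.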
Feature 2: 3D plots
When visualizing high-dimensional data, you often need to resort to dimensionality reduction techniques such as PCA. A 3D scatter plot is especially handy after PCA, when you are exploring which components capture the most variance in the data.
For example, after reducing dimensions, let’s say you’re curious about how your dataset’s samples spread across the three principal components.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the 3D projection on older matplotlib
import matplotlib.pyplot as plt
import numpy as np
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# Assuming 'data' is your PCA-reduced data with shape (n_samples, 3)
data = np.random.rand(100, 3)  # stand-in so the snippet runs on its own
x, y, z = data[:, 0], data[:, 1], data[:, 2]
ax.scatter(x, y, z, c='r', marker='o')
ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_zlabel('PC3')
plt.show()
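If you don’t want to pull in scikit-learn, the `data` array above can be produced with a plain NumPy SVD. This is a minimal sketch under one assumption: `X` is a hypothetical stand-in for your high-dimensional samples.

```python
import numpy as np

def pca_reduce(X, k=3):
    """Project X (n_samples, n_features) onto its top-k principal components."""
    Xc = X - X.mean(axis=0)                        # center each feature
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ Vt[:k].T                           # scores on the first k components

rng = np.random.default_rng(42)
X = rng.normal(size=(200, 10))                     # hypothetical high-dimensional data
data = pca_reduce(X, k=3)                          # shape (200, 3), ready for the 3D scatter
```

By construction, the variance of the scores decreases from PC1 to PC3, which is exactly what the 3D scatter lets you eyeball.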
Feature 3: Plot and continue
In short, when you plot, call plt.show(block=False).
This lets the script keep running, which is very useful when you plot several figures in a row, or when you want to interact with other parts of your script while keeping the plots open.
import matplotlib.pyplot as plt
plt.plot([1, 2, 3], [4, 5, 6])
plt.show(block=False)
plt.pause(1) # Pause for 1 second before continuing execution
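As a small sketch of the pattern above, here are two figures opened non-blocking while the script continues computing (the functions plotted are arbitrary examples):

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 100)
for i, f in enumerate([np.sin, np.cos], start=1):
    plt.figure(i)                # a separate window per function
    plt.plot(x, f(x))
    plt.show(block=False)        # window opens; the loop keeps going

# The script is still running, so we can keep computing while both plots stay open
peak = float(np.max(np.sin(x)))
print(f"peak of sin on the grid: {peak:.3f}")
plt.pause(0.5)                   # give the GUI event loop a moment to render
```

Without `block=False`, the first `plt.show()` would halt the loop until you closed the window.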
Feature 4: Histograms with Gaussian fit
When training deep neural networks, it is important for the weights to stay well-scaled: weights that are too small or too large hinder the model’s ability to learn effectively, causing vanishing or exploding gradients.
Visualizing the distribution of your weights can be insightful, especially with a Gaussian overlay to check for normality.
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
data = np.random.normal(0, 1, 1000)
# Fit a normal distribution to the data
mu, std = norm.fit(data)
# Plot the histogram
plt.hist(data, bins=25, density=True, alpha=0.6, color='g')
# Plot the Gaussian fit
xmin, xmax = plt.xlim()
x = np.linspace(xmin, xmax, 100)
p = norm.pdf(x, mu, std)
plt.plot(x, p, 'k', linewidth=2)
title = "Fit results: mu = %.2f, std = %.2f" % (mu, std)
plt.title(title)
plt.show()
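The same fit applies directly to network weights. Below is a sketch under an assumption: the weights come from a He/Kaiming-style initialization (std = sqrt(2 / fan_in)), with a made-up layer shape. scipy’s Shapiro-Wilk test adds a numeric check to complement the visual one:

```python
import numpy as np
from scipy.stats import norm, shapiro

# Hypothetical layer: weights drawn with He/Kaiming-style std = sqrt(2 / fan_in)
rng = np.random.default_rng(7)
fan_in = 512
w = rng.normal(0.0, np.sqrt(2.0 / fan_in), size=(fan_in, 256)).ravel()

mu, std = norm.fit(w)          # MLE of mean and std, same as in the histogram example
stat, p = shapiro(w[:5000])    # Shapiro-Wilk is most reliable on <= ~5000 samples
print(f"mu={mu:.4f} std={std:.4f}")
```

For a real model you would flatten the layer’s weight tensor in place of `w`; a heavy-tailed or skewed histogram here is often the first visible sign of a training problem.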
Feature 5: Faster rendering with blitting
Sometimes your data can be too big, and plotting will become very slow. This slowness can be due to the massive overhead from redrawing static parts of the plot. By using blitting, only the parts of the plot which have changed are redrawn, which speeds up the rendering significantly.
This technique is particularly useful for fast-updating plots, like those used in animations or interactive sessions.
import matplotlib.pyplot as plt
import numpy as np
fig, ax = plt.subplots()
ax.set_xlim(0, 10)
ax.set_ylim(-0.1, 1.1)
line, = ax.plot([], [], 'r', animated=True)  # animated artists are skipped by full redraws
plt.show(block=False)
plt.pause(0.1)                               # let the figure draw once
background = fig.canvas.copy_from_bbox(ax.bbox)  # cache the static parts of the plot
xdata, ydata = [], []
for x in np.arange(0, 10, 0.5):
    xdata.append(x)
    ydata.append(np.exp(-x**2))
    line.set_data(xdata, ydata)
    fig.canvas.restore_region(background)    # restore the cached background
    ax.draw_artist(line)                     # redraw only the artist that changed
    fig.canvas.blit(ax.bbox)                 # push just that region to the screen
    fig.canvas.flush_events()
plt.show()