How to plot weights animation of a LLM for debugging

How to plot weights animation of a LLM for debugging

So, I want to quickly plot a layer of weights changes overtime during training.

If you don’t want anything fancy like Tensorboard, you can just use matplotLib.

class PlotMultiMatrix:
  """Show a sequence of named matrices as an animated heat map.

  The figure has three stacked axes: the image, a slider that selects which
  entry of ``data`` is displayed, and a horizontal colorbar. The animation
  advances the slider once per frame; the slider callback swaps the image
  data, rescales the colour limits and updates the axes title.
  """

  def __init__(self, figure_name: str, data: OrderedDict[str, torch.Tensor]):
    """Build the figure and draw the first matrix.

    Args:
      figure_name: Text used as the figure's super-title.
      data: Ordered mapping from a display name to the matrix shown for that
        step. Each value is handed to ``imshow`` (presumably a 2-D weight
        matrix -- confirm against the caller).

    Raises:
      ValueError: If ``data`` is empty, since there is nothing to draw.
    """
    # Fail before any figure is created so an empty input leaks nothing.
    if not data:
      raise ValueError('data must contain at least one matrix to plot')

    self.data = data
    self.names = list(data.keys())

    self.fig = plt.figure(figsize=(5,5))

    # One tall row for the image, two thin rows for the slider and colorbar.
    gs = gridspec.GridSpec(3, 1, height_ratios=[1, 0.01, 0.01])
    self.ax = plt.subplot(gs[0])
    ax_sl = plt.subplot(gs[1])
    self.ax_cb = plt.subplot(gs[2])
    gs.tight_layout(self.fig, rect=[None, None, None, None])

    self.fig.suptitle(figure_name)

    # The slider value is used directly as an index into self.names, so its
    # maximum must be the last valid index, not the number of entries.
    last_index = len(self.names) - 1
    self.spos = Slider(ax_sl, 'Pos', 0, last_index, valinit=0, valstep=1)

    first = self._as_plottable(self.data[self.names[0]])
    self.im = self.ax.imshow(first, cmap='tab20b')

  @staticmethod
  def _as_plottable(weights):
    """Return ``weights`` in a form Matplotlib can convert to an array.

    Objects exposing ``detach`` (e.g. torch tensors) are detached from the
    autograd graph and moved to the CPU; anything else is returned unchanged.
    """
    detach = getattr(weights, 'detach', None)
    if callable(detach):
      return detach().cpu()
    return weights

  def animate(self, i):
    """Animation callback: move the slider to frame ``i``.

    Setting the slider value fires ``update`` once it has been registered
    through ``plot_multi_matrix``.

    Returns:
      A one-element tuple holding the figure (kept for backward
      compatibility with existing callers).
    """
    self.spos.set_val(i)
    return self.fig,

  def update(self, val):
    """Slider callback: display the matrix selected by the slider.

    Args:
      val: Value passed by the slider; the current position is read from
        ``self.spos.val`` instead.
    """
    pos = int(self.spos.val)
    name = self.names[pos]
    weights = self._as_plottable(self.data[name])

    self.im.set_data(weights)
    # Rescale colours per frame so each matrix uses the full colormap range.
    self.im.set_clim(vmin=float(weights.min()), vmax=float(weights.max()))
    self.ax.set_title(f'{name}')

    self.fig.canvas.draw_idle()

  def plot_multi_matrix(self):
    """Wire up the slider and colorbar and start the animation.

    Returns:
      The ``FuncAnimation`` object. The caller must keep a reference to it,
      otherwise the animation can be garbage-collected and stop.
    """
    self.spos.on_changed(self.update)
    self.fig.colorbar(self.im, orientation='horizontal', pad=0.05, cax=self.ax_cb)
    # Blitting requires the callback to return every artist it modified; each
    # frame here changes the image, title, slider and colorbar, so the whole
    # figure is redrawn instead.
    ani = FuncAnimation(self.fig, self.animate, frames=len(self.data), interval=1000, blit=False)

    return ani

https://github.com/jljacoblo/jacAI/blob/master/util/plot_graphs.py