How to plot an animation of an LLM's weights for debugging

So, I want to quickly plot how a layer's weights change over time during training.

If you don’t want anything fancy like TensorBoard, you can just use Matplotlib.

class PlotMultiMatrix:
  """Show a collection of named weight matrices as a heat map with a slider.

  The figure is split into three stacked rows: the heat map itself, a slider
  that selects which entry of ``data`` is displayed, and a horizontal colour
  bar.

  The names ``plt`` (matplotlib.pyplot), ``gridspec`` (matplotlib.gridspec),
  ``Slider`` (matplotlib.widgets), ``FuncAnimation`` (matplotlib.animation),
  ``OrderedDict`` and ``torch`` must already be imported by the surrounding
  module.
  """

  def __init__(self, figure_name: str, data: OrderedDict[str, torch.Tensor]):
    """Build the figure and draw the first matrix.

    Args:
      figure_name: Label for the figure. It is accepted for interface
        compatibility but is not used by this class.
      data: Ordered mapping from a display name to the matrix to plot.
        Presumably each value is a 2-D tensor that ``imshow`` can render —
        verify against the caller.

    Raises:
      ValueError: If ``data`` is empty (there would be nothing to draw).
    """
    # Fail early with a clear message instead of an IndexError on names[0],
    # and before a figure has been allocated.
    if not data:
      raise ValueError("data must contain at least one matrix")

    self.data = data
    self.names = list(data.keys())

    self.fig = plt.figure(figsize=(5, 5))

    # Three rows: image (tall), slider (thin), colour bar (thin).
    gs = gridspec.GridSpec(3, 1, height_ratios=[1, 0.01, 0.01])
    self.ax = plt.subplot(gs[0])
    ax_sl = plt.subplot(gs[1])
    self.ax_cb = plt.subplot(gs[2])
    gs.tight_layout(self.fig, rect=[None, None, None, None])

    self.spos = Slider(ax_sl, 'Pos', 0, len(data), valinit=0, valstep=1)
    # Redraw the heat map whenever the slider is moved.
    self.spos.on_changed(self.update)

    weights = self.data[self.names[0]]
    self.im = self.ax.imshow(weights, cmap='tab20b')  # cmap='YlOrRd')

  def animate(self, i):
    """Frame callback for ``FuncAnimation``.

    Args:
      i: Frame index supplied by the animation; unused.

    Returns:
      A one-element tuple holding the figure, as blitting expects an
      iterable of artists to redraw.
    """
    return self.fig,

  def update(self, val):
    """Slider callback: display the matrix selected by the slider.

    Args:
      val: New slider value passed by matplotlib; the current position is
        read from the slider itself.
    """
    # The slider range runs up to len(data), one past the last valid index,
    # so clamp to avoid an IndexError at the far end of the slider.
    pos = min(int(self.spos.val), len(self.names) - 1)
    weights = self.data[self.names[pos]]
    self.im.set_data(weights)
    # Rescale the colour limits to the range of the matrix now on display.
    self.im.set_clim(vmin=weights.min(), vmax=weights.max())
    self.ax.set_title(f'{self.names[pos]}')

  def plot_multi_matrix(self):
    """Attach the colour bar and start the animation.

    Returns:
      The ``FuncAnimation`` object. The caller must keep a reference to it,
      otherwise the animation may be garbage-collected and stop.
    """
    self.fig.colorbar(self.im, orientation='horizontal', pad=0.05, cax=self.ax_cb)
    ani = FuncAnimation(self.fig, self.animate, frames=len(self.data), interval=1000, blit=True)
    return ani