Data gathers, and now tensor watch begins

Quite recently, the Microsoft Research team published a debugging and visualisation tool for deep learning. The tool is called TensorWatch, and it greatly simplifies building arbitrary dashboards that let you watch and control your models while they train. The beauty of the tool is that it displays real-time visualisations right in your Jupyter Notebook cell. Intrigued? Let's see how it works.

Installation

TensorWatch supports Python 3 and can be installed with the pip package manager. The following command installs tensorwatch and all its dependencies. The ! prefix passes the command to the shell.

! pip install tensorwatch

Excellent! Now you have installed tensorwatch in your environment. It works with PyTorch 0.4-1.x and with TensorFlow eager tensors.

Extra dependencies

This blog post requires several additional packages. We need the graphviz library to inspect and visualise our neural network's architecture, and regim for data exploration. The %%capture cell magic captures the output of the cell it heads, so the lengthy installation logs are not printed.

%%capture
! pip install graphviz
! apt install graphviz
! git clone https://github.com/sytelus/regim.git
! pip install -e regim/

Quick Start

Here is a basic example provided by the authors of TensorWatch. First, we create a Watcher object that is in charge of streams. Next, we create a stream to log the data and generate a Jupyter notebook that listens to it. Finally, we write tuples of integers and their squares to the stream, which is logged in the 'test.log' file.

import tensorwatch as tw
import time

# streams will be stored in test.log file
w = tw.Watcher(filename='test.log')

# create a stream for logging
s = w.create_stream(name='metric1')

# generate Jupyter Notebook to view real-time streams
w.make_notebook()

for i in range(1000):
    # write x,y pair we want to log
    s.write((i, i*i)) 

    time.sleep(1)

Now, if you open the newly created notebook and run all its cells, you will get a plot that displays the $(i, i^2)$ pairs in real time, as shown below:

[Figure: Quick start]

This may not seem impressive, but TensorWatch can do much more. Let's dive into a more realistic example.

We are going to download the CIFAR10 dataset and apply t-SNE to a part of it. This technique is well suited for visualising high-dimensional datasets. If t-SNE is a new concept for you, please watch the video that clearly explains it and walks you through the basics.
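For readers who prefer formulas, the textbook t-SNE formulation fits in a few lines. Gaussian similarities between the high-dimensional points $x_i$ are symmetrised into $p_{ij}$, heavy-tailed Student-t similarities $q_{ij}$ are defined between their low-dimensional counterparts $y_i$, and the embedding is found by minimising the Kullback-Leibler divergence $C$ with gradient descent. Each bandwidth $\sigma_i$ is tuned so that the conditional distribution of point $i$ reaches a user-chosen perplexity, and $N$ is the number of points. Implementations such as scikit-learn's add optimisation tricks on top of this basic form.

$$
p_{j \mid i} = \frac{\exp\left(-\lVert x_i - x_j \rVert^2 / 2\sigma_i^2\right)}{\sum_{k \neq i} \exp\left(-\lVert x_i - x_k \rVert^2 / 2\sigma_i^2\right)},
\qquad
p_{ij} = \frac{p_{j \mid i} + p_{i \mid j}}{2N}
$$

$$
q_{ij} = \frac{\left(1 + \lVert y_i - y_j \rVert^2\right)^{-1}}{\sum_{k \neq l} \left(1 + \lVert y_k - y_l \rVert^2\right)^{-1}},
\qquad
C = \mathrm{KL}(P \,\|\, Q) = \sum_{i \neq j} p_{ij} \log \frac{p_{ij}}{q_{ij}}
$$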

Next, we will build a PyTorch model to classify the images and inspect the network and its components.

Finally, we will train the model and visualise the training process live.

Imports

First comes the magic. Jupyter's magics are command-line-like instructions that modify the behaviour of a Jupyter cell.

In our case, the %matplotlib magic sets matplotlib's backend. You have several options:

  • inline
  • ipympl
  • notebook
  • widget

You can read more about them in the documentation. Usually we start a Jupyter notebook with the %matplotlib inline command at the top, which displays plots below the cell that produces them. However, the inline backend renders static images and does not allow interactive plots, so we need another backend to reveal the full potential of TensorWatch.

In the classic Jupyter Notebook, you can use the notebook and widget backends. JupyterLab does not support the notebook backend; there you need ipympl, which is also the backend that %matplotlib widget selects.

Unfortunately, I could not make JupyterLab work with the ipympl backend and had to fall back to Jupyter Notebook. This is frustrating! JupyterLab was released quite recently and still has many issues, for example with Find and Replace.

Here are several links related to the problem with interactive backends in JupyterLab. Maybe you will figure out how to make things work.

github1
github2
stackoverflow1
stackoverflow2

%matplotlib notebook

Otherwise, the imports are as usual. We import the newly installed tensorwatch and regim along with the data science classics: os, pandas, numpy, torch, torchvision and matplotlib. We got familiar with them in the previous tutorial.

import tensorwatch as tw
import regim
from regim import DataUtils

import os
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

Loading data

We will download one of the datasets available in torchvision.datasets; you can read about them in the documentation. Keep in mind that, with download=True, torchvision.datasets.CIFAR10 automatically downloads the files into the provided directory, ./data in our case, unless they are already there. The download may take some time.

The datasets are then wrapped in DataLoader objects, which shuffle the data (for the training set) and serve it in mini-batches of four images.

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ['plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
Files already downloaded and verified
Files already downloaded and verified
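To get a feeling for what the DataLoader objects produce, here is the arithmetic behind the number of mini-batches per epoch as a small Python sketch. It assumes CIFAR10's standard split of 50,000 training and 10,000 test images and mirrors DataLoader's default drop_last=False, where a smaller final batch is kept. The num_batches helper is our own, not part of PyTorch.

```python
import math

# CIFAR10's standard split
n_train, n_test = 50_000, 10_000
batch_size = 4  # same value as in our DataLoader calls


def num_batches(n_samples, batch_size, drop_last=False):
    """Mini-batches a DataLoader yields per epoch (our helper)."""
    if drop_last:
        # incomplete final batch is discarded
        return n_samples // batch_size
    # incomplete final batch is kept
    return math.ceil(n_samples / batch_size)


print(num_batches(n_train, batch_size))  # 12500
print(num_batches(n_test, batch_size))   # 2500
```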

We will borrow the code for building the neural network, along with some utility functions, from the PyTorch tutorial on CIFAR10.

Typically, you display an image or two to understand what you are working with before you construct and train the neural network.

# functions to show an image
def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)  # built-in next() works across PyTorch versions

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
  dog  frog   dog  bird
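The imshow helper above undoes the normalisation applied by the transform. ToTensor scales pixel values to [0, 1], and Normalize with mean 0.5 and standard deviation 0.5 computes (x - 0.5) / 0.5 per channel, mapping them to [-1, 1]. So img / 2 + 0.5 is the exact inverse. Here is a minimal scalar sketch; the helper names normalize and unnormalize are ours.

```python
# Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)): (x - mean) / std per channel
MEAN, STD = 0.5, 0.5


def normalize(x):
    """Map a pixel value from [0, 1] (ToTensor output) to [-1, 1]."""
    return (x - MEAN) / STD


def unnormalize(y):
    """Inverse used in imshow: y / 2 + 0.5 equals y * STD + MEAN here."""
    return y / 2 + 0.5


for x in (0.0, 0.25, 0.5, 1.0):
    y = normalize(x)
    print(x, "->", y, "->", unnormalize(y))
```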

Exploratory analysis with TensorWatch

TensorWatch enhances the data exploration step with interactive visualisations. For this purpose, we reload the data with a new set of transforms that also flattens each image into a one-dimensional vector.

tsf = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
     regim.ReshapeTransform((-1,))])

data = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=tsf)
dataset = torch.utils.data.DataLoader(data, shuffle=True, num_workers=0)
Files already downloaded and verified
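The flattening step is plain shape arithmetic: each 3 × 32 × 32 image becomes a vector of 3,072 values. We assume regim.ReshapeTransform((-1,)) simply reshapes the image tensor with shape (-1,). The NumPy sketch below shows what such a reshape does; NumPy and PyTorch both flatten in row-major order.

```python
import numpy as np

# Stand-in for one normalised CIFAR10 image: channels x height x width
img = np.arange(3 * 32 * 32, dtype=np.float32).reshape(3, 32, 32)

# Shape (-1,) lets NumPy infer the single dimension: 3 * 32 * 32
flat = img.reshape((-1,))
print(flat.shape)  # (3072,)

# Row-major flattening: element [c, h, w] lands at c*32*32 + h*32 + w
c, h, w = 2, 5, 7
assert flat[c * 32 * 32 + h * 32 + w] == img[c, h, w]
```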

The t-SNE algorithm is slow on big datasets such as CIFAR10. This is where the regim library turns out to be useful: DataUtils contains a function sample_by_class that picks an exact number of instances for each class.

We collect k=50 images for each of the 10 classes as numpy arrays (because we set as_np=True), without splitting them into train and test sets. Next, we flatten the images, preparing them for the t-SNE transformation.

inputs, labels = DataUtils.sample_by_class(dataset, k=50, as_np=True, no_test=True)
inputs = inputs.reshape((500, -1))
labels = labels.reshape(-1)
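Class-balanced sampling is easy to picture in NumPy. The sketch below is only an illustration of the idea: sample_by_class_sketch is a hypothetical helper, and drawing uniformly at random without replacement is our assumption, not necessarily what regim does. The toy data merely mimics CIFAR10's shapes.

```python
import numpy as np


def sample_by_class_sketch(inputs, labels, k, seed=0):
    """Hypothetical stand-in for DataUtils.sample_by_class:
    draw exactly k random rows for each distinct label."""
    rng = np.random.default_rng(seed)
    chosen = []
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)  # positions of this class
        chosen.append(rng.choice(idx, size=k, replace=False))
    chosen = np.concatenate(chosen)
    return inputs[chosen], labels[chosen]


# Toy data: 1,000 fake 3 x 32 x 32 "images", 100 per class
rng = np.random.default_rng(42)
toy_inputs = rng.random((1000, 3, 32, 32), dtype=np.float32)
toy_labels = np.arange(1000) % 10

x, y = sample_by_class_sketch(toy_inputs, toy_labels, k=50)
x = x.reshape((len(x), -1))  # 10 classes * 50 = 500 rows
print(x.shape)               # (500, 3072)
print(np.bincount(y))        # [50 50 50 50 50 50 50 50 50 50]
```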

TensorWatch has a wrapper function around sklearn.manifold.TSNE. We run it to get the components for the subsequent visualisation.

components = tw.get_tsne_components((inputs, labels))

In the Quick Start section, we had the following sequence of events:

  1. A Watcher object logged the data in test.log
  2. We created a stream and wrote the data to it during the computation
  3. In a new notebook, a WatcherClient read the data from the log
  4. Finally, a Visualizer made the visualisations
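The four steps above follow a producer/consumer pattern: one process appends records to a log, another reads them back. Below is a toy analogue in plain Python, not TensorWatch's implementation; the ToyStream class, the read_stream function and the JSON-lines format are hypothetical choices made for illustration.

```python
import json
import os
import tempfile


class ToyStream:
    """Hypothetical producer: appends one JSON record per line to a log."""

    def __init__(self, path, name):
        self.path = path
        self.name = name

    def write(self, value):
        with open(self.path, "a") as f:
            f.write(json.dumps({"stream": self.name, "value": value}) + "\n")


def read_stream(path, name):
    """Hypothetical consumer: returns every value logged under `name`."""
    with open(path) as f:
        records = [json.loads(line) for line in f if line.strip()]
    return [r["value"] for r in records if r["stream"] == name]


# A fresh temporary directory keeps the example away from your files
log_path = os.path.join(tempfile.mkdtemp(), "test.log")

s = ToyStream(log_path, "metric1")  # training side: write (x, y) pairs
for i in range(5):
    s.write([i, i * i])

print(read_stream(log_path, "metric1"))  # notebook side: read them back
# [[0, 0], [1, 1], [2, 4], [3, 9], [4, 16]]
```

A real tool also has to follow the log while it grows and refresh a live plot, which is the part TensorWatch handles for you.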

In the next cell, we simply create an ArrayStream and instantly visualise it with a Visualizer. That's what the authors of TensorWatch call lazy logging.

The Visualizer will make a 3D plot of the t-SNE components. We can add the hover_images and hover_image_reshape options to include images, so that you see the respective image when you hover the mouse over a point.

The authors' tutorial used black-and-white images from the MNIST dataset, and it was not clear how to make coloured images appear. The code wraps around the matplotlib library, so I suggest looking into matplotlib's documentation for insights.

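For colour images, the array axes must be ordered as image, height, width, then colour channel. Our flattened inputs come from PyTorch tensors ordered as image, channel, height, width, so we first reshape them to (500, 3, 32, 32) and then move the channel axis to the end with np.transpose.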

comp_stream = tw.ArrayStream(components)
vis = tw.Visualizer(comp_stream, vis_type='tsne', 
                    hover_images=np.transpose(inputs.reshape((500, 3, 32, 32)), (0, 2, 3, 1)),
                    hover_image_reshape=(32,32,3))
vis.show()

The result will appear like this:

[Figure: t-SNE]
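Here is why the hover images need both a reshape and a transpose. The sketch uses random numbers in place of the real inputs array, with the same shapes as in our example.

```python
import numpy as np

# 500 flattened colour images, one row of 3 * 32 * 32 values each
flat = np.random.default_rng(0).random((500, 3 * 32 * 32))

# Step 1: undo the flattening -> (image, channel, height, width)
chw = flat.reshape((500, 3, 32, 32))

# Step 2: move channels last -> (image, height, width, channel),
# the layout matplotlib's imshow expects for RGB data
hwc = np.transpose(chw, (0, 2, 3, 1))
print(hwc.shape)  # (500, 32, 32, 3)

# Same pixel in both layouts; only the axis order changes
n, c, h, w = 7, 1, 4, 9
assert hwc[n, h, w, c] == chw[n, c, h, w]

# A direct reshape to (500, 32, 32, 3) reinterprets the memory order
# and scrambles the colour channels instead
wrong = flat.reshape((500, 32, 32, 3))
print(np.array_equal(wrong, hwc))  # False
```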

Network creation and debugging with TensorWatch

We create a neural network with the following architecture:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()
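The 16 * 5 * 5 in fc1 and in x.view(...) is not arbitrary: it is the number of features left after the two convolution and pooling stages. A short Python sketch of the output-size formula for Conv2d and MaxPool2d with dilation 1, floor((size + 2 * padding - kernel) / stride) + 1, reproduces it. The conv_out helper is our own.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial size after a Conv2d or MaxPool2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                         # CIFAR10 images are 32 x 32
size = conv_out(size, 5)          # conv1, 5x5 kernel    -> 28
size = conv_out(size, 2, 2)       # pool, 2x2, stride 2  -> 14
size = conv_out(size, 5)          # conv2, 5x5 kernel    -> 10
size = conv_out(size, 2, 2)       # pool again           -> 5
flat_features = 16 * size * size  # conv2 has 16 output channels
print(size, flat_features)        # 5 400
```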

Now we can draw a graph of the model's architecture and data flow. TensorWatch's draw_model function visualises the model; its second argument is the input shape, here a batch of one 3×32×32 image. Note that it may not work with PyTorch versions older than 0.4.1; in that case, upgrade PyTorch.

tw.draw_model(net, [1, 3, 32, 32])

TensorWatch's draw_model graphs all layers of the network along with the data shapes at every step, which is useful when you inspect your model's architecture. The model_stats function calculates the model's statistics, including memory usage, shapes, parameters, multiply-adds (MAdd) and FLOPs. This makes it easier to plan the resources needed for training.

tw.model_stats(net, [1, 3, 32, 32])
       module name  input shape  output shape   params  memory(MB)         MAdd      Flops  MemRead(B)  MemWrite(B)  duration[%]  MemR+W(B)
0      conv1        3x32x32      6x28x28         456.0        0.02    705,600.0  357,504.0     14112.0      18816.0       27.82%    32928.0
1      pool         16x10x10     16x5x5            0.0        0.00      1,200.0    1,600.0      6400.0       1600.0       11.17%     8000.0
2      conv2        6x14x14      16x10x10       2416.0        0.01    480,000.0  241,600.0     14368.0       6400.0       26.68%    20768.0
3      fc1          400          120           48120.0        0.00     95,880.0   48,000.0    194080.0        480.0       12.93%   194560.0
4      fc2          120          84            10164.0        0.00     20,076.0   10,080.0     41136.0        336.0       10.96%    41472.0
5      fc3          84           10              850.0        0.00      1,670.0      840.0      3736.0         40.0       10.43%     3776.0
total                                          62006.0        0.03  1,304,426.0  659,624.0      3736.0         40.0      100.00%   301504.0
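The params column can be checked by hand: a Conv2d layer has out_channels * in_channels * k * k weights plus out_channels biases, and a Linear layer has in_features * out_features weights plus out_features biases. The sketch below (helper names are ours) reproduces the per-layer numbers and the total of 62,006. In PyTorch, sum(p.numel() for p in net.parameters()) gives the same total.

```python
def conv_params(c_in, c_out, k):
    """Weights (c_out * c_in * k * k) plus one bias per output channel."""
    return c_out * c_in * k * k + c_out


def linear_params(n_in, n_out):
    """Weight matrix (n_out * n_in) plus one bias per output unit."""
    return n_out * n_in + n_out


params = {
    "conv1": conv_params(3, 6, 5),          # 456
    "conv2": conv_params(6, 16, 5),         # 2416
    "fc1": linear_params(16 * 5 * 5, 120),  # 48120
    "fc2": linear_params(120, 84),          # 10164
    "fc3": linear_params(84, 10),           # 850
}
print(sum(params.values()))  # 62006 (pooling has no parameters)
```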

Now we are ready to start training. The training has to run separately from the notebook, e.g. with python3 model.py on the command line, where model.py is the file containing the model and training code. In that code, we insert a snippet that logs the loss:

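import tensorwatch as tw
import torch.nn as nn
import torch.optim as optim

watcher = tw.Watcher() # <- Create Watcher
loss_stream = watcher.create_stream(name="loss") # <- Open stream

...

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(2):  # number of epochs, as in the PyTorch CIFAR10 tutorial
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()              # reset accumulated gradients
        loss = criterion(net(inputs), labels)  # forward pass
        loss.backward()                    # backpropagation
        optimizer.step()                   # update the weights

        running_loss += loss.item()
        if i % 2000 == 1999:               # every 2000 mini-batches
            loss_stream.write(running_loss / 2000) # <- Write average loss to the stream
            running_loss = 0.0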
            

In the notebook, we create a WatcherClient and open the loss stream, which we then pass to a Visualizer.

client = tw.WatcherClient()
model_loss = client.open_stream('loss')

visualiser = tw.Visualizer(stream=model_loss)
visualiser.show()

Conclusions

TensorWatch appears to be a valuable utility that may facilitate everyday work with neural networks in data analysis. Unfortunately, its use is limited by JupyterLab's issues with matplotlib's interactive backends, so fall back to Jupyter Notebook if you want to try TensorWatch.

TensorWatch's authors could also have opted for bokeh or plotly, libraries designed from the start for interactive and appealing visualisations. It would be nice if future releases of TensorWatch let users choose the plotting library.