Data gathers, and now tensor watch begins
Quite recently, the Microsoft Research team published a debugging and visualisation tool for deep learning called TensorWatch. It greatly simplifies building arbitrary dashboards for monitoring your models during training. The beauty of the tool is that it displays real-time visualisations right in your Jupyter Notebook cell. Intrigued? Let's see how it works.
TensorWatch supports Python 3 and you can get it through the `pip` package manager. The following command installs `tensorwatch` and all its dependencies; the `!` redirects the command to the shell.

```shell
! pip install tensorwatch
```
Excellent! Now you have installed `tensorwatch` in your environment. It works with PyTorch 0.4-1.x and with TensorFlow eager tensors.
This blog post requires several additional packages to be installed. We need the `graphviz` library to inspect and visualise our neural network's architecture, and `regim` for data exploration. The `%%capture` magic captures the output of the cell it is placed in.
```shell
%%capture
! pip install graphviz
! apt install graphviz
! git clone https://github.com/sytelus/regim.git
! pip install -e regim/
```
Here is a basic example provided by the authors of `tensorwatch`. First we create a `Watcher` object that is in charge of streams. Next we create a stream to log the data and generate a Jupyter notebook that listens to it. Finally, we write tuples of integers and their respective squares to the created stream, which is logged in the `test.log` file.
```python
import tensorwatch as tw
import time

# streams will be stored in test.log file
w = tw.Watcher(filename='test.log')

# create a stream for logging
s = w.create_stream(name='metric1')

# generate Jupyter Notebook to view real-time streams
w.make_notebook()

for i in range(1000):
    # write x,y pair we want to log
    s.write((i, i*i))
    time.sleep(1)
```
Now, if you open the newly created notebook and run all cells, you will get a plot that displays the $(i, i^2)$ pairs in real time, like this:
This may not seem impressive, but TensorWatch can do much more than that. Let's dive into a somewhat more realistic example.
We are going to download the CIFAR10 dataset and apply t-SNE to a part of it. This technique is well suited for visualising high-dimensional datasets. If t-SNE is a new concept for you, please watch the video that clearly explains it and walks you through the basics.
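To give a flavour of what t-SNE does, here is a toy numpy sketch of its first step: converting pairwise distances into normalised Gaussian similarities, so that nearby points receive most of the weight. This is an illustration only; it uses a fixed bandwidth `sigma`, whereas real t-SNE tunes it per point from a perplexity parameter.

```python
import numpy as np

# three points in a "high-dimensional" space: two close together, one far away
X = np.array([[0.0, 0.0], [0.0, 1.0], [5.0, 5.0]])

# squared pairwise Euclidean distances
sq_dists = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)

# Gaussian similarities (fixed sigma; real t-SNE derives it from perplexity)
sigma = 1.0
P = np.exp(-sq_dists / (2 * sigma ** 2))
np.fill_diagonal(P, 0.0)   # a point is not its own neighbour
P /= P.sum()               # normalise to a probability distribution

# the close pair (points 0 and 1) dominates the affinity matrix
```

t-SNE then searches for a low-dimensional embedding whose similarities match this distribution as closely as possible.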
Next we will build a model to classify images with PyTorch and inspect the network and its components. Finally, we will train the model, visualising the process live during training.
First comes the magic. Jupyter's magics are command-line-like instructions that modify a Jupyter cell's behaviour. In our case, the `%matplotlib` magic sets matplotlib's backend; several backends are available, and you can read more about them in the documentation. Usually we start a Jupyter notebook with the `%matplotlib inline` command at the top, which makes plots appear below the cell with the respective code. However, the `inline` backend does not allow for interactive plots, and therefore we need another backend to reveal the full potential of TensorWatch.
If you run the classic Jupyter Notebook, you can use the `widget` backend. JupyterLab is not compatible with it; there you would need the `ipympl` backend instead. Unfortunately, I could not make JupyterLab work with the `ipympl` backend and had to fall back to Jupyter Notebook. This is frustrating: JupyterLab was released quite recently and still has many issues, for example with Find and Replace. Here are several links related to the problem with interactive backends in JupyterLab; maybe you will figure out how to make things work.
Otherwise, the imports are as usual. We import the newly installed `regim` along with data science classics such as `matplotlib`, which we got familiar with in the previous tutorial.
```python
import tensorwatch as tw
import regim
from regim import DataUtils

import os
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
```
We will download one of the datasets available in `torchvision.datasets`; you can read up on them in the documentation. Keep in mind that `torchvision.datasets.CIFAR10` automatically downloads the files into the provided directory, in our case `./data`. The download may take some time.
Downloaded files are then collected into datasets and served in batches by `torch.utils.data.DataLoader`:
```python
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ['plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
```
```
Files already downloaded and verified
Files already downloaded and verified
```
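The `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` transform above subtracts 0.5 and divides by 0.5 for each colour channel, mapping the `ToTensor()` pixel range [0, 1] onto [-1, 1]. A quick numpy check of the arithmetic:

```python
import numpy as np

pixels = np.array([0.0, 0.25, 0.5, 1.0])   # ToTensor() output range
normalized = (pixels - 0.5) / 0.5          # what Normalize does per channel
# endpoints 0 and 1 map to -1 and 1, the midpoint 0.5 maps to 0
```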
We will borrow the code to build a neural network and some utility functions from the PyTorch tutorial on CIFAR10.
Typically you display an image or two to understand what you are working with before you construct and train the neural network.
```python
# function to show an image
def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
```
```
dog frog dog bird
```
Exploratory analysis with TensorWatch
TensorWatch enhances the data exploration step with interactive visualisations. For this purpose we reload the data with a new set of transforms that now includes flattening each image into a vector.
```python
tsf = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
     regim.ReshapeTransform((-1,))])

data = torchvision.datasets.CIFAR10(root='./data', train=True,
                                    download=True, transform=tsf)
dataset = torch.utils.data.DataLoader(data, shuffle=True, num_workers=0)
```
```
Files already downloaded and verified
```
The t-SNE algorithm is slow on big datasets such as CIFAR10. Here is where the `regim` library turns out to be useful: its `DataUtils` class contains a `sample_by_class` function that lets us pick an exact number of instances for each class. We take `k=50` images for each of our 10 classes as a `numpy` array, since we chose `as_np=True`, and we do not split into train and test. Next we flatten the images, preparing them for the t-SNE transformation.
```python
inputs, labels = DataUtils.sample_by_class(dataset, k=50, as_np=True,
                                           no_test=True)
inputs = inputs.reshape((500, -1))
labels = labels.reshape(-1)
```
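t-SNE expects a 2-D array of shape `(n_samples, n_features)`, so the reshape above collapses everything belonging to one image into a single row: 500 samples (50 per class × 10 classes), each a 3 × 32 × 32 = 3072-dimensional vector. A shape check on a toy batch:

```python
import numpy as np

batch = np.zeros((4, 3, 32, 32))   # toy batch of 4 CIFAR-sized images
flat = batch.reshape((4, -1))      # one 3*32*32 = 3072-vector per image
# flat.shape is (4, 3072): ready for TSNE
```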
TensorWatch has a wrapper function around `sklearn.manifold.TSNE`. We can run it and get the components for the following visualisation.
```python
components = tw.get_tsne_components((inputs, labels))
```
In the Quick Start section we had the following sequence of events:

- A `Watcher` object logged the data in `test.log`
- We created a stream to which we wrote the data during computation
- In a new notebook, a `WatcherClient` got the data from the logs
In the next cell we simply create an `ArrayStream` and instantly visualise it with a `Visualizer`; that's what the authors of TensorWatch call lazy logging. The `Visualizer` will make a 3D plot of the t-SNE components. We can add the `hover_images` and `hover_image_reshape` options to include images, so that you see the respective image when you hover the mouse over a point.

In the authors' tutorial there were black-and-white images from the MNIST dataset, and it was not clear how to make coloured images appear. The code wraps around the `matplotlib` library, therefore I suggest looking into matplotlib's documentation for insights. In the case of colour images, you need to reshape the arrays in the following order: image index, then height and width, then colour channel.
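A quick sanity check of that axis order on a dummy CIFAR-shaped batch: `np.transpose` with axes `(0, 2, 3, 1)` moves the colour channel from position 1 to the end, turning PyTorch's channels-first layout into the channels-last layout matplotlib can display.

```python
import numpy as np

imgs = np.zeros((500, 3, 32, 32))          # (image, channel, height, width)
hover = np.transpose(imgs, (0, 2, 3, 1))   # -> (image, height, width, channel)
# each hover[i] is now a 32x32x3 array, displayable as an RGB image
```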
```python
comp_stream = tw.ArrayStream(components)
vis = tw.Visualizer(comp_stream, vis_type='tsne',
                    hover_images=np.transpose(
                        inputs.reshape((500, 3, 32, 32)), (0, 2, 3, 1)),
                    hover_image_reshape=(32, 32, 3))
vis.show()
```
The result will appear like this:
Network creation and debugging with TensorWatch
We create a neural network of the following architecture:
```python
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
```
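The magic number `16 * 5 * 5` in the `forward` pass can be derived by tracking the spatial size through the layers: a 5×5 convolution without padding shrinks each side by 4, and every 2×2 max-pool halves it. A small sanity check of that arithmetic:

```python
def conv_out(size, kernel):
    # valid convolution, stride 1, no padding
    return size - kernel + 1

side = 32                       # CIFAR10 images are 32x32
side = conv_out(side, 5) // 2   # conv1 (5x5): 32 -> 28, pool: 28 -> 14
side = conv_out(side, 5) // 2   # conv2 (5x5): 14 -> 10, pool: 10 -> 5
# side is now 5, and conv2 has 16 output channels, hence 16 * 5 * 5
```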
Now we can draw a graph of the model architecture and dataflow. TensorWatch has a `draw_model` function that visualises the model. Note that it may not work with PyTorch older than 0.4.1; in that case, upgrade PyTorch.
```python
tw.draw_model(net, [1, 3, 32, 32])
```
`draw_model` graphs all layers of our network along with the data shapes at every step, which is useful when you inspect your model's architecture. The `model_stats` function calculates your model's statistics, including memory usage, shapes, parameters, et cetera. It will make it easier for you to plan the resources necessary for model training.
```python
tw.model_stats(net, [1, 3, 32, 32])
```
| | module name | input shape | output shape | params | memory(MB) | MAdd | Flops | MemRead(B) | MemWrite(B) | duration[%] | MemR+W(B) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | conv1 | 3 32 32 | 6 28 28 | 456.0 | 0.02 | 705,600.0 | 357,504.0 | 14112.0 | 18816.0 | 27.82% | 32928.0 |
| 1 | pool | 16 10 10 | 16 5 5 | 0.0 | 0.00 | 1,200.0 | 1,600.0 | 6400.0 | 1600.0 | 11.17% | 8000.0 |
| 2 | conv2 | 6 14 14 | 16 10 10 | 2416.0 | 0.01 | 480,000.0 | 241,600.0 | 14368.0 | 6400.0 | 26.68% | 20768.0 |
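The parameter counts in this table are easy to verify by hand. For example, `conv1` has 3 input channels, 6 output channels and 5×5 kernels, plus one bias per output channel:

```python
in_ch, out_ch, k = 3, 6, 5                       # conv1's configuration
conv1_params = in_ch * out_ch * k * k + out_ch   # weights + biases
# 3 * 6 * 25 + 6 = 456, matching the table row for conv1
```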
Now we are ready to start training. We need to run the model separately, e.g., run `python3 model.py` in the command line, where `model.py` is the file with our model code. Into that code we insert a snippet to log the loss:
```python
watcher = tw.Watcher()                            # <- create Watcher
loss_stream = watcher.create_stream(name="loss")  # <- open stream
...
for epoch in range(2000):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        ...
        running_loss += loss.item()
        ...
        loss_stream.write(running_loss/2000)      # <- write loss to the stream
```
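Under the hood, this pattern boils down to appending records to a log that another process (the notebook) later reads back. Here is a stdlib-only toy sketch of the same idea; this is not the actual TensorWatch API, and `ToyStream` and `read_stream` are made-up names for illustration.

```python
import json
import os
import tempfile

class ToyStream:
    """Made-up stand-in for a TensorWatch stream: appends JSON lines to a log."""
    def __init__(self, path, name):
        self.path, self.name = path, name

    def write(self, value):
        with open(self.path, "a") as f:
            f.write(json.dumps({"stream": self.name, "value": value}) + "\n")

def read_stream(path, name):
    """What a client on the notebook side would do: replay the logged values."""
    with open(path) as f:
        records = [json.loads(line) for line in f]
    return [r["value"] for r in records if r["stream"] == name]

# the training process writes losses as they are computed...
log_path = os.path.join(tempfile.mkdtemp(), "toy.log")
loss_stream = ToyStream(log_path, "loss")
for loss in [0.9, 0.7, 0.4]:
    loss_stream.write(loss)

# ...and the notebook process replays them for visualisation
```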
In the notebook we create a `WatcherClient` and open the stream with the logs, which we pass to a `Visualizer`:
```python
client = tw.WatcherClient()
model_loss = client.open_stream('loss')
visualiser = tw.Visualizer(stream=model_loss)
visualiser.show()
```
TensorWatch appears to be a valuable utility that may facilitate everyday work with neural networks in data analysis. It is unfortunate that its application is limited by JupyterLab's issues with matplotlib's backends, so fall back to Jupyter Notebook in order to try TensorWatch.
TensorWatch's authors could also have opted for `plotly` or similar libraries that were designed from the start for interactive and appealing visualisations. It would be nice if users could choose the plotting library in future releases of TensorWatch.