Tensorization of neural network layers

Introduction

First proposed by Novikov et al. (2015), tensorization of neural networks aims to represent the large weight matrices in neural networks as products of smaller, higher-dimensional tensors, which greatly reduces the number of parameters and the computational cost. For more details and applications, please refer to my talk. In this post, we see how easy it has become to implement such tensorized neural networks.
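Concretely, in the tensor-train (TT) matrix format this means that a weight matrix $W \in \mathbb{R}^{M \times N}$ with $M = \prod_k m_k$ and $N = \prod_k n_k$ is stored as a product of small cores,

$$ W\big[(i_1, \dots, i_d), (j_1, \dots, j_d)\big] = G_1[i_1, j_1] \, G_2[i_2, j_2] \cdots G_d[i_d, j_d], $$

where each $G_k[i_k, j_k]$ is an $r_{k-1} \times r_k$ matrix (with $r_0 = r_d = 1$), so that only the small cores need to be stored and learned.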

1. Einstein summation notation

Einstein summation notation can express tensor contraction in an extremely elegant fashion. There are only two conventions to follow:

  • The modes (“dimensions”) of a tensor are denoted by indices; this is in contrast to the usual matrix notation, where indices denote a single entry of the tensor.
  • The contraction is carried out over all modes whose indices appear more than once. Note that the order in which the contractions over multiple modes take place is not denoted explicitly.
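For example, the ordinary matrix product is written in this notation as

$$ C_{mk} = A_{mn} B_{nk}, $$

where the summation over the repeated index $n$ is implicit.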

A wide range of tensor operations can be expressed using Einstein summation notation in numpy, torch, tensorflow and other libraries.
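For instance, the very same subscript string works with np.einsum and torch.einsum alike; a minimal numpy sketch:

import numpy as np

A = np.random.randn(3, 5)
B = np.random.randn(5, 4)

# matrix product expressed in Einstein summation notation
C = np.einsum('mn,nk->mk', A, B)

# agrees with the ordinary matrix product
assert np.allclose(C, A @ B)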

Here are some basic usage examples:

import numpy as np
import torch

np.random.seed(123)
torch.manual_seed(123)

1.1. Summation over all elements

a = torch.randn(5)

torch.einsum('n->', a)

# equivalent to a.sum()

1.2. Summation over rows or columns

A = torch.randn(5, 3)

torch.einsum('nm->m', A)

# equivalent to A.sum(axis=0)

1.3. Inner product of vectors

a = torch.randn(5)
b = torch.randn(5)

torch.einsum('n,n->', a, b)

# equivalent to torch.inner(a, b)

1.4. Outer product of vectors

a = torch.randn(5)
b = torch.randn(3)

torch.einsum('n,m->nm', a, b) 

# equivalent to torch.outer(a, b)

1.5. Matrix vector product

A = torch.randn(3, 5)
b = torch.randn(5)

torch.einsum('mn,n->m', A, b) 

# equivalent to torch.matmul(A, b)

1.6. Matrix matrix product

A = torch.randn(3, 5)
B = torch.randn(5, 4)

torch.einsum('mn,nk->mk', A, B) 

# equivalent to torch.matmul(A, B)

1.7. Tensor contraction


A = torch.randn(5, 10, 7)
B = torch.randn(5, 7, 2)

torch.einsum('mnk,mkr->nr', A, B)

which is equivalent to

C = torch.zeros(10, 2)

for m in range(5):
    for n in range(10):
        for k in range(7):
            for r in range(2):
                C[n, r] += A[m, n, k] * B[m, k, r]

or

C = torch.zeros(10, 2)

for n in range(10):
    for r in range(2):
        C[n, r] = (A[:, n, :] * B[:, :, r]).sum()
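To convince ourselves that the three formulations agree, we can compare the loop result C with the einsum output directly (a quick check using the tensors defined above):

C_einsum = torch.einsum('mnk,mkr->nr', A, B)
print(torch.allclose(C, C_einsum, atol=1e-5))
# expected: True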

2. Tensorizing neural network layers

We create a simple regression task that serves as an isolated linear layer.

import numpy as np
import torch
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score

np.random.seed(123)
torch.manual_seed(123)

2.1. Differentiability

Obviously, tensor contractions are differentiable linear operations, so the einsum API is differentiable as well.

W = torch.randn(3, 5, requires_grad=True)
x = torch.randn(5)
y = torch.randn(3)

y_hat = torch.einsum('NM,M->N', W, x)
loss = ((y - y_hat)**2).sum()
loss.backward()

W.grad

# tensor([[ 2.3307,  2.4076,  4.2191, -3.5718, -4.3023],
#         [ 3.2671,  3.3749,  5.9143, -5.0069, -6.0309],
#         [-4.2287, -4.3682, -7.6550,  6.4805,  7.8060]])

As we know, the gradient of the squared error with respect to $W$ is $-2 \cdot (y - W x ) x^T$. One can check this with

-torch.outer(y-(torch.matmul(W, x)), x)*2

# tensor([[ 2.3307,  2.4076,  4.2191, -3.5718, -4.3023],
#         [ 3.2671,  3.3749,  5.9143, -5.0069, -6.0309],
#         [-4.2287, -4.3682, -7.6550,  6.4805,  7.8060]], grad_fn=<MulBackward0>)

2.2. A tensor-train layer

2.2.1. We create a toy dataset and verify the existence of a solution:


n_samples = 256
n_informative = 128

n_input = 4096
n_output = 16

X, y, coef = make_regression(n_samples=n_samples, n_features=n_input, n_informative=n_informative, n_targets=n_output, coef=True)
X = (X - X.mean(0)) / X.std(0)
y = (y - y.mean(0)) / y.std(0)

lr = LinearRegression()
lr.fit(X, y)

lr_y_hat = lr.predict(X)
mean_squared_error(y.reshape(-1), lr_y_hat.reshape(-1)), r2_score(y.reshape(-1), lr_y_hat.reshape(-1))

2.2.2. We decompose the input and output dimensions as

$$ 4096 = 8 \times 8 \times 8 \times 8, \ 16 = 2 \times 2 \times 2 \times 2 $$ and reshape the data accordingly:


m = 8
n = 2

X = X.reshape(n_samples, m, m, m, m)
y = y.reshape(n_samples, n, n, n, n)

X = torch.tensor(X, dtype=torch.float)
y = torch.tensor(y, dtype=torch.float)

2.2.3. Define the rank of the decomposition and initialize the tensor-train cores:

r = 8
W1_init = np.random.randn(m, n, r)
W1_init /= (W1_init**2).sum()**0.5

W2_init = np.random.randn(m, n, r, r)
W2_init /= (W2_init**2).sum()**0.5

W3_init = np.random.randn(m, n, r, r)
W3_init /= (W3_init**2).sum()**0.5

W4_init = np.random.randn(m, n, r)
W4_init /= (W4_init**2).sum()**0.5

W1 = torch.tensor(W1_init, requires_grad=True, dtype=torch.float)
W2 = torch.tensor(W2_init, requires_grad=True, dtype=torch.float)
W3 = torch.tensor(W3_init, requires_grad=True, dtype=torch.float)
W4 = torch.tensor(W4_init, requires_grad=True, dtype=torch.float)
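Written out explicitly, the tensor-train layer that we train in the next step computes

$$ \hat{Y}_{n,e,f,g,h} = \sum_{a,b,c,d} \sum_{i,j,k} W^{(1)}_{a,e,i} \, W^{(2)}_{b,f,i,j} \, W^{(3)}_{c,g,j,k} \, W^{(4)}_{d,h,k} \, X_{n,a,b,c,d}, $$

which is exactly what the einsum string 'aei, bfij, cgjk, dhk, nabcd -> nefgh' below encodes.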

2.2.4. Perform the gradient descent:

opt = torch.optim.Adam([W1, W2, W3, W4], lr=5e-2)
loss = torch.nn.MSELoss()

n_epochs = 10000

for i in range(n_epochs):
    y_hat = torch.einsum('aei, bfij, cgjk, dhk, nabcd -> nefgh', W1, W2, W3, W4, X)
    error = loss(y_hat.reshape(n_samples, -1), y.reshape(n_samples, -1))
    error.backward()
    opt.step()
    opt.zero_grad()
    

Note that we only need a single line to define the tensor-train layer; the backward() and step() calls are the usual boilerplate from torch.
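For reuse inside a larger network, one could also wrap the contraction in a torch.nn.Module. The following is a minimal sketch; the class name TTLayer and the initialization scale are my own choices, not part of the code above:

class TTLayer(torch.nn.Module):
    """A tensor-train layer with four cores, matching the shapes used above."""

    def __init__(self, m=8, n=2, r=8):
        super().__init__()
        # cores of the tensor train: (m, n, r), (m, n, r, r), (m, n, r, r), (m, n, r)
        self.W1 = torch.nn.Parameter(torch.randn(m, n, r) / (m * n * r) ** 0.5)
        self.W2 = torch.nn.Parameter(torch.randn(m, n, r, r) / (m * n * r * r) ** 0.5)
        self.W3 = torch.nn.Parameter(torch.randn(m, n, r, r) / (m * n * r * r) ** 0.5)
        self.W4 = torch.nn.Parameter(torch.randn(m, n, r) / (m * n * r) ** 0.5)

    def forward(self, x):
        # x: (batch, m, m, m, m) -> output: (batch, n, n, n, n)
        return torch.einsum('aei,bfij,cgjk,dhk,nabcd->nefgh',
                            self.W1, self.W2, self.W3, self.W4, x)

A layer defined this way can be trained with any torch optimizer, just like the hand-rolled version above.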

Of course, we could have implemented the update from scratch, such as:

lr = 5e-2
for i in range(10000): 
    y_hat = torch.einsum('aei, bfij, cgjk, dhk, nabcd -> nefgh', W1, W2, W3, W4, X)
    error = ((y-y_hat)**2).mean()

    error.backward()
    with torch.no_grad():
        W1 -= lr*W1.grad 
        W2 -= lr*W2.grad
        W3 -= lr*W3.grad
        W4 -= lr*W4.grad

        W1.grad.zero_()
        W2.grad.zero_()
        W3.grad.zero_()
        W4.grad.zero_()

But at least for this example, Adam performs much better than plain SGD.
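The hand-written loop above is just plain SGD; to compare the two optimizers directly, one could simply swap the constructor and reuse the training loop from 2.2.4:

opt = torch.optim.SGD([W1, W2, W3, W4], lr=5e-2)  # instead of torch.optim.Adam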

2.2.5. Evaluation

y_pred = torch.einsum('aei, bfij, cgjk, dhk, nabcd -> nefgh', W1, W2, W3, W4, X).detach().numpy()

print(mean_squared_error(y.numpy().reshape(-1), y_pred.reshape(-1)))

print(r2_score(y.numpy().reshape(-1), y_pred.reshape(-1)))

# 0.055088233
# 0.9449117677629324

It does not fit the data as perfectly as the full linear model did, but keep in mind that we use only about 3.5% of the parameters:


tt_size = np.prod(W1_init.shape) + np.prod(W2_init.shape) + np.prod(W3_init.shape) + np.prod(W4_init.shape)
full_size = n_input * n_output

tt_size / full_size
# 0.03515625
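Spelled out, the four cores contain

$$ 8 \cdot 2 \cdot 8 + 8 \cdot 2 \cdot 8 \cdot 8 + 8 \cdot 2 \cdot 8 \cdot 8 + 8 \cdot 2 \cdot 8 = 2304 $$

parameters, compared to $4096 \cdot 16 = 65536$ entries of the full weight matrix, i.e. a ratio of $2304 / 65536 = 0.03515625$.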