Mar 30 2023 · 4 min read

Memory usage is a common issue for large ML models. Especially in academia, we have to use resources wisely and make the most of what is available. While working on the KL objective of my mixture model, I had to make some less common optimizations to reduce memory usage.

The decoder outputs a large matrix \(O\) with dimensionality \((M \times B \times L \times D)\), where \(M\) is the number of clusters, \(B\) is the batch size, \(L\) is the sequence length and \(D\) is the model output dimension. From now on, let’s merge the batch and sequence axes, \(T = B \times L\), so we can reduce the number of dimensions by one and work with an \(M \times T \times D\) matrix.

- Each row \(O_m\) of \(O\) represents a mixture component \(Q_m = PowerSpherical(\mu_m, \kappa_m)\), with location \(\mu_m = O_m / \|O_m\|\) and spread parameter \(\kappa_m = \|O_m\|\) (norms taken over the last, \(D\)-dimensional axis); the mixture is \(Q(x) = \sum_{m=1}^{M} p(m) Q_m(x)\).
- We sample \(z_m\) from each mixture component \(Q_m\) and get a matrix of samples, where row \(m\) is the sample from component \(m\).
- Now, our goal is to calculate \(\log q_j(z_m)\), i.e., for each sample \(z_m\), the log-probability under every mixture component \(q_j\), where \(j \in [1, M]\). So we end up with a matrix of size \(M \times M \times T\).
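To make the target concrete, here is a deliberately naive sketch that fills the \(M \times M \times T\) grid one pair \((m, j)\) at a time. It assumes the Power Spherical log-density from [1], written out by hand so the snippet runs without the library; the helper names `ps_log_prob` and `pairwise_log_prob_loop`, the tiny sizes, and the random stand-in samples are all hypothetical and only for illustration.

```python
import math
import torch
import torch.nn.functional as F

def ps_log_prob(loc, scale, value):
    """Hypothetical helper: Power Spherical log-density log q(value).

    loc: (..., D) unit locations, scale: (...) concentrations, value: (..., D) unit vectors.
    """
    d = loc.shape[-1]
    beta = (d - 1) / 2          # Beta concentration0
    alpha = beta + scale        # Beta concentration1
    log_norm = -(
        (alpha + beta) * math.log(2)
        + torch.lgamma(alpha)
        - torch.lgamma(alpha + beta)
        + beta * math.log(math.pi)
    )
    return log_norm + scale * torch.log1p((loc * value).sum(-1))

def pairwise_log_prob_loop(loc, scale, z):
    """Hypothetical helper: result[m, j, t] = log q_j(z_m) at position t, shape (M, M, T)."""
    M, T, _ = loc.shape
    result = torch.empty(M, M, T)
    for m in range(M):          # sample drawn from component m
        for j in range(M):      # evaluated under component j
            result[m, j] = ps_log_prob(loc[j], scale[j], z[m])
    return result

# Tiny example: M=3 components, T=4 positions, D=5 features
torch.manual_seed(0)
raw = torch.rand(3, 4, 5)                         # stand-in for the decoder output O
loc, scale = F.normalize(raw, dim=-1), raw.norm(dim=-1)
z = F.normalize(torch.randn(3, 4, 5), dim=-1)     # stand-in for the samples z_m
grid = pairwise_log_prob_loop(loc, scale, z)
print(grid.shape)  # torch.Size([3, 3, 4])
```

The double loop is only a readable reference; the rest of the post is about computing the same grid in one vectorized call without running out of memory.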

Now, let’s talk code. We use the implementation of the PowerSpherical distribution provided by Nicola De Cao (https://github.com/nicola-decao/power_spherical) [1]. In particular, we are interested in one function, `log_prob` (and the `log_normalizer` it calls):

```
# [1] De Cao, N., Aziz, W. (2020).
# The Power Spherical distribution.
# In Proceedings of the 37th International
# Conference on Machine Learning, INNF+.
# https://github.com/nicola-decao/power_spherical/blob/master/power_spherical/distributions.py#L176-L189
# Excerpt: two methods of the PowerSpherical class (not runnable on their own)
def log_prob(self, value):
    # normalizer + kappa * log(1 + mu^T x); the dot product is taken over the last axis
    return self.log_normalizer() + self.scale * torch.log1p(
        (self.loc * value).sum(-1)
    )

def log_normalizer(self):
    # Concentrations of the Beta marginal used by the distribution
    alpha = self.base_dist.marginal_t.base_dist.concentration1
    beta = self.base_dist.marginal_t.base_dist.concentration0
    return -(
        (alpha + beta) * math.log(2)
        + torch.lgamma(alpha)
        - torch.lgamma(alpha + beta)
        + beta * math.log(math.pi)
    )
```
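Written out, the two methods above compute the log-density of a unit vector \(z\) under component \(j\). Here \(\alpha_j = \tfrac{D-1}{2} + \kappa_j\) and \(\beta = \tfrac{D-1}{2}\) stand for `concentration1` and `concentration0`; these explicit values are taken from [1] and are the ones used in the scripted `calc_log_prob` later in this post:

\[
\log q_j(z) = -\Big[(\alpha_j + \beta)\log 2 + \log\Gamma(\alpha_j) - \log\Gamma(\alpha_j + \beta) + \beta \log \pi\Big] + \kappa_j \log\big(1 + \mu_j^\top z\big)
\]

Only the last term couples a sample with a component, through the dot product \(\mu_j^\top z\); evaluating it for every pair \((m, j)\) is what produces the \(M \times M \times T\) grid.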

Let’s look closely at one line:

```
(self.loc * value).sum(-1)
```

To get the probability of each sample under every mixture component, we rely on broadcasting: we expand the samples `z_m` at dimension 1, so `value` is \(M \times 1 \times T \times D\) while `loc` is \(M \times T \times D\). Then `self.loc * value` is an \(M \times M \times T \times D\) matrix, which we then `sum` over the last dimension. Remember, in NLP we typically work with high-dimensional vectors. Scary.
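A toy-sized sketch of the same broadcasting (the sizes are made up so it runs instantly, and random tensors stand in for `loc` and the samples) shows the full four-dimensional intermediate being created before the sum:

```python
import torch

M, T, D = 4, 6, 8              # toy sizes; the post uses M=128, T=512, D=128
loc = torch.randn(M, T, D)     # one location per component and position
z = torch.randn(M, T, D)       # one sample per component and position

value = z.unsqueeze(1)         # (M, 1, T, D)
product = loc * value          # loc broadcasts as (1, M, T, D) -> (M, M, T, D)
dots = product.sum(-1)         # (M, M, T): dots[m, j, t] = <z[m, t], loc[j, t]>

print(value.shape, product.shape, dots.shape)
```

The peak cost is set by `product`, which holds \(M^2 T D\) numbers even though only \(M^2 T\) of them survive the sum.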

Let’s plug in typical numbers for M, T and D (M = 128, T = 512, D = 128) and see how much memory it takes to get the log-probs:

```
import torch
import torch.autograd.profiler as profiler
import torch.nn.functional as F
import power_spherical

def call_log_prob(q, z_samples):
    # value: M x 1 x T x D, broadcast against loc: M x T x D
    return q.log_prob(z_samples.unsqueeze(1))

if __name__ == "__main__":
    # M=128 components, T=512 positions, D=128 features, allocated on the GPU
    out = torch.rand(128, 512, 128, device="cuda")
    scale = torch.norm(out, dim=-1)   # kappa: M x T
    loc = F.normalize(out, dim=-1)    # mu: M x T x D (unit vectors)
    q = power_spherical.PowerSpherical(loc, scale)
    z_samples = q.rsample()           # M x T x D
    with profiler.profile(
        use_cuda=True, with_stack=True, profile_memory=True
    ) as prof:
        with profiler.record_function("call_log_prob"):
            val1 = call_log_prob(q, z_samples)
    print(prof.key_averages(group_by_stack_n=5).table(sort_by="self_cuda_time_total", row_limit=20))
```

This operation requires **4.1GB** of CUDA memory! And we cannot run it with T larger than 512!
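The number is roughly what a back-of-the-envelope estimate predicts for the \(M \times M \times T \times D\) intermediate alone, assuming float32 (4 bytes per element, the default dtype of `torch.rand`):

```python
M, T, D = 128, 512, 128
BYTES_PER_FLOAT32 = 4

elements = M * M * T * D                        # size of loc * value before the sum
intermediate_bytes = elements * BYTES_PER_FLOAT32
output_bytes = M * M * T * BYTES_PER_FLOAT32    # size after summing over D

print(f"{elements:,} elements -> {intermediate_bytes / 2**30:.2f} GiB")
print(f"after .sum(-1): {output_bytes / 2**20:.0f} MiB")
```

This prints 1,073,741,824 elements, i.e. 4.00 GiB for the temporary, against 32 MiB for the summed result, which suggests that almost all of the reported memory goes into the product that is immediately reduced away.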

As mentioned above, `(self.loc * value).sum(-1)` requires a broadcasted multiplication followed by a summation. Conveniently, this is exactly the kind of contraction `einsum` [3,4] expresses, so we can use it instead!
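Here is a minimal check on toy sizes (random tensors, made-up shapes) that the `einsum` contraction used below matches the broadcast-and-sum version. Note that the samples are passed *without* the `unsqueeze(1)`:

```python
import torch

M, T, D = 4, 6, 8  # toy sizes for illustration
torch.manual_seed(0)
loc = torch.randn(M, T, D)
z = torch.randn(M, T, D)

# Broadcast version: materializes an (M, M, T, D) tensor, then sums over D
dots_broadcast = (loc * z.unsqueeze(1)).sum(-1)       # (M, M, T)

# einsum version: contracts over k (the D axis) directly; the output order l, i, j
# means [sample m, component j, position t], the same layout as above
dots_einsum = torch.einsum("ijk,ljk->lij", loc, z)    # (M, M, T)

print(torch.allclose(dots_broadcast, dots_einsum, atol=1e-6))
```

For a two-operand contraction like this one, PyTorch typically evaluates `einsum` as a batched matrix product over the shared index, so the \(M \times M \times T \times D\) temporary does not need to exist.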

However, since `log_prob` calls methods of an external distribution object, we cannot apply `TorchScript` [2] to it. So let’s compute `log_normalizer` on the fly instead.

```
import math
import torch
import torch.autograd.profiler as profiler
import torch.nn.functional as F
import power_spherical

def call_log_prob(q, z_samples):
    # Baseline: broadcasts to M x M x T x D, then sums over D
    return q.log_prob(z_samples.unsqueeze(1))

def calc_log_prob_partial(q, value):
    # einsum replaces the broadcasted multiply + sum; value is M x T x D (no unsqueeze)
    log_prob = q.log_normalizer() + q.scale * torch.log1p(
        torch.einsum("ijk, ljk -> lij", q.loc, value)
    )
    return log_prob

@torch.jit.script
def calc_log_prob(loc, scale, value):
    # Same computation with the normalizer inlined, so the whole function can be scripted
    beta = torch.tensor([(loc.shape[-1] - 1) / 2], device=loc.device)
    alpha = beta + scale
    log_prob = -(
        (alpha + beta) * math.log(2)
        + torch.lgamma(alpha)
        - torch.lgamma(alpha + beta)
        + beta * math.log(math.pi)
    ) + scale * torch.log1p(
        torch.einsum("ijk, ljk -> lij", loc, value)
    )
    return log_prob

if __name__ == "__main__":
    out = torch.rand(128, 512, 128, device="cuda")
    scale = torch.norm(out, dim=-1)
    loc = F.normalize(out, dim=-1)
    q = power_spherical.PowerSpherical(loc, scale)
    z_samples = q.rsample()
    with profiler.profile(
        use_cuda=True, with_stack=True, profile_memory=True
    ) as prof:
        with profiler.record_function("call_log_prob"):
            val1 = call_log_prob(q, z_samples)
        with profiler.record_function("calc_log_prob_partial"):
            val2 = calc_log_prob_partial(q, z_samples)
        with profiler.record_function("calc_log_prob"):
            val3 = calc_log_prob(q.loc, q.scale, z_samples)
    # ensure correctness of computations
    assert torch.allclose(val1, val2)
    assert torch.allclose(val1, val3)
    print(prof.key_averages(group_by_stack_n=5).table(sort_by="self_cuda_time_total", row_limit=20))
```

Let’s combine everything in one table! We are mostly interested in the Self CUDA Mem and CUDA time avg columns.

Name | Self CUDA Mem | CUDA time avg | CPU total % | CPU total | CPU time avg | Self CUDA | Self CUDA % |
---|---|---|---|---|---|---|---|
call_log_prob | -4.10Gb | 71.429ms | 1.84 | 11.105ms | 11.105ms | 419.000us | 0.07 |
calc_log_prob_partial | -98.50Mb | 466.243ms | 87.03 | 524.771ms | 524.771ms | 38.000us | 0.01 |
calc_log_prob torchscript | 0b | 63.637ms | 10.36 | 63.455ms | 63.455ms | 60.112ms | 9.81 |

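GPU memory usage is 4.1 GB vs 98.5 MB vs 0 (?!) with `torch.jit.script`: `einsum` alone already uses ~40 times less memory (!!!!). Impressive, right? And with `torch.jit.script` we do not pay for it in computation time: the scripted version averages 63.6 ms of CUDA time versus 71.4 ms for the original, while the eager `einsum` version takes 466.2 ms.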
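The Self CUDA Mem column is a net, per-row figure rather than a peak, which is presumably why it can be negative, or zero for the scripted call. A simple cross-check is to measure the peak allocation around each call with PyTorch's CUDA memory statistics. This is a sketch under that assumption: `peak_cuda_memory`, `broadcast_dots` and `einsum_dots` are hypothetical helpers, the sizes are toy values, and without a GPU the helper just runs the function and reports `None`.

```python
import torch

def peak_cuda_memory(fn, *args):
    """Hypothetical helper: run fn(*args), return (result, peak extra bytes on the current GPU).

    Returns (result, None) when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return fn(*args), None
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    baseline = torch.cuda.memory_allocated()
    result = fn(*args)
    torch.cuda.synchronize()
    return result, torch.cuda.max_memory_allocated() - baseline

def broadcast_dots(loc, z):
    # Materializes the (M, M, T, D) product before summing
    return (loc * z.unsqueeze(1)).sum(-1)

def einsum_dots(loc, z):
    # Same (M, M, T) result via a single contraction
    return torch.einsum("ijk,ljk->lij", loc, z)

device = "cuda" if torch.cuda.is_available() else "cpu"
loc = torch.randn(16, 64, 32, device=device)
z = torch.randn(16, 64, 32, device=device)
for fn in (broadcast_dots, einsum_dots):
    _, peak = peak_cuda_memory(fn, loc, z)
    print(fn.__name__, peak)
```

On a GPU you would expect the broadcast version's peak to include the full four-dimensional temporary, while the `einsum` version needs little beyond its output (library workspaces can add some noise to the numbers).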

- If you have a huge broadcasted multiplication followed by a sum → use `einsum`.
- If you can use `TorchScript`, use it. You might need to get rid of some external function calls, and your code will be less readable. However, it saves both memory and computation time.

[1] De Cao, N., Aziz, W. (2020). The Power Spherical distribution. In Proceedings of the 37th International Conference on Machine Learning, INNF+.

[2] PyTorch documentation, TorchScript: https://pytorch.org/docs/stable/jit.html

[3] PyTorch documentation, torch.einsum: https://pytorch.org/docs/stable/generated/torch.einsum.html