Memory usage is a common issue for large ML models. Especially in academia, we have to use resources wisely and make the most of what is available. While working on my mixture model’s KL objective, I had to make some less common optimizations to reduce memory usage.
The decoder outputs a large tensor \(O\) with dimensionality \((M \times B \times L \times D)\), where \(M\) is the number of clusters, \(B\) is the batch size, \(L\) is the sequence length, and \(D\) is the model output dimension. Let’s assume from now on that \(T = B \times L\), so we can reduce the number of dimensions by one and work with \(M \times T \times D\).
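As a small sketch of what that reshape looks like (the variable names and the split of \(T = 512\) into \(B = 8\), \(L = 64\) are purely illustrative, not from the model code):

    import torch

    M, B, L, D = 128, 8, 64, 128      # illustrative sizes
    O = torch.rand(M, B, L, D)        # decoder output
    out = O.reshape(M, B * L, D)      # fold batch and sequence length into T = B * L
    print(out.shape)                  # torch.Size([128, 512, 128])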
Now, let’s talk code. We use the implementation of the PowerSpherical distribution provided by Nicola De Cao (https://github.com/nicola-decao/power_spherical) [1]. In particular, we are interested in one function, log_prob:
# [1] De Cao, N., Aziz, W. (2020).
# The Power Spherical distribution.
# In Proceedings of the 37th International
# Conference on Machine Learning, INNF+.
# https://github.com/nicola-decao/power_spherical/blob/master/power_spherical/distributions.py#L176-L189
def log_prob(self, value):
    return self.log_normalizer() + self.scale * torch.log1p(
        (self.loc * value).sum(-1)
    )

def log_normalizer(self):
    alpha = self.base_dist.marginal_t.base_dist.concentration1
    beta = self.base_dist.marginal_t.base_dist.concentration0
    return -(
        (alpha + beta) * math.log(2)
        + torch.lgamma(alpha)
        - torch.lgamma(alpha + beta)
        + beta * math.log(math.pi)
    )
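For context, constructing the distribution and evaluating log_prob looks roughly like this (a minimal sketch following the repository’s README, not code from my model):

    import torch
    import torch.nn.functional as F
    from power_spherical import PowerSpherical

    loc = F.normalize(torch.rand(4), dim=-1)  # mean direction, a unit vector
    scale = torch.tensor(5.0)                 # concentration
    q = PowerSpherical(loc, scale)
    z = q.rsample()                           # reparameterized sample on the sphere
    print(q.log_prob(z))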
Let’s look closely at one line:
(self.loc * value).sum(-1)
To get the probability of each mixture component for each value, we rely on broadcasting: we unsqueeze the samples at dimension 1, so value has shape \(M \times 1 \times T \times D\) while loc has shape \(M \times T \times D\). Then self.loc * value is an \(M \times M \times T \times D\) tensor, and we sum over the last dimension. Remember, in NLP we typically work with high-dimensional vectors. Scary.
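To see the blow-up concretely, here is a toy-sized sketch (the sizes are illustrative, only the shapes matter):

    import torch

    M, T, D = 4, 3, 2                          # toy sizes
    loc = torch.rand(M, T, D)
    value = torch.rand(M, T, D).unsqueeze(1)   # (M, 1, T, D), as in log_prob above
    prod = loc * value                         # broadcasts to (M, M, T, D)
    print(prod.shape, prod.sum(-1).shape)      # torch.Size([4, 4, 3, 2]) torch.Size([4, 4, 3])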
Let’s assume typical numbers for \(M\), \(T\), and \(D\): \(M=128\), \(T=512\), \(D=128\), and see how much memory it takes to compute the log probs.
import math
import torch.autograd.profiler as profiler
import power_spherical
import torch
import torch.nn.functional as F
import numpy as np


def call_log_prob(q, z_samples):
    return q.log_prob(z_samples.unsqueeze(1))


if __name__ == "__main__":
    out = torch.rand(128, 512, 128, device="cuda")  # (M, T, D) on GPU
    scale = torch.norm(out, dim=-1)
    loc = F.normalize(out, dim=-1)
    q = power_spherical.PowerSpherical(loc, scale)
    z_samples = q.rsample()

    with profiler.profile(
        use_cuda=True, with_stack=True, profile_memory=True
    ) as prof:
        with profiler.record_function("call_log_prob"):
            val1 = call_log_prob(q, z_samples)

    print(prof.key_averages(group_by_stack_n=5).table(
        sort_by="self_cuda_time_total", row_limit=20))
This operation requires 4.1 GB of CUDA memory! And we cannot run it with more than \(T=512\)!
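This matches a back-of-envelope estimate for the broadcasted intermediate tensor alone, assuming 32-bit floats:

\[
M \times M \times T \times D \times 4\,\text{bytes} = 128 \times 128 \times 512 \times 128 \times 4\,\text{bytes} = 4\,\text{GiB}.
\]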
As we mentioned above, (self.loc * value).sum(-1) performs a broadcasted multiplication followed by a summation. Very convenient: we can use einsum [3] instead!
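As a quick sanity check (toy sizes again, my own variable names), the einsum contraction gives the same result as broadcast-and-sum without materializing the \(M \times M \times T \times D\) intermediate:

    import torch

    M, T, D = 4, 3, 2
    loc = torch.rand(M, T, D)
    value = torch.rand(M, T, D)

    ref = (loc * value.unsqueeze(1)).sum(-1)           # via the (M, M, T, D) intermediate
    ein = torch.einsum("ijk, ljk -> lij", loc, value)  # contracts over D directly
    print(torch.allclose(ref, ein))                    # True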
However, since log_prob calls other methods on the distribution object, we cannot apply TorchScript [2] to it directly! Let’s compute log_normalizer on the fly instead: as the library code above shows, the Beta parameters of the PowerSpherical distribution are \(\beta = (D - 1)/2\) and \(\alpha = \beta + \text{scale}\).
import math
import torch.autograd.profiler as profiler
import power_spherical
import torch
import torch.nn.functional as F
import numpy as np


def call_log_prob(q, z_samples):
    # original version: broadcast and sum inside the library's log_prob
    return q.log_prob(z_samples.unsqueeze(1))


def calc_log_prob_partial(q, value):
    # einsum version: still calls the distribution's log_normalizer
    log_prob = q.log_normalizer() + q.scale * torch.log1p(
        torch.einsum("ijk, ljk -> lij", q.loc, value)
    )
    return log_prob


@torch.jit.script
def calc_log_prob(loc, scale, value):
    # einsum + TorchScript version: the normalizer is computed inline,
    # so the whole function can be scripted
    beta = torch.tensor([(loc.shape[-1] - 1) / 2], device='cuda')
    alpha = beta + scale
    log_prob = -(
        (alpha + beta) * math.log(2)
        + torch.lgamma(alpha)
        - torch.lgamma(alpha + beta)
        + beta * math.log(math.pi)
    ) + scale * torch.log1p(
        torch.einsum("ijk, ljk -> lij", loc, value)
    )
    return log_prob


if __name__ == "__main__":
    out = torch.rand(128, 512, 128, device="cuda")  # (M, T, D) on GPU
    scale = torch.norm(out, dim=-1)
    loc = F.normalize(out, dim=-1)
    q = power_spherical.PowerSpherical(loc, scale)
    z_samples = q.rsample()

    with profiler.profile(
        use_cuda=True, with_stack=True, profile_memory=True
    ) as prof:
        with profiler.record_function("call_log_prob"):
            val1 = call_log_prob(q, z_samples)
        with profiler.record_function("calc_log_prob_partial"):
            val2 = calc_log_prob_partial(q, z_samples)
        with profiler.record_function("calc_log_prob"):
            val3 = calc_log_prob(q.loc, q.scale, z_samples)

    # ensure correctness of computations
    assert torch.allclose(val1, val2)
    assert torch.allclose(val1, val3)

    print(prof.key_averages(group_by_stack_n=5).table(
        sort_by="self_cuda_time_total", row_limit=20))
Let’s combine everything into a table! We are mostly interested in Self CUDA Mem and CUDA time avg.
| Name | Self CUDA Mem | CUDA time avg | CPU total % | CPU total | CPU time avg | Self CUDA | Self CUDA % |
|---|---|---|---|---|---|---|---|
| call_log_prob | -4.10Gb | 71.429ms | 1.84 | 11.105ms | 11.105ms | 419.000us | 0.07 |
| calc_log_prob_partial | -98.50Mb | 466.243ms | 87.03 | 524.771ms | 524.771ms | 38.000us | 0.01 |
| calc_log_prob torchscript | 0b | 63.637ms | 10.36 | 63.455ms | 63.455ms | 60.112ms | 9.81 |
Memory usage on GPU: 4.1 GB vs 98.5 MB vs 0 (?!) with torch.jit.script; even the plain einsum version uses ~40 times less memory (!!!!). Impressive, right? And with torch.jit.script we do not suffer from the increase in computation time that the plain einsum version shows.
einsum and TorchScript: use them. You might need to get rid of some external function calls, and your code will be less readable. However, it saves both memory and computation time.

[1] De Cao, N., Aziz, W. (2020). The Power Spherical distribution. In Proceedings of the 37th International Conference on Machine Learning, INNF+.
[2] https://pytorch.org/docs/stable/jit.html
[3] https://pytorch.org/docs/stable/generated/torch.einsum.html