Memory usage is a common issue for large ML models. Especially in academia, we have to use resources wisely and make the most of what is available. While working on my mixture model’s KL objective, I had to make some less common optimizations to reduce memory usage.
The decoder outputs a large tensor \(O\) with dimensionality \((M \times B \times L \times D)\), where \(M\) is the number of clusters, \(B\) is the batch size, \(L\) is the sequence length and \(D\) is the model output dimension. Let’s assume from now on that \(T = B \times L\), so we can reduce the number of dimensions by one and work with tensors of shape \((M \times T \times D)\).
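For concreteness, here is a minimal sketch of this flattening (the split of \(T\) into \(B\) and \(L\) below is hypothetical, chosen only so that \(B \times L = 512\) as in the profiling example later):
import torch

M, B, L, D = 128, 4, 128, 128      # hypothetical sizes; only M, T = B * L and D matter below
O = torch.rand(M, B, L, D)         # decoder output
O = O.reshape(M, B * L, D)         # collapse batch and length into T = B * L
print(O.shape)                     # torch.Size([128, 512, 128])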
Now, let’s talk code. We use the implementation of the PowerSpherical distribution provided by Nicola De Cao (https://github.com/nicola-decao/power_spherical) [1]. In particular, we are interested in one function, log_prob:
# [1] De Cao, N., Aziz, W. (2020).
# The Power Spherical distribution.
# In Proceedings of the 37th International
# Conference on Machine Learning, INNF+.
# https://github.com/nicola-decao/power_spherical/blob/master/power_spherical/distributions.py#L176-L189
def log_prob(self, value):
    return self.log_normalizer() + self.scale * torch.log1p(
        (self.loc * value).sum(-1)
    )

def log_normalizer(self):
    alpha = self.base_dist.marginal_t.base_dist.concentration1
    beta = self.base_dist.marginal_t.base_dist.concentration0
    return -(
        (alpha + beta) * math.log(2)
        + torch.lgamma(alpha)
        - torch.lgamma(alpha + beta)
        + beta * math.log(math.pi)
    )
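As a quick reminder of the ordinary (non-mixture) usage, here is a minimal sketch assuming the shapes above: with loc of shape \((M \times T \times D)\) and scale of shape \((M \times T)\), calling log_prob on a sample of the same shape is cheap, because no cross-component broadcasting happens.
import torch
import torch.nn.functional as F
import power_spherical

out = torch.rand(128, 512, 128)                 # (M, T, D)
loc = F.normalize(out, dim=-1)                  # mean directions on the unit sphere
scale = torch.norm(out, dim=-1)                 # concentration parameters, shape (M, T)
q = power_spherical.PowerSpherical(loc, scale)
z = q.rsample()                                 # (M, T, D): one sample per component and position
print(q.log_prob(z).shape)                      # torch.Size([128, 512]); no cross terms yet
The trouble starts only when we need the density of every sample under every component.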
Let’s look closely at one line:
(self.loc * value).sum(-1)
To get the probability of each mixture component for each value, we need to involve broadcasting, so we expand the samples \(z_m\) at dimension 1: value is \(M \times 1 \times T \times D\) and loc is \(M \times T \times D\). Then self.loc * value is broadcast to an \(M \times M \times T \times D\) tensor, and only then do we sum over the last dimension. Remember, in NLP we typically work with high-dimensional vectors. Scary.
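To make the blow-up concrete, here is a small sketch with toy shapes (4, 6 and 8 stand in for \(M\), \(T\) and \(D\)): the broadcasted product materializes the full \(M \times M \times T \times D\) tensor before the sum reduces it.
import torch

M, T, D = 4, 6, 8                          # toy stand-ins for the real sizes
loc = torch.rand(M, T, D)
value = torch.rand(M, T, D).unsqueeze(1)   # (M, 1, T, D): the expanded samples
prod = loc * value                         # broadcast to (M, M, T, D): the expensive intermediate
print(prod.shape)                          # torch.Size([4, 4, 6, 8])
print(prod.sum(-1).shape)                  # torch.Size([4, 4, 6])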
Let’s assume our typical numbers for \(M\), \(T\) and \(D\): \(M=128\), \(T=512\), \(D=128\), and see how much memory it takes to compute the log probabilities.
import math
import torch.autograd.profiler as profiler
import power_spherical
import torch
import torch.nn.functional as F
import numpy as np


def call_log_prob(q, z_samples):
    # log_prob of every sample under every mixture component via broadcasting
    return q.log_prob(z_samples.unsqueeze(1))


if __name__ == "__main__":
    out = torch.rand(128, 512, 128, device="cuda")  # (M, T, D)
    scale = torch.norm(out, dim=-1)
    loc = F.normalize(out, dim=-1)
    q = power_spherical.PowerSpherical(loc, scale)
    z_samples = q.rsample()

    with profiler.profile(
        use_cuda=True, with_stack=True, profile_memory=True
    ) as prof:
        with profiler.record_function("call_log_prob"):
            val1 = call_log_prob(q, z_samples)

    print(prof.key_averages(group_by_stack_n=5).table(
        sort_by="self_cuda_time_total", row_limit=20))
This single operation requires 4.1 GB of CUDA memory! And we cannot run it with \(T\) larger than 512!
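This matches a back-of-the-envelope estimate: the broadcasted intermediate alone holds \(M \times M \times T \times D = 128 \times 128 \times 512 \times 128 \approx 1.07 \times 10^9\) float32 values, i.e. roughly 4 GB, before we even sum over the last dimension.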
As we mentioned above, (self.loc * value).sum(-1) is a multiplication with broadcasting followed by a summation. Very convenient: we can use einsum [3] instead, which never materializes the \(M \times M \times T \times D\) intermediate! However, since log_prob calls an external method of the distribution object, we cannot apply TorchScript [2] to this function directly. So let’s also compute log_normalizer on the fly: for the PowerSpherical distribution, \(\beta = (D - 1)/2\) and \(\alpha = \beta + \kappa\), where \(\kappa\) is the scale parameter.
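Before wiring this into the full example, here is a minimal sanity check (with toy shapes again) that the einsum expression reproduces the broadcasted product followed by the sum:
import torch

M, T, D = 4, 6, 8
loc = torch.rand(M, T, D)
z = torch.rand(M, T, D)                                  # one sample per component and position

via_broadcast = (loc * z.unsqueeze(1)).sum(-1)           # (M, M, T), materializes (M, M, T, D)
via_einsum = torch.einsum("ijk, ljk -> lij", loc, z)     # (M, M, T), no large intermediate
assert torch.allclose(via_broadcast, via_einsum)
With that confirmed, here is the full comparison: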
import math
import torch.autograd.profiler as profiler
import power_spherical
import torch
import torch.nn.functional as F
import numpy as np


def call_log_prob(q, z_samples):
    # baseline: broadcasting inside the library's log_prob
    return q.log_prob(z_samples.unsqueeze(1))


def calc_log_prob_partial(q, value):
    # einsum instead of broadcast-and-sum, but still calling the
    # distribution's log_normalizer, so it cannot be TorchScripted
    log_prob = q.log_normalizer() + q.scale * torch.log1p(
        torch.einsum("ijk, ljk -> lij", q.loc, value)
    )
    return log_prob


@torch.jit.script
def calc_log_prob(loc, scale, value):
    # log_normalizer computed on the fly: beta = (D - 1) / 2, alpha = beta + scale
    beta = torch.tensor([(loc.shape[-1] - 1) / 2], device="cuda")
    alpha = beta + scale
    log_prob = -(
        (alpha + beta) * math.log(2)
        + torch.lgamma(alpha)
        - torch.lgamma(alpha + beta)
        + beta * math.log(math.pi)
    ) + scale * torch.log1p(
        torch.einsum("ijk, ljk -> lij", loc, value)
    )
    return log_prob


if __name__ == "__main__":
    out = torch.rand(128, 512, 128, device="cuda")  # (M, T, D)
    scale = torch.norm(out, dim=-1)
    loc = F.normalize(out, dim=-1)
    q = power_spherical.PowerSpherical(loc, scale)
    z_samples = q.rsample()

    with profiler.profile(
        use_cuda=True, with_stack=True, profile_memory=True
    ) as prof:
        with profiler.record_function("call_log_prob"):
            val1 = call_log_prob(q, z_samples)
        with profiler.record_function("calc_log_prob_partial"):
            val2 = calc_log_prob_partial(q, z_samples)
        with profiler.record_function("calc_log_prob"):
            val3 = calc_log_prob(q.loc, q.scale, z_samples)

    # ensure correctness of the computations
    assert torch.allclose(val1, val2)
    assert torch.allclose(val1, val3)

    print(prof.key_averages(group_by_stack_n=5).table(
        sort_by="self_cuda_time_total", row_limit=20))
Let’s combine everything into a table! We are mostly interested in Self CUDA Mem and CUDA time avg.
| Name | Self CUDA Mem | CUDA time avg | CPU total % | CPU total | CPU time avg | Self CUDA | Self CUDA % |
|---|---|---|---|---|---|---|---|
| call_log_prob | -4.10Gb | 71.429ms | 1.84 | 11.105ms | 11.105ms | 419.000us | 0.07 |
| calc_log_prob_partial | -98.50Mb | 466.243ms | 87.03 | 524.771ms | 524.771ms | 38.000us | 0.01 |
| calc_log_prob (TorchScript) | 0b | 63.637ms | 10.36 | 63.455ms | 63.455ms | 60.112ms | 9.81 |
Memory usage on the GPU: 4.1 GB vs 98.5 MB vs 0 b (?!) with torch.jit.script. Even the plain einsum version uses ~40 times less memory (!!!!). Impressive, right? And with torch.jit.script we also do not suffer from the increase in computation time that the non-scripted einsum version shows.
The takeaway: if your computation is a broadcasted product followed by a reduction, use einsum. If you can apply TorchScript, use it. You might need to get rid of some calls to external functions and your code will be a bit less readable; however, it saves both memory and computation time.
[1] De Cao, N., Aziz, W. (2020). The Power Spherical distribution. In Proceedings of the 37th International Conference on Machine Learning, INNF+.
[2] https://pytorch.org/docs/stable/jit.html
[3] https://pytorch.org/docs/stable/generated/torch.einsum.html