
One example of memory optimization for KL-loss calculation in pytorch

Mar 30 2023 · 4 min read
#pytorch #memory #fairseq #nlp #machine-learning

Memory usage is a common issue for large ML models. Especially in academia, we have to use the resources we have wisely and make the most of them. While working on the KL-objective of my mixture model, I had to make some less common optimizations to reduce memory usage.

Setup

The decoder outputs a large tensor \(O\) with dimensionality \((M \times B \times L \times D)\), where \(M\) is the number of clusters, \(B\) is the batch size, \(L\) is the sequence length and \(D\) is the model output dimension. Let's assume from now on that \(T = B \times L\), so we can reduce the number of dimensions by one and work with \((M \times T \times D)\).

  1. Each row \(o_m\) of \(O\) parameterizes a mixture component \(Q_m = \mathrm{PowerSpherical}(\mu_m, \kappa_m)\), with location \(\mu_m = o_m / ||o_m||\) and spread parameter \(\kappa_m = ||o_m||\), and the mixture is \(Q(x) = \sum_{m=1}^{M} p(m)\, Q_m(x)\)
  2. We sample \(z_m\) from each mixture component \(Q_m\) and get a matrix of samples, where each row is the sample drawn from component \(m\).
  3. Now, our goal is to calculate \(\log q_j(z_m)\), i.e., for each sample \(z_m\) to calculate the log probability under every mixture component \(q_j\), \(j \in [1, M]\). So we end up with a tensor of size \(M \times M \times T\) (sketched in code right after this list).
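
To make the shapes concrete, here is a minimal sketch of the setup with toy sizes (the real setting below uses M=128, T=512, D=128); the parameterization mirrors the scripts later in the post, where loc is the normalized decoder output and scale is its norm.

import torch
import torch.nn.functional as F

M, T, D = 4, 6, 8                  # toy sizes, only to illustrate the shapes
out = torch.rand(M, T, D)          # decoder output with batch and length flattened into T
loc = F.normalize(out, dim=-1)     # mu_m: unit-norm locations, shape (M, T, D)
scale = torch.norm(out, dim=-1)    # kappa_m: spread parameters, shape (M, T)
# after sampling one z_m per component we want log q_j(z_m) for all j,
# i.e. a tensor of shape (M, M, T)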

Now, let's talk code. We use the implementation of the PowerSpherical distribution provided by Nicola De Cao (https://github.com/nicola-decao/power_spherical) [1]. In particular, we are interested in one function: log_prob.

#[1] De Cao, N., Aziz, W. (2020). 
#The Power Spherical distribution.
#In Proceedings of the 37th International 
#Conference on Machine Learning, INNF+.
#https://github.com/nicola-decao/power_spherical/blob/master/power_spherical/distributions.py#L176-L189

def log_prob(self, value):
    return self.log_normalizer() + self.scale * torch.log1p(
        (self.loc * value).sum(-1)
    )

def log_normalizer(self):
    alpha = self.base_dist.marginal_t.base_dist.concentration1
    beta = self.base_dist.marginal_t.base_dist.concentration0
    return -(
        (alpha + beta) * math.log(2)
        + torch.lgamma(alpha)
        - torch.lgamma(alpha + beta)
        + beta * math.log(math.pi)
    )
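
For orientation, a single component is used roughly like this (a small sketch based on the package interface as it appears in the scripts below):

import torch
import torch.nn.functional as F
import power_spherical

D = 8
row = torch.rand(D)                                # one row of the decoder output
mu = F.normalize(row, dim=-1)                      # location on the unit sphere
kappa = torch.norm(row)                            # spread parameter
q_m = power_spherical.PowerSpherical(mu, kappa)
z = q_m.rsample()                                  # one sample from this component
print(q_m.log_prob(z))                             # scalar log-density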

Let’s look closely at one line

(self.loc * value).sum(-1)

To get the probability of each sample under each mixture component, we rely on broadcasting: we expand the samples z_m at dimension 1, so value has shape \(M \times 1 \times T \times D\) while loc has shape \(M \times T \times D\).

Then self.loc * value is an \(M \times M \times T \times D\) tensor, and we sum over the last dimension. Remember, in NLP we typically work with high-dimensional vectors. Scary.
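
To see how scary, here is a rough back-of-the-envelope check of the broadcasted intermediate with the sizes used in the next section (it only computes the shape; nothing is actually allocated):

import torch

M, T, D = 128, 512, 128
broadcast_shape = torch.broadcast_shapes((M, 1, T, D), (M, T, D))  # shape of loc * value
n_elements = 1
for s in broadcast_shape:
    n_elements *= s
print(broadcast_shape)                    # torch.Size([128, 128, 512, 128])
print(n_elements * 4 / 2**30, "GiB")      # ~4.0 GiB of float32 for this intermediate alone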

Dive into optimization

Let's assume typical values for M, T and D: M=128, T=512, D=128, and see how much memory it takes to compute the log probabilities.

import math
import torch.autograd.profiler as profiler
import power_spherical
import torch
import torch.nn.functional as F
import numpy as np

def call_log_prob(q, z_samples):
    return q.log_prob(z_samples.unsqueeze(1))

if __name__=="__main__":
    out = torch.rand(128, 512, 128, device="cuda")  # (M, T, D) on the GPU, so the profiler sees CUDA memory
    scale = torch.norm(out, dim=-1)
    loc = F.normalize(out, dim=-1)

    q = power_spherical.PowerSpherical(loc, scale)

    z_samples = q.rsample()

    with profiler.profile(
        use_cuda=True, with_stack=True, profile_memory=True
        ) as prof:
        with profiler.record_function("call_log_prob"):
            val1 = call_log_prob(q, z_samples)
    print(prof.key_averages(group_by_stack_n=5).table(sort_by="self_cuda_time_total", row_limit=20))

This operation requires 4.1 GB of CUDA memory, essentially the \(128 \times 128 \times 512 \times 128\) float32 intermediate from above! And we cannot run it with more than T=512!

As we mentioned above, (self.loc * value).sum(-1) is a broadcasted multiplication followed by a summation. Very conveniently, we can use einsum [3, 4] instead, which computes the same result without materializing the \(M \times M \times T \times D\) intermediate!
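
As a quick sanity check on toy tensors (sizes chosen only for illustration), the einsum expression used below matches the broadcasted multiply-and-sum:

import torch

M, T, D = 4, 6, 8                                    # toy sizes
loc = torch.rand(M, T, D)
value = torch.rand(M, T, D)

ref = (loc * value.unsqueeze(1)).sum(-1)             # materializes an (M, M, T, D) intermediate
opt = torch.einsum("ijk, ljk -> lij", loc, value)    # same (M, M, T) result, no huge intermediate
print(torch.allclose(ref, opt, atol=1e-6))           # True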

However, since the function calls into an external distribution object (log_normalizer reaches into q.base_dist internals), we cannot apply TorchScript [2] to it! Let's calculate log_normalizer on the fly instead: comparing log_normalizer above with the scripted version below, \(\beta = (D - 1)/2\) and \(\alpha = \beta + \kappa\), so both can be rebuilt directly from loc.shape[-1] and scale.

import math
import torch.autograd.profiler as profiler
import power_spherical
import torch
import torch.nn.functional as F
import numpy as np

def call_log_prob(q, z_samples):
    return q.log_prob(z_samples.unsqueeze(1))

def calc_log_prob_partial(q, value):
    # same formula as log_prob above, but with einsum instead of broadcast-multiply-and-sum
    log_prob = q.log_normalizer() + q.scale * torch.log1p(
        torch.einsum("ijk, ljk -> lij", q.loc, value)
    )
    return log_prob

@torch.jit.script
def calc_log_prob(loc, scale, value):
    # log_normalizer recomputed on the fly: beta = (D - 1) / 2, alpha = beta + kappa
    beta = torch.tensor([(loc.shape[-1] - 1) / 2], device='cuda')
    alpha = beta + scale
    log_prob = -(
        (alpha + beta) * math.log(2)
        + torch.lgamma(alpha)
        - torch.lgamma(alpha + beta)
        + beta * math.log(math.pi)
    ) + scale * torch.log1p(
        torch.einsum("ijk, ljk -> lij", loc, value)
    )
    return log_prob

if __name__=="__main__":
    out = torch.rand(128, 512, 128, device="cuda")  # (M, T, D) on the GPU, so the profiler sees CUDA memory
    scale = torch.norm(out, dim=-1)
    loc = F.normalize(out, dim=-1)

    q = power_spherical.PowerSpherical(loc, scale)

    z_samples = q.rsample()

    with profiler.profile(
        use_cuda=True, with_stack=True, profile_memory=True
        ) as prof:
        with profiler.record_function("call_log_prob"):
            val1 = call_log_prob(q, z_samples)
        with profiler.record_function("calc_log_prob_partial"):
            val2 = calc_log_prob_partial(q, z_samples)
        with profiler.record_function("calc_log_prob"):
            val3 = calc_log_prob(q.loc, q.scale, z_samples)
    #ensure correctness of computations
    assert torch.allclose(val1, val2)
    assert torch.allclose(val1, val3)
    print(prof.key_averages(group_by_stack_n=5).table(sort_by="self_cuda_time_total", row_limit=20))

Let's combine everything into a table! We are mostly interested in Self CUDA Mem and CUDA time avg.

Name                        | Self CUDA Mem | CUDA time avg | CPU total % | CPU total | CPU time avg | Self CUDA | Self CUDA %
call_log_prob               | -4.10Gb       | 71.429ms      | 1.84        | 11.105ms  | 11.105ms     | 419.000us | 0.07
calc_log_prob_partial       | -98.50Mb      | 466.243ms     | 87.03       | 524.771ms | 524.771ms    | 38.000us  | 0.01
calc_log_prob (TorchScript) | 0b            | 63.637ms      | 10.36       | 63.455ms  | 63.455ms     | 60.112ms  | 9.81

Memory usage on the GPU: 4.1 Gb vs 98.5 Mb vs 0 b (?!) with torch.jit.script; even the plain einsum version uses ~40 times less memory (!!!!). Impressive, right? And with torch.jit.script we also avoid the slowdown that the non-scripted einsum version suffers.

Lessons learned:

  1. If you have a huge matrix multiplication or broadcasted product followed by a sum → use einsum
  2. If you can use TorchScript, use it. You might need to get rid of some calls to external functions and your code will be less readable; however, it saves both memory and computation time.

[1] De Cao, N., Aziz, W. (2020). The Power Spherical distribution. In Proceedings of the 37th International Conference on Machine Learning, INNF+.

[2] https://pytorch.org/docs/stable/jit.html

[3] https://pytorch.org/docs/stable/generated/torch.einsum.html

[4] https://rockt.github.io/2018/04/30/einsum