Why is it hard to run inference for large transformer models, and how can we overcome it?


Besides the increasing size of SoTA models, there are two main factors contributing to the inference challenge (Pope et al. 2022):

Large memory footprint. Both model parameters and intermediate states need to be kept in memory at inference time. In particular, the KV cache has to be stored in memory throughout decoding; for example, with a batch size of 512 and a context length of 2048, the KV cache totals 3TB, that is 3x the model size (!). In addition, the inference cost of the attention mechanism scales quadratically with input sequence length. (A back-of-the-envelope sketch of the KV cache size follows this list.)

Low parallelizability. Generation is executed in an autoregressive fashion, making the decoding process hard to parallelize (see the decoding-loop sketch below).
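To see where a multi-terabyte KV cache comes from, here is a rough back-of-the-envelope calculator. The hyperparameters below (96 layers, hidden size 12288, bf16 storage) are illustrative assumptions for a very large dense model, not the exact configuration behind the 3TB figure above.

```python
def kv_cache_bytes(batch_size, seq_len, n_layers, d_model, bytes_per_elem=2):
    """Rough size of the KV cache for standard multi-head attention.

    Per layer we cache one key and one value vector of size d_model
    for every token in every sequence of the batch.
    """
    return 2 * n_layers * batch_size * seq_len * d_model * bytes_per_elem

# Illustrative hyperparameters (assumed values, not any specific published model).
size = kv_cache_bytes(batch_size=512, seq_len=2048, n_layers=96, d_model=12288)
print(f"KV cache ~ {size / 1e12:.1f} TB")  # ~4.9 TB at bf16 precision
```

The exact total depends on layer count, hidden size, attention variant (multi-query attention shrinks the cache dramatically), and numeric precision, but the linear growth in both batch size and context length is what makes the cache balloon.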
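The parallelizability problem is easiest to see in code: every new token depends on all previously generated tokens, so the time dimension of decoding cannot be batched away. Below is a minimal greedy-decoding sketch, where `model(tokens)` is a hypothetical callable that returns next-token logits.

```python
def greedy_decode(model, prompt_ids, max_new_tokens):
    """Sequential greedy decoding: one model call per generated token."""
    tokens = list(prompt_ids)
    for _ in range(max_new_tokens):
        # The next token depends on everything generated so far, so the
        # iterations of this loop cannot run in parallel with each other.
        next_token_logits = model(tokens)  # hypothetical: returns a vocab-size list of logits
        next_token = max(range(len(next_token_logits)), key=next_token_logits.__getitem__)
        tokens.append(next_token)
    return tokens
```

Training, by contrast, can process all positions of a sequence in a single pass with teacher forcing, which is why decoding is the disproportionately expensive part.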

In this post, we will look into several approaches for making transformer inference more efficient. Some are general network compression methods, while others are specific to transformer architecture.

Methods Overview

In general, we consider the following as goals for model inference optimization:

Reduce the memory footprint of the model by using fewer GPU devices and less GPU memory;

Reduce the required computation by lowering the number of FLOPs needed;

Reduce the inference latency and make things run faster.

Several methods can be used to make inference cheaper in memory and/or faster in time; short code sketches illustrating some of them follow the list.

Apply various parallelism techniques to scale up the model across a large number of GPUs. Smart parallelism of model components and data makes it possible to run a model with trillions of parameters.

Memory offloading to move temporarily unused data to the CPU and read it back when needed later. This helps with memory usage but incurs higher latency.

Smart batching strategy; e.g., EffectiveTransformer packs consecutive sequences together to remove padding within one batch.

Network compression techniques, such as pruning, quantization, and distillation. A model of smaller size, in terms of parameter count or bitwidth, should demand less memory and run faster.

Improvements specific to a target model architecture. Many architectural changes, especially those to attention layers, help with transformer decoding speed.
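As a toy illustration of the parallelism idea, a linear layer `Y = X @ W` can be split column-wise so that each device holds only a shard of `W`. The sketch below simulates two "devices" with NumPy; it is a conceptual sketch, not a distributed implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 8))        # activations: [batch, d_in]
W = rng.normal(size=(8, 6))        # full weight matrix: [d_in, d_out]

# Column-parallel split: each "device" stores and multiplies only half of W.
W_shards = np.split(W, 2, axis=1)
Y_shards = [X @ shard for shard in W_shards]   # would run on separate GPUs
Y = np.concatenate(Y_shards, axis=1)           # gather the partial outputs

assert np.allclose(Y, X @ W)
```

With N-way sharding each device stores only 1/N of the layer's weights, which is what makes trillion-parameter models fit across many GPUs.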
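Memory offloading, in its simplest form, just moves tensors between devices. A minimal PyTorch sketch, assuming a CUDA device is available:

```python
import torch

weight = torch.randn(4096, 4096, device="cuda")   # lives on the GPU while in use

# Offload to CPU memory while the tensor is not needed ...
weight_cpu = weight.to("cpu")
del weight
torch.cuda.empty_cache()

# ... and bring it back right before it is used again.
weight = weight_cpu.to("cuda", non_blocking=True)
```

Real offloading systems overlap such transfers with computation, but the CPU-GPU round trip is the latency cost noted above.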
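The padding-removal idea can be sketched as sequence packing: concatenate variable-length sequences and keep offsets so later kernels never touch padding tokens. This is a conceptual sketch, not EffectiveTransformer's actual API.

```python
def pack_sequences(batch):
    """Concatenate variable-length sequences and record offsets,
    so no compute is wasted on padding tokens."""
    packed, offsets = [], [0]
    for seq in batch:
        packed.extend(seq)
        offsets.append(len(packed))
    return packed, offsets

def unpack_sequences(packed, offsets):
    return [packed[offsets[i]:offsets[i + 1]] for i in range(len(offsets) - 1)]

batch = [[7, 3, 9], [4], [5, 2, 2, 8]]
packed, offsets = pack_sequences(batch)   # [7, 3, 9, 4, 5, 2, 2, 8], [0, 3, 4, 8]
assert unpack_sequences(packed, offsets) == batch
```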
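As a tiny example of the compression idea, here is symmetric per-tensor int8 post-training quantization, which cuts a weight matrix's footprint by 4x versus fp32. This is a simplified sketch; production schemes add per-channel scales, outlier handling, and calibration.

```python
import numpy as np

def quantize_int8(w):
    """Symmetric per-tensor int8 quantization."""
    scale = np.abs(w).max() / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

w = np.random.default_rng(0).normal(size=(256, 256)).astype(np.float32)
q, scale = quantize_int8(w)
print(q.nbytes / w.nbytes)                          # 0.25: 4x smaller
print(np.abs(dequantize(q, scale) - w).max())       # small rounding error
```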