Scaling LLMs with PyTorch 2.0 FSDP on Amazon EKS – Part 2

This article is a collaborative effort with Meta’s PyTorch team, serving as the second installment in our series showcasing the effectiveness and simplicity of utilizing PyTorch 2.0 on AWS.

Recent advancements in machine learning (ML) research have demonstrated that large language models (LLMs) trained on vast datasets yield superior model quality. The evolution of contemporary models has led to significant growth in their size, necessitating advanced tools and infrastructure for efficient large-scale training. While PyTorch Distributed Data Parallel (DDP) enables data processing at scale, it requires the entire model to fit on a single GPU. In contrast, the PyTorch Fully Sharded Data Parallel (FSDP) library overcomes this limitation by allowing model sharding, facilitating the training of large models across data parallel workers.

To effectively train distributed models, a scalable cluster of worker nodes is essential. Amazon Elastic Kubernetes Service (Amazon EKS) simplifies the execution of AI/ML workloads, making it more efficient and less time-consuming to manage. In this blog post, we explore how to leverage the PyTorch FSDP library for seamless linear scaling of deep learning models on AWS, using Amazon EKS and AWS Deep Learning Containers (DLCs). We illustrate this with a step-by-step implementation of training 7B, 13B, and 70B Llama2 models on 16 Amazon Elastic Compute Cloud (Amazon EC2) p4de.24xlarge instances (each equipped with 8 NVIDIA A100 Tensor Core GPUs and 80 GB HBM2e memory) or 16 p5.48xlarge instances (each with 8 NVIDIA H100 Tensor Core GPUs and 80 GB HBM3 memory), achieving nearly linear throughput scaling and thus facilitating faster training times.

The scaling chart below illustrates that the p5.48xlarge instances achieve 87% scaling efficiency with FSDP Llama2 fine-tuning in a 16-node cluster configuration.
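
As a rough illustration of how such a figure is typically derived (this is our assumed definition of scaling efficiency, not a description of the exact measurement methodology behind the chart), the observed aggregate throughput at N nodes is compared against N times the single-node throughput:

```python
# Sketch of how a scaling-efficiency figure is typically computed.
# The throughput numbers below are placeholders, not measured values from this post.
def scaling_efficiency(throughput_n_nodes: float, throughput_1_node: float, n_nodes: int) -> float:
    """Ratio of observed aggregate throughput to ideal linear scaling."""
    return throughput_n_nodes / (n_nodes * throughput_1_node)

# Hypothetical example: 16 nodes delivering 13.9x the single-node throughput
print(f"{scaling_efficiency(13.9, 1.0, 16):.0%}")  # ~87%
```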

Challenges in Training LLMs

The adoption of LLMs is growing among businesses for various applications such as virtual assistants, translation, content generation, and computer vision, enhancing efficiency and accuracy. However, the process of training or fine-tuning these extensive models for specific use cases demands substantial data and computational power, complicating the ML stack. Limited memory on individual GPUs restricts the size of trainable models and the per-GPU batch sizes.

To tackle this issue, model parallelism techniques like DeepSpeed ZeRO and PyTorch FSDP have emerged, addressing GPU memory limitations. These methods use sharded data parallel techniques, where each accelerator retains only a portion (or shard) of a model replica, significantly reducing the memory burden of training jobs.

In this post, we demonstrate the fine-tuning of the Llama2 model using PyTorch FSDP on Amazon EKS, scaling up compute and GPU resources to meet the model’s requirements.

Overview of FSDP

In traditional PyTorch DDP training, each GPU (referred to as a worker) maintains a complete model copy, including weights, gradients, and optimizer states. Each worker processes a data batch and synchronizes gradients via an all-reduce operation after the backward pass.
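
A minimal sketch of this workflow follows, using a stand-in model and random data (both placeholders); DDP wraps the local model copy, and gradients are all-reduced across workers during the backward pass:

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# One process per GPU; rank and world-size environment come from the launcher (e.g., torchrun).
dist.init_process_group(backend="nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

# Stand-in model and data: every worker holds a full model copy under DDP.
model = torch.nn.Linear(1024, 1024).cuda()
ddp_model = DDP(model, device_ids=[local_rank])
optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=1e-4)

for _ in range(10):
    x = torch.randn(32, 1024, device="cuda")  # each worker sees a different data batch
    loss = ddp_model(x).square().mean()       # dummy loss for illustration
    loss.backward()                           # gradients are all-reduced across workers here
    optimizer.step()
    optimizer.zero_grad()
```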

The replication of the entire model on each GPU limits the size of models in a DDP workflow. FSDP addresses this constraint by distributing model parameters, optimizer states, and gradients across data parallel workers while maintaining the simplicity of data parallelism.

As illustrated in the diagram below, in the case of DDP, every GPU holds a complete copy of the model state (M(OS + G + P)). Conversely, in FSDP, each GPU retains only a shard of the model state (M(OS + G + P)/N, where N is the number of data parallel workers). This approach drastically lowers the GPU memory footprint compared to DDP, enabling the training of very large models or permitting larger batch sizes.
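
To make the saving concrete, here is a back-of-the-envelope estimate, assuming bf16 parameters and gradients with fp32 Adam optimizer states (roughly 16 bytes per parameter); these numbers are illustrative approximations, not measurements from this post:

```python
# Approximate per-GPU memory for model state only (activations excluded).
params = 70e9                    # Llama2 70B
bytes_per_param = 2 + 2 + 12     # bf16 weights + bf16 grads + fp32 master copy, momentum, variance
full_state_gb = params * bytes_per_param / 1e9

print(f"DDP, per GPU:   {full_state_gb:,.0f} GB")        # ~1,120 GB -- far beyond a single 80 GB GPU
print(f"FSDP, 128 GPUs: {full_state_gb / 128:,.0f} GB")  # ~9 GB of model state per GPU when sharded
```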

However, this reduction comes at the cost of increased communication overhead, which FSDP mitigates through optimizations such as overlapping communication with computation and pre-fetching parameter shards. For further details, consult Getting Started with Fully Sharded Data Parallel (FSDP).
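
A minimal FSDP wrapping sketch is shown below, using a stand-in model (a placeholder for the Llama2 model being fine-tuned) and the standard PyTorch FSDP prefetch options so that parameter all-gathers overlap with computation:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, BackwardPrefetch

# Rank and world-size environment variables are expected from the launcher (e.g., torchrun).
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Stand-in model; in practice this would be the Llama2 model being fine-tuned.
model = torch.nn.Sequential(*[torch.nn.Linear(4096, 4096) for _ in range(8)]).cuda()

fsdp_model = FSDP(
    model,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # issue the next all-gather before it is needed
    forward_prefetch=True,                            # overlap forward all-gathers with compute
    device_id=torch.cuda.current_device(),
)
```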

FSDP provides various parameters to optimize training job performance and efficiency. Key features, illustrated in the configuration sketch after this list, include:

  • Transformer wrapping policy
  • Flexible mixed precision
  • Activation checkpointing
  • Diverse sharding strategies to accommodate different network speeds and cluster configurations:
    • FULL_SHARD – Shard model parameters, gradients, and optimizer states
    • HYBRID_SHARD – Full shard within a node, DDP across nodes, yielding a full model replica per node (HSDP)
    • SHARD_GRAD_OP – Shard only gradients and optimizer states
    • NO_SHARD – Similar to DDP

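The sketch below is our own illustration (not the exact configuration used by the accompanying automation scripts) of how these features map to FSDP arguments when wrapping a Hugging Face Llama2 model; the checkpoint name is a placeholder, and `LlamaDecoderLayer` is the transformer block class used for both the wrapping policy and activation checkpointing:

```python
import functools

import torch
import torch.distributed as dist
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Checkpoint name is illustrative; the Llama2 weights are gated and require access approval.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# Transformer wrapping policy: shard the model at decoder-layer boundaries.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer}
)

# Flexible mixed precision: bf16 parameters, gradient reduction, and buffers.
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

fsdp_model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    mixed_precision=bf16_policy,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,  # full shard within a node, DDP across nodes
    device_id=torch.cuda.current_device(),
)

# Activation checkpointing on each decoder layer trades recomputation for memory.
apply_activation_checkpointing(
    fsdp_model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda module: isinstance(module, LlamaDecoderLayer),
)
```
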
For more information about FSDP, refer to Efficient Large-Scale Training with PyTorch FSDP and AWS.

The figure below demonstrates how FSDP operates across two data parallel processes.

Solution Overview

In this post, we establish a compute cluster utilizing Amazon EKS, a managed service for running Kubernetes in both the AWS Cloud and on-premises data centers. Numerous customers are turning to Amazon EKS for Kubernetes-based AI/ML workloads, taking advantage of its performance, scalability, reliability, and seamless integration with AWS networking, security, and other services.

For our FSDP application, we employ the Kubeflow Training Operator on Amazon EKS, a Kubernetes-native project that simplifies scalable distributed training and fine-tuning of ML models. It supports several ML frameworks, including PyTorch, allowing for the deployment and management of PyTorch training jobs at scale.

By leveraging the PyTorchJob custom resource from the Kubeflow Training Operator, we execute training jobs on Kubernetes with a configurable number of worker replicas, optimizing resource utilization.

Key components of the training operator relevant to our Llama2 fine-tuning endeavor include the following (a minimal PyTorchJob sketch follows the list):

  • A centralized Kubernetes controller orchestrating distributed PyTorch training jobs.
  • PyTorchJob, a Kubernetes custom resource provided by the Kubeflow Training Operator, to define and deploy Llama2 training tasks on Kubernetes.
  • etcd, which plays a crucial role in implementing the rendezvous mechanism to coordinate distributed training of PyTorch models. The etcd server facilitates coordination and synchronization among participating workers during the training process.
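
A hedged sketch of what such a PyTorchJob might look like is shown below, submitted with the Kubernetes Python client. The image URI, script name, rendezvous endpoint, and resource counts are placeholders rather than values from the referenced automation scripts, and in practice the same resource is usually applied as a YAML manifest with kubectl:

```python
from kubernetes import client, config

config.load_kube_config()

# Placeholder values: image, command, namespace, and replica counts are illustrative only.
worker_spec = {
    "replicas": 16,
    "restartPolicy": "OnFailure",
    "template": {
        "spec": {
            "containers": [{
                "name": "pytorch",
                "image": "<aws-deep-learning-container-uri>",                    # placeholder
                "command": ["torchrun", "--nproc_per_node=8", "finetune_llama2.py"],  # placeholder
                "resources": {"limits": {"nvidia.com/gpu": 8}},
            }]
        }
    },
}

pytorch_job = {
    "apiVersion": "kubeflow.org/v1",
    "kind": "PyTorchJob",
    "metadata": {"name": "llama2-fsdp", "namespace": "default"},
    "spec": {
        # etcd-based rendezvous; host and port are placeholders for the in-cluster etcd service.
        "elasticPolicy": {"rdzvBackend": "etcd", "rdzvHost": "etcd-service", "rdzvPort": 2379},
        "pytorchReplicaSpecs": {"Worker": worker_spec},
    },
}

client.CustomObjectsApi().create_namespaced_custom_object(
    group="kubeflow.org", version="v1", namespace="default",
    plural="pytorchjobs", body=pytorch_job,
)
```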

The diagram below illustrates the architecture of the solution.

Most details will be managed through automation scripts used to execute the Llama2 example. We reference the following code examples in this case:

  • End-to-end fsdp example
  • Llama-recipes example

What is Llama2?

Llama2 represents a significant advancement in the realm of LLMs, providing enhanced capabilities and performance. Its architecture is designed to tackle complex tasks effectively, making it a vital tool in the AI landscape. For further insights into the intricacies of Llama2, you can explore this excellent resource.

