Quay lại
Mở Rộng Huấn Luyện PyTorch Trên Nhiều GPU: Làm Chủ Song Song Dữ Liệu Trong Các Dự Án Thực Tế
28 Th12, 2025
•
08.02 AM
Mở Rộng Huấn Luyện PyTorch Trên Nhiều GPU: Làm Chủ Song Song Dữ Liệu Trong Các Dự Án Thực Tế
Tóm tắt: Từng bị kẹt khi huấn luyện mô hình AI lớn trên một GPU vì hết bộ nhớ hoặc thời gian huấn luyện quá lâu? Song song dữ liệu cho phép bạn chia nhỏ tập dữ liệu trên nhiều GPU, huấn luyện nhanh hơn mà không cần viết lại code mô hình. Trong bài này, tôi sẽ hướng dẫn bạn từng chi tiết của DataParallel và DistributedDataParallel trong PyTorch, chia sẻ các đoạn code từ dự án của tôi, những cạm bẫy tôi đã tránh được, và khi nào nên chọn cái này thay vì cái kia. Lợi ích: cắt thời gian huấn luyện xuống còn 1/3–1/4 trên phần cứng phổ thông, nhưng phải để ý chi phí truyền thông và công việc thiết lập.
Giới thiệu
Nếu bạn đang xây dựng mô hình ML trong hệ sinh thái AI bùng nổ ở Việt Nam—từ chatbot cho thương mại điện tử đến nhận dạng hình ảnh cho nhà máy—giới hạn một GPU sẽ giết chết năng suất của bạn. Tôi đã thấy nhiều team lãng phí hàng tuần chờ mô hình hội tụ trên RTX 3090. Song song dữ liệu quan trọng vì nó biến dàn multi-GPU của bạn thành một "con thú", xử lý batch lớn hơn và mô hình to hơn. Điều này thể hiện rõ khi fine-tuning LLM hoặc huấn luyện CNN trên các tập dữ liệu tùy biến. Dev có 2+ GPU, data scientist đang scale prototype, và kỹ sư startup nên quan tâm—đây là cây cầu từ thử nghiệm đồ chơi đến môi trường production.
Bối cảnh và Thuật ngữ
- Song song dữ liệu (Data Parallelism): Chia batch dữ liệu của bạn trên các GPU; mỗi GPU chạy đầy đủ forward/backward trên phần dữ liệu của nó, rồi trung bình gradient.
- Song song mô hình (Model Parallelism): Chia các tầng của mô hình trên nhiều GPU—hữu ích cho các mô hình khổng lồ như GPT, nhưng khó debug hơn.
- DistributedDataParallel (DDP): Phiên bản mở rộng của PyTorch cho huấn luyện multi-node, multi-GPU; xử lý truyền thông hiệu quả.
- DataParallel (DP): Wrapper đơn giản cho một máy; nhân bản mô hình sang nhiều GPU nhưng có thể nghẽn tại thread chính.
- Đồng bộ gradient (Gradient Synchronization): Tất cả GPU chia sẻ và trung bình gradient sau mỗi backward pass để giữ mô hình đồng bộ.
- Kích thước batch (Batch Size): Tổng số mẫu mỗi iteration; tăng lên với nhiều GPU để giữ ổn định.
- AllReduce: Toán tử tập thể cộng gradient trên các GPU, sau đó broadcast giá trị trung bình.
Phân tích Kỹ thuật
Hoạt động như sau từng bước, giống như tôi vẽ bảng giải thích cho đồng đội. Đầu tiên, wrap mô hình của bạn: PyTorch nhân bản nó trên mỗi GPU. Batch đầu vào được chia đều—ví dụ 128 mẫu trên 4 GPU thành 32 mẫu mỗi GPU. Mỗi GPU tính loss độc lập. Sau đó, gradient được all-reduce thông qua backend NCCL. Bước tối ưu hóa (optimizer step) diễn ra một lần trên mỗi GPU, đã được đồng bộ.
Sơ đồ dòng chảy dạng text:
Batch (128) --> Split to GPU0(32), GPU1(32), GPU2(32), GPU3(32)
|
v
Forward Pass (parallel)
|
v
Backward Pass (parallel)
|
v
AllReduce Gradients <--> Average
|
v
Optimizer Step (synced)Các thành phần: Bộ chia (Splitter) chia dữ liệu; các bản sao (replica) tính toán; ring-allreduce (cây hiệu quả) đồng bộ gradient.
Các chế độ lỗi: Sự cố mạng trong môi trường multi-node có thể làm DDP sập—tôi đã mất vài ngày vì Ethernet chập chờn. Batch không cân đều gây OOM trên một GPU. Mô hình đe dọa (threat model): GPU bị mất đồng bộ do NaN trong gradient; giảm thiểu bằng gradient clipping.
Triển khai Thực tế
Bắt đầu với DP để test nhanh trên một máy. Bước 1: Import torch.nn.parallel.DataParallel. Bước 2: Wrap mô hình sau khi move sang device. Bước 3: Dùng DataLoader với num_workers>0. Lưu ý: DP giữ mô hình trên GPU0, rồi scatter dữ liệu—ổn cho 4 GPU, nhưng nghẽn ở 8+.
Chuyển sang DDP để scale: chạy bằng torchrun, init_process_group. Cạm bẫy: đặt world_size sai sẽ khiến chương trình treo mãi.
# DDP setup - why: scales beyond one node, async comms
import torch.distributed as dist
import torch.multiprocessing as mp
def setup(rank, world_size):
dist.init_process_group("nccl", rank=rank, world_size=world_size)
model = MyModel().to(rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
# Why device_ids: pins to local GPU, avoids cross-node messMẹo: Bật pin_memory trong DataLoader tăng tốc truyền dữ liệu khoảng 20%.
Đánh giá và Chỉ số
Theo dõi throughput (mẫu/giây), mức sử dụng GPU (>90% là tốt), sử dụng bộ nhớ, và tốc độ hội tụ (đường cong loss giống với single-GPU). Trong môi trường production, giám sát độ trễ truyền thông—đột biến là dấu hiệu nghẽn.
| Phương pháp | Tăng tốc (4x RTX 4090) | Hiệu quả bộ nhớ | Dùng khi |
|---|---|---|---|
| DataParallel | 3x | Tốt | Máy đơn, <8 GPU |
| DistributedDataParallel | 3.8x | Tốt nhất | Multi-node, quy mô lớn |
Chọn DP cho prototype; DDP cho cluster. Tăng tốc tuyến tính là hiếm—đạt 80% đã là thắng lợi.
Hạn chế và Đánh đổi
DP tuần tự hóa gradient trên rank 0, làm các GPU khác bị đói tài nguyên ở quy mô lớn. DDP cần kết nối tốc độ cao (InfiniBand là lý tưởng, không phải LAN văn phòng). Nó không giải quyết nhu cầu song song mô hình cho mô hình 100B+ tham số. Chi phí: Nhiều GPU = tốn tiền, và debug phân tán rất mệt—log bị phân tán khắp nơi. Không phù hợp cho mô hình nhỏ; overhead sẽ ăn hết lợi ích. Ở Việt Nam, tiền điện cũng khá đau nếu chạy dàn máy 24/7.
Kết luận
- DP để tăng tốc dễ dàng trên một máy; DDP cho scale nghiêm túc—hãy test cả hai sớm.
- Luôn clip gradient và scale batch size tuyến tính theo số GPU.
- Giám sát mức sử dụng và truyền thông; >90% nghĩa là bạn đang tối ưu tốt.
- Bắt đầu đơn giản, sau đó nâng cấp lên torchrun cho môi trường production.
Hiện tại bạn có bao nhiêu GPU trong hệ thống, và điều gì đang giới hạn tốc độ huấn luyện của bạn?
Tài liệu tham khảo
Blog Khác
•
11.46 AM
•
10.03 AM
•
05.35 PM
•
05.44 PM