Thuật toán lan truyền ngược (backpropagation) là một phương pháp cốt lõi trong học sâu (deep learning) và mạng nơ-ron nhân tạo (artificial neural networks), được sử dụng để huấn luyện mô hình bằng cách tối ưu hóa các tham số (trọng số và độ lệch) dựa trên việc giảm thiểu hàm mất mát (loss function). Dưới đây là tổng quan, cơ sở toán học và ví dụ minh họa:
Tổng quan
Lan truyền ngược là một kỹ thuật tính toán đạo hàm của hàm mất mát đối với từng tham số trong mạng nơ-ron, từ đó cập nhật các tham số này theo hướng giảm dần sai số. Quá trình này bao gồm hai giai đoạn chính:
- Lan truyền xuôi (Forward Propagation): Tín hiệu đầu vào được truyền qua các tầng của mạng nơ-ron để tạo ra đầu ra dự đoán.
- Lan truyền ngược (Backward Propagation): Sai số giữa đầu ra dự đoán và giá trị thực tế được truyền ngược từ tầng cuối về tầng đầu, sử dụng quy tắc chuỗi (chain rule) để tính gradient cho từng tham số.
Mục tiêu là sử dụng gradient descent (hoặc các biến thể của nó) để điều chỉnh trọng số sao cho hàm mất mát giảm dần qua các vòng lặp huấn luyện.
Cơ sở toán học
1. Mô hình mạng nơ-ron cơ bản
Giả sử ta có một mạng nơ-ron với:
- Đầu vào: $x$
- Trọng số: $w$ (kết nối giữa các nơ-ron)
- Độ lệch: $b$
- Hàm kích hoạt: $\sigma$ (ví dụ: sigmoid, ReLU)
- Đầu ra dự đoán: $\hat{y}$
- Giá trị thực tế: $y$
- Hàm mất mát: $L$
ví dụ: Mean Squared Error - MSE: $$L = \frac{1}{2} (\hat{y} - y)^2$$
Công thức cơ bản tại một nơ-ron ở tầng ( l ):
- $z^l = w^l a^{l-1} + b^l$ (tổng trọng số đầu vào và độ lệch)
- $a^l = \sigma(z^l)$ (đầu ra sau hàm kích hoạt)
2. Quy tắc chuỗi trong lan truyền ngược
Lan truyền ngược dựa trên quy tắc chuỗi để tính đạo hàm của $L$ đối với từng tham số $( w ), ( b )$ qua các tầng:
- Gradient của hàm mất mát đối với $w^l$:
$$\frac{\partial L}{\partial w^l} = \frac{\partial L}{\partial a^l} \cdot \frac{\partial a^l}{\partial z^l} \cdot \frac{\partial z^l}{\partial w^l}$$ - Trong đó:
- $( \frac{\partial L}{\partial a^l} )$: Sai số tại tầng $l$ (thường gọi là $ \delta^l$)
- $\frac{\partial a^l}{\partial z^l} = \sigma'(z^l)$: Đạo hàm của hàm kích hoạt
- $\frac{\partial z^l}{\partial w^l} = a^{l-1}$: Đầu ra của tầng trước
- Sai số $\delta^l$ tại tầng $ l $ được tính dựa trên tầng tiếp theo $( l+1 )$:
$$\delta^l = (w^{l+1})^T \delta^{l+1} \cdot \sigma'(z^l)$$
3. Cập nhật trọng số
Sau khi tính gradient, trọng số được cập nhật bằng gradient descent:
$$ w^l = w^l - \eta \cdot \frac{\partial L}{\partial w^l}$$
Trong đó $\eta $ là tốc độ học (learning rate).
Ví dụ minh họa
Tình huống đơn giản
Giả sử ta có một mạng nơ-ron với:
- 1 đầu vào: $ x = 2$
- 1 tầng ẩn: $ w_1 = 0.5 , b_1 = 0 $, hàm kích hoạt $sigmoid ( \sigma(z) = \frac{1}{1 + e^{-z}} )$
- 1 đầu ra: $w_2 = 0.8 , b_2 = 0$, không có hàm kích hoạt ở tầng cuối
- Giá trị mục tiêu: $ y = 1 $
- Hàm mất mát: $L = \frac{1}{2} (\hat{y} - y)^2$
- Tốc độ học: $\eta = 0.1 $
Bước 1: Lan truyền xuôi
- Tầng ẩn: $z_1 = w_1 x + b_1 = 0.5 \cdot 2 + 0 = 1$
$ a_1 = \sigma(1) = \frac{1}{1 + e^{-1}} \approx 0.731$ - Tầng đầu ra: $z_2 = w_2 a_1 + b_2 = 0.8 \cdot 0.731 + 0 \approx 0.585$
- Mất mát: $L = \frac{1}{2} (0.585 - 1)^2 = \frac{1}{2} (0.415)^2 \approx 0.086$
Bước 2: Lan truyền ngược
- Tính sai số tại tầng đầu ra:
$$ \delta_2 = \frac{\partial L}{\partial \hat{y}} = \hat{y} - y = 0.585 - 1 = -0.415$$
$$ \frac{\partial L}{\partial w_2} = \delta_2 \cdot a_1 = -0.415 \cdot 0.731 \approx -0.303$$ - Tính sai số tại tầng ẩn:
$$\delta_1 = \delta_2 \cdot w_2 \cdot \sigma'(z_1) $$
Với $\sigma'(z) = \sigma(z) (1 - \sigma(z)) $, ta có $ \sigma'(1) = 0.731 \cdot (1 - 0.731) \approx 0.197 $
$$ \delta_1 = -0.415 \cdot 0.8 \cdot 0.197 \approx -0.065 $$
$$\frac{\partial L}{\partial w_1} = \delta_1 \cdot x = -0.065 \cdot 2 \approx -0.13 $$
Bước 3: Cập nhật trọng số
- $w_2 = 0.8 - 0.1 \cdot (-0.303) = 0.8 + 0.0303 \approx 0.83$
- $w_1 = 0.5 - 0.1 \cdot (-0.13) = 0.5 + 0.013 \approx 0.513$
Quá trình này lặp lại cho đến khi ( L ) đủ nhỏ.
Kết luận
Thuật toán lan truyền ngược là nền tảng để huấn luyện mạng nơ-ron hiệu quả. Nó tận dụng tính toán gradient một cách tuần tự và có thể mở rộng cho các mạng sâu với nhiều tầng và hàng triệu tham số. Trong thực tế, các thư viện như TensorFlow hoặc PyTorch tự động hóa quá trình này, nhưng hiểu cơ chế toán học giúp nắm rõ cách mạng nơ-ron học từ dữ liệu.