Trong học máy, chưng cất tri thức (knowledge distillation) hay chưng cất mô hình (model distillation) là quá trình chuyển giao tri thức từ mô hình lớn sang mô hình nhỏ hơn. Dù các mô hình lớn (như mạng nơ-ron sâu hoặc tập hợp nhiều mô hình) có khả năng lưu trữ tri thức cao hơn, khả năng này có thể chưa được sử dụng hết. Chi phí tính toán để chạy mô hình vẫn cao ngay cả khi nó chỉ sử dụng một phần nhỏ tri thức. Chưng cất tri thức cho phép chuyển giao kiến thức từ mô hình lớn sang nhỏ mà không làm giảm hiệu quả. Các mô hình nhỏ tốn ít chi phí tính toán hơn, có thể triển khai trên thiết bị phần cứng yếu hơn (như điện thoại di động).[1]
Khái niệm này khác với nén mô hình (model compression) - phương pháp giảm kích thước mô hình lớn mà không cần huấn luyện mô hình mới. Nén mô hình thường giữ nguyên kiến trúc và số lượng tham số, chỉ giảm số bit trên mỗi tham số.
Chưng cất tri thức đã được áp dụng thành công trong nhiều lĩnh vực như phát hiện vật thể,[2] mô hình âm học,[3] và xử lý ngôn ngữ tự nhiên.[4] Gần đây, phương pháp này cũng được áp dụng cho mạng nơ-ron đồ thị xử lý dữ liệu phi lưới.[5]
Việc chuyển giao tri thức từ mô hình lớn sang nhỏ cần đảm bảo tính hiệu quả. Nếu cả hai mô hình cùng được huấn luyện trên một tập dữ liệu, mô hình nhỏ có thể thiếu khả năng học biểu diễn tri thức cô đọng so với mô hình lớn. Tuy nhiên, thông tin về biểu diễn tri thức được mã hóa trong phân phối giả xác suất (pseudolikelihood) đầu ra. Khi mô hình dự đoán đúng một lớp, nó gán giá trị lớn cho biến đầu ra tương ứng và giá trị nhỏ hơn cho các biến khác. Phân phối giá trị đầu ra cung cấp thông tin về cách mô hình lớn biểu diễn tri thức. Do đó, mục tiêu triển khai mô hình hiệu quả có thể đạt được bằng cách: huấn luyện mô hình lớn trên dữ liệu để khai thác khả năng học biểu diễn tri thức, sau đó chưng cất tri thức này vào mô hình nhỏ thông qua việc huấn luyện nó học đầu ra "mềm" của mô hình lớn.[1]
Với mô hình lớn là hàm số của biến véc-tơ được huấn luyện cho bài toán phân loại, lớp cuối cùng thường là softmax dạng:
Trong đó là nhiệt độ (temperature) - tham số thường đặt bằng 1 cho softmax chuẩn. Toán tử softmax chuyển giá trị logit thành giả xác suất. Giá trị nhiệt độ cao tạo phân phối "mềm" hơn giữa các lớp đầu ra. Quá trình chưng lọc tri thức bao gồm việc huấn luyện mô hình nhỏ (gọi là mô hình chưng cất) trên tập chuyển giao (khác với tập dữ liệu huấn luyện mô hình lớn) sử dụng hàm mất mát entropy chéo giữa đầu ra của mô hình chưng cất và đầu ra của mô hình lớn , với giá trị nhiệt độ cao cho cả hai mô hình[1]:
Trong bối cảnh này, nhiệt độ cao làm tăng entropy của đầu ra. Điều này cung cấp nhiều thông tin hơn để mô hình chưng cất học so với mục tiêu cứng. Đồng thời, nó làm giảm phương sai của gradient giữa các mẫu khác nhau. Nhờ đó, tốc độ học có thể được tăng lên.[1]
Nếu nhãn thực tế (ground truth) có sẵn cho tập chuyển giao, quá trình có thể được củng cố bằng cách thêm entropy chéo vào hàm mất mát. Thành phần này tính toán giữa đầu ra của mô hình chưng cất (với ) và nhãn đã biết :
Trong đó, thành phần mất mát liên quan đến mô hình lớn được nhân với hệ số . Lý do là khi nhiệt độ tăng, gradient của mất mát so với trọng số mô hình tỷ lệ với .[1]
Giả sử các giá trị logit có trung bình bằng 0. Khi đó, nén mô hình là trường hợp đặc biệt của chưng cất tri thức. Gradient của hàm mất mát chưng cất so với logit của mô hình chưng cất được tính bởi:
Ở đây, là logit của mô hình lớn. Với giá trị lớn, công thức này có thể xấp xỉ thành:
Giả thiết trung bình bằng 0 dẫn đến . Đây là đạo hàm của , nghĩa là mất mát tương đương với việc khớp logit của hai mô hình như trong nén mô hình.[1]
Lặp lại cho đến khi đạt được mức độ thưa hoặc hiệu suất mong muốn:
Huấn luyện mạng (bằng các phương pháp như lan truyền ngược) cho đến khi thu được nghiệm hợp lý
Tính toán độ quan trọng (saliency) cho từng tham số
Xóa các tham số có độ quan trọng thấp nhất
Việc xóa tham số nghĩa là cố định giá trị của nó về 0. "Độ quan trọng" của tham số được định nghĩa là , với là hàm mất mát. Đạo hàm bậc hai có thể tính bằng lan truyền ngược bậc hai.
Ý tưởng chính của OBD là xấp xỉ hàm mất mát trong vùng lân cận nghiệm tối ưu bằng khai triển Taylor:Do (vì là tối ưu) và bỏ qua đạo hàm chéo để tiết kiệm tính toán. Như vậy, độ quan trọng của tham số ước lượng mức tăng mất mát nếu tham số đó bị xóa.
Một phương pháp liên quan là nén mô hình hoặc tỉa mô hình, trong đó một mạng đã được huấn luyện sẽ được giảm kích thước. Điều này lần đầu tiên được thực hiện vào năm 1965 bởi Alexey Ivakhnenko và Valentin Lapa tại Ukraina (1965).[7][8][9] Các mạng sâu của họ được huấn luyện từng lớp thông qua phân tích hồi quy. Các đơn vị ẩn không cần thiết được tỉa bỏ bằng cách sử dụng một tập kiểm định riêng.[10] Các phương pháp nén mạng nơ-ron khác bao gồm Biased Weight Decay[11] và Optimal Brain Damage.[6]
Một ví dụ sớm về chưng cất mạng nơ-ron đã được Jürgen Schmidhuber công bố vào năm 1991, trong lĩnh vực mạng nơ-ron hồi quy (RNN). Vấn đề là dự đoán chuỗi cho các chuỗi dài, tức là học sâu. Nó được giải quyết bởi hai RNN. Một trong số chúng (automatizer) dự đoán chuỗi, và một cái khác (chunker) dự đoán lỗi của automatizer. Đồng thời, automatizer dự đoán trạng thái bên trong của chunker. Sau khi automatizer dự đoán tốt trạng thái bên trong của chunker, nó sẽ bắt đầu sửa lỗi, và cuối cùng chunker trở nên lỗi thời, chỉ còn lại một RNN.[12][13]
Ý tưởng sử dụng đầu ra của một mạng nơ-ron để huấn luyện một mạng nơ-ron khác cũng được nghiên cứu dưới dạng cấu hình mạng giáo viên-học sinh.[14] Năm 1992, một số bài báo đã nghiên cứu cơ học thống kê của cấu hình giáo viên-học sinh với các máy chuyên gia (committee machine)[15][16] hoặc cả hai đều là máy riêng biệt.[17]
Việc nén kiến thức của nhiều mô hình vào một mạng nơ-ron duy nhất được gọi là nén mô hình vào năm 2006: nén được thực hiện bằng cách huấn luyện một mô hình nhỏ hơn trên một lượng lớn dữ liệu giả được gán nhãn bởi một tập hợp hiệu suất cao, tối ưu hóa để khớp logit của mô hình nén với logit của tập hợp.[18] Bản in trước về chưng cất tri thức của Geoffrey Hinton và cộng sự (2015)[1] đã định nghĩa khái niệm và cho thấy một số kết quả đạt được trong nhiệm vụ phân loại hình ảnh.
Việc chưng cất kiến thức cũng liên quan đến khái niệm nhân bản hành vi được Faraz Torabi và cộng sự thảo luận.[19]
^ abcdefgHinton, Geoffrey; Vinyals, Oriol; Dean, Jeff (2015). "Distilling the knowledge in a neural network". arΧiv:1503.02531 [stat.ML].
^Chen, Guobin; Choi, Wongun; Yu, Xiang; Han, Tony; Chandraker, Manmohan (2017). “Learning efficient object detection models with knowledge distillation”. Advances in Neural Information Processing Systems: 742–751.
^Asami, Taichi; Masumura, Ryo; Yamaguchi, Yoshikazu; Masataki, Hirokazu; Aono, Yushi (2017). Domain adaptation of DNN acoustic models using knowledge distillation. IEEE International Conference on Acoustics, Speech and Signal Processing. tr. 5185–5189.
^Cui, Jia; Kingsbury, Brian; Ramabhadran, Bhuvana; Saon, George; Sercu, Tom; Audhkhasi, Kartik; Sethy, Abhinav; Nussbaum-Thom, Markus; Rosenberg, Andrew (2017). Knowledge distillation across ensembles of multilingual models for low-resource languages. IEEE International Conference on Acoustics, Speech and Signal Processing. tr. 4825–4829.
^Buciluǎ, Cristian; Caruana, Rich; Niculescu-Mizil, Alexandru (2006). “Model compression”. Proceedings of the 12th ACM SIGKDD international conference on Knowledge discovery and data mining.
^Torabi, Faraz; Warnell, Garrett; Stone, Peter (2018). "Behavioral Cloning from Observation". arΧiv:1805.01954 [cs.AI].