武林争霸
588.47MB · 2025-09-11
大家好,我是IT周瑜。
Transformer大家都知道吧,就是各种大模型背后的那个”变形金刚“,Transformer翻译过来就是”变形金刚“。
业界都是用Python来实现Transformer,前几天我突发奇想,我能不能用Java也来实现一下Transformer呢?
于是乎,就有了这篇文章和后续的几篇文章,挑战一下,看看能不能用纯Java代码,也能实现Transformer。
今天先实现QKV机制。
在进入代码实战之前,我们先用一个通俗的比喻来理解什么是QKV。
想象一下你在图书馆里查资料。
你的大脑会拿着你的查询(Q),去和所有书的标签(K)进行匹配,计算一个“相关性分数”。分数越高的,说明这本书你越感兴趣。然后,根据这个分数,你按比例吸收所有书的内容(V),最终得到的信息就高度集中于你最关心的“机器学习”领域。
这个过程可以用一个著名的公式来概括:
接下来,我们就来看看Java代码是如何实现这个流程的。
我们的目标就是用Java实现上述公式的计算过程。下面是完整的代码,后面我们会解析它的主流程,如果你想要完整Transformer代码以及对应讲解视频,可以加我微信:it_zhouyu
package com.zhouyu;
import java.util.Arrays;
/**
* 作者:IT周瑜
* 公众号:IT周瑜
* 微信号:it_zhouyu
*/
public class QKV {
public static void main(String[] args) {
// 我喜欢编程 -> 我、喜欢、编程
// (1,4) 一个词的词向量
double[][] query = new double[][]{
{0.5, 0.2, 0.8, 0.1}
};
// (3,4) 三个词的词向量
double[][] key = new double[][]{
{0.3, 0.6, 0.1, 0.9},
{0.9, 0.2, 0.5, 0.4},
{0.1, 0.8, 0.7, 0.3}
};
// (3,4) 三个词的词向量
double[][] value = new double[][]{
{1.0, 2.0, 3.0, 4.0},
{2.5, 3.5, 4.5, 5.5},
{5.1, 4.1, 3.1, 2.1}
};
int dk = key[0].length;
// (1,4) * (4,3)
// printMatrix("Query:", query);
// printMatrix("Key:", key);
// printMatrix("Key Transpose:", transpose(key));
// 注意力分数
double[][] scores = dotProduct(query, transpose(key));
printMatrix("Scores:", scores);
// 缩放
double scaleFactor = Math.sqrt(dk);
for (int i = 0; i < scores.length; i++) {
for (int j = 0; j < scores[i].length; j++) {
scores[i][j] /= scaleFactor;
}
}
// 注意力权重
double[][] attentionWeights = softmax(scores);
printMatrix("Attention Weights:", attentionWeights);
// 注意力输出
double[][] attentionOutput = dotProduct(attentionWeights, value);
printMatrix("Attention Output", attentionOutput);
}
public static double[][] dotProduct(double[][] matrixA, double[][] matrixB) {
int a_rows = matrixA.length;
int a_columns = matrixA[0].length;
int b_rows = matrixB.length;
int b_columns = matrixB[0].length;
// (1,4) * (4,3) 第一个矩阵的列数必须等于第二个矩阵的行数
if (a_columns != b_rows) {
throw new IllegalArgumentException("矩阵维度不匹配,无法进行乘法运算。");
}
double[][] result = new double[a_rows][b_columns];
for (int i = 0; i < a_rows; i++) {
for (int j = 0; j < b_columns; j++) {
for (int k = 0; k < a_columns; k++) {
result[i][j] += matrixA[i][k] * matrixB[k][j];
}
}
}
return result;
}
public static double[][] transpose(double[][] matrix) {
int m = matrix.length;
int n = matrix[0].length;
double[][] transposedMatrix = new double[n][m];
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
transposedMatrix[j][i] = matrix[i][j];
}
}
return transposedMatrix;
}
public static double[][] softmax(double[][] matrix) {
double[][] result = new double[matrix.length][matrix[0].length];
for (int i = 0; i < matrix.length; i++) {
double maxVal = Double.NEGATIVE_INFINITY;
// 找到行中的最大值
for (double val : matrix[i]) {
if (val > maxVal) {
maxVal = val;
}
}
double sumExp = 0.0;
// 计算 exp(x - max) 的总和
for (int j = 0; j < matrix[i].length; j++) {
result[i][j] = Math.exp(matrix[i][j] - maxVal);
sumExp += result[i][j];
}
// 除以总和,得到概率分布
for (int j = 0; j < matrix[i].length; j++) {
result[i][j] /= sumExp;
}
}
return result;
}
public static void printMatrix(String name, double[][] matrix) {
System.out.println(name + ":");
for (double[] row : matrix) {
System.out.println(Arrays.toString(row));
}
System.out.println();
}
}
代码的main
方法完美地对应了我们前面提到的注意力公式和图书馆的比喻。整个过程分为四步:
第一步:计算相关性分数 (Q * K^T
)
我们用 query
矩阵乘以 key
矩阵的转置。这一步的目的,就是计算出我们的“查询”词和句子中其他所有“标签”词之间的相关性分数。分数越高,代表关系越近。
第二步:缩放分数
将上一步得到的所有分数除以一个固定的值(向量维度的平方根)。这主要是为了在模型训练时保持数据的稳定性,我们可以暂时理解为一个标准化的操作。
第三步:计算最终权重 (Softmax)
Softmax
函数会将上一步得到的分数转换成一组总和为1的概率值,也就是最终的“注意力权重”。这个权重决定了我们应该对每个词的“内容(V)”投入多少关注。
第四步:加权求和 (Weights * V
)
最后,我们用上一步得到的权重,去乘以每个词对应的 value
矩阵(词的“内容”)。这一步相当于,根据权重,按比例提取所有词的信息,然后融合在一起,得到一个包含了全局上下文信息的全新向量。
运行代码后,最后的Attention Output
就是一个融合了其他词信息的新向量。
Attention Output:
[2.964126624094556, 3.275048669898762, 3.5859707157029685, 3.8968927615071745]
通过这个简单的Java程序,我们一步步实现了Transformer模型中最核心的QKV机制,点赞+关注,下次继续,必须用Java把完整的Transformer实现出来。