狠人用500行SQL实现GPT大模型


本文介绍了如何在SQL中实现一个大型语言模型。作者解释了生成式大型语言模型的技术原理,并详细讲解了GPT2的实现过程。其中包括分词器的实现和字典的构建。通过使用递归CTE,将文本分割为标记,并合并最佳相邻对,以生成编码后的表示。作者还提到了生成文本的过程,并介绍了GPT2中的生成算法。

原文摘要:
今年,全城的热门话题是人工智能以及它如何为你做一切事情。

我喜欢某人或某事为我做一切事情。为此,我决定请ChatGPT来写我的新年帖:

“嘿ChatGPT。你能用SQL实现大型语言模型吗?”
“不,SQL不适合实现大型语言模型。
SQL 是一种用于管理和查询关系数据库中的数据的语言。
实现语言模型需要复杂的算法、大型数据集,通常还需要深度学习框架,而这超出了 SQL 的能力。”

如果你想把一件事做好,你就必须自己去做。
受到这一乐观预测的鼓舞,今天我们将用 SQL实现一个大型语言模型。

在撰写本文时,我使用了 Jay Mody 撰写的精彩文章《GPT[url=https://jaykmody.com/blog/gpt-from-scratch/] in 60 Lines of [/url]NumPy》。本文对GPT模型的内部运作原理的解释比我希望的要好得多。

从技术角度来看什么是生成式大语言模型?
生成式 LLM 是一个函数。它接受一个文本字符串作为输入(在人工智能术语中称为“提示”),并返回一个字符串和数字数组。该函数的签名如下所示:

llm(prompt: str) -> list[tuple[str, float]]

这个函数是确定性的。它在底层做了很多数学运算,但所有这些数学运算都是硬连线的。如果您使用相同的输入重复调用它,它将始终返回相同的输出。

对于任何使用ChatGPT和类似产品的人来说,这可能会感到惊讶,因为它们可以对同一问题给出不同的答案。
然而,这是真的。我们很快就会看到它是如何工作的。

函数返回什么值?
像这样的东西:

llm("I wish you a happy New")
 
0       (' Year', 0.967553)
1       (' Years', 0.018199688)
2       (' year', 0.003573329)
3       (' York', 0.003114716)
4       (' New', 0.0009022804)

50252   (' carbohyd', 2.3950911e-15)
50253   (' volunte', 2.2590102e-15)
50254   ('pmwiki', 1.369229e-15)
50255   (' proport', 1.1198108e-15)
50256   (' cumbers', 7.568147e-17)

它返回一个元组数组。每个元组由一个单词(或者更确切地说,一个字符串)和一个数字组成。数字是这个词将继续提示的概率。该模型“认为短语“I wish you a happy New”后面会跟有字符序列“Year”,概率为 96.7%,“Years”为 1.8%,依此类推。

上面引用“思考”这个词是因为,当然,模型并没有真正思考。它根据一些硬连线的内部逻辑机械地返回单词和数字的数组。

如果它是那么笨,那么确定,它怎么能生成不同的文本呢?
大型语言模型用于文本应用程序(聊天机器人、内容生成器、代码助手等)。
这些应用程序会反复调用模型,并选择模型建议的单词(具有一定的随机性)。下一个建议的单词被添加到提示中,然后再次调用模型。如此循环往复,直到生成足够多的单词。

累积的单词序列看起来就像人类语言中的文本,包含语法、句法,甚至看起来像是智能和推理。
在这一点上,它与马尔科夫链 的工作原理并无二致。

大型语言模型的内部结构被连接起来,以便下一个建议的单词将是提示语的自然延续,并具有完整的语法、语义和情感。通过一系列科学上的突破(以及编程上的艰辛),GPT(即生成式预训练转换器)算法家族得以发展,使函数具备了这样的逻辑。

GPT生成式预训练转换器(Generative Pre-trained Transformer)是什么意思?
"Generative 生成 "意味着它可以生成文本(正如我们之前看到的那样,通过向提示符递归添加连续字符)。

"Transformer 转换器"是指它使用了一种特殊的神经网络,这种网络由 Google 首先开发,并在本文中进行了描述。

"Pre-trained 预训练 "有点历史。最初,人们认为模型延续文本的能力只是更专业任务的先决条件:

  1. 推理(找到短语之间的逻辑联系)
  2. 分类(例如,从评论文本中猜测酒店评分中的星数)、机器翻译等等。

有人认为,这两部分应该分开训练,语言部分只是为后面第二步的 "真正 "任务进行的预训练。

正如最初的GPT论文所说:

我们证明,通过在各种未标记文本的语料库上对语言模型进行生成式预训练,然后对每个特定任务进行区分性微调, 可以实现这些任务的巨大收益。

直到后来人们才意识到,当模型足够大时,第二步往往是不必要的。Transformer 模型经过训练除了生成文本之外什么也不做,结果证明能够遵循这些文本中包含的人类语言指令,无需额外的训练(人工智能术语中的“微调”)。

代码
当我们尝试使用 GPT2 根据提示生成文本时,会发生以下情况:

def generate(prompt: str) -> str:
  # 将字符串转换为标记列表。
  tokens = tokenize(prompt) # tokenize(prompt: str) -> list[int]
 
  while True:
 
    # 运行算法。
    返回标记的概率:50257 个浮点数的列表,加起来等于 1。
    candidates = gpt2(tokens) # gpt2(tokens: list[int]) -> list[float]
 
    从候选标记列表中选择下一个标记
    next_token = select_next_token(candidates)
    # select_next_token(candidates: list[float]) -> int
 
    # 将其添加到标记列表中
    tokens.append(next_token)
 
    # 决定是否要停止生成。
    # 它可以是标记计数器、超时、停止符或其他。
    if should_stop_generating():
      break
 
  # 将标记列表转换为字符串
  completion = detokenize(tokens) # detokenize(tokens: list[int]) -> str
  return completion

让我们用 SQL 来一一实现所有这些部分。

分词器/标记器
在将文本输入神经网络之前,需要将其转换为数字列表。当然,这并不是什么新闻:Unicode 等文本编码就是这么做的。然而,普通 Unicode 并不能很好地与神经网络配合使用。

神经网络的核心是进行大量的矩阵乘法运算,并在这些矩阵的系数中捕捉它们所具有的预测能力。这些矩阵中,有些是 "字母表 "中每个可能的值都有一行,有些则是每个 "字符 "都有一行。

在这里,"字母表 "和 "字符 "并不具有通常的含义。在 Unicode 中,"字母表 "的长度为 149186 个字符(这是本文撰写时的 Unicode 点数),而一个 "字符 "可以是这样的:﷽(是的,这是一个 Unicode 点,编号为 65021,编码阿拉伯语中对穆斯林特别重要的一整句话)。请注意,同样的短语也可以用普通的阿拉伯字母书写。这意味着同一文本可以有多种编码。

以 "PostgreSQL "为例。如果我们使用 Unicode 对其进行编码(转换为数字数组),我们将得到 10 个数字,可能从 1 到 149186。这意味着我们的神经网络需要存储一个包含 149186 行的矩阵,并对矩阵中的 10 行进行多次计算。其中一些行(对应英文字母)会被大量使用,并包含大量信息;而另一些行,如大便表情符号和来自死语言的晦涩符号,则几乎不会被使用,但仍会占用空间。

当然,我们希望尽可能减少 "字母表 "长度和 "字符 "数量。理想情况下,字母表中的所有 "字符 "都应均匀分布,而且我们仍然希望我们的编码能像 Unicode 一样强大。

直观地说,我们能做到这一点的方法就是为我们所处理的文本中经常出现的单词序列分配唯一的数字。在 Unicode 中,阿拉伯语中的同一宗教短语可以使用单个码位或逐个字母进行编码。既然我们正在进行自己的编码,那么对于模型中重要的单词和短语(即在文本中经常出现的单词和短语),我们也可以进行同样的编码。

例如,我们可以为 "Post"、"greSQL "和 "ing "设置不同的数字。这样,"PostgreSQL "和 "Posting "在我们的表示法中长度都是 2。当然,我们仍然会为较短的序列和单个字节保留单独的代码点。即使我们遇到的是胡言乱语或外语文本,尽管长度较长,但仍然可以编码。

GPT2 使用一种名为 "字节对编码 "的算法变体来实现这一目的。它的标记器使用一个包含 50257 个码位(用人工智能术语来说,就是 "标记")的字典,这些码位对应 UTF-8 中的不同字节序列(加上作为单独标记的 "文本结束")。

这个字典是通过统计分析建立的:

  • 从 256 个标记的简单编码开始:每个字节一个标记。
  • 选取一个大型文本语料库(最好是模型将要训练的语料库)。
  • 对其进行编码。
  • 计算哪对标记出现频率最高。假设是 0x20 0x74(空格后是小写字母 "t")。
  • 为这对字节分配下一个可用值(257)。
  • 重复步骤 3-5,注意字节序列。如果字节序列可以用复合标记编码,则使用复合标记。如果存在歧义(例如,"abc "在某些情况下可以编码为 "a "+"bc "或 "ab "+"c"),则使用数字最小的那个(因为它是较早添加的,因此频率较高)。如此递归进行,直到所有能折叠成一个标记的序列都能折叠成一个标记。
  • 如此反复进行 50000 次。

50000 这个数字或多或少是开发者随意选择的。其他模式的代币数量保持在类似的范围内(从 30k 到 100k)。

该算法每迭代一次,就会向字典中添加一个由之前的两个标记连接而成的新标记。最终,我们将得到 50256 个标记。为 "文本结束 "添加一个固定数字的标记,就大功告成了。

GPT2 版本的 BTE 还有另一层编码:令牌字典将令牌映射为字符串,而不是字节数组。该函数定义了从字节到字符串的映射。我们将把它生成的字典保存在表格编码器中。

让我们看看如何在 SQL 中实现标记符。

令牌生成器是 GPT2 不可分割的一部分,令牌字典可与模型的其他部分一起从 OpenAI 网站下载。我们需要将其导入表令牌器。
在本篇文章的底部,您可以找到代码库的链接。
其代码将自动填充模型所需的数据库表。

在递归 CTE 中,我们将把单词分割成标记(从单字节开始),然后合并最佳的相邻对,直到没有可合并的内容为止。合并本身发生在嵌套递归 CTE 中。

在演示中,我将使用单词 "Mississippilessly"。结果集中的每条记录都显示了迄今为止找到的最佳词对,以及查询的进度。

WITH    RECURSIVE
        bpe AS
        (
        SELECT  (n + 1)::BIGINT AS position, character, TRUE AS continue, 1 AS step,
                NULL::INT AS token, NULL::TEXT AS combined
        FROM    CONVERT_TO('Mississippilessly', 'UTF-8') AS bytes
        CROSS JOIN LATERAL
                GENERATE_SERIES(0, LENGTH(bytes) - 1) AS n
        JOIN    encoder
        ON      byte = GET_BYTE(bytes, n)
        UNION ALL
        (
        WITH    RECURSIVE
                base AS
                (
                SELECT  *
                FROM    bpe
                WHERE   continue
                ),
                bn AS
                (
                SELECT  ROW_NUMBER() OVER (ORDER BY position) AS position,
                        continue,
                        character,
                        character || LEAD(character) OVER (ORDER BY position) AS cluster
                FROM    base
                ),
                top_rank AS
                (
                SELECT  tokenizer.*
                FROM    bn
                CROSS JOIN LATERAL
                        (
                        SELECT  *
                        FROM    tokenizer
                        WHERE   tokenizer.cluster = bn.cluster
                        LIMIT   1
                        ) tokenizer
                ORDER BY
                        token
                LIMIT   1
                ),
                breaks AS
                (
                SELECT  0::BIGINT AS position, 1 AS length
                UNION ALL
                SELECT  bn.position,
                        CASE WHEN token IS NULL THEN 1 ELSE 2 END
                FROM    breaks
                JOIN    bn
                ON      bn.position = breaks.position + length
                LEFT JOIN
                        top_rank
                USING   (cluster)
                )
        SELECT  position, character, token IS NOT NULL,
                (SELECT step + 1 FROM base LIMIT 1), token, top_rank.cluster
        FROM    breaks
        LEFT JOIN
                top_rank
        ON      1 = 1
        CROSS JOIN LATERAL
                (
                SELECT  STRING_AGG(character, '' ORDER BY position) AS character
                FROM    bn
                WHERE   bn.position >= breaks.position
                        AND bn.position < breaks.position + length
                ) bn
        WHERE   position > 0
        )
        )
SELECT  step, MAX(token) AS token, MAX(combined) AS combined, ARRAY_AGG(character ORDER BY position)
FROM    bpe
WHERE   continue
GROUP BY
        step
ORDER BY
        step

step    token    combined    array_agg
1    None    None    ['M', 'i', 's', 's', 'i', 's', 's', 'i', 'p', 'p', 'i', 'l', 'e', 's', 's', 'l', 'y']
2    271    is    ['M', 'is', 's', 'is', 's', 'i', 'p', 'p', 'i', 'l', 'e', 's', 's', 'l', 'y']
3    274    es    ['M', 'is', 's', 'is', 's', 'i', 'p', 'p', 'i', 'l', 'es', 's', 'l', 'y']
4    306    ly    ['M', 'is', 's', 'is', 's', 'i', 'p', 'p', 'i', 'l', 'es', 's', 'ly']
5    346    il    ['M', 'is', 's', 'is', 's', 'i', 'p', 'p', 'il', 'es', 's', 'ly']
6    381    pp    ['M', 'is', 's', 'is', 's', 'i', 'pp', 'il', 'es', 's', 'ly']
7    408    ess    ['M', 'is', 's', 'is', 's', 'i', 'pp', 'il', 'ess', 'ly']
8    747    iss    ['M', 'iss', 'iss', 'i', 'pp', 'il', 'ess', 'ly']
9    3974    ipp    ['M', 'iss', 'iss', 'ipp', 'il', 'ess', 'ly']
10    17140    Miss    ['Miss', 'iss', 'ipp', 'il', 'ess', 'ly']
11    30608    iless    ['Miss', 'iss', 'ipp', 'iless', 'ly']

在每一步中,BPE 算法都会找到要合并的最佳标记对并将其合并(您可以在输出中看到合并后的标记对及其等级)。这一过程将标记空间的大小从 Unicode 的 150k 降至 50k,标记数量(在这一特定单词中)从 17 降至 5。

在处理多个单词时,标记符号生成器会首先使用这个 regexp 将文本分割成不同的单词,然后分别合并每个单词内的标记符号。
不幸的是,PostgreSQL 不支持 regexp 中的 Unicode 字符属性,所以我不得不稍作调整(很可能在这个过程中扼杀了对 Unicode 的正确支持)。下面是它在 SQL 中的样子:

WITH    input AS
        (
        SELECT  'PostgreSQL is great' AS prompt
        ),
        clusters AS
        (
        SELECT  part_position, bpe.*
        FROM    input
        CROSS JOIN LATERAL
                REGEXP_MATCHES(prompt, '''s|''t|''re|''ve|''m|''ll|''d| ?\w+| ?\d+| ?[^\s\w\d]+|\s+(?!\S)|\s+', 'g') WITH ORDINALITY AS rm (part, part_position)
        CROSS JOIN LATERAL
                (
                WITH    RECURSIVE
                        bpe AS
                        (
                        SELECT  (n + 1)::BIGINT AS position, character, TRUE AS continue
                        FROM    CONVERT_TO(part[1], 'UTF-8') AS bytes
                        CROSS JOIN LATERAL
                                GENERATE_SERIES(0, LENGTH(bytes) - 1) AS n
                        JOIN    encoder
                        ON      byte = GET_BYTE(bytes, n)
                        UNION ALL
                        (
                        WITH    RECURSIVE
                                base AS
                                (
                                SELECT  *
                                FROM    bpe
                                WHERE   continue
                                ),
                                bn AS
                                (
                                SELECT  ROW_NUMBER() OVER (ORDER BY position) AS position,
                                        continue,
                                        character,
                                        character || LEAD(character) OVER (ORDER BY position) AS cluster
                                FROM    base
                                ),
                                top_rank AS
                                (
                                SELECT  tokenizer.*
                                FROM    bn
                                CROSS JOIN LATERAL
                                        (
                                        SELECT  *
                                        FROM    tokenizer
                                        WHERE   tokenizer.cluster = bn.cluster
                                        LIMIT   1
                                        ) tokenizer
                                ORDER BY
                                        token
                                LIMIT   1
                                ),
                                breaks AS
                                (
                                SELECT  0::BIGINT AS position, 1 AS length
                                UNION ALL
                                SELECT  bn.position,
                                        CASE WHEN token IS NULL THEN 1 ELSE 2 END
                                FROM    breaks
                                JOIN    bn
                                ON      bn.position = breaks.position + length
                                LEFT JOIN
                                        top_rank
                                USING   (cluster)
                                )
                        SELECT  position, character, token IS NOT NULL
                        FROM    breaks
                        LEFT JOIN
                                top_rank
                        ON      1 = 1
                        CROSS JOIN LATERAL
                                (
                                SELECT  STRING_AGG(character, '' ORDER BY position) AS character
                                FROM    bn
                                WHERE   bn.position >= breaks.position
                                        AND bn.position < breaks.position + length
                                ) bn
                        WHERE   position > 0
                        )
                        )
                SELECT  position, character AS cluster
                FROM    bpe
                WHERE   NOT continue
                ) bpe
        ),
        tokens AS
        (
        SELECT  token, cluster
        FROM    clusters
        JOIN    tokenizer
        USING   (cluster)
        )
SELECT  *
FROM    tokens

输出
token    cluster
6307    Post
47701    greSQL
318    Ġis
1049    Ġgreat

这个奇怪的字符就是空格。

该查询将提示符标记化,并将其转换为数字数组。这样,提示符就可以进入模型的各个层级了。


嵌入
标记代表了人类语言的各个部分(一般来说,每个标记代表约 0.75 个单词),因此任何试图成功完成文本补全的模型都应该以某种方式对这些部分之间的关系进行编码。即使孤立来看,语音的各个部分也具有正交属性。

让我们来看看 "传票subpoena "这个词(在 GPT2 标记符号化器中,它本身就有一个完整的标记符号)。它是一个名词吗?是的,非常像。它是动词吗?算是吧。是形容词吗?不是,但如果你眯起眼睛仔细看的话,它也可以是。是法律术语吗?当然是。等等。

所有这些属性都是正交的,即相互独立。一个词可以是法律名词,但不能是形容词或动词。在英语中,它们的任何组合都可能发生。(banq注:在不同上下文中词语的意思不同,人工建模中使用BC这样概念逐个上下文建模,而在人工智能中则将多个上下文作为向量矩阵建模)

具有正交属性的事物最好使用矢量进行编码。我们可以有多种属性,而不是单一的属性(如代币编号)。而且,如果我们能随心所欲地摆动它们,也会有所帮助。例如,要想用一个词来延续 "律师引用的法院判决中提到了...... "这个短语,我们可能需要一个侧重于法律层面的词,同时又侧重于作为一个名词。我们并不关心它是否兼做形容词、动词或花朵。

在数学中,将较窄的值映射到较宽的空间(如将标记 ID 映射到向量)被称为嵌入。这正是我们要做的。

我们如何确定这些向量所代表的属性?我们不需要。我们只是为每个标记提供足够的向量空间,并希望模型在训练阶段能在这些维度上填充一些有意义的东西。GPT2 的向量使用了 768 个维度。我们无法事先知道(实际上,甚至在回溯时也无法知道),比如说,维数 247 会编码单词的什么属性。它肯定会编码一些东西,但要知道是什么并不容易。

我们想在向量空间中嵌入每个标记的什么属性?
任何与下一个标记有关的属性。

令牌 ID?当然。不同的令牌有不同的含义。

令牌在文本中的位置?请回答。"蓝紫色 "和 "蓝紫色 "不是一回事。

标记之间的关系?当然可以!这可能是工作中最重要的部分,而变形金刚架构的 "注意 "区块是第一个把它处理好的区块。

标记和位置很容易嵌入。比方说,我们有一句短语 "PostgreSQL 太棒了",正如我们已经知道的,它映射到四个标记:[6307, 47701, 318, 1049]。

在 GPT2 的其他参数中,有两个矩阵被称为 WTE(词标记嵌入)和 WPE(词位置嵌入)。顾名思义,前者存储词标记的嵌入,后者存储词位置的嵌入。这些嵌入的实际值是在 GPT2 的训练过程中填充("学习")的。就我们而言,它们是数据库表 wte 和 wpe 中的常量。

WTE 为 50257×768,WPE 为 1024×768。后者意味着我们可以在 GPT2 提示中使用的最大标记数是 1024。如果我们在提示符中提供更多的标记,就无法为它们提取位置嵌入。这是模型的一个架构方面(用人工智能术语来说就是 "超参数"),在设计时就已设定,不能通过训练来改变。当人们谈论 LLM 的 "上下文窗口 "时,他们指的就是这个数字。

我们在第 0 位有 6307 个标记,第 1 位有 47701 个标记,第 2 位有 318 个标记,第 3 位有 1049 个标记。对于每个标记和位置,我们都有两个向量:一个来自 WTE,另一个来自 WPE。我们需要将它们相加。得到的四个向量将作为算法下一部分的输入:带有注意力机制的前馈神经网络。

对于 SQL 部分,我们将使用 pgvector,这是 PostgreSQL 的一个扩展。

这里有一点免责声明:通常情况下,我都是用普通 SQL 编写新年文章的代码,有时会使用纯 SQL 函数作为助手。这篇文章也完全可以这样做,在数组上定义矢量操作,但代价是性能会有所下降(在第一版中已经这样做了,尽管速度很慢)。随着人工智能的出现和矢量数据库重要性的增加,pgvector 或其等价物肯定会在两三个版本中成为 PostgreSQL 的核心。我只是决定顺势而为。

WITH    embeddings AS
        (
        SELECT  place, values
        FROM    UNNEST(ARRAY[6307, 47701, 318, 1049]) WITH ORDINALITY AS tokens (token, ordinality)
        CROSS JOIN LATERAL
                (
                SELECT  ordinality - 1 AS place
                ) o
        CROSS JOIN LATERAL
                (
                SELECT  wte.values + wpe.values AS values
                FROM    wte
                CROSS JOIN
                        wpe
                WHERE   wte.token = tokens.token
                        AND wpe.place = o.place
                ) embedding
        )
SELECT  place, (values::REAL)[0:5]
FROM    embeddings

place    values
0    [0.1035146, -0.22879261, 0.18413992, -0.29924694, 0.18642524]
1    [0.10757777, -0.0011023134, -0.0077463835, 0.03656415, -0.14654925]
2    [-0.005507436, -0.07471258, 0.11009377, -0.11708109, -0.14026159]
3    [-0.04785268, -0.0792546, 0.1628486, -0.3598496, 0.11462127]

注意
使 Transformer 架构真正发挥作用的是自我注意机制。
瓦斯马尼等人在 2017 年发表的论文《Attention is all you need》中首次描述了这一机制,该论文可能是最著名的人工智能论文,其名字后来也成为了雪泥鸿爪(为其他论文命名的陈词滥调)。

到目前为止,我们已经有了几个向量,希望它们能编码提示语中单词的一些句法和语义属性。我们需要将这些属性以某种方式转移到最后一个向量上。剧透一下:最后,最后一个向量将存储续词的嵌入。

在 "我看着紫罗兰,发现它不是寻常的...... "这样的短语中,省略号必须是你看到的东西(这个概念必须从 "看到 "跳转),是紫罗兰的属性(从 "紫罗兰 "跳转到 "它",再跳转到省略号),是 "不寻常 "的东西,这是从 "不是 "和 "寻常 "跳转,在负责通常性的维度上(上下文中)翻转符号。

现实世界中的类比是,一个人在阅读一本外语书,他对这本书有基本的了解,但还不是很熟悉。他们需要有意识地从一个单词追溯到另一个单词,如果不注意短语的关键部分,他们的理解就会出错。

为了实现从一个标记到另一个标记的意义转换,我们需要让所有标记的向量相互影响。如果我们想在 "它 "这个词中填充一些具体的语义,那么有多少语义应该来自提示语中的前几个向量,又有多少应该来自 "它 "这个词本身呢?

为了解决这个问题,该模型使用了 12 组矩阵,分别称为 Q(查询)、K(关键)和 V(值)。每个矩阵有 64 列。它们通过 768×2304 的线性变换 c_attn 从向量嵌入中获得,其权重和偏差存储在表 c_attn_w 和 c_attn_b 中。

c_attn 的结果是一个有 n_token 行和 2304 列(3×12×64)的矩阵。它由 12 个 Q 矩阵、12 个 K 矩阵和 12 个 V 矩阵按此顺序水平堆叠而成。

每组 Q、K 和 V 称为一个 "头"。它们通过计算注意力函数来执行称为 "多头因果自注意力 "的步骤。

下面是注意力函数的计算公式:

点击标题见原文

其中 softmax 是权重标准化函数。其定义如下

点击标题见原文

 是一个常数矩阵,称为 "因果掩码"。其定义如下
点击标题见原文
Softmax 将负无穷变为零。

为什么需要屏蔽Mask?
前面例子中的提示有 4 个标记,模型首先要做的就是计算这 4 个标记的 4 个嵌入向量。随着模型的发展,这些向量将进行大量的计算,但在大多数情况下,它们是独立和并行的。一个向量的变化不会影响其他向量,就好像它们不存在一样。自我关注区块是整个模型中唯一一个矢量会相互影响的地方。

一旦模型完成数学运算,下一个标记的候选者将完全由上一次嵌入决定。所有的信息流都应流向最后一个向量,而不是从它流出。在模型的前向传递过程中,最后一个嵌入向量的瞬态值不应影响之前嵌入向量的瞬态值。

这就是为什么我们要 "屏蔽 "后一个嵌入向量,使其不会通过这一特定通道影响前一个嵌入向量。这就是 "多头因果自关注 "中的 "因果 "一词的由来。

为什么矩阵被称为 "查询"、"键 "和 "值"?
老实说,我不确定这是不是一个好的比喻。不过,我还是会从直觉上谈谈我的看法。

在机器学习中,一般来说,计算不应涉及变长循环或语句分支。一切都应通过简单分析函数(加法、乘法、幂、对数和三角函数)的组合来完成。这样,依赖于自动微分等技术的反向传播技术才能高效工作。

详细点击标题

前馈
这就是深度神经网络的作用。这一步实际上使用了模型参数的大部分。

这一步是一个具有三层(768、3072、768)的多层感知器,使用高斯误差线性单元( GELU)作为激活函数:

下面是我们在 SQL 中执行此操作的方法:


我们在前面的步骤中看到的内容会在各层(称为 "块")中重复。区块以流水线的方式设置,因此前一个区块的输出会直接进入下一个区块。每个区块都有自己的学习参数集。

在 SQL 中,我们需要使用递归 CTE 来连接各个区块。

当最后一个数据块产生数值后,我们需要使用已学参数对其进行规范化处理


标记
我们已经有了一个嵌入(一个 768 向量),根据模型,它捕捉到了提示语最有可能延续的语义和语法。现在,我们需要将其映射回标记。

模型的第一步就是将标记映射到它们的嵌入。这是通过 50257×768 矩阵 wpe 完成的。我们需要使用相同的矩阵将嵌入映射回标记。

问题是,精确的反向映射是不可能的:嵌入不(可能)等于矩阵中的任何一行。因此,我们需要找到与嵌入 "最接近 "的标记。

由于嵌入的维度(正如我们所希望的那样)捕捉了标记的某些语义和语法方面,因此我们需要它们尽可能地匹配。合并每个维度的接近程度的一种方法是计算两个嵌入式的点积。点积越大,则标记与预测越接近。

为此,我们将把嵌入值乘以矩阵 wte。结果将是一个 50257 行高的单列矩阵。结果中的每个值都是预测嵌入和标记嵌入的点积。这个数字越大,标记继续提示的可能性就越大。

要选择下一个标记,我们需要将相似度转换为概率。为此,我们将使用我们的好朋友 softmax(与我们用来规范注意力权重的函数相同)。

为什么使用softmax计算概率?
Softmax具有满足Luce 选择公理的良好特性。这意味着两个选项的相对概率不依赖于其他选项的存在或概率。如果 A 的概率是 B 的两倍,那么其他选项的存在或不存在都不会改变这个比率(尽管它当然可以改变绝对值)。

点积向量(人工智能术语中的“ logit ”)包含没有内在尺度的任意分数。如果 A 的分数比 B 的分数大,我们就知道它的可能性更大,但仅此而已。我们可以随意调整softmax的输入,只要它们保持顺序(即分数越大)。

一种常见的方法是通过减去集合中的最大值来标准化分数(使最大分数变为 0,其余变为负数)。然后我们取一些固定数量(比如说五个或十个)最高分。最后,我们将每个分数乘以一个常数,然后将其输入到softmax中。

我们获得的最高分的数量通常被称为顶部\_n乘法常数(或者更确切地说,它的倒数)被称为“温度”(时间)。温度越高,概率越平滑,下一个选择的令牌不仅仅是第一个令牌的机会就越大。

为什么叫“温度”?
softmax函数还有一个名字:玻尔兹曼分布。它广泛应用于物理学。除其他外,它还作为气压公式的基础,该公式说明密度或空气如何随高度变化。

直觉上,热空气上升。它传播到距地球更远的地方。当空气很热时,空气分子更有可能从邻近的空气分子反弹并跳跃到原本不可能的高度。与较冷的温度相比,空气密度在高海拔地区增加,在海平面地区下降。

以此类推,大的“温度”会增加第二选择令牌被选择的概率(当然,以牺牲第一选择令牌为代价)。推论变得更不可预测并且更具“创造性”。

让我们将这一切放入 SQL 中。提示是“ PostgreSQL很棒”。根据模型,以下是最有可能延续该短语的前 5 个标记,以及它们在不同温度下的概率:


推理
最后,我们可以进行真正的推理了:运行模型,根据概率选择一个标记,将其添加到提示中,然后重复,直到生成足够多的标记。

正如我们之前看到的,LLM 本身是确定性的:它只是对预定义常量进行一系列矩阵乘法和其他数学运算。只要提示符和温度、top_n 等超参数相同,输出结果也会相同。

唯一的非确定过程是令牌选择。其中存在随机性(程度不一)。这就是为什么基于 GPT 的聊天机器人可以对相同的提示给出不同的答案。

我们将使用 "新年快乐!我希望 "作为提示语,并让模型为这个提示语生成 10 个新标记。温度设置为 2,top_n 设置为 5。

您可以在GitHub存储库中找到查询和安装代码:quassnoi/explain-extended-2024