diff --git a/README.md b/README.md
index 3c29bad..3906389 100644
--- a/README.md
+++ b/README.md
@@ -135,8 +135,8 @@ If you're interviewing for any role touching LLMs or Transformers, expect at lea
| # | Problem | What You'll Implement | Difficulty | Freq | Key Concepts |
|:---:|---------|----------------------|:----------:|:----:|--------------|
-| 23 | Cross-Attention
| `MultiHeadCrossAttention` (nn.Module) |  | โญ | Encoder-decoder, Q from decoder, K/V from encoder |
-| 5 | Scaled Dot-Product Attention
| `scaled_dot_product_attention(Q, K, V)` |  | ๐ฅ | `softmax(QK^T/โd_k)V`, the foundation of everything |
+| 23 | Cross-Attention | `MultiHeadCrossAttention` (nn.Module) | ๐ด | โญ | Encoder-decoder, Q from decoder, K/V from encoder |
+| 5 | Scaled Dot-Product Attention | `scaled_dot_product_attention(Q, K, V)` | ๐ | ๐ฅ | `softmax(QK^T/โd_k)V`, the foundation of everything |
| 6 | Multi-Head Attention
| `MultiHeadAttention` (nn.Module) |  | ๐ฅ | Parallel heads, split/concat, projection matrices |
| 9 | Causal Self-Attention
| `causal_attention(Q, K, V)` |  | ๐ฅ | Autoregressive masking with `-inf`, GPT-style |
| 10 | Grouped Query Attention
| `GroupQueryAttention` (nn.Module) |  | โญ | GQA (LLaMA 2), KV sharing across heads |
diff --git a/solutions/05_attention_solution.ipynb b/solutions/05_attention_solution.ipynb
index e82f45f..eecef07 100644
--- a/solutions/05_attention_solution.ipynb
+++ b/solutions/05_attention_solution.ipynb
@@ -1,106 +1,107 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "5f63d076",
- "metadata": {},
- "source": [
- "[](https://colab.research.google.com/github/duoan/TorchCode/blob/master/solutions/05_attention_solution.ipynb)\n\n",
- "# ๐ด Solution: Softmax Attention\n",
- "\n",
- "Reference solution for the core Transformer attention mechanism.\n",
- "\n",
- "$$\\text{Attention}(Q, K, V) = \\text{softmax}\\!\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)V$$"
- ]
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/solutions/05_attention_solution.ipynb)\n",
+ "\n",
+ "# ๐ Solution: Softmax Attention\n",
+ "\n",
+ "Reference solution for the core Transformer attention mechanism.\n",
+ "\n",
+ "$$\\text{Attention}(Q, K, V) = \\text{softmax}\\!\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)V$$"
+ ],
+ "id": "5f63d076"
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n",
+ "try:\n",
+ " import google.colab\n",
+ " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n",
+ "except ImportError:\n",
+ " pass\n"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "id": "ce663fb0"
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "import torch\n",
+ "import math"
+ ],
+ "execution_count": null,
+ "outputs": [], "id": "4c1d9a2e"
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# โ
SOLUTION\n",
+ "\n",
+ "def scaled_dot_product_attention(Q, K, V):\n",
+ " d_k = K.size(-1)\n",
+ " scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d_k)\n",
+ " weights = torch.softmax(scores, dim=-1)\n",
+ " return torch.bmm(weights, V)"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "id": "828be673"
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# Verify\n",
+ "torch.manual_seed(42)\n",
+ "Q = torch.randn(2, 4, 8)\n",
+ "K = torch.randn(2, 4, 8)\n",
+ "V = torch.randn(2, 4, 8)\n",
+ "\n",
+ "out = scaled_dot_product_attention(Q, K, V)\n",
+ "print(\"Output shape:\", out.shape)\n",
+ "print(\"Attention weights sum to 1?\", torch.allclose(torch.softmax(torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(K.size(-1)), dim=-1).sum(dim=-1), torch.ones(2, 4)))\n",
+ "\n",
+ "# Cross-attention (seq_q != seq_k)\n",
+ "Q2 = torch.randn(1, 3, 16)\n",
+ "K2 = torch.randn(1, 5, 16)\n",
+ "V2 = torch.randn(1, 5, 32)\n",
+ "out2 = scaled_dot_product_attention(Q2, K2, V2)\n",
+ "print(\"Cross-attention shape:\", out2.shape, \"(expected: 1, 3, 32)\")"
+ ],
+ "execution_count": null,
+ "outputs": [], "id": "9e3b7f12"
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# Run judge\n",
+ "from torch_judge import check\n",
+ "check(\"attention\")"
+ ],
+ "execution_count": null,
+ "outputs": [], "id": "6a2d8c4f"
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.11.0"
+ }
},
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "ce663fb0",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n",
- "try:\n",
- " import google.colab\n",
- " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n",
- "except ImportError:\n",
- " pass\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch\n",
- "import math"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "828be673",
- "metadata": {},
- "outputs": [],
- "source": [
- "# โ
SOLUTION\n",
- "\n",
- "def scaled_dot_product_attention(Q, K, V):\n",
- " d_k = K.size(-1)\n",
- " scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d_k)\n",
- " weights = torch.softmax(scores, dim=-1)\n",
- " return torch.bmm(weights, V)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Verify\n",
- "torch.manual_seed(42)\n",
- "Q = torch.randn(2, 4, 8)\n",
- "K = torch.randn(2, 4, 8)\n",
- "V = torch.randn(2, 4, 8)\n",
- "\n",
- "out = scaled_dot_product_attention(Q, K, V)\n",
- "print(\"Output shape:\", out.shape)\n",
- "print(\"Attention weights sum to 1?\", True)\n",
- "\n",
- "# Cross-attention (seq_q != seq_k)\n",
- "Q2 = torch.randn(1, 3, 16)\n",
- "K2 = torch.randn(1, 5, 16)\n",
- "V2 = torch.randn(1, 5, 32)\n",
- "out2 = scaled_dot_product_attention(Q2, K2, V2)\n",
- "print(\"Cross-attention shape:\", out2.shape, \"(expected: 1, 3, 32)\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Run judge\n",
- "from torch_judge import check\n",
- "check(\"attention\")"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "name": "python",
- "version": "3.11.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
\ No newline at end of file
diff --git a/solutions/23_cross_attention_solution.ipynb b/solutions/23_cross_attention_solution.ipynb
index bb7cc66..0eebf0f 100644
--- a/solutions/23_cross_attention_solution.ipynb
+++ b/solutions/23_cross_attention_solution.ipynb
@@ -1,106 +1,113 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "[](https://colab.research.google.com/github/duoan/TorchCode/blob/master/solutions/23_cross_attention_solution.ipynb)\n\n",
- "# Solution: Multi-Head Cross-Attention\n",
- "\n",
- "Reference solution."
- ],
- "outputs": []
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/solutions/23_cross_attention_solution.ipynb)\n",
+ "\n",
+ "# ๐ด Solution: Multi-Head Cross-Attention\n",
+ "\n",
+ "Reference solution.\n",
+ "\n",
+ "$$Q = x_q \\, W_Q, \\quad K = x_{kv} \\, W_K, \\quad V = x_{kv} \\, W_V$$\n",
+ "\n",
+ "$$\\text{head}_i = \\text{softmax}\\!\\left(\\frac{Q_i K_i^T}{\\sqrt{d_k}}\\right) V_i$$\n",
+ "\n",
+ "$$\\text{MultiHead}(x_q, x_{kv}) = \\text{Concat}(\\text{head}_1, \\dots, \\text{head}_h) \\, W_O$$"
+ ]
+
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n",
+ "try:\n",
+ " import google.colab\n",
+ " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n",
+ "except ImportError:\n",
+ " pass\n"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import math"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# โ
SOLUTION\n",
+ "\n",
+ "class MultiHeadCrossAttention(nn.Module):\n",
+ " def __init__(self, d_model, num_heads):\n",
+ " super().__init__()\n",
+ " self.num_heads = num_heads\n",
+ " self.d_k = d_model // num_heads\n",
+ " self.W_q = nn.Linear(d_model, d_model)\n",
+ " self.W_k = nn.Linear(d_model, d_model)\n",
+ " self.W_v = nn.Linear(d_model, d_model)\n",
+ " self.W_o = nn.Linear(d_model, d_model)\n",
+ "\n",
+ " def forward(self, x_q, x_kv):\n",
+ " B, S_q, _ = x_q.shape\n",
+ " S_kv = x_kv.shape[1]\n",
+ " q = self.W_q(x_q).view(B, S_q, self.num_heads, self.d_k).transpose(1, 2)\n",
+ " k = self.W_k(x_kv).view(B, S_kv, self.num_heads, self.d_k).transpose(1, 2)\n",
+ " v = self.W_v(x_kv).view(B, S_kv, self.num_heads, self.d_k).transpose(1, 2)\n",
+ " scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)\n",
+ " weights = torch.softmax(scores, dim=-1)\n",
+ " attn = torch.matmul(weights, v)\n",
+ " return self.W_o(attn.transpose(1, 2).contiguous().view(B, S_q, -1))"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# Demo\n",
+ "attn = MultiHeadCrossAttention(64, 4)\n",
+ "x_q = torch.randn(2, 6, 64)\n",
+ "x_kv = torch.randn(2, 10, 64)\n",
+ "print('Output:', attn(x_q, x_kv).shape)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "from torch_judge import check\n",
+ "check('cross_attention')"
+ ],
+ "execution_count": null,
+ "outputs": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.11.0"
+ }
},
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n",
- "try:\n",
- " import google.colab\n",
- " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n",
- "except ImportError:\n",
- " pass\n"
- ],
- "outputs": [],
- "execution_count": null
- },
- {
- "cell_type": "code",
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch\n",
- "import torch.nn as nn\n",
- "import math"
- ],
- "execution_count": null
- },
- {
- "cell_type": "code",
- "metadata": {},
- "outputs": [],
- "source": [
- "# โ
SOLUTION\n",
- "\n",
- "class MultiHeadCrossAttention(nn.Module):\n",
- " def __init__(self, d_model, num_heads):\n",
- " super().__init__()\n",
- " self.num_heads = num_heads\n",
- " self.d_k = d_model // num_heads\n",
- " self.W_q = nn.Linear(d_model, d_model)\n",
- " self.W_k = nn.Linear(d_model, d_model)\n",
- " self.W_v = nn.Linear(d_model, d_model)\n",
- " self.W_o = nn.Linear(d_model, d_model)\n",
- "\n",
- " def forward(self, x_q, x_kv):\n",
- " B, S_q, _ = x_q.shape\n",
- " S_kv = x_kv.shape[1]\n",
- " q = self.W_q(x_q).view(B, S_q, self.num_heads, self.d_k).transpose(1, 2)\n",
- " k = self.W_k(x_kv).view(B, S_kv, self.num_heads, self.d_k).transpose(1, 2)\n",
- " v = self.W_v(x_kv).view(B, S_kv, self.num_heads, self.d_k).transpose(1, 2)\n",
- " scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)\n",
- " weights = torch.softmax(scores, dim=-1)\n",
- " attn = torch.matmul(weights, v)\n",
- " return self.W_o(attn.transpose(1, 2).contiguous().view(B, S_q, -1))"
- ],
- "execution_count": null
- },
- {
- "cell_type": "code",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Demo\n",
- "attn = MultiHeadCrossAttention(64, 4)\n",
- "x_q = torch.randn(2, 6, 64)\n",
- "x_kv = torch.randn(2, 10, 64)\n",
- "print('Output:', attn(x_q, x_kv).shape)"
- ],
- "execution_count": null
- },
- {
- "cell_type": "code",
- "metadata": {},
- "outputs": [],
- "source": [
- "from torch_judge import check\n",
- "check('cross_attention')"
- ],
- "execution_count": null
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "name": "python",
- "version": "3.11.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
\ No newline at end of file
diff --git a/templates/05_attention.ipynb b/templates/05_attention.ipynb
index 047243e..0bc1de1 100644
--- a/templates/05_attention.ipynb
+++ b/templates/05_attention.ipynb
@@ -1,118 +1,118 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "[](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/05_attention.ipynb)\n",
- "\n",
- "# ๐ด Hard: Softmax Attention\n",
- "\n",
- "Implement the core attention mechanism used in Transformers.\n",
- "\n",
- "$$\\text{Attention}(Q, K, V) = \\text{softmax}\\!\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)V$$\n",
- "\n",
- "### Signature\n",
- "```python\n",
- "def scaled_dot_product_attention(\n",
- " Q: torch.Tensor, # (batch, seq_q, d_k)\n",
- " K: torch.Tensor, # (batch, seq_k, d_k)\n",
- " V: torch.Tensor, # (batch, seq_k, d_v)\n",
- ") -> torch.Tensor: # (batch, seq_q, d_v)\n",
- " ...\n",
- "```\n",
- "\n",
- "### Rules\n",
- "- Do **NOT** use `F.scaled_dot_product_attention`\n",
- "- You **may** use `torch.softmax` and `torch.bmm`\n",
- "- Must support autograd\n",
- "- Must handle cross-attention (seq_q โ seq_k)"
- ]
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/05_attention.ipynb)\n",
+ "\n",
+ "# ๐ Medium: Softmax Attention\n",
+ "\n",
+ "Implement the core attention mechanism used in Transformers.\n",
+ "\n",
+ "$$\\text{Attention}(Q, K, V) = \\text{softmax}\\!\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)V$$\n",
+ "\n",
+ "### Signature\n",
+ "```python\n",
+ "def scaled_dot_product_attention(\n",
+ " Q: torch.Tensor, # (batch, seq_q, d_k)\n",
+ " K: torch.Tensor, # (batch, seq_k, d_k)\n",
+ " V: torch.Tensor, # (batch, seq_k, d_v)\n",
+ ") -> torch.Tensor: # (batch, seq_q, d_v)\n",
+ " ...\n",
+ "```\n",
+ "\n",
+ "### Rules\n",
+ "- Do **NOT** use `F.scaled_dot_product_attention`\n",
+ "- You **may** use `torch.softmax` and `torch.bmm`\n",
+ "- Must support autograd\n",
+ "- Must handle cross-attention (seq_q โ seq_k)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n",
+ "try:\n",
+ " import google.colab\n",
+ " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n",
+ "except ImportError:\n",
+ " pass\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import math"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# โ๏ธ YOUR IMPLEMENTATION HERE\n",
+ "\n",
+ "def scaled_dot_product_attention(Q, K, V):\n",
+ " pass # Replace this"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ๐งช Debug\n",
+ "torch.manual_seed(42)\n",
+ "Q = torch.randn(2, 4, 8)\n",
+ "K = torch.randn(2, 4, 8)\n",
+ "V = torch.randn(2, 4, 8)\n",
+ "\n",
+ "out = scaled_dot_product_attention(Q, K, V)\n",
+ "print(\"Output shape:\", out.shape) # should be (2, 4, 8)\n",
+ "print(\"Has NaN? \", torch.isnan(out).any().item()) # should be False\n",
+ "print(\"Has Inf? \", torch.isinf(out).any().item()) # should be False\n",
+ "\n",
+ "# Cross-attention: seq_q != seq_k\n",
+ "Q2 = torch.randn(1, 3, 16)\n",
+ "K2 = torch.randn(1, 5, 16)\n",
+ "V2 = torch.randn(1, 5, 32)\n",
+ "out2 = scaled_dot_product_attention(Q2, K2, V2)\n",
+ "print(\"Cross-attn shape:\", out2.shape) # should be (1, 3, 32)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# โ
SUBMIT\n",
+ "from torch_judge import check\n",
+ "check(\"attention\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.11.0"
+ }
},
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n",
- "try:\n",
- " import google.colab\n",
- " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n",
- "except ImportError:\n",
- " pass\n"
- ],
- "outputs": [],
- "execution_count": null
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch\n",
- "import math"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# โ๏ธ YOUR IMPLEMENTATION HERE\n",
- "\n",
- "def scaled_dot_product_attention(Q, K, V):\n",
- " pass # Replace this"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# ๐งช Debug\n",
- "torch.manual_seed(42)\n",
- "Q = torch.randn(2, 4, 8)\n",
- "K = torch.randn(2, 4, 8)\n",
- "V = torch.randn(2, 4, 8)\n",
- "\n",
- "out = scaled_dot_product_attention(Q, K, V)\n",
- "print(\"Output shape:\", out.shape) # should be (2, 4, 8)\n",
- "print(\"Has NaN? \", torch.isnan(out).any().item()) # should be False\n",
- "print(\"Has Inf? \", torch.isinf(out).any().item()) # should be False\n",
- "\n",
- "# Cross-attention: seq_q != seq_k\n",
- "Q2 = torch.randn(1, 3, 16)\n",
- "K2 = torch.randn(1, 5, 16)\n",
- "V2 = torch.randn(1, 5, 32)\n",
- "out2 = scaled_dot_product_attention(Q2, K2, V2)\n",
- "print(\"Cross-attn shape:\", out2.shape) # should be (1, 3, 32)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# โ
SUBMIT\n",
- "from torch_judge import check\n",
- "check(\"attention\")"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "name": "python",
- "version": "3.11.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
+ "nbformat": 4,
+ "nbformat_minor": 4
}
diff --git a/templates/23_cross_attention.ipynb b/templates/23_cross_attention.ipynb
index 2467285..0319e67 100644
--- a/templates/23_cross_attention.ipynb
+++ b/templates/23_cross_attention.ipynb
@@ -1,108 +1,113 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "[](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/23_cross_attention.ipynb)\n",
- "\n",
- "# ๐ Medium: Multi-Head Cross-Attention\n",
- "\n",
- "Implement **multi-head cross-attention** (encoder-decoder attention).\n",
- "\n",
- "### Signature\n",
- "```python\n",
- "class MultiHeadCrossAttention(nn.Module):\n",
- " def __init__(self, d_model: int, num_heads: int): ...\n",
- " def forward(self, x_q: Tensor, x_kv: Tensor) -> Tensor:\n",
- " # x_q: (B, S_q, D) โ decoder queries\n",
- " # x_kv: (B, S_kv, D) โ encoder keys/values\n",
- "```\n",
- "\n",
- "### Key Differences from Self-Attention\n",
- "- Q comes from the decoder, K and V come from the encoder\n",
- "- No causal mask (all encoder positions visible)"
- ],
- "outputs": []
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/23_cross_attention.ipynb)\n",
+ "\n",
+ "# ๐ด Hard: Multi-Head Cross-Attention\n",
+ "\n",
+ "Implement **multi-head cross-attention** (encoder-decoder attention).\n",
+ "\n",
+ "$$Q = x_q \\, W_Q, \\quad K = x_{kv} \\, W_K, \\quad V = x_{kv} \\, W_V$$\n",
+ "\n",
+ "$$\\text{head}_i = \\text{softmax}\\!\\left(\\frac{Q_i K_i^T}{\\sqrt{d_k}}\\right) V_i$$\n",
+ "\n",
+ "$$\\text{MultiHead}(x_q, x_{kv}) = \\text{Concat}(\\text{head}_1, \\dots, \\text{head}_h) \\, W_O$$\n",
+ "\n",
+ "### Signature\n",
+ "```python\n",
+ "class MultiHeadCrossAttention(nn.Module):\n",
+ " def __init__(self, d_model: int, num_heads: int): ...\n",
+ " def forward(self, x_q: Tensor, x_kv: Tensor) -> Tensor:\n",
+ " # x_q: (B, S_q, D) โ decoder queries\n",
+ " # x_kv: (B, S_kv, D) โ encoder keys/values\n",
+ "```\n",
+ "\n",
+ "### Key Differences from Self-Attention\n",
+ "- Q comes from the decoder, K and V come from the encoder\n",
+ "- No causal mask (all encoder positions visible)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n",
+ "try:\n",
+ " import google.colab\n",
+ " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n",
+ "except ImportError:\n",
+ " pass\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import math"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# โ๏ธ YOUR IMPLEMENTATION HERE\n",
+ "\n",
+ "class MultiHeadCrossAttention(nn.Module):\n",
+ " def __init__(self, d_model, num_heads):\n",
+ " super().__init__()\n",
+ " pass # W_q, W_k, W_v, W_o\n",
+ "\n",
+ " def forward(self, x_q, x_kv):\n",
+ " pass # Q from x_q, K/V from x_kv, no causal mask"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ๐งช Debug\n",
+ "attn = MultiHeadCrossAttention(64, 4)\n",
+ "x_q = torch.randn(2, 6, 64)\n",
+ "x_kv = torch.randn(2, 10, 64)\n",
+ "print('Output:', attn(x_q, x_kv).shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# โ
SUBMIT\n",
+ "from torch_judge import check\n",
+ "check('cross_attention')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.11.0"
+ }
},
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n",
- "try:\n",
- " import google.colab\n",
- " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n",
- "except ImportError:\n",
- " pass\n"
- ],
- "outputs": [],
- "execution_count": null
- },
- {
- "cell_type": "code",
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch\n",
- "import torch.nn as nn\n",
- "import math"
- ],
- "execution_count": null
- },
- {
- "cell_type": "code",
- "metadata": {},
- "outputs": [],
- "source": [
- "# โ๏ธ YOUR IMPLEMENTATION HERE\n",
- "\n",
- "class MultiHeadCrossAttention(nn.Module):\n",
- " def __init__(self, d_model, num_heads):\n",
- " super().__init__()\n",
- " pass # W_q, W_k, W_v, W_o\n",
- "\n",
- " def forward(self, x_q, x_kv):\n",
- " pass # Q from x_q, K/V from x_kv, no causal mask"
- ],
- "execution_count": null
- },
- {
- "cell_type": "code",
- "metadata": {},
- "outputs": [],
- "source": [
- "# ๐งช Debug\n",
- "attn = MultiHeadCrossAttention(64, 4)\n",
- "x_q = torch.randn(2, 6, 64)\n",
- "x_kv = torch.randn(2, 10, 64)\n",
- "print('Output:', attn(x_q, x_kv).shape)"
- ],
- "execution_count": null
- },
- {
- "cell_type": "code",
- "metadata": {},
- "outputs": [],
- "source": [
- "# โ
SUBMIT\n",
- "from torch_judge import check\n",
- "check('cross_attention')"
- ],
- "execution_count": null
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "name": "python",
- "version": "3.11.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
+ "nbformat": 4,
+ "nbformat_minor": 4
}
diff --git a/torch_judge/tasks/attention.py b/torch_judge/tasks/attention.py
index 0ef1dd5..d09d6d3 100644
--- a/torch_judge/tasks/attention.py
+++ b/torch_judge/tasks/attention.py
@@ -2,7 +2,7 @@
TASK = {
"title": "Softmax Attention",
- "difficulty": "Hard",
+ "difficulty": "Medium",
"function_name": "scaled_dot_product_attention",
"hint": "scores = Q @ K^T / sqrt(d_k), then softmax(scores, dim=-1) @ V. Use torch.bmm for batched matmul.",
"tests": [
diff --git a/torch_judge/tasks/cross_attention.py b/torch_judge/tasks/cross_attention.py
index c7d08cd..064e918 100644
--- a/torch_judge/tasks/cross_attention.py
+++ b/torch_judge/tasks/cross_attention.py
@@ -2,7 +2,7 @@
TASK = {
"title": "Multi-Head Cross-Attention",
- "difficulty": "Medium",
+ "difficulty": "Hard",
"function_name": "MultiHeadCrossAttention",
"hint": "Q from decoder (x_q), K/V from encoder (x_kv). Project, reshape to multi-head, compute scaled dot-product attention (no causal mask). Concat heads and project output.",
"tests": [