diff --git a/README.md b/README.md index 3c29bad..3906389 100644 --- a/README.md +++ b/README.md @@ -135,8 +135,8 @@ If you're interviewing for any role touching LLMs or Transformers, expect at lea | # | Problem | What You'll Implement | Difficulty | Freq | Key Concepts | |:---:|---------|----------------------|:----------:|:----:|--------------| -| 23 | Cross-Attention Open In Colab | `MultiHeadCrossAttention` (nn.Module) | ![Medium](https://img.shields.io/badge/Medium-FF9800?style=flat-square) | โญ | Encoder-decoder, Q from decoder, K/V from encoder | -| 5 | Scaled Dot-Product Attention Open In Colab | `scaled_dot_product_attention(Q, K, V)` | ![Hard](https://img.shields.io/badge/Hard-F44336?style=flat-square) | ๐Ÿ”ฅ | `softmax(QK^T/โˆšd_k)V`, the foundation of everything | +| 23 | Cross-Attention Open In Colab | `MultiHeadCrossAttention` (nn.Module) | ![Hard](https://img.shields.io/badge/Hard-F44336?style=flat-square) | โญ | Encoder-decoder, Q from decoder, K/V from encoder | +| 5 | Scaled Dot-Product Attention Open In Colab | `scaled_dot_product_attention(Q, K, V)` | ![Medium](https://img.shields.io/badge/Medium-FF9800?style=flat-square) | ๐Ÿ”ฅ | `softmax(QK^T/โˆšd_k)V`, the foundation of everything | | 6 | Multi-Head Attention Open In Colab | `MultiHeadAttention` (nn.Module) | ![Hard](https://img.shields.io/badge/Hard-F44336?style=flat-square) | ๐Ÿ”ฅ | Parallel heads, split/concat, projection matrices | | 9 | Causal Self-Attention Open In Colab | `causal_attention(Q, K, V)` | ![Hard](https://img.shields.io/badge/Hard-F44336?style=flat-square) | ๐Ÿ”ฅ | Autoregressive masking with `-inf`, GPT-style | | 10 | Grouped Query Attention Open In Colab | `GroupQueryAttention` (nn.Module) | ![Hard](https://img.shields.io/badge/Hard-F44336?style=flat-square) | โญ | GQA (LLaMA 2), KV sharing across heads | diff --git a/solutions/05_attention_solution.ipynb b/solutions/05_attention_solution.ipynb index e82f45f..eecef07 100644 --- a/solutions/05_attention_solution.ipynb +++ b/solutions/05_attention_solution.ipynb @@ -1,106 +1,107 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "5f63d076", - "metadata": {}, - "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/solutions/05_attention_solution.ipynb)\n\n", - "# ๐Ÿ”ด Solution: Softmax Attention\n", - "\n", - "Reference solution for the core Transformer attention mechanism.\n", - "\n", - "$$\\text{Attention}(Q, K, V) = \\text{softmax}\\!\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)V$$" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/solutions/05_attention_solution.ipynb)\n", + "\n", + "# ๐ŸŸ  Solution: Softmax Attention\n", + "\n", + "Reference solution for the core Transformer attention mechanism.\n", + "\n", + "$$\\text{Attention}(Q, K, V) = \\text{softmax}\\!\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)V$$" + ], + "id": "5f63d076" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", + "try:\n", + " import google.colab\n", + " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", + "except ImportError:\n", + " pass\n" + ], + "execution_count": null, + "outputs": [], + "id": "ce663fb0" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "import torch\n", + "import math" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# โœ… SOLUTION\n", + "\n", + "def scaled_dot_product_attention(Q, K, V):\n", + " d_k = K.size(-1)\n", + " scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d_k)\n", + " weights = torch.softmax(scores, dim=-1)\n", + " return torch.bmm(weights, V)" + ], + "execution_count": null, + "outputs": [], + "id": "828be673" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Verify\n", + "torch.manual_seed(42)\n", + "Q = torch.randn(2, 4, 8)\n", + "K = torch.randn(2, 4, 8)\n", + "V = torch.randn(2, 4, 8)\n", + "\n", + "out = scaled_dot_product_attention(Q, K, V)\n", + "print(\"Output shape:\", out.shape)\n", + "print(\"Attention weights sum to 1?\", True)\n", + "\n", + "# Cross-attention (seq_q != seq_k)\n", + "Q2 = torch.randn(1, 3, 16)\n", + "K2 = torch.randn(1, 5, 16)\n", + "V2 = torch.randn(1, 5, 32)\n", + "out2 = scaled_dot_product_attention(Q2, K2, V2)\n", + "print(\"Cross-attention shape:\", out2.shape, \"(expected: 1, 3, 32)\")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Run judge\n", + "from torch_judge import check\n", + "check(\"attention\")" + ], + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + } }, - { - "cell_type": "code", - "execution_count": null, - "id": "ce663fb0", - "metadata": {}, - "outputs": [], - "source": [ - "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", - "try:\n", - " import google.colab\n", - " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", - "except ImportError:\n", - " pass\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import math" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "828be673", - "metadata": {}, - "outputs": [], - "source": [ - "# โœ… SOLUTION\n", - "\n", - "def scaled_dot_product_attention(Q, K, V):\n", - " d_k = K.size(-1)\n", - " scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d_k)\n", - " weights = torch.softmax(scores, dim=-1)\n", - " return torch.bmm(weights, V)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Verify\n", - "torch.manual_seed(42)\n", - "Q = torch.randn(2, 4, 8)\n", - "K = torch.randn(2, 4, 8)\n", - "V = torch.randn(2, 4, 8)\n", - "\n", - "out = scaled_dot_product_attention(Q, K, V)\n", - "print(\"Output shape:\", out.shape)\n", - "print(\"Attention weights sum to 1?\", True)\n", - "\n", - "# Cross-attention (seq_q != seq_k)\n", - "Q2 = torch.randn(1, 3, 16)\n", - "K2 = torch.randn(1, 5, 16)\n", - "V2 = torch.randn(1, 5, 32)\n", - "out2 = scaled_dot_product_attention(Q2, K2, V2)\n", - "print(\"Cross-attention shape:\", out2.shape, \"(expected: 1, 3, 32)\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Run judge\n", - "from torch_judge import check\n", - "check(\"attention\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.11.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/solutions/23_cross_attention_solution.ipynb b/solutions/23_cross_attention_solution.ipynb index bb7cc66..0eebf0f 100644 --- a/solutions/23_cross_attention_solution.ipynb +++ b/solutions/23_cross_attention_solution.ipynb @@ -1,106 +1,113 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/solutions/23_cross_attention_solution.ipynb)\n\n", - "# Solution: Multi-Head Cross-Attention\n", - "\n", - "Reference solution." - ], - "outputs": [] + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/solutions/23_cross_attention_solution.ipynb)\n", + "\n", + "# ๐Ÿ”ด Solution: Multi-Head Cross-Attention\n", + "\n", + "Reference solution.\n", + "\n", + "$$Q = x_q \\, W_Q, \\quad K = x_{kv} \\, W_K, \\quad V = x_{kv} \\, W_V$$\n", + "\n", + "$$\\text{head}_i = \\text{softmax}\\!\\left(\\frac{Q_i K_i^T}{\\sqrt{d_k}}\\right) V_i$$\n", + "\n", + "$$\\text{MultiHead}(x_q, x_{kv}) = \\text{Concat}(\\text{head}_1, \\dots, \\text{head}_h) \\, W_O$$" + ], + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", + "try:\n", + " import google.colab\n", + " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", + "except ImportError:\n", + " pass\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import math" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# โœ… SOLUTION\n", + "\n", + "class MultiHeadCrossAttention(nn.Module):\n", + " def __init__(self, d_model, num_heads):\n", + " super().__init__()\n", + " self.num_heads = num_heads\n", + " self.d_k = d_model // num_heads\n", + " self.W_q = nn.Linear(d_model, d_model)\n", + " self.W_k = nn.Linear(d_model, d_model)\n", + " self.W_v = nn.Linear(d_model, d_model)\n", + " self.W_o = nn.Linear(d_model, d_model)\n", + "\n", + " def forward(self, x_q, x_kv):\n", + " B, S_q, _ = x_q.shape\n", + " S_kv = x_kv.shape[1]\n", + " q = self.W_q(x_q).view(B, S_q, self.num_heads, self.d_k).transpose(1, 2)\n", + " k = self.W_k(x_kv).view(B, S_kv, self.num_heads, self.d_k).transpose(1, 2)\n", + " v = self.W_v(x_kv).view(B, S_kv, self.num_heads, self.d_k).transpose(1, 2)\n", + " scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)\n", + " weights = torch.softmax(scores, dim=-1)\n", + " attn = torch.matmul(weights, v)\n", + " return self.W_o(attn.transpose(1, 2).contiguous().view(B, S_q, -1))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Demo\n", + "attn = MultiHeadCrossAttention(64, 4)\n", + "x_q = torch.randn(2, 6, 64)\n", + "x_kv = torch.randn(2, 10, 64)\n", + "print('Output:', attn(x_q, x_kv).shape)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "from torch_judge import check\n", + "check('cross_attention')" + ], + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + } }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", - "try:\n", - " import google.colab\n", - " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", - "except ImportError:\n", - " pass\n" - ], - "outputs": [], - "execution_count": null - }, - { - "cell_type": "code", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torch.nn as nn\n", - "import math" - ], - "execution_count": null - }, - { - "cell_type": "code", - "metadata": {}, - "outputs": [], - "source": [ - "# โœ… SOLUTION\n", - "\n", - "class MultiHeadCrossAttention(nn.Module):\n", - " def __init__(self, d_model, num_heads):\n", - " super().__init__()\n", - " self.num_heads = num_heads\n", - " self.d_k = d_model // num_heads\n", - " self.W_q = nn.Linear(d_model, d_model)\n", - " self.W_k = nn.Linear(d_model, d_model)\n", - " self.W_v = nn.Linear(d_model, d_model)\n", - " self.W_o = nn.Linear(d_model, d_model)\n", - "\n", - " def forward(self, x_q, x_kv):\n", - " B, S_q, _ = x_q.shape\n", - " S_kv = x_kv.shape[1]\n", - " q = self.W_q(x_q).view(B, S_q, self.num_heads, self.d_k).transpose(1, 2)\n", - " k = self.W_k(x_kv).view(B, S_kv, self.num_heads, self.d_k).transpose(1, 2)\n", - " v = self.W_v(x_kv).view(B, S_kv, self.num_heads, self.d_k).transpose(1, 2)\n", - " scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)\n", - " weights = torch.softmax(scores, dim=-1)\n", - " attn = torch.matmul(weights, v)\n", - " return self.W_o(attn.transpose(1, 2).contiguous().view(B, S_q, -1))" - ], - "execution_count": null - }, - { - "cell_type": "code", - "metadata": {}, - "outputs": [], - "source": [ - "# Demo\n", - "attn = MultiHeadCrossAttention(64, 4)\n", - "x_q = torch.randn(2, 6, 64)\n", - "x_kv = torch.randn(2, 10, 64)\n", - "print('Output:', attn(x_q, x_kv).shape)" - ], - "execution_count": null - }, - { - "cell_type": "code", - "metadata": {}, - "outputs": [], - "source": [ - "from torch_judge import check\n", - "check('cross_attention')" - ], - "execution_count": null - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.11.0" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/templates/05_attention.ipynb b/templates/05_attention.ipynb index 047243e..0bc1de1 100644 --- a/templates/05_attention.ipynb +++ b/templates/05_attention.ipynb @@ -1,118 +1,118 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/05_attention.ipynb)\n", - "\n", - "# ๐Ÿ”ด Hard: Softmax Attention\n", - "\n", - "Implement the core attention mechanism used in Transformers.\n", - "\n", - "$$\\text{Attention}(Q, K, V) = \\text{softmax}\\!\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)V$$\n", - "\n", - "### Signature\n", - "```python\n", - "def scaled_dot_product_attention(\n", - " Q: torch.Tensor, # (batch, seq_q, d_k)\n", - " K: torch.Tensor, # (batch, seq_k, d_k)\n", - " V: torch.Tensor, # (batch, seq_k, d_v)\n", - ") -> torch.Tensor: # (batch, seq_q, d_v)\n", - " ...\n", - "```\n", - "\n", - "### Rules\n", - "- Do **NOT** use `F.scaled_dot_product_attention`\n", - "- You **may** use `torch.softmax` and `torch.bmm`\n", - "- Must support autograd\n", - "- Must handle cross-attention (seq_q โ‰  seq_k)" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/05_attention.ipynb)\n", + "\n", + "# ๐ŸŸ  Medium: Softmax Attention\n", + "\n", + "Implement the core attention mechanism used in Transformers.\n", + "\n", + "$$\\text{Attention}(Q, K, V) = \\text{softmax}\\!\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)V$$\n", + "\n", + "### Signature\n", + "```python\n", + "def scaled_dot_product_attention(\n", + " Q: torch.Tensor, # (batch, seq_q, d_k)\n", + " K: torch.Tensor, # (batch, seq_k, d_k)\n", + " V: torch.Tensor, # (batch, seq_k, d_v)\n", + ") -> torch.Tensor: # (batch, seq_q, d_v)\n", + " ...\n", + "```\n", + "\n", + "### Rules\n", + "- Do **NOT** use `F.scaled_dot_product_attention`\n", + "- You **may** use `torch.softmax` and `torch.bmm`\n", + "- Must support autograd\n", + "- Must handle cross-attention (seq_q โ‰  seq_k)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", + "try:\n", + " import google.colab\n", + " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", + "except ImportError:\n", + " pass\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import math" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# โœ๏ธ YOUR IMPLEMENTATION HERE\n", + "\n", + "def scaled_dot_product_attention(Q, K, V):\n", + " pass # Replace this" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ๐Ÿงช Debug\n", + "torch.manual_seed(42)\n", + "Q = torch.randn(2, 4, 8)\n", + "K = torch.randn(2, 4, 8)\n", + "V = torch.randn(2, 4, 8)\n", + "\n", + "out = scaled_dot_product_attention(Q, K, V)\n", + "print(\"Output shape:\", out.shape) # should be (2, 4, 8)\n", + "print(\"Has NaN? \", torch.isnan(out).any().item()) # should be False\n", + "print(\"Has Inf? \", torch.isinf(out).any().item()) # should be False\n", + "\n", + "# Cross-attention: seq_q != seq_k\n", + "Q2 = torch.randn(1, 3, 16)\n", + "K2 = torch.randn(1, 5, 16)\n", + "V2 = torch.randn(1, 5, 32)\n", + "out2 = scaled_dot_product_attention(Q2, K2, V2)\n", + "print(\"Cross-attn shape:\", out2.shape) # should be (1, 3, 32)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# โœ… SUBMIT\n", + "from torch_judge import check\n", + "check(\"attention\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + } }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", - "try:\n", - " import google.colab\n", - " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", - "except ImportError:\n", - " pass\n" - ], - "outputs": [], - "execution_count": null - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import math" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# โœ๏ธ YOUR IMPLEMENTATION HERE\n", - "\n", - "def scaled_dot_product_attention(Q, K, V):\n", - " pass # Replace this" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# ๐Ÿงช Debug\n", - "torch.manual_seed(42)\n", - "Q = torch.randn(2, 4, 8)\n", - "K = torch.randn(2, 4, 8)\n", - "V = torch.randn(2, 4, 8)\n", - "\n", - "out = scaled_dot_product_attention(Q, K, V)\n", - "print(\"Output shape:\", out.shape) # should be (2, 4, 8)\n", - "print(\"Has NaN? \", torch.isnan(out).any().item()) # should be False\n", - "print(\"Has Inf? \", torch.isinf(out).any().item()) # should be False\n", - "\n", - "# Cross-attention: seq_q != seq_k\n", - "Q2 = torch.randn(1, 3, 16)\n", - "K2 = torch.randn(1, 5, 16)\n", - "V2 = torch.randn(1, 5, 32)\n", - "out2 = scaled_dot_product_attention(Q2, K2, V2)\n", - "print(\"Cross-attn shape:\", out2.shape) # should be (1, 3, 32)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# โœ… SUBMIT\n", - "from torch_judge import check\n", - "check(\"attention\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.11.0" - } - }, - "nbformat": 4, - "nbformat_minor": 4 + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/templates/23_cross_attention.ipynb b/templates/23_cross_attention.ipynb index 2467285..0319e67 100644 --- a/templates/23_cross_attention.ipynb +++ b/templates/23_cross_attention.ipynb @@ -1,108 +1,113 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/23_cross_attention.ipynb)\n", - "\n", - "# ๐ŸŸ  Medium: Multi-Head Cross-Attention\n", - "\n", - "Implement **multi-head cross-attention** (encoder-decoder attention).\n", - "\n", - "### Signature\n", - "```python\n", - "class MultiHeadCrossAttention(nn.Module):\n", - " def __init__(self, d_model: int, num_heads: int): ...\n", - " def forward(self, x_q: Tensor, x_kv: Tensor) -> Tensor:\n", - " # x_q: (B, S_q, D) โ€” decoder queries\n", - " # x_kv: (B, S_kv, D) โ€” encoder keys/values\n", - "```\n", - "\n", - "### Key Differences from Self-Attention\n", - "- Q comes from the decoder, K and V come from the encoder\n", - "- No causal mask (all encoder positions visible)" - ], - "outputs": [] + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/23_cross_attention.ipynb)\n", + "\n", + "# ๐Ÿ”ด Hard: Multi-Head Cross-Attention\n", + "\n", + "Implement **multi-head cross-attention** (encoder-decoder attention).\n", + "\n", + "$$Q = x_q \\, W_Q, \\quad K = x_{kv} \\, W_K, \\quad V = x_{kv} \\, W_V$$\n", + "\n", + "$$\\text{head}_i = \\text{softmax}\\!\\left(\\frac{Q_i K_i^T}{\\sqrt{d_k}}\\right) V_i$$\n", + "\n", + "$$\\text{MultiHead}(x_q, x_{kv}) = \\text{Concat}(\\text{head}_1, \\dots, \\text{head}_h) \\, W_O$$\n", + "\n", + "### Signature\n", + "```python\n", + "class MultiHeadCrossAttention(nn.Module):\n", + " def __init__(self, d_model: int, num_heads: int): ...\n", + " def forward(self, x_q: Tensor, x_kv: Tensor) -> Tensor:\n", + " # x_q: (B, S_q, D) โ€” decoder queries\n", + " # x_kv: (B, S_kv, D) โ€” encoder keys/values\n", + "```\n", + "\n", + "### Key Differences from Self-Attention\n", + "- Q comes from the decoder, K and V come from the encoder\n", + "- No causal mask (all encoder positions visible)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", + "try:\n", + " import google.colab\n", + " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", + "except ImportError:\n", + " pass\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import math" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# โœ๏ธ YOUR IMPLEMENTATION HERE\n", + "\n", + "class MultiHeadCrossAttention(nn.Module):\n", + " def __init__(self, d_model, num_heads):\n", + " super().__init__()\n", + " pass # W_q, W_k, W_v, W_o\n", + "\n", + " def forward(self, x_q, x_kv):\n", + " pass # Q from x_q, K/V from x_kv, no causal mask" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ๐Ÿงช Debug\n", + "attn = MultiHeadCrossAttention(64, 4)\n", + "x_q = torch.randn(2, 6, 64)\n", + "x_kv = torch.randn(2, 10, 64)\n", + "print('Output:', attn(x_q, x_kv).shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# โœ… SUBMIT\n", + "from torch_judge import check\n", + "check('cross_attention')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + } }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", - "try:\n", - " import google.colab\n", - " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", - "except ImportError:\n", - " pass\n" - ], - "outputs": [], - "execution_count": null - }, - { - "cell_type": "code", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torch.nn as nn\n", - "import math" - ], - "execution_count": null - }, - { - "cell_type": "code", - "metadata": {}, - "outputs": [], - "source": [ - "# โœ๏ธ YOUR IMPLEMENTATION HERE\n", - "\n", - "class MultiHeadCrossAttention(nn.Module):\n", - " def __init__(self, d_model, num_heads):\n", - " super().__init__()\n", - " pass # W_q, W_k, W_v, W_o\n", - "\n", - " def forward(self, x_q, x_kv):\n", - " pass # Q from x_q, K/V from x_kv, no causal mask" - ], - "execution_count": null - }, - { - "cell_type": "code", - "metadata": {}, - "outputs": [], - "source": [ - "# ๐Ÿงช Debug\n", - "attn = MultiHeadCrossAttention(64, 4)\n", - "x_q = torch.randn(2, 6, 64)\n", - "x_kv = torch.randn(2, 10, 64)\n", - "print('Output:', attn(x_q, x_kv).shape)" - ], - "execution_count": null - }, - { - "cell_type": "code", - "metadata": {}, - "outputs": [], - "source": [ - "# โœ… SUBMIT\n", - "from torch_judge import check\n", - "check('cross_attention')" - ], - "execution_count": null - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.11.0" - } - }, - "nbformat": 4, - "nbformat_minor": 4 + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/torch_judge/tasks/attention.py b/torch_judge/tasks/attention.py index 0ef1dd5..d09d6d3 100644 --- a/torch_judge/tasks/attention.py +++ b/torch_judge/tasks/attention.py @@ -2,7 +2,7 @@ TASK = { "title": "Softmax Attention", - "difficulty": "Hard", + "difficulty": "Medium", "function_name": "scaled_dot_product_attention", "hint": "scores = Q @ K^T / sqrt(d_k), then softmax(scores, dim=-1) @ V. Use torch.bmm for batched matmul.", "tests": [ diff --git a/torch_judge/tasks/cross_attention.py b/torch_judge/tasks/cross_attention.py index c7d08cd..064e918 100644 --- a/torch_judge/tasks/cross_attention.py +++ b/torch_judge/tasks/cross_attention.py @@ -2,7 +2,7 @@ TASK = { "title": "Multi-Head Cross-Attention", - "difficulty": "Medium", + "difficulty": "Hard", "function_name": "MultiHeadCrossAttention", "hint": "Q from decoder (x_q), K/V from encoder (x_kv). Project, reshape to multi-head, compute scaled dot-product attention (no causal mask). Concat heads and project output.", "tests": [