01P — Tensor Basics Practice

Practice companion for Chapter 01: Tensor Basics.

This notebook focuses on the mechanics that show up everywhere in ML code: shape literacy, indexing and slicing, broadcasting, reductions, and reshape/view/permute.

The problems start basic and gradually become more like the shape puzzles you see inside real model implementations.

How to use this notebook

For each problem:

  1. Read the prompt.
  2. Try solving it before reading the solution.
  3. Run the tests.
  4. Compare with the provided solution.

The goal is not to memorize PyTorch APIs.
The goal is to build shape intuition.

import torch

torch.manual_seed(42)
print(torch.__version__)
2.5.1

Level 1 — Shape Literacy

These are warmups. Do not skip them. Most tensor bugs are shape bugs.

Problem 1 — Predict the shape

Given:

x = torch.randn(2, 3, 4)
a = x[0]
b = x[:, 1]
c = x[:, :, 2]
d = x[None]
e = x.unsqueeze(1)

Predict the shape of a, b, c, d, and e.

Then verify using code.

# Problem 1 — your attempt

x = torch.randn(2, 3, 4)

a = x[0]
b = x[:, 1]
c = x[:, :, 2]
d = x[None]
e = x.unsqueeze(1)

# Fill these in before printing:
predicted_shapes = {
    "a": None,
    "b": None,
    "c": None,
    "d": None,
    "e": None,
}

actual_shapes = {
    "a": tuple(a.shape),
    "b": tuple(b.shape),
    "c": tuple(c.shape),
    "d": tuple(d.shape),
    "e": tuple(e.shape),
}

actual_shapes

Indexing with an integer removes that dimension.
Slicing with : keeps the dimension.
None / unsqueeze adds a new dimension.

Code
# Solution 1

x = torch.randn(2, 3, 4)

a = x[0]          # choose first item from dim 0 -> (3, 4)
b = x[:, 1]       # keep dim 0, choose index from dim 1 -> (2, 4)
c = x[:, :, 2]    # keep dims 0 and 1, choose index from dim 2 -> (2, 3)
d = x[None]       # add new leading dim -> (1, 2, 3, 4)
e = x.unsqueeze(1)  # add dim at position 1 -> (2, 1, 3, 4)

assert a.shape == (3, 4)
assert b.shape == (2, 4)
assert c.shape == (2, 3)
assert d.shape == (1, 2, 3, 4)
assert e.shape == (2, 1, 3, 4)

print("All tests passed.")

Problem 2 — Batch, sequence, hidden

You have an LLM hidden-state tensor:

x.shape == (batch, seq_len, hidden)

Create:

  1. first_token: the first token embedding for every batch item.
  2. last_token: the last token embedding for every batch item.
  3. first_batch: all tokens for the first batch item.
  4. single_value: the scalar at batch 0, token 1, hidden dim 2.

Use indexing only.

# Problem 2 — your attempt

batch, seq_len, hidden = 4, 6, 8
x = torch.randn(batch, seq_len, hidden)

first_token = None
last_token = None
first_batch = None
single_value = None

# Tests
# assert first_token.shape == (batch, hidden)
# assert last_token.shape == (batch, hidden)
# assert first_batch.shape == (seq_len, hidden)
# assert single_value.ndim == 0
Code
# Solution 2

batch, seq_len, hidden = 4, 6, 8
x = torch.randn(batch, seq_len, hidden)

first_token = x[:, 0, :]
last_token = x[:, -1, :]
first_batch = x[0, :, :]
single_value = x[0, 1, 2]

assert first_token.shape == (batch, hidden)
assert last_token.shape == (batch, hidden)
assert first_batch.shape == (seq_len, hidden)
assert single_value.ndim == 0

print("All tests passed.")

Problem 3 — Keep the dimension

Same tensor:

x.shape == (batch, seq_len, hidden)

Extract the first token for every batch item, but keep the sequence dimension.

Expected shape:

(batch, 1, hidden)

This is common when you want the selected token to still broadcast against sequence-shaped tensors.

# Problem 3 — your attempt

batch, seq_len, hidden = 4, 6, 8
x = torch.randn(batch, seq_len, hidden)

first_token_keepdim = None

# assert first_token_keepdim.shape == (batch, 1, hidden)

Using 0 drops the dimension.
Using 0:1 keeps the dimension.

Code
# Solution 3

batch, seq_len, hidden = 4, 6, 8
x = torch.randn(batch, seq_len, hidden)

first_token_keepdim = x[:, 0:1, :]

assert first_token_keepdim.shape == (batch, 1, hidden)

print("All tests passed.")
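
Why keeping the sequence dimension matters can be sketched with a small experiment. The example below is only an illustration, and the names first_keep, first_drop, and centered are made up for it. It subtracts the first token from every position, which only broadcasts when the sequence dimension is still present.

```python
import torch

batch, seq_len, hidden = 4, 6, 8
x = torch.randn(batch, seq_len, hidden)

first_keep = x[:, 0:1, :]   # (batch, 1, hidden): sequence dim kept
first_drop = x[:, 0, :]     # (batch, hidden): sequence dim dropped

# The size-1 sequence dim broadcasts across all seq_len positions.
centered = x - first_keep
assert centered.shape == (batch, seq_len, hidden)
assert torch.equal(centered[:, 0, :], torch.zeros(batch, hidden))

# Without the kept dim, shapes are aligned from the right:
# (4, 6, 8) vs (4, 8) compares 6 against 4, which is incompatible.
try:
    x - first_drop
    drop_failed = False
except RuntimeError:
    drop_failed = True

assert drop_failed
```

If seq_len happened to equal batch, the second subtraction would run without an error and silently compute the wrong thing, which is one reason to keep the dimension explicit.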

Level 2 — Indexing and Slicing

These problems mimic common data and model preprocessing patterns.

Problem 4 — Take every second token

Given:

x.shape == (batch, seq_len, hidden)

Return only tokens at even positions:

0, 2, 4, ...

Expected shape:

(batch, ceil(seq_len / 2), hidden)
# Problem 4 — your attempt

batch, seq_len, hidden = 3, 7, 5
x = torch.randn(batch, seq_len, hidden)

even_tokens = None

# assert even_tokens.shape == (batch, 4, hidden)
Code
# Solution 4

batch, seq_len, hidden = 3, 7, 5
x = torch.randn(batch, seq_len, hidden)

even_tokens = x[:, ::2, :]

assert even_tokens.shape == (batch, 4, hidden)
assert torch.equal(even_tokens, x[:, [0, 2, 4, 6], :])

print("All tests passed.")

Problem 5 — Remove the class token

Vision Transformer-style tensors often prepend a class token:

tokens.shape == (batch, 1 + num_patches, hidden)

Remove the first token and keep only patch tokens.

Expected shape:

(batch, num_patches, hidden)
# Problem 5 — your attempt

batch, num_patches, hidden = 2, 16, 12
tokens = torch.randn(batch, 1 + num_patches, hidden)

patch_tokens = None

# assert patch_tokens.shape == (batch, num_patches, hidden)

Solution 5

Code
# Solution 5

batch, num_patches, hidden = 2, 16, 12
tokens = torch.randn(batch, 1 + num_patches, hidden)

patch_tokens = tokens[:, 1:, :]

assert patch_tokens.shape == (batch, num_patches, hidden)

print("All tests passed.")

Problem 6 — Gather specific token positions

Given:

x.shape == (batch, seq_len, hidden)
positions.shape == (k,)

Select the same token positions from every batch item.

Example:

positions = torch.tensor([0, 3, 5])

Expected shape:

(batch, k, hidden)
# Problem 6 — your attempt

batch, seq_len, hidden = 4, 8, 6
x = torch.randn(batch, seq_len, hidden)
positions = torch.tensor([0, 3, 5])

selected = None

# assert selected.shape == (batch, len(positions), hidden)

Solution 6

Code
# Solution 6

batch, seq_len, hidden = 4, 8, 6
x = torch.randn(batch, seq_len, hidden)
positions = torch.tensor([0, 3, 5])

selected = x[:, positions, :]

assert selected.shape == (batch, len(positions), hidden)
assert torch.equal(selected[:, 0, :], x[:, 0, :])
assert torch.equal(selected[:, 1, :], x[:, 3, :])
assert torch.equal(selected[:, 2, :], x[:, 5, :])

print("All tests passed.")

Level 3 — Broadcasting

Broadcasting is one of the most important tensor ideas. It lets small tensors act like larger tensors without explicit copying.
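
Before the problems, here is a minimal sketch of the rule itself. Shapes are compared from the last dimension backwards, and two sizes are compatible when they are equal or one of them is 1. The shapes are arbitrary choices for illustration, and the check with expand() is one possible way to see that no data is copied.

```python
import torch

# Shapes are aligned from the trailing dimension backwards.
# Missing leading dimensions are treated as size 1.
assert torch.broadcast_shapes((5, 7), (7,)) == (5, 7)
assert torch.broadcast_shapes((4, 1, 8), (6, 8)) == (4, 6, 8)
assert torch.broadcast_shapes((2, 3, 4, 5), (4, 1)) == (2, 3, 4, 5)

# expand() makes the "no explicit copying" part visible:
# the broadcast dimension gets stride 0 and the storage is shared.
bias = torch.randn(7)
expanded = bias.expand(5, 7)

assert expanded.shape == (5, 7)
assert expanded.stride() == (0, 1)
assert expanded.data_ptr() == bias.data_ptr()
```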

Problem 7 — Add per-feature bias

Given:

x.shape == (batch, hidden)
bias.shape == (hidden,)

Add the bias to every batch row.

Expected shape:

(batch, hidden)
# Problem 7 — your attempt

batch, hidden = 5, 7
x = torch.randn(batch, hidden)
bias = torch.randn(hidden)

y = None

# assert y.shape == (batch, hidden)
# assert torch.allclose(y[0], x[0] + bias)

Solution 7

Code
# Solution 7

batch, hidden = 5, 7
x = torch.randn(batch, hidden)
bias = torch.randn(hidden)

y = x + bias

assert y.shape == (batch, hidden)
assert torch.allclose(y[0], x[0] + bias)

print("All tests passed.")

Problem 8 — Add per-channel bias to images

Given image tensor:

x.shape == (batch, channels, height, width)
bias.shape == (channels,)

Add one bias value per channel.

Expected shape:

(batch, channels, height, width)

This is a classic broadcasting trap.

# Problem 8 — your attempt

batch, channels, height, width = 2, 3, 4, 5
x = torch.randn(batch, channels, height, width)
bias = torch.randn(channels)

y = None

# assert y.shape == x.shape

bias has shape (channels,), but the channel dimension in x is dimension 1.
So reshape bias to (1, channels, 1, 1).

Code
# Solution 8

batch, channels, height, width = 2, 3, 4, 5
x = torch.randn(batch, channels, height, width)
bias = torch.randn(channels)

y = x + bias.view(1, channels, 1, 1)

assert y.shape == x.shape
assert torch.allclose(y[:, 0, :, :], x[:, 0, :, :] + bias[0])
assert torch.allclose(y[:, 1, :, :], x[:, 1, :, :] + bias[1])

print("All tests passed.")

Problem 9 — Add positional embeddings

LLM-style hidden states:

x.shape == (batch, seq_len, hidden)
pos_emb.shape == (seq_len, hidden)

Add pos_emb to every batch item.

Expected shape:

(batch, seq_len, hidden)
# Problem 9 — your attempt

batch, seq_len, hidden = 3, 6, 8
x = torch.randn(batch, seq_len, hidden)
pos_emb = torch.randn(seq_len, hidden)

y = None

# assert y.shape == x.shape

pos_emb broadcasts from (seq_len, hidden) to (batch, seq_len, hidden).

Code
# Solution 9

batch, seq_len, hidden = 3, 6, 8
x = torch.randn(batch, seq_len, hidden)
pos_emb = torch.randn(seq_len, hidden)

y = x + pos_emb

assert y.shape == x.shape
assert torch.allclose(y[0], x[0] + pos_emb)
assert torch.allclose(y[1], x[1] + pos_emb)

print("All tests passed.")

Problem 10 — Normalize per sample

Given:

x.shape == (batch, hidden)

For each row independently, subtract that row’s mean.

Expected output:

y.shape == (batch, hidden)

Each row of y should have mean approximately zero.

# Problem 10 — your attempt

batch, hidden = 5, 10
x = torch.randn(batch, hidden)

y = None

# assert y.shape == x.shape
# assert torch.allclose(y.mean(dim=1), torch.zeros(batch), atol=1e-6)

Use keepdim=True so the row mean keeps shape (batch, 1) and can broadcast back across the hidden dimension.

Code
# Solution 10

batch, hidden = 5, 10
x = torch.randn(batch, hidden)

row_mean = x.mean(dim=1, keepdim=True)
y = x - row_mean

assert y.shape == x.shape
assert torch.allclose(y.mean(dim=1), torch.zeros(batch), atol=1e-6)

print("All tests passed.")

Level 4 — Reductions

Reductions collapse dimensions. The core skill is knowing which dimension disappears and when to keep it.
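
A quick sketch of which dimension disappears, using an arbitrary (2, 3, 4) tensor chosen only for illustration:

```python
import torch

x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

# The reduced dimension is removed from the shape.
assert x.sum(dim=0).shape == (3, 4)
assert x.sum(dim=1).shape == (2, 4)
assert x.sum(dim=-1).shape == (2, 3)

# Several dimensions can be reduced at once; no dim reduces everything.
assert x.sum(dim=(1, 2)).shape == (2,)
assert x.sum().shape == ()

# keepdim=True leaves a size-1 dimension in place of the reduced one,
# so the result still broadcasts against the original tensor.
kept = x.sum(dim=1, keepdim=True)
assert kept.shape == (2, 1, 4)
assert (x - kept).shape == x.shape
```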

Problem 11 — Mean pool over tokens

Given:

x.shape == (batch, seq_len, hidden)

Compute the mean over the sequence dimension.

Expected shape:

(batch, hidden)
# Problem 11 — your attempt

batch, seq_len, hidden = 4, 7, 9
x = torch.randn(batch, seq_len, hidden)

pooled = None

# assert pooled.shape == (batch, hidden)

Solution 11

Code
# Solution 11

batch, seq_len, hidden = 4, 7, 9
x = torch.randn(batch, seq_len, hidden)

pooled = x.mean(dim=1)

assert pooled.shape == (batch, hidden)

print("All tests passed.")

Problem 12 — Global average pool image features

Given:

x.shape == (batch, channels, height, width)

Average over height and width.

Expected shape:

(batch, channels)
# Problem 12 — your attempt

batch, channels, height, width = 2, 3, 4, 5
x = torch.randn(batch, channels, height, width)

pooled = None

# assert pooled.shape == (batch, channels)

Solution 12

Code
# Solution 12

batch, channels, height, width = 2, 3, 4, 5
x = torch.randn(batch, channels, height, width)

pooled = x.mean(dim=(2, 3))

assert pooled.shape == (batch, channels)

print("All tests passed.")

Problem 13 — Find most likely token

Given logits:

logits.shape == (batch, seq_len, vocab_size)

Find the most likely token id at every position.

Expected shape:

(batch, seq_len)
# Problem 13 — your attempt

batch, seq_len, vocab_size = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab_size)

token_ids = None

# assert token_ids.shape == (batch, seq_len)

argmax(dim=-1) collapses the vocabulary dimension.

Code
# Solution 13

batch, seq_len, vocab_size = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab_size)

token_ids = logits.argmax(dim=-1)

assert token_ids.shape == (batch, seq_len)
assert token_ids.max() < vocab_size
assert token_ids.min() >= 0

print("All tests passed.")

Level 5 — Reshape, View, Permute

This is where shape thinking becomes real model implementation.

Problem 14 — Split attention heads

Given:

x.shape == (batch, seq_len, hidden)
num_heads = 4
hidden = num_heads * head_dim

Reshape into:

(batch, num_heads, seq_len, head_dim)

This is the standard LLM attention-head layout.

# Problem 14 — your attempt

batch, seq_len, hidden = 2, 6, 12
num_heads = 4
x = torch.randn(batch, seq_len, hidden)

head_dim = hidden // num_heads
y = None

# assert y.shape == (batch, num_heads, seq_len, head_dim)

First reshape hidden into (num_heads, head_dim), then move num_heads before seq_len.

Code
# Solution 14

batch, seq_len, hidden = 2, 6, 12
num_heads = 4
x = torch.randn(batch, seq_len, hidden)

head_dim = hidden // num_heads

y = x.view(batch, seq_len, num_heads, head_dim)
y = y.permute(0, 2, 1, 3)

assert y.shape == (batch, num_heads, seq_len, head_dim)

print("All tests passed.")

Problem 15 — Merge attention heads

Invert the previous operation.

Given:

y.shape == (batch, num_heads, seq_len, head_dim)

Return:

x.shape == (batch, seq_len, hidden)

where:

hidden = num_heads * head_dim
# Problem 15 — your attempt

batch, num_heads, seq_len, head_dim = 2, 4, 6, 3
y = torch.randn(batch, num_heads, seq_len, head_dim)

x = None

# assert x.shape == (batch, seq_len, num_heads * head_dim)

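After permute, call .contiguous() before .view(): permute only changes strides, so the result is usually not contiguous in memory, and .view() raises an error when the requested shape is incompatible with those strides.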
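
The contiguity issue can be observed directly. The sketch below uses the same sizes as this problem and the illustrative names p, a, and b. It inspects strides after permute, shows .view() failing, and compares .contiguous().view() with .reshape(), which copies when it has to.

```python
import torch

y = torch.randn(2, 4, 6, 3)          # (batch, num_heads, seq_len, head_dim)
p = y.permute(0, 2, 1, 3)            # (batch, seq_len, num_heads, head_dim)

# permute returns a view with reordered strides; no data is moved.
assert y.stride() == (72, 18, 3, 1)
assert p.stride() == (72, 3, 18, 1)
assert y.is_contiguous()
assert not p.is_contiguous()

# Merging the last two dims needs them to be adjacent in memory,
# which they no longer are, so view() refuses.
try:
    p.view(2, 6, 12)
    view_failed = False
except RuntimeError:
    view_failed = True
assert view_failed

# Either make a contiguous copy first, or let reshape() copy as needed.
a = p.contiguous().view(2, 6, 12)
b = p.reshape(2, 6, 12)
assert torch.equal(a, b)
```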

Code
# Solution 15

batch, num_heads, seq_len, head_dim = 2, 4, 6, 3
y = torch.randn(batch, num_heads, seq_len, head_dim)

x = y.permute(0, 2, 1, 3).contiguous()
x = x.view(batch, seq_len, num_heads * head_dim)

assert x.shape == (batch, seq_len, num_heads * head_dim)

print("All tests passed.")

Problem 16 — Flatten image into tokens

Given image tensor:

x.shape == (batch, channels, height, width)

Convert it into tokens:

(batch, height * width, channels)

This treats each spatial location as a token.

# Problem 16 — your attempt

batch, channels, height, width = 2, 3, 4, 5
x = torch.randn(batch, channels, height, width)

tokens = None

# assert tokens.shape == (batch, height * width, channels)

Move channels to the end, then flatten height and width.

Code
# Solution 16

batch, channels, height, width = 2, 3, 4, 5
x = torch.randn(batch, channels, height, width)

tokens = x.permute(0, 2, 3, 1).contiguous()
tokens = tokens.view(batch, height * width, channels)

assert tokens.shape == (batch, height * width, channels)

print("All tests passed.")

Problem 17 — Unflatten tokens into image

Invert the previous operation.

Given:

tokens.shape == (batch, height * width, channels)

Return:

x.shape == (batch, channels, height, width)
# Problem 17 — your attempt

batch, channels, height, width = 2, 3, 4, 5
tokens = torch.randn(batch, height * width, channels)

x = None

# assert x.shape == (batch, channels, height, width)

Solution 17

Code
# Solution 17

batch, channels, height, width = 2, 3, 4, 5
tokens = torch.randn(batch, height * width, channels)

x = tokens.view(batch, height, width, channels)
x = x.permute(0, 3, 1, 2).contiguous()

assert x.shape == (batch, channels, height, width)

print("All tests passed.")

Level 6 — LLM-Style Tensor Problems

These are simplified versions of patterns that appear in transformer implementations.

Problem 18 — Causal mask

Create a causal attention mask for sequence length seq_len.

The mask should be shape:

(seq_len, seq_len)

And should contain:

  • True where a token is allowed to attend
  • False where a token is not allowed to attend

For causal attention, position i can attend to positions j <= i.

# Problem 18 — your attempt

seq_len = 6

mask = None

# assert mask.shape == (seq_len, seq_len)
# assert mask.dtype == torch.bool

torch.tril gives the lower triangle.

Code
# Solution 18

seq_len = 6

mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

assert mask.shape == (seq_len, seq_len)
assert mask.dtype == torch.bool
assert mask[0, 0]
assert not mask[0, 1]
assert mask[5, 0]
assert mask[5, 5]

print(mask.int())
print("All tests passed.")

Problem 19 — Apply causal mask to attention scores

Given attention scores:

scores.shape == (batch, num_heads, seq_len, seq_len)

Apply a causal mask so future positions get a very negative value, e.g. -1e9.

Expected output shape:

(batch, num_heads, seq_len, seq_len)
# Problem 19 — your attempt

batch, num_heads, seq_len = 2, 3, 5
scores = torch.randn(batch, num_heads, seq_len, seq_len)

masked_scores = None

# assert masked_scores.shape == scores.shape

The mask has shape (seq_len, seq_len).
It broadcasts over batch and heads.

Code
# Solution 19

batch, num_heads, seq_len = 2, 3, 5
scores = torch.randn(batch, num_heads, seq_len, seq_len)

mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
masked_scores = scores.masked_fill(~mask, -1e9)

assert masked_scores.shape == scores.shape
assert masked_scores[0, 0, 0, 1] == -1e9
assert masked_scores[0, 0, 4, 0] != -1e9

print("All tests passed.")

Problem 20 — Last-token logits

Autoregressive models often only need logits for the final position during generation.

Given:

logits.shape == (batch, seq_len, vocab_size)

Extract final-token logits.

Expected shape:

(batch, vocab_size)
# Problem 20 — your attempt

batch, seq_len, vocab_size = 3, 7, 20
logits = torch.randn(batch, seq_len, vocab_size)

last_logits = None

# assert last_logits.shape == (batch, vocab_size)

Solution 20

Code
# Solution 20

batch, seq_len, vocab_size = 3, 7, 20
logits = torch.randn(batch, seq_len, vocab_size)

last_logits = logits[:, -1, :]

assert last_logits.shape == (batch, vocab_size)
assert torch.equal(last_logits, logits[:, seq_len - 1, :])

print("All tests passed.")

Problem 21 — KV cache append

During autoregressive generation, key tensors may be stored as:

past_k.shape == (batch, num_heads, past_len, head_dim)
new_k.shape == (batch, num_heads, 1, head_dim)

Append new_k along the sequence dimension.

Expected shape:

(batch, num_heads, past_len + 1, head_dim)
# Problem 21 — your attempt

batch, num_heads, past_len, head_dim = 2, 4, 5, 8
past_k = torch.randn(batch, num_heads, past_len, head_dim)
new_k = torch.randn(batch, num_heads, 1, head_dim)

full_k = None

# assert full_k.shape == (batch, num_heads, past_len + 1, head_dim)

The sequence dimension is dimension 2.

Code
# Solution 21

batch, num_heads, past_len, head_dim = 2, 4, 5, 8
past_k = torch.randn(batch, num_heads, past_len, head_dim)
new_k = torch.randn(batch, num_heads, 1, head_dim)

full_k = torch.cat([past_k, new_k], dim=2)

assert full_k.shape == (batch, num_heads, past_len + 1, head_dim)
assert torch.equal(full_k[:, :, :-1, :], past_k)
assert torch.equal(full_k[:, :, -1:, :], new_k)

print("All tests passed.")

Level 7 — Diffusion and Image-Model Style Problems

These are simplified patterns from image models, ViTs, U-Nets, and diffusion code.

Problem 22 — Patchify an image

Given:

x.shape == (batch, channels, height, width)
patch_size = 2

Assume height and width are divisible by patch_size.

Convert image into patch tokens:

(batch, num_patches, patch_dim)

where:

num_patches = (height // patch_size) * (width // patch_size)
patch_dim = channels * patch_size * patch_size
# Problem 22 — your attempt

batch, channels, height, width = 2, 3, 8, 8
patch_size = 2
x = torch.randn(batch, channels, height, width)

patches = None

# num_patches = (height // patch_size) * (width // patch_size)
# patch_dim = channels * patch_size * patch_size
# assert patches.shape == (batch, num_patches, patch_dim)

Use view (or reshape) to split height and width each into a grid dimension and a within-patch dimension, then permute so the patch-grid dimensions come before the patch contents.

Code
# Solution 22

batch, channels, height, width = 2, 3, 8, 8
patch_size = 2
x = torch.randn(batch, channels, height, width)

h_grid = height // patch_size
w_grid = width // patch_size

patches = x.view(batch, channels, h_grid, patch_size, w_grid, patch_size)
patches = patches.permute(0, 2, 4, 1, 3, 5).contiguous()
patches = patches.view(batch, h_grid * w_grid, channels * patch_size * patch_size)

num_patches = (height // patch_size) * (width // patch_size)
patch_dim = channels * patch_size * patch_size

assert patches.shape == (batch, num_patches, patch_dim)

print("All tests passed.")

Problem 23 — Unpatchify an image

Invert the previous operation.

Given:

patches.shape == (batch, num_patches, patch_dim)

Return:

x.shape == (batch, channels, height, width)

Use:

height = width = 8
patch_size = 2
channels = 3
# Problem 23 — your attempt

batch, channels, height, width = 2, 3, 8, 8
patch_size = 2
h_grid = height // patch_size
w_grid = width // patch_size
num_patches = h_grid * w_grid
patch_dim = channels * patch_size * patch_size

patches = torch.randn(batch, num_patches, patch_dim)

x = None

# assert x.shape == (batch, channels, height, width)

Solution 23

Code
# Solution 23

batch, channels, height, width = 2, 3, 8, 8
patch_size = 2
h_grid = height // patch_size
w_grid = width // patch_size
num_patches = h_grid * w_grid
patch_dim = channels * patch_size * patch_size

patches = torch.randn(batch, num_patches, patch_dim)

x = patches.view(batch, h_grid, w_grid, channels, patch_size, patch_size)
x = x.permute(0, 3, 1, 4, 2, 5).contiguous()
x = x.view(batch, channels, height, width)

assert x.shape == (batch, channels, height, width)

print("All tests passed.")
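
Shape assertions alone do not show that Solutions 22 and 23 are inverses of each other. A round-trip check could look like the sketch below. The function names patchify and unpatchify are introduced here only for illustration; their bodies follow the same view/permute steps as the two solutions.

```python
import torch

def patchify(x, patch_size):
    # (batch, channels, height, width) -> (batch, num_patches, patch_dim)
    b, c, h, w = x.shape
    hg, wg = h // patch_size, w // patch_size
    p = x.view(b, c, hg, patch_size, wg, patch_size)
    p = p.permute(0, 2, 4, 1, 3, 5).contiguous()
    return p.view(b, hg * wg, c * patch_size * patch_size)

def unpatchify(patches, channels, height, width, patch_size):
    # (batch, num_patches, patch_dim) -> (batch, channels, height, width)
    b = patches.shape[0]
    hg, wg = height // patch_size, width // patch_size
    x = patches.view(b, hg, wg, channels, patch_size, patch_size)
    x = x.permute(0, 3, 1, 4, 2, 5).contiguous()
    return x.view(b, channels, height, width)

x = torch.randn(2, 3, 8, 8)
patches = patchify(x, patch_size=2)
restored = unpatchify(patches, channels=3, height=8, width=8, patch_size=2)

# The round trip recovers the image exactly, since only data movement is involved.
assert patches.shape == (2, 16, 12)
assert torch.equal(restored, x)

# Patch 0 holds the top-left 2x2 block of every channel.
assert torch.equal(patches[:, 0, :], x[:, :, 0:2, 0:2].reshape(2, -1))
```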

Problem 24 — Add timestep embedding to image features

In diffusion-style models, a timestep embedding may be added to every spatial location.

Given:

x.shape == (batch, channels, height, width)
t_emb.shape == (batch, channels)

Add t_emb to x.

Expected shape:

(batch, channels, height, width)
# Problem 24 — your attempt

batch, channels, height, width = 4, 6, 8, 8
x = torch.randn(batch, channels, height, width)
t_emb = torch.randn(batch, channels)

y = None

# assert y.shape == x.shape

t_emb must become (batch, channels, 1, 1) to broadcast over height and width.

Code
# Solution 24

batch, channels, height, width = 4, 6, 8, 8
x = torch.randn(batch, channels, height, width)
t_emb = torch.randn(batch, channels)

y = x + t_emb[:, :, None, None]

assert y.shape == x.shape
assert torch.allclose(y[0, :, 0, 0], x[0, :, 0, 0] + t_emb[0])

print("All tests passed.")

Level 8 — Tricky / Informatics-Olympiad-Style Tensor Reasoning

These problems are harder. They reward translating loops into tensor operations.
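
As a warm-up for that translation, here is a small sketch on a toy task that is not one of the problems below: for each a[i], count how many b[j] are smaller. The tensors a and b and the task itself are assumptions chosen for illustration.

```python
import torch

a = torch.tensor([0.5, 2.0, -1.0])
b = torch.tensor([0.0, 1.0, 3.0, -2.0])

# Loop version: two nested loops over i and j.
counts_loop = torch.zeros(len(a), dtype=torch.long)
for i in range(len(a)):
    for j in range(len(b)):
        if b[j] < a[i]:
            counts_loop[i] += 1

# Tensor version: each loop index becomes its own axis.
# a[:, None] has shape (3, 1), b[None, :] has shape (1, 4),
# so the comparison broadcasts to all (i, j) pairs at once.
less = b[None, :] < a[:, None]       # (3, 4), dtype bool
counts_vec = less.sum(dim=1)         # reduce over j

assert less.shape == (3, 4)
assert torch.equal(counts_loop, counts_vec)
```

The pattern is the same in the harder problems: give each loop index an axis, broadcast, then reduce over the axes that the loop summed over.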

Problem 25 — Pairwise squared distances without loops

Given:

x.shape == (n, d)

Compute:

dist[i, j] = sum_k (x[i, k] - x[j, k]) ** 2

Expected shape:

(n, n)

Do not use Python loops.

# Problem 25 — your attempt

n, d = 5, 3
x = torch.randn(n, d)

dist = None

# assert dist.shape == (n, n)
# assert torch.allclose(torch.diag(dist), torch.zeros(n), atol=1e-6)

Broadcasting approach:

x[:, None, :] - x[None, :, :]

creates shape (n, n, d).

Code
# Solution 25

n, d = 5, 3
x = torch.randn(n, d)

diff = x[:, None, :] - x[None, :, :]
dist = (diff ** 2).sum(dim=-1)

assert dist.shape == (n, n)
assert torch.allclose(torch.diag(dist), torch.zeros(n), atol=1e-6)
assert torch.allclose(dist, dist.T, atol=1e-6)

print("All tests passed.")

Problem 26 — Pairwise dot products

Given:

x.shape == (n, d)

Compute the matrix:

dots[i, j] = dot(x[i], x[j])

Expected shape:

(n, n)

Do not use loops.

# Problem 26 — your attempt

n, d = 6, 4
x = torch.randn(n, d)

dots = None

# assert dots.shape == (n, n)

This is a matrix multiplication between x and x.T.

# Solution 26

n, d = 6, 4
x = torch.randn(n, d)

dots = x @ x.T

assert dots.shape == (n, n)
assert torch.allclose(dots, dots.T, atol=1e-6)

print("All tests passed.")
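
Problems 25 and 26 are connected by the identity

dist[i, j] = |x[i]|^2 + |x[j]|^2 - 2 * dot(x[i], x[j])

so the distance matrix can also be built from the dot-product matrix without materializing the (n, n, d) difference tensor. The sketch below checks this numerically. The variable names and the tolerance are choices made for this illustration.

```python
import torch

n, d = 6, 4
x = torch.randn(n, d)

# Dot-product route: the diagonal of x @ x.T holds the squared norms.
dots = x @ x.T
sq_norms = dots.diagonal()                               # (n,)
dist_from_dots = sq_norms[:, None] + sq_norms[None, :] - 2 * dots

# Direct broadcasting route from Problem 25.
diff = x[:, None, :] - x[None, :, :]
dist_direct = (diff ** 2).sum(dim=-1)

assert dist_from_dots.shape == (n, n)
assert torch.allclose(dist_from_dots, dist_direct, atol=1e-4)
```

Because of floating-point rounding, the dot-product route can produce tiny negative values where the true distance is zero, so clamping at zero may be needed before taking a square root.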

Problem 27 — Top-k nearest neighbors

Given:

x.shape == (n, d)

Find the indices of the k nearest neighbors of every row, excluding itself.

Expected shape:

(n, k)

Use squared Euclidean distance.

Hint: after computing distances, set the diagonal to inf so that no row picks itself.

# Problem 27 — your attempt

n, d, k = 8, 3, 2
x = torch.randn(n, d)

nearest_idx = None

# assert nearest_idx.shape == (n, k)

Solution 27

Code
# Solution 27

n, d, k = 8, 3, 2
x = torch.randn(n, d)

dist = ((x[:, None, :] - x[None, :, :]) ** 2).sum(dim=-1)
dist.fill_diagonal_(float("inf"))

nearest_idx = dist.topk(k, largest=False).indices

assert nearest_idx.shape == (n, k)

# Make sure no row selects itself.
rows = torch.arange(n)[:, None]
assert not torch.any(nearest_idx == rows)

print("All tests passed.")

Problem 28 — One-hot encode token ids

Given:

ids.shape == (batch, seq_len)

Create one-hot vectors:

one_hot.shape == (batch, seq_len, vocab_size)

Avoid loops.

# Problem 28 — your attempt

batch, seq_len, vocab_size = 3, 5, 10
ids = torch.randint(0, vocab_size, (batch, seq_len))

one_hot = None

# assert one_hot.shape == (batch, seq_len, vocab_size)

Use torch.nn.functional.one_hot.

Code
# Solution 28

import torch.nn.functional as F

batch, seq_len, vocab_size = 3, 5, 10
ids = torch.randint(0, vocab_size, (batch, seq_len))

one_hot = F.one_hot(ids, num_classes=vocab_size).float()

assert one_hot.shape == (batch, seq_len, vocab_size)
assert torch.all(one_hot.sum(dim=-1) == 1)

print("All tests passed.")

Problem 29 — Masked mean over tokens

Given:

x.shape == (batch, seq_len, hidden)
mask.shape == (batch, seq_len)

where mask == 1 means valid token and mask == 0 means padding.

Compute the mean hidden vector over valid tokens only.

Expected shape:

(batch, hidden)

Avoid counting padded tokens in the denominator.

# Problem 29 — your attempt

batch, seq_len, hidden = 3, 6, 4
x = torch.randn(batch, seq_len, hidden)
mask = torch.tensor([
    [1, 1, 1, 0, 0, 0],
    [1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 0],
], dtype=torch.float32)

pooled = None

# assert pooled.shape == (batch, hidden)

The mask must become (batch, seq_len, 1) so it can multiply hidden vectors.

Code
# Solution 29

batch, seq_len, hidden = 3, 6, 4
x = torch.randn(batch, seq_len, hidden)
mask = torch.tensor([
    [1, 1, 1, 0, 0, 0],
    [1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 0],
], dtype=torch.float32)

mask_expanded = mask[:, :, None]
summed = (x * mask_expanded).sum(dim=1)
counts = mask.sum(dim=1, keepdim=True).clamp(min=1)
pooled = summed / counts

assert pooled.shape == (batch, hidden)

# Manual check for first row:
assert torch.allclose(pooled[0], x[0, :3].mean(dim=0))

print("All tests passed.")

Problem 30 — Segment sum using index_add

You have token features:

x.shape == (num_tokens, hidden)
segment_ids.shape == (num_tokens,)

Each token belongs to exactly one segment. Compute the sum of token features per segment.

Example:

segment_ids = [0, 0, 1, 2, 2]

Expected output:

segment_sums.shape == (num_segments, hidden)

This pattern appears in batching, graph neural networks, ragged tensors, and pooling.

# Problem 30 — your attempt

num_tokens, hidden, num_segments = 7, 4, 3
x = torch.randn(num_tokens, hidden)
segment_ids = torch.tensor([0, 0, 1, 2, 2, 1, 0])

segment_sums = None

# assert segment_sums.shape == (num_segments, hidden)

index_add_ accumulates rows of x into rows of segment_sums according to segment_ids.

Code
# Solution 30

num_tokens, hidden, num_segments = 7, 4, 3
x = torch.randn(num_tokens, hidden)
segment_ids = torch.tensor([0, 0, 1, 2, 2, 1, 0])

segment_sums = torch.zeros(num_segments, hidden)
segment_sums.index_add_(0, segment_ids, x)

assert segment_sums.shape == (num_segments, hidden)
assert torch.allclose(segment_sums[0], x[[0, 1, 6]].sum(dim=0))
assert torch.allclose(segment_sums[1], x[[2, 5]].sum(dim=0))
assert torch.allclose(segment_sums[2], x[[3, 4]].sum(dim=0))

print("All tests passed.")
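
The same segment sum can be expressed with the one-hot encoding from Problem 28: a one-hot assignment matrix, transposed and multiplied with x, adds up the rows of each segment. This is only an alternative sketch for building intuition (the name assign is illustrative). For many segments it uses far more memory than index_add_.

```python
import torch
import torch.nn.functional as F

num_tokens, hidden, num_segments = 7, 4, 3
x = torch.randn(num_tokens, hidden)
segment_ids = torch.tensor([0, 0, 1, 2, 2, 1, 0])

# assign[t, s] == 1 when token t belongs to segment s.
assign = F.one_hot(segment_ids, num_classes=num_segments).to(x.dtype)  # (7, 3)

# (3, 7) @ (7, 4) -> (3, 4): each output row sums the tokens of one segment.
sums_matmul = assign.T @ x

# Reference result using index_add_, as in Solution 30.
sums_index_add = torch.zeros(num_segments, hidden).index_add_(0, segment_ids, x)

assert sums_matmul.shape == (num_segments, hidden)
assert torch.allclose(sums_matmul, sums_index_add, atol=1e-5)
```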

Bonus Challenge — Build a Tiny Attention Shape Pipeline

This combines several ideas from the notebook.

Problem 31 — Mini attention pipeline

Given:

x.shape == (batch, seq_len, hidden)

and projection matrices:

Wq.shape == (hidden, hidden)
Wk.shape == (hidden, hidden)
Wv.shape == (hidden, hidden)

Do the following:

  1. Compute q, k, v.
  2. Split each into heads with shape (batch, num_heads, seq_len, head_dim).
  3. Compute attention scores:
     scores.shape == (batch, num_heads, seq_len, seq_len)
  4. Apply causal mask.
  5. Compute attention weights with softmax.
  6. Compute output:
     out.shape == (batch, seq_len, hidden)

This is not meant to be a production attention implementation.
It is a shape exercise.

# Problem 31 — your attempt

batch, seq_len, hidden = 2, 5, 12
num_heads = 3
head_dim = hidden // num_heads

x = torch.randn(batch, seq_len, hidden)
Wq = torch.randn(hidden, hidden)
Wk = torch.randn(hidden, hidden)
Wv = torch.randn(hidden, hidden)

out = None

# assert out.shape == (batch, seq_len, hidden)

Solution 31

Code
# Solution 31

import math
import torch.nn.functional as F

batch, seq_len, hidden = 2, 5, 12
num_heads = 3
head_dim = hidden // num_heads

x = torch.randn(batch, seq_len, hidden)
Wq = torch.randn(hidden, hidden)
Wk = torch.randn(hidden, hidden)
Wv = torch.randn(hidden, hidden)

# 1. Linear projections
q = x @ Wq
k = x @ Wk
v = x @ Wv

assert q.shape == (batch, seq_len, hidden)
assert k.shape == (batch, seq_len, hidden)
assert v.shape == (batch, seq_len, hidden)

# 2. Split heads
def split_heads(t):
    t = t.view(batch, seq_len, num_heads, head_dim)
    t = t.permute(0, 2, 1, 3).contiguous()
    return t

q = split_heads(q)
k = split_heads(k)
v = split_heads(v)

assert q.shape == (batch, num_heads, seq_len, head_dim)
assert k.shape == (batch, num_heads, seq_len, head_dim)
assert v.shape == (batch, num_heads, seq_len, head_dim)

# 3. Attention scores
scores = q @ k.transpose(-2, -1)
scores = scores / math.sqrt(head_dim)

assert scores.shape == (batch, num_heads, seq_len, seq_len)

# 4. Causal mask
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
scores = scores.masked_fill(~mask, -1e9)

# 5. Softmax
weights = F.softmax(scores, dim=-1)

assert weights.shape == (batch, num_heads, seq_len, seq_len)

# 6. Weighted sum over values
out_heads = weights @ v

assert out_heads.shape == (batch, num_heads, seq_len, head_dim)

# Merge heads
out = out_heads.permute(0, 2, 1, 3).contiguous()
out = out.view(batch, seq_len, hidden)

assert out.shape == (batch, seq_len, hidden)

print("All tests passed.")

Reflection Questions

Use these to test whether you really understood the notebook.

  1. Why does x[:, 0, :] produce a different rank from x[:, 0:1, :]?
  2. Why does image channel bias need shape (1, channels, 1, 1)?
  3. Why do transformer implementations often permute from (batch, seq, heads, head_dim) to (batch, heads, seq, head_dim)?
  4. Why do we often need .contiguous() after permute() before calling .view()?
  5. Why is keepdim=True useful in reductions?
  6. In pairwise distance computation, why does x[:, None, :] - x[None, :, :] create all pairs?
  7. What is the difference between flattening spatial dimensions and flattening channel dimensions?
  8. Why is masking done before softmax in attention?
  9. What shape should a padding mask have before multiplying hidden states?
  10. Which problems in this notebook felt like shape memorization, and which felt like real reasoning?
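
For question 8, a small experiment may help you check your own answer. The sketch below uses a single (seq_len, seq_len) score matrix and the same -1e9 fill value as Problem 19. The names before, after, and row_sums are illustrative.

```python
import torch
import torch.nn.functional as F

seq_len = 4
scores = torch.randn(seq_len, seq_len)
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# Mask first, then softmax: future positions get zero weight
# and every row is still a proper distribution.
before = F.softmax(scores.masked_fill(~mask, -1e9), dim=-1)
assert torch.all(before[~mask] == 0)
assert torch.allclose(before.sum(dim=-1), torch.ones(seq_len))

# Softmax first, then zero out: the masked weights are removed
# but the remaining ones no longer sum to 1.
after = F.softmax(scores, dim=-1) * mask
row_sums = after.sum(dim=-1)
assert torch.all(row_sums[:-1] < 1)   # the last row has nothing masked

# Renormalizing each row recovers the mask-first result.
assert torch.allclose(after / row_sums[:, None], before, atol=1e-6)
```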

Suggested next notebook

02P — Matrix Multiplication Practice

Possible topics:

  • matrix-vector products
  • batched matrix multiplication
  • dot products
  • attention score computation
  • image patch projection
  • linear layer equivalence
  • einsum basics
  • no-loop dynamic programming-style tensor tricks