This series is not written for the “happy path” learner.
It is for the person who:
- knows some math but still freezes at tensor shapes,
- understands operations in isolation but gets lost in actual code,
- does not want hand-wavy comfort,
- wants a shape-first, debug-first, hacker-style understanding.
The thesis of this notebook is simple:
Most tensor confusion is not about algebra. It is about shape tracking.
If you build the habit of reading code as:
\[
[\text{shape of left}] \;\to\; \text{operation} \;\to\; [\text{shape of result}]
\]
then a huge part of PyTorch becomes less magical and more mechanical.
We will stay practical, but we will not stay shallow.
Setup
We will use PyTorch, print shapes aggressively, and keep the examples small enough to inspect.
A recurring pattern in this notebook:
state the shape,
state the operation,
state the resulting shape,
then verify in code.
That habit matters more than memorizing functions.
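The cells below also call a small printbold helper, which is not part of PyTorch. Here is a minimal setup cell they assume; this particular printbold implementation is only a guess at what such a helper looks like (a plain print would work just as well).

# Minimal setup assumed by the cells below.
# printbold is NOT a PyTorch function; this is a guessed convenience helper
# that prints its arguments in ANSI bold.
import torch

def printbold(*args):
    text = " ".join(str(a) for a in args)
    print(f"\033[1m{text}\033[0m")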
4. The most useful distinction in practice: * vs @
A lot of confusion comes from mixing up these two.
Element-wise multiplication
\[
[a_{ij}] * [b_{ij}] = [a_{ij} b_{ij}]
\]
Same position multiplied with same position.
Matrix multiplication
\[
C_{ij} = \sum_k A_{ik} B_{kj}
\]
This is not entry-by-entry. It is row-by-column with a sum.
A = torch.tensor([[1., 2.], [3., 4.]])
B = torch.tensor([[10., 20.], [30., 40.]])

printbold("Element-wise A * B:")
print(A * B)
print()
printbold("Matrix multiply A @ B:")
print(A @ B)
Element-wise A * B:
tensor([[ 10.,  40.],
        [ 90., 160.]])

Matrix multiply A @ B:
tensor([[ 70., 100.],
        [150., 220.]])
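To tie the formula to the output, here is a quick manual check (a sketch using the A and B above): the entry C at position (0, 0) is the first row of A dotted with the first column of B.

# Manual check: C[0, 0] = sum_k A[0, k] * B[k, 0]
manual_c00 = A[0, 0] * B[0, 0] + A[0, 1] * B[1, 0]   # 1*10 + 2*30 = 70
print(manual_c00)        # tensor(70.)
print((A @ B)[0, 0])     # tensor(70.) -- row times column, then summed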
5. Shape mechanics for 1D tensors in PyTorch
This is where PyTorch takes some convenience liberties.
Case A: vector on the left
\[
[n] @ [n, p] \to [p]
\]
Internally, PyTorch behaves roughly like:
\[
[1, n] @ [n, p] \to [1, p] \to [p]
\]
Case B: vector on the right
\[
[m, n] @ [n] \to [m]
\]
Internally:
\[
[m, n] @ [n, 1] \to [m, 1] \to [m]
\]
This is convenient in code, but less explicit than pure linear algebra notation.
a = torch.randn(4)
B = torch.randn(4, 6)
A = torch.randn(5, 4)
b = torch.randn(4)

left_result = a @ B
right_result = A @ b

printbold("a.shape:", a.shape)
printbold("B.shape:", B.shape)
printbold("a @ B shape:", left_result.shape)
print()
printbold("A.shape:", A.shape)
printbold("b.shape:", b.shape)
printbold("A @ b shape:", right_result.shape)
a.shape: torch.Size([4])
B.shape: torch.Size([4, 6])
a @ B shape: torch.Size([6])
A.shape: torch.Size([5, 4])
b.shape: torch.Size([4])
A @ b shape: torch.Size([5])
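The "internally behaves roughly like" story can be checked with unsqueeze and squeeze. This is only an illustration of the shape bookkeeping, not a claim about PyTorch's actual implementation:

# Case A: [n] @ [n, p] behaves like [1, n] @ [n, p] -> [1, p] -> [p]
explicit_left = (a.unsqueeze(0) @ B).squeeze(0)      # [1, 4] @ [4, 6] -> [1, 6] -> [6]
print(torch.allclose(explicit_left, a @ B))          # True

# Case B: [m, n] @ [n] behaves like [m, n] @ [n, 1] -> [m, 1] -> [m]
explicit_right = (A @ b.unsqueeze(1)).squeeze(1)     # [5, 4] @ [4, 1] -> [5, 1] -> [5]
print(torch.allclose(explicit_right, A @ b))         # True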
6. Broadcasting: the rule that removes loops
Broadcasting is one of the most important concepts in PyTorch.
The rule is simple:
compare shapes from the right
dimensions are compatible if they are equal, or one of them is 1
missing dimensions are treated like leading 1s
Example
\([2,3] + [3]\)
Right-align the shapes:
\([2,3]\)
\([1,3]\)
Now apply the rule:
last dimension: \((3)\) vs \((3)\) → OK
second dimension: \((2)\) vs \((1)\) → stretch \((1 \rightarrow 2)\)
So the second tensor behaves like:
\([2,3]\)
Final result:
\([2,3]\)
Broadcasting is just implicit expansion along size-1 dimensions.
A = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])     # [2, 3]
b = torch.tensor([10., 20., 30.])    # [3]

C = A + b

printbold("A.shape:", A.shape)
printbold("b.shape:", b.shape)
printbold("C.shape:", C.shape)
print()
printbold("Value of C is:")
print(C)
A.shape: torch.Size([2, 3])
b.shape: torch.Size([3])
C.shape: torch.Size([2, 3])
Value of C is:
tensor([[11., 22., 33.],
        [14., 25., 36.]])
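Two quick checks on the rule itself (a sketch; torch.broadcast_shapes exists in recent PyTorch releases): ask for the broadcast shape directly, and watch an incompatible pair fail.

# Ask for the broadcast result shape without building tensors
print(torch.broadcast_shapes((2, 3), (3,)))   # torch.Size([2, 3])

# Incompatible case: right-aligned dims are 3 vs 2 -- neither equal nor 1
try:
    _ = A + torch.tensor([1., 2.])            # [2, 3] + [2] cannot broadcast
except RuntimeError as e:
    print("broadcast error:", e)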
7. Broadcasting is virtual expansion, not copying
Conceptually, PyTorch behaves as if a smaller tensor were expanded.
But you should think:
not “copy data everywhere”
but “treat it as repeatable along size-1 axes”
This matters because broadcasting is how tensor code stays compact and fast.
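One concrete way to see the "virtual" part (a sketch reusing the b tensor from the cell above): expand produces a broadcast-style view, and the stretched axis gets stride 0, meaning the same memory is revisited rather than copied.

b_view = b.expand(2, 3)          # b has shape [3]; the view behaves like [2, 3]

print(b_view.shape)              # torch.Size([2, 3])
print(b_view.stride())           # (0, 1) -- the stretched axis has stride 0: no copy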
Two-way axis broadcasting
This is one of the most important non-happy-path patterns.
Take:
\([3,5,1] * [3,1,7] \to [3,5,7]\)
Why?
first dimension: \((3)\) matches \((3)\)
second: \((5)\) with \((1)\) stretches to \((5)\)
third: \((1)\) with \((7)\) stretches to \((7)\)
You are not “removing” information.
You are letting each tensor contribute structure on different axes.
C = torch.randn(3, 5, 1)
D = torch.randn(3, 1, 7)
E = C * D

printbold("C.shape:", C.shape)
printbold("D.shape:", D.shape)
printbold("E.shape:", E.shape)
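A spot check of what that product actually computes (a sketch using the C, D, E above): every entry E[i, j, k] equals C[i, j, 0] * D[i, 0, k].

# E[i, j, k] equals C[i, j, 0] * D[i, 0, k]: check one entry...
print(E[1, 2, 3] == C[1, 2, 0] * D[1, 0, 3])                        # tensor(True)

# ...and all entries, by materializing the broadcast explicitly
print(torch.allclose(E, C.expand(3, 5, 7) * D.expand(3, 5, 7)))     # True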
17. Tensor contraction mindset without jargon overload
A deep but useful idea:
Many tensor operations are of the form:
keep some axes,
multiply along some axes,
sum over those axes.
That is what matrix multiplication already does.
For:
\(C_{ij} = \sum_k A_{ik}B_{kj}\)
the axis \((k)\) is the contracted axis.
You do not “throw away” that axis.
You collapse it into a scalar contribution for each remaining index pair \((i, j)\).
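That keep/multiply/sum recipe can be spelled out by hand. Here is a sketch that rebuilds A @ B from an elementwise product plus a sum over the contracted axis k (fresh A and B, defined just for this cell):

A = torch.randn(2, 4)
B = torch.randn(4, 3)

# Line up axes via broadcasting: A -> [2, 4, 1] (i, k, -), B -> [1, 4, 3] (-, k, j)
products = A.unsqueeze(2) * B.unsqueeze(0)   # [2, 4, 3], entries A_ik * B_kj

# Collapse the contracted axis k (dim=1), keeping i and j
C = products.sum(dim=1)                      # [2, 3]

print(torch.allclose(C, A @ B))              # True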
18. Why shape literacy matters more than memorizing APIs
You can forget function names and still survive if you know:
- which axes should match,
- which axes should remain,
- which axes should broadcast,
- and what each axis means.
But if you memorize functions without semantics, you’ll keep fighting the framework.
19. Anti-textbook checklist
Before writing a tensor operation, ask:
What does each axis mean?
What shape do I want at the end?
Which dimensions should match?
Which dimensions should broadcast?
Am I changing storage layout or just interpretation?
If I printed all shapes now, would the code still make sense?
That is how you stop “hoping PyTorch understands your intention”.
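One cheap way to make the checklist executable (a sketch, not a prescribed pattern; project is just a hypothetical example function): write the expected shapes as assertions next to the operation, so a wrong assumption fails loudly instead of broadcasting silently.

def project(x, W):
    # Checklist in code: what each axis means, and which axes must match
    assert x.dim() == 2, f"x should be [batch, features], got {tuple(x.shape)}"
    assert x.shape[1] == W.shape[0], f"feature axes must match: {tuple(x.shape)} vs {tuple(W.shape)}"

    out = x @ W                              # [batch, in] @ [in, out] -> [batch, out]
    assert out.shape == (x.shape[0], W.shape[1])
    return out

print(project(torch.randn(8, 4), torch.randn(4, 3)).shape)   # torch.Size([8, 3])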
20. Mini summary
The core rules
A tensor is inseparable from its shape.
[3] is not [1, 3] or [3, 1].
* and @ are fundamentally different.
Broadcasting compares from the right.
Size-1 dimensions are stretchable.
unsqueeze makes broadcasting intentional.
view changes how the same elements are grouped into a shape; it does not automatically preserve what each axis means.
transpose / permute change axis order (a short sketch of these three follows this list).
GPU changes device, not math.
Most PyTorch confusion is really shape confusion.
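A compact sketch of those three reshaping rules, tracking only shapes:

x = torch.arange(6.)                 # [6]

print(x.unsqueeze(0).shape)          # torch.Size([1, 6]) -- explicit size-1 axis for broadcasting
print(x.view(2, 3).shape)            # torch.Size([2, 3]) -- same 6 elements, new grouping
print(x.view(2, 3).T.shape)          # torch.Size([3, 2]) -- same values, axis order swapped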
The notebook thesis
If you track shapes aggressively, tensors become less scary and more programmable.
# Final compact cheat block
print("[m, n] @ [n, p] -> [m, p]")
print("[..., m, n] @ [..., n, p] -> [..., m, p]")
print("[2, 3] + [3] -> [2, 3] (broadcast)")
print("[3, 1] * [4] -> [3, 4] (outer-product style broadcast)")
print("view/reshape: same number of elements, new shape")
print("transpose/permute: same values, different axis order")
[m, n] @ [n, p] -> [m, p]
[..., m, n] @ [..., n, p] -> [..., m, p]
[2, 3] + [3] -> [2, 3] (broadcast)
[3, 1] * [4] -> [3, 4] (outer-product style broadcast)
view/reshape: same number of elements, new shape
transpose/permute: same values, different axis order