NOTE: This post is going to be a compiler post, not a machine learning tutorial, so please treat it as such. Less ML than the last post, even.
Hello everyone! I’m back. I wasn’t satisfied with adding a Dot operator to micrograd and manually using it in the MLP implementation from the last post (read that first if you want more context, but it’s not super necessary). I kept wondering if it was possible to turn all the +/* nodes in the graph into Dot nodes automatically using an optimizer. So I did just that.
Forget about all my changes to micrograd: we’re going to start from a clean micrograd and talk about autovectorization. The idea isn’t new, but this seems to be a simple enough and small enough language that it is easy instead of very difficult.
Remember the Value class? That was the main node type in the graph intermediate representation. It works on scalars (numbers) instead of bigger structures like vectors. But Andrej is doing tensor math—dot products, matrix multiplications. He’s just encoding it at the lowest level—the scalar level. We’re going to try and find that high-level structure and lift it out.
To do that we’re going to use the union-find/disjoint-set data structure, which, if you don’t think about it too hard, is really simple (proving it fast is a different matter entirely). If you are a programming languages person and haven’t used it much, I recommend taking a look at CF’s Implementing a Toy Optimizer. That was what a) made it stick, b) made it simple, and c) got me hooked. I also recommend this visualization tool by David Galles at USF. His other visualizations are excellent too.
This is, regrettably, an invasive change we will have to make to the Value class. I think we could do it with a data structure on the side, but maybe I will leave that as an exercise for the reader.
diff --git a/micrograd/engine.py b/micrograd/engine.py
index afd82cc..1c1863a 100644
--- a/micrograd/engine.py
+++ b/micrograd/engine.py
@@ -7,8 +7,33 @@ class Value:
self.grad = 0
# internal variables used for autograd graph construction
self._backward = lambda: None
- self._prev = set(_children)
+ self._prev = tuple(_children)
self._op = _op # the op that produced this node
+ self.forwarded = None
+
+ def find(self):
+ op = self
+ while isinstance(op, Value):
+ next = op.forwarded
+ if next is None:
+ return op
+ op = next
+ return op
+
+ def args(self):
+ return [v.find() for v in self._prev]
+
+ def arg(self, idx):
+ return self._prev[idx].find()
+
+ def make_equal_to(self, other):
+ self.find().set_forwarded(other)
+
+ def set_forwarded(self, other):
+ if self._op == '':
+ assert self.data == other.data
+ else:
+ self.forwarded = other
The important parts are adding a forwarded field to the object, the find function, and the make_equal_to function. Forcing an order for the _prev isn’t really necessary, but it makes debugging easier.
Union-find is a data structure with two operations: union(x, y), which merges the sets containing x and y, and find(x), which looks up which set x belongs to.
The funky part is that it’s not exactly objects and sets. When you talk about an object, you also talk about the entire set it’s a part of, and you are implicitly speaking about the “representative value” for that set. So really, the two operations are: union(x, y), which merges the two sets and picks a single representative for the merged set, and find(x), which returns the representative of x’s set.
And when you think of the sets as equivalence classes for operations—ops that, when run, would produce the same value—you can start to see its value in optimizing an SSA/SSI IR. You can optimize a node in the graph without rewriting all its uses. So for an optimizer, the operations become: make_equal_to(x, y), which records that the node x can be replaced by the better node y, and find(x), which returns the most-optimized version of x discovered so far.
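If the abstract version helps, here is a minimal stand-alone sketch (not code from this post; just the textbook structure, minus the rank and path-compression tricks that make it fast):

class DisjointSet:
    def __init__(self):
        # Each object points at another object in its set;
        # representatives point at themselves.
        self.parent = {}

    def find(self, x):
        # Walk parent pointers until we reach the representative.
        self.parent.setdefault(x, x)
        while self.parent[x] is not x:
            x = self.parent[x]
        return x

    def union(self, x, y):
        # Merge the two sets; y's representative wins.
        self.parent[self.find(x)] = self.find(y)

In the Value version above, forwarded plays the role of the parent pointer, make_equal_to is union (where the argument becomes the new representative), and find walks the forwarded chain.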
The main cognitive difference when using union-find in an optimizer is that when you are looking at an object, you need to make sure you are looking at the representative of the equivalence class—you need to call .find().
(And, by the way, if you want to hold multiple optimization versions of an object at once, you start getting into e-graphs. Fun stuff awaits.)
To get a feel for how union-find works, we can take a look at a small example. Let’s say we have three nodes: a + and its two children:
l = Value(1)
r = Value(2)
root = l + r
Right now, we have a three-node graph with the + being the root. If we find out that the root node is actually equivalent to the constant 3—say, via a constant folding pass—we can write root.make_equal_to(Value(3)). This doesn’t delete anything from our graph. It actually adds to it. But when we later call root.find(), the representative is a constant instead of a plus node. So in some sense, the graph has shrunk, if you look at the graph of only representatives.
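Concretely, with the methods from the diff above (a small sketch; the asserts just show what find returns):

l = Value(1)
r = Value(2)
root = l + r                    # root._op is '+'
root.make_equal_to(Value(3))    # pretend constant folding discovered this
assert root.find().data == 3    # the representative is now the constant
assert root.find()._op == ''    # ...a plain constant node, not a +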
Remember that dot(a, b) = a[0]*b[0] + ... + a[n]*b[n]. Unfortunately, we don’t have that kind of many-argument + explicitly encoded in the scalar IR. What we have looks more like:
v0 = input
v3 = input
v6 = * v0 v3
v10 = + v6 0
v1 = input
v4 = input
v7 = * v1 v4
v11 = + v10 v7
v2 = input
v5 = input
v8 = * v2 v5
v12 = + v11 v8
It’s this deeply-nested tree structure instead of a shallow one. Where’s the dot product? I don’t see anything that looks like a long series of adds… Ideally instead we would have:
v0 = input
v3 = input
v6 = * v0 v3
v1 = input
v4 = input
v7 = * v1 v4
v2 = input
v5 = input
v8 = * v2 v5
v15 = + v6 v7 v8
Where all the * nodes have been brought together as siblings in the + node.

Now finding a dot product looks a lot easier. You can see there’s a wide + node where all the children are multiplications. Great.
Let’s think about how we would do this at the individual node level. In order to make one + out of multiple, the tree must look like:
v2 = v0 + v1
v4 = v2 + v3
So, a plus made up of other nested plus nodes. If we take the children of v2 (v0 and v1) and bring them up to be children of v4, we get v4 = v0 + v1 + v3. Neat. And if v4’s children aren’t all +, that’s fine; we just leave the other operations as they are. Except that we have to make a new node because we’re not modifying the graph, so we get v5:
v5 = v0 + v1 + v3
In this graph diagram, I have kept around the old v2 and v4 because we never really delete them in our optimizer. The garbage collector might get to them eventually if nothing else uses them. It also illustrates (using a dotted line) that v4 is forwarded to v5. That v5 is now the representative for that equivalence class.
To do this, we make a function to optimize one Value at a time: optimize_one. What we’re looking for is a + made out of other + nodes—a many-argument +. If we find such a situation, we make a new, wider + node with the grandchildren added to it.
def optimize_one(v):
if v._op == "+":
args = v.args()
if any(arg._op == "+" for arg in args):
new_args = []
for arg in args:
if arg._op == "+":
new_args.extend(arg.args())
else:
new_args.append(arg)
v.make_equal_to(Value(0, tuple(new_args), "+"))
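To see it fire, here is a quick check (a sketch, using the union-find methods from the diff; note that the pass only rewrites structure and does not recompute data):

v2 = Value(1) + Value(2)    # inner +
v4 = v2 + Value(3)          # outer +
optimize_one(v4)
wide = v4.find()            # the new, wider + node
assert wide._op == '+'
assert len(wide._prev) == 3    # all three constants are now siblings
assert [a._op for a in wide.args()] == ['', '', '']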
Remember that v0 and v1 might be entire graphs on their own, so this is only one step of a bigger loop. If v0 and v1 are also deeply nested plus trees, we want to flatten them as well. The “normal” way to do this optimization in a functional style is a depth-first transformation that returns a new copy of the graph. With union-find, we can avoid doing a bunch of those copies.
We also already have an operation to traverse leaf-first: topological sort. A topological sort of a graph orders node dependencies before the node itself—children before parents. Bottom up.
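The topo() method isn’t shown here; it’s essentially the build_topo helper from micrograd’s backward(), hoisted out. A sketch (assuming we traverse representatives via args() and find(); it’s recursive, so for graphs this deep you’d bump the recursion limit or rewrite it iteratively):

class Value:
    # ... existing methods ...

    def topo(self):
        # Children-first (post-order) traversal of the representative graph.
        order = []
        visited = set()
        def build(v):
            if v not in visited:
                visited.add(v)
                for child in v.args():
                    build(child)
                order.append(v)
        build(self.find())
        return order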
So here is one loop over each node in the graph, optimizing from the leaves on up, doing cascading squishing of + nodes:
def run_optimize_one(v):
topo = v.topo()
for op in topo:
optimize_one(op.find())
def optimize(v):
run_optimize_one(v)
Let’s check to see if this optimization works. To do that, I use a little collections.Counter to check the distribution of Value node types in the graph.
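The count helper is nothing fancy. A sketch (assuming the topo() from earlier, which walks representatives, so forwarded nodes don’t get counted twice):

import collections

def count(v):
    # Tally each kind of node reachable from v.
    return collections.Counter(op._op for op in v.topo())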
# Fake MNIST
dim_in = 28 * 28
net = MLP(dim_in, [50, 10])
model = net([Value(i, (), "input") for i in range(dim_in)])
loss = sum(model) # fake "loss" function to turn the array into a scalar
print(" ", count(loss.find()))
optimize(loss.find())
print(" ", count(loss.find()))
And now if we run it and cross our fingers…
$ time /usr/bin/pypy3.8 test.py
Counter({'': 39761, '+': 39710, '*': 39700, 'input': 784, 'ReLU': 50})
Counter({'': 39761, '*': 39700, 'input': 784, '+': 51, 'ReLU': 50})
$
Alright, that worked extremely well. The number of + nodes went from 39,710 (39 thousand!) to just 51. Fifty-one! And it left all the other operations in the graph unchanged. Super.
Now that we have all these big + nodes, we can turn all +-of-* patterns into Dot. Well, kinda. We have one more step first: Array.
Since we’re now making vector operations, it makes sense to add a first-class type for them.
class Array(Value):
def __init__(self, data):
super().__init__(0, data, 'array')
Unlike the other node types, there is no transformation happening on the data; an Array just collects other nodes together. So there is no forward or backward computation happening here.
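For example (a small sketch):

xs = (Value(1, (), 'input'), Value(2, (), 'input'))
arr = Array(xs)
assert arr._op == 'array'
assert arr.args() == list(xs)    # it just holds its children; no math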
Alright, let’s take a look at the pass to make Dots:
def optimize(v):
    run_optimize_one(v)
topo = v.find().topo()
for op in topo:
args = op.args()
if op._op == "+" and any(arg._op == "*" for arg in args):
mul_args = tuple(arg for arg in args if arg._op == "*")
assert all(len(arg._prev) == 2 for arg in mul_args)
mul_left = Array(tuple(arg.arg(0) for arg in mul_args))
mul_right = Array(tuple(arg.arg(1) for arg in mul_args))
other_args = tuple(arg for arg in args if arg._op != "*")
op.make_equal_to(Value(0, (Dot(mul_left, mul_right), *other_args), "+"))
For each +, if it has * children, we can optimize: we partition the * children out from the rest. Remember that since the MLP also adds the bias to the long scalarized dot products, we will have some data Value nodes in there too.

We split all the * nodes—which are still all binary—into two Arrays and make a Dot out of those. Then we add all the other non-* arguments to it in a new + node.
Now, what we have right now will make a new array every time someone does a dot product with [v0, v1, v2]. We probably don’t want to make a new array for the same collection of objects; we can have some reusable storage. To do that, we memoize the array creation: if the tuple of arguments has been seen before, return the old Array.
import functools

@functools.lru_cache(maxsize=None)
def hashcons_array(vs):
return Array(vs)
def optimize(v):
    run_optimize_one(v)
topo = v.find().topo()
for op in topo:
args = op.args()
if op._op == "+" and any(arg._op == "*" for arg in args):
# ...
mul_left = hashcons_array(tuple(arg.arg(0) for arg in mul_args))
mul_right = hashcons_array(tuple(arg.arg(1) for arg in mul_args))
# ...
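This works because Value doesn’t define __eq__ or __hash__, so tuples of the same node objects hash identically and lru_cache hands back the cached Array:

x0 = Value(0, (), 'input')
x1 = Value(1, (), 'input')
a = hashcons_array((x0, x1))
b = hashcons_array((x0, x1))
assert a is b    # same nodes, same Array; no duplicate storage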
We should probably know what we’re building, so let’s take a look at the fancy new Dot operator:
class Dot(Value):
def __init__(self, left, right):
assert len(left._prev) == len(right._prev)
super().__init__(0, (left, right), 'dot')
# TODO(max): Figure out a way to compute this automatically using chain
# rule.
def _backward():
left = self._prev[0]
right = self._prev[1]
            for i in range(len(left._prev)):
left._prev[i].grad += right._prev[i].data*self.grad
right._prev[i].grad += left._prev[i].data*self.grad
self._backward = _backward
The only really notable thing is the hand-derived (hopefully correct) backpropagation function. You can see my little note about that, too. I think it’s probably possible to use this same technique to build more complex backpropagation functions automatically as we optimize the graph. But I haven’t figured that one out yet.
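As a quick sanity check of the hand-derived gradient (a sketch; recall that the derivative of dot(l, r) with respect to l[i] is r[i], and vice versa):

l = Array((Value(2), Value(3)))
r = Array((Value(4), Value(5)))
d = Dot(l, r)
d.grad = 1
d._backward()
assert [v.grad for v in l._prev] == [4, 5]    # right's data
assert [v.grad for v in r._prev] == [2, 3]    # left's data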
Again, I ask: does it work? Great question.
$ time /usr/bin/pypy3.8 test.py
Counter({'': 39761, '+': 39710, '*': 39700, 'input': 784, 'ReLU': 50})
Counter({'': 39761, 'input': 784, 'array': 53, 'dot': 51, '+': 51, 'ReLU': 50})
$
All the + and * went away! Nearly 40,000 of each turned into just 51 dot products and a handful of adds (for the bias).
Remember to think very hard about whether your machine learning is actually going to help people instead of being a solution in search of a problem, or, worse, hurt people.
Can we turn these Dots into Matmuls? What about automatically deriving _backward functions?
What about scheduling? Sure, we have all these vector operations now, but no CPU actually supports arbitrary-length vectors natively. We have to encode them as x86_64 dppd instructions or something. Maybe e-graphs would be fun here to optimally schedule them.
I would like to see how much we can optimize, say, femtoGPT with this.
For my version of the code in this post, check out the mb-vectorize branch of my fork of micrograd.