But... does numpy parallelize? I had always understood that it just does loops 'under the hood', which are faster because they are written more optimally, in C. For example, the docs about broadcasting say: "Broadcasting provides a means of vectorizing array operations so that looping occurs in C instead of Python." https://numpy.org/doc/stable/user/basics.broadcasting.html
I can't find any clear documentation of the extent to which it uses the GPU or parallelizes vector operations. In 2020, a top user demand was to improve performance, and many of the comments request access to the GPU to parallelize. https://numpy.org/user-survey-2020-details/content/2021/priorities.html
Am I missing something?
Yes, JAX!
Everything is GPU*
* if you have hardware available, don't need high precision arithmetic, your problem isn't inevitably highly serial, you're not bottlenecked by data movement bandwidth, etc.!
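To make the "Yes, JAX!" answer concrete, here's a minimal sketch (mine, not from the thread; the function f is made up): jax.numpy mirrors most of the numpy API, and jax.jit compiles the whole computation for whatever accelerator is available.

    import jax
    import jax.numpy as jnp

    @jax.jit                       # compiles and fuses the ops below
    def f(x, w):
        return jnp.tanh(x @ w).sum()

    x = jnp.ones((1024, 512))
    w = jnp.ones((512, 256))
    print(f(x, w))                 # runs on GPU/TPU if present, else CPU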
(I like the post and also am super frustrated by numpy, but surely the world is much bigger than matrix operations on GPUs)
This is very interesting to me as I always really loved the fancy indexing and especially the broadcasting in NumPy. There is certainly some friction sometimes, and it often feels like I am chasing some platonic ideal of "if I am able to get this right, the code will be simple, look correct, be correct and DTRT in every situation imaginable".
However, to me, Numpy being polymorphic in the number of dimensions is exactly where the magic is. Being able to handle different problem shapes in the same exact line of code is _the thing_. (Although, admittedly, the range of shapes is probably limited to "one of" or "many of").
> Being able to handle different problem shapes in the same exact line of code is _the thing_.
I think this is the idea behind A[n] working even when A has three dimensions. I like the idea, but in practice it's never worked for me. I'm quite doubtful that there's any implementation of (say) attention that will be both simple and work for multi-head attention.
My perspective is that the solution for having the same line of code work with different shapes is: Write a function for one set of shapes and then use loops+indices or vmap.
I do certainly agree there's some trade-off in terms of generality. But personally, I'd rather give up generality for the simpler goal of being more confident it will work with just one set of shapes!
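A minimal sketch of that workflow (mine, not from the thread; assumes JAX, and the names attention/multi_head are made up): write attention for a single head, then let vmap add the heads axis.

    import jax
    import jax.numpy as jnp

    def attention(q, k, v):                    # q: (n, d), k: (m, d), v: (m, d)
        scores = q @ k.T / jnp.sqrt(q.shape[-1])
        return jax.nn.softmax(scores, axis=-1) @ v

    multi_head = jax.vmap(attention)           # now q: (heads, n, d), etc.

    q = jnp.ones((4, 10, 16))
    k = jnp.ones((4, 12, 16))
    v = jnp.ones((4, 12, 16))
    out = multi_head(q, k, v)                  # shape (4, 10, 16)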
It'd be great to have this available on GitHub and PyPI.
I'm open to the idea, though I'm hesitant to do that because I feel like it might imply that it's "real software" as opposed to "thing I made to support a long rant"? (Isn't downloading a single file quite easy?)
I just barely understand this stuff, but wasn’t Mojo https://en.m.wikipedia.org/wiki/Mojo_(programming_language) trying to do something like this? Though the project doesn’t look very healthy right now
I'm not aware of Mojo having any loop->GPU translation happening, though I imagine it would probably be possible to build DuMojo if you wanted to!
@Dynomight, how long did this take you to build? You seemed to have whipped this up quite quickly!
I'm honestly not sure. Maybe 15-ish hours in total? It was over the course of a long time.
(That excludes the time I wasted trying to get AI to write it. In the end that amounted to almost nothing.)
You are nicely rewinding the clock to around 2009 when "just write the loops and let the system work out how to distribute the work across available processors" last seems to have been in fashion.
> Honestly, I’m a little confused why it isn’t already standard.
I think the reasons are explainable:
---
1. DumPy is not sufficient
The "evil meta-programming AST macro bytecode interception" (i.e. torchdynamo) exists to implement graph transformations, not per-op batching. As a simple example, the two following examples should dispatch to different GPU kernels:
- Z['i','j'] = A['i',:] * B[:,'j']
- Z['j','i'] = A['i',:] * B[:,'j']
If you only needed batched ops, only a registry of batched & unbatched op impls would be needed. If you instead plan to vectorize an entire compute graph, you end up with vmap().
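To make the kernel-dispatch point concrete, here's my reading of those two lines in plain numpy (a sketch, not from the comment above): they differ only in the order of the output axes, yet a compiler may still want different kernels or memory layouts for each.

    import numpy as np

    A = np.random.rand(3, 4)
    B = np.random.rand(4, 5)
    # Z[i,j,k] = A[i,k] * B[k,j], in the two output orderings from above
    Z_ij = np.einsum('ik,kj->ijk', A, B)   # Z['i','j'] ordering
    Z_ji = np.einsum('ik,kj->jik', A, B)   # Z['j','i'] ordering
    assert np.allclose(Z_ij, Z_ji.transpose(1, 0, 2))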
---
2. DumPy is verbose (in certain cases)
DumPy syntax is expression-level. That makes it convenient if I need to express a single operation at higher batch dims.
It is very inconvenient to manually edit an entire NN implementation to be batched with DumPy syntax. It may be more convenient today, with code agents, but up until ~1yr ago it would not have been.
If you believe in explicit shape documentation, jaxtyping is opt-in. I'd be happy with opt-in DumPy too, whereas fully removing broadcasting from torch would be a good way to fracture the ecosystem && force many users downstream to simply stop updating torch. I assume the same is true of jax/numpy.
---
3. Without removals, DumPy implementation would only increase code coverage / workload for compiler maintainers
This issue goes away if you successfully sell DumPy to library maintainers.
---
I don't think I follow your first point. All DumPy really does is generate jax.vmap statements. You can call jax.jit on DumPy code and JAX will still do all its evil meta-programming and kernel selection (which I agree is necessary, just no *extra* meta-programming for this feature).
it's a multipart argument, because "this should exist" can be argued via:
- "it is sufficient to replace what exists" (deny with [1])
- "it is universally better than what exists" (deny with [2])
- "it would be a reasonably easy add-on" (deny with [3])
I am aware that the underlying implementation of MappedFunction as given in dumpy does vmap+jit inside. What I doubt by default is that it's free; maybe I am wrong in this case, but the assumption on the torch side of things is that you can't implement the same business logic in a different fashion && expect torch.compile to reach the same optimizations as it would in "conventional code".
I believe that is correct for torch. I complained about torch.vmap in the post because I think torch.vmap would limit how well this idea would work in pytorch. But I think jax is good enough. (Not debating your other points for now, I just think this one in particular seems off to me.)
Oh man do I support this. Thank you from every person who learned math with index notation before learning to code and then had to unlearn index notation.
I love index notation. By the way, have you ever used the np.einsum functionality? I found that much, much easier to understand.
Yes, I am pro-einsum! https://dynomight.net/numpy/#ok-i-lied
Ah you are sadly correct that it is not full functionality, but I use it when I can! Einstein notation is the ONLY way (in my opinion) to avoid horrible tensor confusion headaches
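For anyone who hasn't tried it, a tiny illustration of the einsum style (my example, not from the thread): the index string reads almost exactly like the math.

    import numpy as np

    Q = np.random.rand(2, 5, 8)            # (batch, queries, features)
    K = np.random.rand(2, 7, 8)            # (batch, keys, features)
    # S[b,i,j] = sum_k Q[b,i,k] * K[b,j,k]
    S = np.einsum('bik,bjk->bij', Q, K)
    assert S.shape == (2, 5, 7)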
Another typo spot
# legal in both numpy and dumpy
A[1, 1:6, C, 2:10]
I think the C should be B to match the text below.
You're right, thanks! (Actually I think I should change the text to C.)
This is great though. Just need static type checking to guarantee that all your stuff is correctly sized when you run the code.
I *think* that should be doable, though I'm not 100% sure how to go about it. Python types are a bit of a mystery to me. Especially here, since I think we'd need some kind of fancy parameterized types?
Yeah, I think it's theoretically possible, but whether it's actually implementable in any real programming language without fixing the size of everything, I'm not sure.
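For what it's worth, a hedged sketch of what exists today (this assumes the jaxtyping package mentioned earlier, plus beartype; the shapes are checked at runtime rather than statically, and matvec is a made-up example):

    import numpy as np
    from beartype import beartype
    from jaxtyping import Float, jaxtyped

    @jaxtyped(typechecker=beartype)
    def matvec(A: Float[np.ndarray, "m n"],
               x: Float[np.ndarray, "n"]) -> Float[np.ndarray, "m"]:
        return A @ x

    matvec(np.ones((3, 4)), np.ones(4))     # ok
    # matvec(np.ones((3, 4)), np.ones(5))   # raises: 'n' mismatch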
I'm curious if you used or attempted to use any LLM coding tools in this effort?
My feeling is that they tend to increase confusion in confusing situations.
Good question! The answer is I tried but failed: https://mastodon.social/@dynomight/114438979721967821
My impression is that there's a certain (rapidly increasing!) level of complexity where they can nail problems the first time, but then there's a pretty sharp phase change and they become borderline useless? There might be some better workflow, but I didn't get much out of them for this. (I guess it was somewhat helpful to look at what they attempted.)
Thanks!
My feeling is that they are good at:
a) this thing is like that thing
and
b) this thing is like that thing but different
but
c) creating the spec for 'this thing is like that thing but different' is difficult for various reasons, including but not limited to "dum"
They were surprisingly OK at making the spec! It was far from perfect, but I edited it and made it (I think) flawless. My big problem was that, even given an extremely detailed spec, they just couldn't implement it, even when repeatedly prompted with errors.
I guess someone could test if AI could take this whole blogpost and generate working code. I doubt it, but I haven't checked!
Hmm, now wondering how to make DyTorch with the same idea.
My impression is that torch.vmap would be the weak link here, though I could be wrong about that!
small code typo:
> Z[i,j] = Y[k,:] @ dp.linalg.solve(A[i,j,:,:], X[i,:])
this should be `j` instead of `k`.
Yes, thank you!
oh, another typo:
> Hij = 1/(1+j+1)
should probably be i+j+1 inside?
Keep them coming!
well there's this bit...
> I = dp.Array([0,1,2,3,4])
> J = dp.Array([0,1,2,3,4,5,6,7,8,9])
> X['i','j'] = 1 / (1 + I['i'] + J['j'])
and I don't understand why the second one continues all the way to 9, rather than stopping at 4. (this happens twice so it seems like it's not a mistake, and it's been like two years since I touched numpy math, so I have no idea if it's intentional)
Don't overthink! It's just because I'm dum
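For the record, a plain-numpy sketch of my reading of that snippet (not from the post): each named index runs over the full length of the array it indexes, so 'i' covers 0..4, 'j' covers 0..9, and X comes out with shape (5, 10).

    import numpy as np

    I = np.array([0, 1, 2, 3, 4])
    J = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    X = 1 / (1 + I[:, None] + J[None, :])   # broadcast to shape (5, 10)
    assert X.shape == (5, 10)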