If you want to replace NumPy with a work-alike library, consider KxSystems (kx.com) for inspiration. They offer a well-designed tensor library derived from APL, bundled with an in-memory columnar database.
I am a big jax fan, and as others have pointed out you can solve your problem with vmap. But beyond that, I feel like one of the problems with tensors is that they are just hard to get your head around. Even if you write out multi-headed attention on a piece of paper with diagrams (keeping clear which dimension does what), it still doesn't look easy. The solution for me is just to describe what each dimension does at each step in a comment.
jax.vmap definitely helps, but I still find it hard. (Too much axis= stuff). It's the basis for my new thing though!
I was waiting for someone to mention it, but nobody did, so obligatory comment: why not Julia? It has actual first-class linear algebra, you can write loops that are not slow (which is not only nice, but also means the whole ecosystem isn't built around avoiding loops), and it has an einsum with sane notation (Tullio.jl). It can solve the linear problems on the GPU with a loop, though that may not be the most efficient approach; for that you probably have to hit a lower-level API.
Well, Julia makes loops fast, but it doesn't (I think) solve the problem of providing an easy API for the GPU. If you really want everything to run on the GPU, my impression was that Julia was still far less mature than Python?
> einsum with sane notation (Tullio.jl)
This is very nice! https://github.com/mcabbott/Tullio.jl
Depends what you mean by everything, but this should work:
using CUDA
A = CUDA.randn(100, 5, 5)
x = CUDA.randn(100, 5)
y = CUDA.zeros(100, 5)
for i = 1:100
@views y[i, :] .= A[i, :, :] \ x[i, :]
end
It's just going to launch 100 GPU kernels, which is not very efficient in ML use cases (and Julia arrays are column-major, so batching over the first index like this is also cache-unfriendly). I don't know how to do it in one CUDA call (but then, I don't know much about GPUs).
I suggest you try `jax.vmap`, it greatly simplifies situations like this: https://docs.jax.dev/en/latest/automatic-vectorization.html
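A minimal sketch of what that looks like for the batched-solve case discussed upthread (assuming jax is installed; the shapes are just illustrative):

```python
import jax
import jax.numpy as jnp

A = jax.random.normal(jax.random.PRNGKey(0), (100, 5, 5))
x = jax.random.normal(jax.random.PRNGKey(1), (100, 5))

# vmap maps the unbatched solve over the leading axis of both arguments,
# so each A[i] (5x5 matrix) is paired with x[i] (length-5 vector).
batched_solve = jax.vmap(jnp.linalg.solve)
y = batched_solve(A, x)  # shape (100, 5)

# Same result as an explicit Python loop over the batch:
y_loop = jnp.stack([jnp.linalg.solve(A[i], x[i]) for i in range(100)])
```

Under jit, the vmapped version compiles to one batched operation instead of 100 separate kernel launches.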
If you like jax.vmap, you'll love my new thing. Or maybe you won't... But it builds heavily on jax.vmap.
Rather than "numpy but sane" I think a better approach is "APL but familiar". Q is an example of an APL-like language deliberately aimed at mainstream programmers. It keeps the fundamentals but syntactically uses words instead of single-character symbols.
One fundamental design choice where I feel NumPy went wrong is its broadcasting rules instead of the "leading axis model". NumPy aligns operations on the innermost (trailing) dimensions first by default, replicating and extending according to arcane broadcasting rules, whereas APL (and descendants like J and BQN) apply operations to the outermost axes first, and let you override this with a user-defined "function rank". So if you write a function for one shape of data, it's trivial to reuse it for differently-shaped data.
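To make the contrast concrete, a small sketch of the trailing-axis alignment described above (plain NumPy; the leading-axis behavior is only described in the comments, since NumPy doesn't offer it):

```python
import numpy as np

a = np.ones((3, 4))

# NumPy aligns shapes from the trailing (innermost) axis outward:
b = np.arange(4)        # shape (4,) broadcasts against the LAST axis
print((a + b).shape)    # (3, 4)

# Pairing with the leading axis instead needs a manual dummy axis:
c = np.arange(3)        # shape (3,)
try:
    a + c               # (3, 4) and (3,) do not align trailing-first
except ValueError:
    pass
print((a + c[:, None]).shape)  # (3, 4)
```

In a leading-axis model, `a + c` would be the natural pairing, and the trailing-axis case would be the one that needs explicit rank control.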
I've spent a little time getting familiar with APL-type languages. I'm OK with the crazy symbols, but I still find it requires a lot of thinking. Maybe after I propose my "better numpy", you can tell me if you think APL/J/K/Q would do it equally well...
To contextualize for those of us who are uninitiated: What are you doing with NumPy that requires this sort of complex matrix multiplication?
Believe it or not, things as simple as making plots like these: https://dynomight.net/theanine-2/ (Also: Multi-headed self-attention)
I've always thought this without verbalizing it. Waiting with bated breath for your "numpy but sane".
While I definitely failed to read your initial example right on the first pass, I think the particular example chosen isn't ambiguous.
The `Returns` part of the docstring strictly implies that all inputs b with ndim >= 2 will be interpreted as the M,K case.
From there, my (2nd attempt) instincts said to let K=1 and unsqueeze-squeeze, which was correct.
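A minimal NumPy sketch of that unsqueeze-squeeze trick, using the batched case from upthread (100 independent 5x5 systems; shapes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((100, 5, 5))
b = rng.standard_normal((100, 5))

# Append a K=1 axis so each right-hand side is an (M, K) matrix, making
# solve's "stack of matrices" interpretation apply unambiguously, then
# squeeze the dummy axis back out.
y = np.linalg.solve(A, b[:, :, None])[:, :, 0]

# Same as looping over the batch:
y_loop = np.stack([np.linalg.solve(A[i], b[i]) for i in range(100)])
```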
·
In general, I do agree that tensor methods can be ill-formed.
Contractually, it would be nice if all ops had
* a single unbatched implementation with fixed ndims for all inputs,
* clear descriptions of what broadcasting/batching semantics are supported for the op.
Or, if you could remake everything from scratch, you would gate all batching behavior behind a vmap() wrap which simply replaces all ops with their batched variant, making intended code behavior obvious... but this gets yucky when you need to flatten/split a batch dim for reasons.
·
I don't think broadcasting can be ditched. As you describe, it's good in "simple" cases, and it's quite easy to see how ugly code would become in the no-broadcast world. I think most cases of bad ambiguity go away if your broadcasting system is repeat-only rather than also ndim modifying.
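A small NumPy illustration of that distinction ("repeat-only" vs "ndim-modifying" is my reading of the comment, not official NumPy terminology):

```python
import numpy as np

row = np.arange(3)            # shape (3,)
col = np.arange(3)[:, None]   # shape (3, 1)

# ndim-modifying broadcasting: (3,) is silently promoted to (1, 3), then
# both operands are repeated, producing an outer sum you may not have wanted.
print((row + col).shape)      # (3, 3)

# A repeat-only rule would reject the mismatched ndims and force the
# intent to be spelled out, e.g. via an explicit repeat:
tiled = np.broadcast_to(row, (4, 3))  # repeat along a stated batch axis
print(tiled.shape)            # (4, 3)
```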
·
I agree arrays-in-indices ("advanced" indexing) was a mistake. It is more confusing than a simple .take/.index_select. The nd cases are even worse.
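For instance (plain NumPy; `take` here is the real `np.ndarray.take`):

```python
import numpy as np

a = np.arange(12).reshape(3, 4)

# Simple case: both spellings pick whole rows 0 and 2.
rows = a[np.array([0, 2])]          # advanced indexing
rows_take = a.take([0, 2], axis=0)  # explicit, does exactly one thing

# The confusing nd case: two index arrays are paired POINTWISE, not
# crossed, so this selects elements (0, 1) and (2, 3) -- not a 2x2 block.
pointwise = a[[0, 2], [1, 3]]       # array([ 1, 11])
```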
·
I'm eagerly looking forward to the “better” NumPy :)
You mean the np.linalg.solve documentation? Yeah, I don't think it's actually ambiguous, just very hard to understand!
I tried learning Numpy a while back and found it completely incomprehensible, but I chalked it up to not having taken linear algebra. Happy to know it's not just me.
If you stay out of np.linalg, there's really not much linear algebra. I actually wonder if it might be easier to learn the basics with a less general-purpose language like octave and then switch over to numpy later. (The ecosystem of other packages with numpy/python is insanely better, so you'd probably want to learn that eventually.)
I've never used NumPy, but I've done some scripting in R (though nothing nearly as complex as your examples). How does R's array handling compare to NumPy? Just curious.
As an R aficionado myself, I was going to post this exact same question -- including the disclaimer about my problems not being as complex! But you beat me to it.
I eagerly await Dynomight's reply.
I'm not an expert with R, but from what I've seen it's broadly similar, though it doesn't try quite as hard to be general. So I think it's a bit less powerful for array stuff but also less complex. But again: not an expert!