20 Comments
Roy Lowrance:

If you want to replace Numpy with a work-alike library, consider KxSystems (kx.com) for inspiration. They offer a well-designed Tensor library derived from APL, bundled with an in-memory columnar database.

Niclas:

I am a big jax fan, and as others have pointed out, you can solve your problem with vmap. But beyond that, I feel like one of the problems with tensors is that they are just hard to get your head around. Even if you write out multi-headed attention on a piece of paper with diagrams (drawn so that it stays clear which dimension does what), it still doesn't look easy. The solution for me is just to describe what each dimension does in each step as a comment.
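
For example, a rough sketch of that style for plain dot-product attention (the shape names B, H, S, D are just illustrative):

```python
import numpy as np

B, H, S, D = 2, 4, 16, 8                  # batch, heads, sequence length, head dim
q = np.random.randn(B, H, S, D)           # queries: (B, H, S, D)
k = np.random.randn(B, H, S, D)           # keys:    (B, H, S, D)
v = np.random.randn(B, H, S, D)           # values:  (B, H, S, D)

scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(D)   # (B, H, S, S): query/key similarities
weights = np.exp(scores)                           # (B, H, S, S)
weights /= weights.sum(axis=-1, keepdims=True)     # softmax over the key axis
out = weights @ v                                  # (B, H, S, D): per-head outputs
```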

dynomight:

jax.vmap definitely helps, but I still find it hard. (Too much axis= stuff). It's the basis for my new thing though!

Antoine Levitt:

I was waiting for someone to mention it, but nobody did, so obligatory comment: why not julia? It has actual first-class linear algebra, you can actually write loops that are not slow (which is not only nice, but also means that the whole ecosystem is not built around avoiding loops), and it also has an einsum with sane notation (Tullio.jl). It can solve the linear problems on the GPU with a loop, though that might not be the most efficient way to do it; for that you probably have to hit a lower-level API.

dynomight:

Well, julia makes loops fast, but it doesn't (I think) solve the problem of creating an easy API to interface to the GPU. If you really want everything to run on the GPU, my impression was that julia was still way less mature than python?

> einsum with sane notation (Tullio.jl)

This is very nice! https://github.com/mcabbott/Tullio.jl

Antoine Levitt:

Depends what you mean by everything, but this should work:

using CUDA

A = CUDA.randn(100, 5, 5)
x = CUDA.randn(100, 5)
y = CUDA.zeros(100, 5)

for i = 1:100
    @views y[i, :] .= A[i, :, :] \ x[i, :]   # solve A[i] * y[i] = x[i]
end

It's just going to launch 100 GPU kernels, which is not very efficient for ML use cases (and also julia arrays are column-major, so this indexing order isn't ideal). I don't know how to do it with one CUDA call (but then, I don't know anything about GPUs).

Julian:

I suggest you try `jax.vmap`, it greatly simplifies situations like this: https://docs.jax.dev/en/latest/automatic-vectorization.html
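
For the batched-solve case being discussed, a rough sketch of what that looks like (the array contents here are just placeholders):

```python
import jax
import jax.numpy as jnp

A = jnp.eye(5) + jnp.ones((100, 5, 5))   # a stack of 100 invertible 5x5 matrices
x = jnp.ones((100, 5))                   # a stack of 100 right-hand-side vectors

# Write the single-example version; vmap adds the batch dimension.
def solve_one(A_i, x_i):                 # A_i: (5, 5), x_i: (5,)
    return jnp.linalg.solve(A_i, x_i)

y = jax.vmap(solve_one)(A, x)            # y: (100, 5)
```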

dynomight:

If you like jax.vmap, you'll love my new thing. Or maybe you won't... But it builds heavily on jax.vmap.

Alex Shroyer:

Rather than "numpy but sane" I think a better approach is "APL but familiar". Q is an example of an APL-like language deliberately aimed at mainstream programmers. It keeps the fundamentals but syntactically uses words instead of single-character symbols.

One fundamental design choice where I feel numpy went wrong is broadcasting rules instead of the "leading axis model". Numpy applies operations to the innermost dimensions first by default (with replication and extension according to arcane broadcasting rules), whereas APL (and others like J and BQN) apply operations to the outermost axes first, but allow you to override this with a user-defined "function rank". So if you write a function for one shape of data, it's trivial to reuse the same function for differently-shaped data.
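
A small NumPy illustration of that trailing-axis alignment (not APL, just showing where the manual reshaping creeps in):

```python
import numpy as np

A = np.ones((3, 4))
row = np.arange(4)           # shape (4,)
col = np.arange(3)           # shape (3,)

A + row                      # fine: (4,) lines up with the trailing axis of (3, 4)
# A + col                    # error: (3,) would have to line up with the leading axis
A + col[:, None]             # works, but only after manually adding a trailing axis
```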

dynomight:

I've spent a little time getting familiar with APL-type languages. I'm OK with the crazy symbols, but I still find it requires a lot of thinking. Maybe after I propose my "better numpy", you can tell me if you think APL/J/K/Q would do it equally well...

Sol Hando:

To contextualize for those of us who are uninitiated: What are you doing with NumPy that requires this sort of complex matrix multiplication?

dynomight:

Believe it or not, things as simple as making plots like these: https://dynomight.net/theanine-2/ (Also: Multi-headed self-attention)

Brzozowski:

I've always thought this without verbalizing it. Waiting with bated breath for your "numpy but sane".

Sherman:

While I definitely failed to read your initial example right on the first pass, I think the particular example chosen isn't ambiguous.

The `Returns` part of the docstring strictly implies that all inputs b with ndim >= 2 will be interpreted as the M,K case.

From there, my (2nd attempt) instincts said to let K=1 and unsqueeze-squeeze, which was correct.
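
Concretely, something like this (the shapes are just for illustration):

```python
import numpy as np

A = np.random.randn(100, 5, 5) + 10 * np.eye(5)   # a stack of 100 well-conditioned matrices
x = np.random.randn(100, 5)                       # a stack of 100 vectors

# Treat each vector as a (5, 1) matrix (K=1), solve, then drop the trailing axis.
y = np.linalg.solve(A, x[..., None])[..., 0]      # shape (100, 5)

# Sanity check against an explicit loop.
for i in range(100):
    assert np.allclose(y[i], np.linalg.solve(A[i], x[i]))
```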

·

In general, I do agree that tensor methods can be ill-formed.

Contractually, it would be nice if all ops had

* a single unbatched implementation with fixed ndims for all inputs,

* clear descriptions of what broadcasting/batching semantics are supported for the op.

Or, if you could remake everything from scratch, you would gate all batching behavior behind a vmap() wrapper that simply replaces all ops with their batched variants, making the intended behavior obvious... but this gets yucky when you need to flatten/split a batch dim for reasons.

·

I don't think broadcasting can be ditched. As you describe, it's good in "simple" cases, and it's quite easy to see how ugly code would become in the no-broadcast world. I think most cases of bad ambiguity go away if your broadcasting system is repeat-only rather than also ndim modifying.

·

I agree arrays-in-indices ("advanced" indexing) was a mistake. It is more confusing than a simple .take/.index_select. The nd cases are even worse.
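
For example, in the one-axis case:

```python
import numpy as np

a = np.arange(12).reshape(3, 4)
idx = np.array([2, 0])

np.take(a, idx, axis=0)   # explicit: pick rows 2 and 0 along axis 0
a[idx]                    # "advanced" indexing gives the same result here, but the
                          # rules get hard to predict once several index arrays mix
```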

·

I'm eagerly looking forward to the “better” NumPy :)

dynomight:

You mean the np.linalg.solve documentation? Yeah, I don't think it's actually ambiguous, just very hard to understand!

mithrandir15:

I tried learning Numpy a while back and found it completely incomprehensible, but I chalked it up to not having taken linear algebra. Happy to know it's not just me.

dynomight:

If you stay out of np.linalg, there's really not much linear algebra. I actually wonder if it might be easier to learn the basics with a less general-purpose language like octave and then switch over to numpy later. (The ecosystem of other packages with numpy/python is insanely better, so you'd probably want to learn that eventually.)

Alex C.:

I've never used NumPy, but I've done some scripting in R (though nothing nearly as complex as your examples). How does R's array handling compare to NumPy? Just curious.

DH:

As an R aficionado myself, I was going to post this exact same question -- including the disclaimer about my problems not being as complex! But you beat me to it.

I eagerly await Dynomight's reply.

dynomight:

I'm not an expert with R, but from what I've seen it's broadly similar; it just doesn't try quite as hard to be general. So I think it's a bit less powerful for array stuff but also less complex. But again—not an expert!
