My thoughts on JAX and Julia



Shared attributes

I consider both languages to be advanced tools in the realm of scientific computing. JAX offers a DSL with many restrictions on function purity and control flow; likewise, writing idiomatic code is necessary to achieve maximum performance in Julia.

Writing code without a proper understanding of how the respective tools work will result in frustrating, error-prone development potentially leading to inefficient code. Compare this to NumPy, for example, which places far fewer restrictions on what is or isn’t a valid expression as it doesn’t have the same goals, namely performance and correctness, as the more constraining tools.

On the flip side, with an appreciation for the inner mechanics of their compilers, one can write concise code to perform complex transformations. Moreover, this code will be free of the syntactic noise found in traditional performant approaches and the hacky tendencies of simpler tools.

Luckily, another strength of both tools is the community supporting them. If a user is struggling with a particular problem or wants some advice on their code then there are many sponges of knowledge lurking on the discussion platforms — primarily Discourse for Julia and GitHub for JAX — and so no question (following the guidelines) is left unanswered. Consequently, however, a lot of information is buried in threads or the brains of a clever few, something that is especially impactful as documentation of advanced concepts is lacklustre (more so for Julia). While these may be mostly edge cases, the curse of dimensionality tells us that if a project is complicated enough, one is very likely to encounter a problem that can only be solved by such information.

Julia is less restrictive

The approach taken by Julia’s core creators is that Julia should be a general-purpose language. It allows a mix of procedural and functional code, offers strong metaprogramming capabilities, and has libraries stretching to nearly every scientific domain.

As alluded to earlier, JAX is specifically built for high-performance machine learning research. While delivered as a Python extension (or two), it acts as a functional array programming language that leaves everything unrelated to this to its host language.

The restrictions JAX imposes are well covered in its documentation, which provides a refreshingly simple list of things that the programmer has to take into account while reasoning about the implementation of an algorithm. In summary, functions must be pure, arrays are immutable and the limited updates it allows are performed out-of-place, compiled function arguments must be JAX arrays or directly derived therefrom, Python if-statements inside compiled functions cannot depend on run-time values, and random numbers require slightly more than zero effort.
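As a rough illustration of that list, here is a minimal sketch of three of the restrictions in practice. The function name relu_shift and the input values are made up for this example; the calls used (x.at[...].set, jnp.where, jax.random.split) are standard JAX, but treat this as one possible way of working within the rules rather than the canonical one.

```python
import jax
import jax.numpy as jnp


@jax.jit
def relu_shift(x):
    # Arrays are immutable: .at[...].set(...) returns a new array
    # instead of modifying x in place.
    x = x.at[0].set(0.0)
    # A Python `if x > 0:` on a traced value would fail under jit, so the
    # value-dependent choice is expressed element-wise with jnp.where.
    return jnp.where(x > 0, x, 0.0)


x = jnp.array([1.0, -2.0, 3.0])
y = relu_shift(x)  # x itself is left unchanged

# Random numbers need an explicit key, which is split before each use
# so that no key is ever consumed twice.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, shape=(3,))
```

The function stays pure throughout: it reads only its argument and returns a new value, which is what lets jit trace and compile it safely.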

Julia’s comparatively permissive attitude is fantastic for those who “just want things to work” until the user inevitably notices how painfully slow their program is running and either complains to the community or goes back to Python. It is precisely because of its general-purpose goal that the language hands out footguns so freely. Unwilling to impose what is or isn’t possible, it won’t make so much as a peep when encountering untyped global variables, a type instability in the output of a function, abstract typed containers or an instruction to heap-allocate the same vector a million times even if it never leaves scope.

On the other hand, if you are in a field that JAX is designed for then everything works out of the box. In my personal experience, JAX’s limitations are a blessing more often than a curse. The fact that it forces the functional paradigm is a core part of why JAX feels so good, as it is a natural way of solving problems in maths-heavy domains like deep learning. In comparison, when I program in Julia I often want to write code functionally, but can’t afford to for performance reasons because a for-loop would be so much faster. I applaud the fact that I can write tacit functions with the Unicode character meaning composition (e.g. gcd ∘ collect ∘ extrema from this fun series), but I weep when I have to go from thinking in terms of high-level array transformations to thinking in terms of for-loops, indices, and procedural updates.

Julia’s utility tools are unmatched by Python and JAX

Rust’s Cargo is the gold standard. For such a complex language, Cargo is a beacon of user-friendly design (and btw it’s ⚡ blazingly fast ⚡). Standing on Cargo’s shoulders, Julia’s package manager, Pkg, makes it easy to handle environments, dependency management and packaging in a reproducible way, an asset amidst science’s replication crisis.

The best efforts in the Python ecosystem to replicate Cargo’s successes are Poetry (v1.0 in December 2019) and Hatch (v1.0 in April 2022). However, their adoption has been less-than-instant in the scientific community, a people well known for unenthusiastically learning badly-designed software and developing Stockholm syndrome once something new comes along.

Even so, Python’s tools are not as integrated into their target language as Julia’s Pkg. The REPL could be a section all on its own, but its Pkg mode, activated with a single input (]), makes it trivial to do useful but odd things like test out a new package, only to uninstall it 30s later. Similarly, ? mode in the REPL improves productivity by letting users fuzzy-search documentation of functions, macros, and types.

Finally, once you’ve used BenchmarkTools you never want to go back to timeit, whose bizarre syntax I have to look up every time I use it. Julia’s inbuilt @time and BenchmarkTools’ @btime and @benchmark allow you to annotate a line of code and immediately know its distribution of execution times, how many allocations occur, and the total space allocated. Furthermore, Julia’s profiler shows where allocations occur as of v1.8 (previously just execution time), and can be invoked and viewed with the @profview macro that comes with Julia’s VS Code extension.

The popularity of Jupyter notebooks for scientific Python development is no surprise given the productivity boost magic commands like %timeit and %prun provide.
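For readers who have not met it, this is roughly what the plain timeit module looks like outside a notebook. It is a minimal sketch using only the standard library; the function total and the repeat and number values are arbitrary choices for illustration.

```python
import timeit


def total(n):
    # Toy workload to be timed.
    return sum(range(n))


# Five repeats, each timing 1000 calls. Every entry in `times` is the
# total time in seconds for all 1000 calls, not the time per call.
times = timeit.repeat(lambda: total(1000), repeat=5, number=1000)
best_per_call = min(times) / 1000
```

Note that this reports timings only. Allocation counts and memory use, which @btime prints alongside the time, have to come from a separate tool.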

Unfortunately, JAX itself is difficult to profile. The opaque names of operations in the graphical trace mean it can be difficult to determine what’s going on or what has happened to your code.

Julia is always compiled, JAX has Python as a fallback

The compiled nature of both tools means that the time between running a script and seeing results is far longer than with an interpreted language like Python. It’s no secret that JAX’s compilation takes far longer than Julia’s in many cases, but the fact that JAX users can selectively compile bits of their code means that one of Julia’s biggest pain points, time-to-first-X, can be completely sidestepped.
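A minimal sketch of what selective compilation can look like: the same function is called once eagerly and once through jax.jit. The function step and its input are invented for this example.

```python
import jax
import jax.numpy as jnp


def step(x):
    # Ordinary Python function built from jax.numpy operations.
    return jnp.tanh(x) * 2.0 + 1.0


# Opt in to compilation only where it is wanted.
fast_step = jax.jit(step)

x = jnp.linspace(0.0, 1.0, 5)
eager = step(x)          # dispatched one operation at a time, no whole-function compile
compiled = fast_step(x)  # first call traces and compiles, later calls reuse the result
```

While iterating on code, one can stay with the eager version and add jax.jit only once the compile wait is worth paying.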

Python is a limiting surrogate language for JAX

Julia was designed around multiple dispatch. Its system of abstract, concrete, and parametric types gives rise to an incredibly useful toolbox for writing generic, performant code. It also provides simple, extensible ways for handling many problems that occur commonly in science, such as missing values handled through the Missing type. As a side effect, the type system also provides another one of Julia’s more frustrating points: time-to-first-MethodError (as anyone who has used autodiff in Julia will know).

Python’s type system is a cobbled-together afterthought that looks especially weak in the current landscape of programming languages. JAX unfortunately suffers for this, even in its limited scope of array programming and especially when compounded with its meagre capacity for metaprogramming. While a powerful tool like jaxtyping can provide the capability to perform compile-time shape/type checking on JAX arrays (at no performance cost!), it essentially builds a new parametric type system out of strings to do so.

Which tool should I use?

JAX and Julia are powerful tools that require a non-trivial amount of investment to wield. JAX’s restrictive approach means that it is the perfect tool for its limited scope. Julia’s main goal was to solve the two-language problem and it has done so fairly well*. Its limitations to more widespread adoption in science come mostly from its frustrating error messages, slow compilation speed, incredibly high knowledge-ceiling of mastery, and lack of teaching tools (and a JetBrains-designed IDE).

Once mastered, both JAX and Julia can be used in almost all scientific settings. That said, inexperienced users may struggle when using JAX in a domain where reasoning in a functional way is difficult. Furthermore, JAX’s fundamental transformations are built for SIMD-heavy operations that run on GPUs/TPUs, so code not taking advantage of vmap may run faster in Julia.
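For anyone who has not used vmap, here is a minimal sketch of the idea: a function written for a single pair of vectors is turned into a batched one without an explicit loop. The names dot_one and batched_dot and the input arrays are made up for illustration.

```python
import jax
import jax.numpy as jnp


def dot_one(a, b):
    # Written for one pair of 1-D vectors.
    return jnp.dot(a, b)


# vmap maps the function over the leading axis of both arguments.
batched_dot = jax.vmap(dot_one)

a = jnp.arange(6.0).reshape(3, 2)  # three vectors of length 2
b = jnp.ones((3, 2))
out = batched_dot(a, b)            # one dot product per row
```

Code that fits this shape maps naturally onto accelerators. Code that is inherently sequential or index-heavy is where a Julia for-loop may come out ahead.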

For Julia, hell is other people’s code. The ecosystem is highly developed in certain areas (see SciML for example) but stability and fragmentation are issues in others. If one has a lot of domain knowledge and is planning to write a lot of custom code then Julia is a great language to choose.

*Although ultra-performant Julia code can sometimes read like a different language