Hacker News
Exact binary vector search for RAG in 100 lines of Julia (domluna.com)
180 points by lunaticd 22 days ago | 23 comments



Dom! Fellow Julian here!

I loved this post <3 The hamming distance is one of my favorite demos of the conciseness of Julia:

hamming_distance(s1, s2) = mapreduce(!=, +, s1, s2)

I'm a bit swamped at the moment but I'll write a response article later - there's still some juicy perf on the table here.

Thanks for the post, such a good showcase.


please make it even faster!


And this right here is why the Julia crowd is awesome...

A great article inspires everyone to participate in the fun.

Cheers =3


I wanted to get too fancy and I tried:

* LoopVectorization.jl - @turbo choked on the loop

* a direct llvmcall to use the AVX512 popcount - I malformed the types for the instruction

* defining the `db` as

db = [rand(Int8) for _ in 1:64, j in 1:(10^6)];

to avoid the vec of vecs structure,

and then

    function my_cluster!(db, query, k)
        db .= query .⊻ db
        popcounts = mapreduce(count_ones, +, db, dims = 1)
        results = reshape(popcounts, last(size(db)))
        partialsortperm!(results, results, k)
        @views results[begin:k]
    end

...which I couldn't get to be faster than your version. If you use the `partialsortperm!` and reuse the same cache array, I suspect you'll get good speedups, as you won't be sorting the array every time. This is a classic `nth_element` algorithm.

The above is not the most amazing code, but I suspect the lack of indexing will make it ridiculously friendly for a GPU (Edit: Nope, it chokes on `partialsortperm!`).

I'm guessing the manual loopy approach should be just as good, but I battled hard to get it somewhat competitive here in 6 lines of code.

    #@be my_cluster!(X2, q1, 5)
    Benchmark: 3 samples with 1 evaluation
     42.883 ms (17 allocs: 15.259 MiB)
     45.711 ms (17 allocs: 15.259 MiB)
     46.670 ms (17 allocs: 15.259 MiB)

    #@be k_closest(X1, q1, 5)
    Benchmark: 4 samples with 1 evaluation
     27.994 ms (2 allocs: 176 bytes)
     28.733 ms (2 allocs: 176 bytes)
     29.000 ms (2 allocs: 176 bytes)
     30.709 ms (2 allocs: 176 bytes)

I also didn't try using `FixedSizeArrays.jl` as Mose Giordano recommended in my livestream chat.
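For what it's worth, the `partialsortperm!`-with-a-reused-buffer idea might look something like this sketch (the names and shapes are my assumptions, not the post's actual code):

```julia
# Hypothetical sketch: reuse one index buffer across queries so the
# partial selection doesn't allocate or fully sort every time.
function top_k!(ix::Vector{Int}, dists::AbstractVector{<:Integer}, k::Int)
    # partialsortperm! overwrites ix with a permutation of eachindex(dists);
    # only the first k positions are guaranteed sorted (nth_element-style).
    partialsortperm!(ix, dists, 1:k)
    return @views ix[1:k]   # indices of the k smallest distances
end
```

The buffer `ix` only needs to be allocated once, the same length as the distance vector.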


Why not use the built in BitVector type that has specialized code for things like xor?

https://docs.julialang.org/en/v1/base/arrays/#Base.BitArray


It doesn't seem to have better support for things like xor and count_ones. I believe the main use case is comparisons.


Combined with broadcasting it's fast.

Under the hood it's doing the same thing with a vector of ints (64-bit chunks for BitVectors) and all the bulk manipulation is handled that way, so SIMD is inherent as well. Worth a shot.
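To make that concrete, here's a minimal sketch of the BitVector route (my own toy code, not benchmarked against the post):

```julia
# BitVectors pack bits into UInt64 chunks, so xor broadcasts chunk-wise
# and the bit count runs 64 bits at a time.
hamming_bv(a::BitVector, b::BitVector) = count(a .⊻ b)

# Allocation-free variant over the packed storage (relies on the
# internal `chunks` field, so it's a hack rather than public API):
hamming_chunks(a::BitVector, b::BitVector) =
    sum(count_ones(x ⊻ y) for (x, y) in zip(a.chunks, b.chunks))
```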


    for i in 0:7
        c += (r >> i) & 1
    end
This is just popcnt, surely Julia has a built-in for that.


There is, it's called count_ones. I wouldn't be surprised if LLVM could optimize some of these loops into a popcnt, but I'm sure it would be brittle.


author here. I thought there might be a machine instruction for this but wasn't sure; I also didn't know Julia had a count_ones that counted the 1s.

Thanks! With this the timings are even faster. I'll update the post.


    julia> @code_typed hamming_distance(Int8(33), Int8(125))
    CodeInfo(
    1 ─ %1 = Base.xor_int(x1, x2)::Int8
    │   %2 = Base.ctpop_int(%1)::Int8
    │   %3 = Base.sext_int(Int64, %2)::Int64
    │        nothing::Nothing
    └──      return %3
    ) => Int64

    julia> @code_llvm hamming_distance(Int8(33), Int8(125))
    ; Function Signature: hamming_distance(Int8, Int8)
    ;  @ /Users/lunaticd/code/tiny-binary-rag/rag.jl:13 within `hamming_distance`
    define i64 @julia_hamming_distance_16366(i8 signext %"x1::Int8", i8 signext %"x2::Int8") #0 {
    top:
    ;  @ /Users/lunaticd/code/tiny-binary-rag/rag.jl:14 within `hamming_distance`
    ; ┌ @ int.jl:373 within `xor`
       %0 = xor i8 %"x2::Int8", %"x1::Int8"
    ; └
    ; ┌ @ int.jl:415 within `count_ones`
       %1 = call i8 @llvm.ctpop.i8(i8 %0)
    ; │┌ @ int.jl:549 within `rem`
        %2 = zext i8 %1 to i64
    ; └└
      ret i64 %2
    }

it lowers to the machine instruction now.

I also tried 8 Int64s vs 64 Int8s and it doesn't seem to make a difference when doing the search.

EDIT: apologize for the formatting


I think you may need to update the figures in the rest of the article. At some point you mention it should take around 128ns but with the new benchmark that's probably closer to 64*1.25=80ns.


I had Opus translate your code to Rust:

    fn hamming_distance_u8(x1: u8, x2: u8) -> usize {
        (x1 ^ x2).count_ones() as usize
    }


From what I've heard it's actually faster to create a 256-byte lookup table than to use popcnt.


It used to be pretty bad on old Intel processors, but nowadays it should be faster than an L1 fetch.
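For reference, the 256-entry table approach being compared looks roughly like this (a sketch, not benchmarked here):

```julia
# Precompute the popcount of every possible byte once...
const BYTE_POPCOUNT = UInt8[count_ones(b) for b in 0:255]

# ...then each lookup is one L1 load instead of a popcnt instruction.
popcount_lut(x::UInt8) = @inbounds BYTE_POPCOUNT[Int(x) + 1]
```

On hardware with a fast popcnt, `count_ones` avoids the memory traffic entirely, which is the parent's point.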


For those like me who are not familiar with the field... the article assumes you know the entire context - as far as I could see, there is no explanation of anything except the technical details.

RAG = Retrieval-Augmented Generation

The field is machine learning. Retrieval = get relevant documents. Generation = create an answer for the user (based on the docs).


Nice speed up!

Have you tried to benchmark also this heap implementation?

https://juliacollections.github.io/DataStructures.jl/latest/...


I did - it does a ton of allocations, which is why I made a simple max-heap implementation and then sort the final result.
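Not the author's actual code, but the shape of the idea - keep only the k best candidates and evict the current worst - can be sketched like this (for small k a sorted buffer behaves like the bounded max-heap described):

```julia
# Hedged sketch: track the k closest (distance, index) pairs in one
# small buffer, evicting the current worst whenever a closer one appears.
function k_closest_sketch(dists::AbstractVector{<:Integer}, k::Int)
    best = Tuple{Int,Int}[]        # (distance, index) pairs, worst kept last
    sizehint!(best, k + 1)         # single up-front allocation
    for (i, d) in pairs(dists)
        if length(best) < k || d < best[end][1]
            insert!(best, searchsortedfirst(best, (d, i)), (d, i))
            length(best) > k && pop!(best)   # evict the current worst
        end
    end
    return [i for (d, i) in best]  # indices, closest first
end
```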


I'm not sure what is meant by "exact" here - do they describe their binarisation process at all? This seems more like an XOR benchmark than a RAG benchmark; there's no mention of recall or other relevant performance metrics.


Some (not all) of your questions may be answered by the linked article near the top of the submitted article, which goes into more detail about how much is lost quantizing to 1 bit (and 1 byte): https://huggingface.co/blog/embedding-quantization


Exact in this case means that all the vectors are compared against the query vector, whereas other search methods such as HNSW are approximate.
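In other words, a brute-force scan over the whole database, something like this toy sketch (my names, not the post's code):

```julia
# Hamming distance between two byte vectors: popcount of the xor.
hamming(a, b) = mapreduce(count_ones ∘ xor, +, a, b)

# Exact search: score every vector against the query, no index structure.
nearest(db, query) = argmin(i -> hamming(db[i], query), eachindex(db))
```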


I used

    hamming_bitwise(x::Union{UInt, UInt128, BigInt}, y::Union{UInt, UInt128, BigInt}) = count_ones(x ⊻ y)

to get a fast hamming distance with "binary vectors" encoded as ints.


Cool to see binary vector search -- I made something similar with numpy in Python: https://github.com/sdan/vlite



