Show HN: Want something better than k-means? Try BanditPAM (github.com/motiwari)
281 points by motiwari on April 5, 2023 | 41 comments
Want something better than k-means? I'm happy to announce our SOTA k-medoids algorithm from NeurIPS 2020, BanditPAM, is now publicly available! `pip install banditpam` or `install.packages("banditpam")` and you're good to go!

k-means is one of the most widely-used algorithms to cluster data. However, it has several limitations: a) it requires the use of L2 distance for efficient clustering, which b) restricts the data you're clustering to vectors, and c) doesn't require the means to be datapoints in the dataset.

Unlike in k-means, the k-medoids problem requires cluster centers to be actual datapoints, which permits greater interpretability of your cluster centers. k-medoids also works with arbitrary distance metrics, so your clustering can be more robust to outliers if you use metrics like L1. Despite these advantages, most people don't use k-medoids because prior algorithms were too slow.

In our NeurIPS 2020 paper, BanditPAM, we sped up the best known algorithm from O(n^2) to O(n log n) by using techniques from multi-armed bandits. We were inspired by prior research that demonstrated many algorithms can be sped up by sampling the data intelligently, instead of performing exhaustive computations.

We've released our implementation, which is pip- and CRAN-installable. It's written in C++ for speed, but callable from Python and R. It also supports parallelization and intelligent caching at no extra complexity to end users. Its interface matches sklearn.cluster.KMeans, so minimal changes are necessary to existing code.
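
For example, here's a minimal usage sketch (assuming the KMedoids class, the n_medoids parameter, and the fit(X, "L2") call shown in the repo README; please check the README for the exact interface):

    import numpy as np
    from banditpam import KMedoids  # pip install banditpam

    # Toy data: 1,000 points in 2-D
    X = np.random.randn(1000, 2)

    # n_medoids plays the role of sklearn's n_clusters; the distance
    # ("L1", "L2", ...) is passed to fit()
    kmed = KMedoids(n_medoids=5, algorithm="BanditPAM")
    kmed.fit(X, "L2")

    print(kmed.medoids)  # the chosen cluster centers (see the README for the exact attributes)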

PyPI: https://pypi.org/project/banditpam

CRAN: https://cran.r-project.org/web/packages/banditpam/index.html

Repo: https://github.com/motiwari/BanditPAM

Paper: https://arxiv.org/abs/2006.06856

If you find our work valuable, please consider starring the repo or citing our work. These help us continue development on this project.

I'm Mo Tiwari (motiwari.com), a PhD student in Computer Science at Stanford University. A special thanks to my collaborators on this project, Martin Jinye Zhang, James Mayclin, Sebastian Thrun, Chris Piech, and Ilan Shomorony, as well as the author of the R package, Balasubramanian Narasimhan.

(This is my first time posting on HN; I've read the FAQ before posting, but please let me know if I broke any rules)




Awesome work and it's a great HN contribution! I loved your blog post; it gave me the gist of everything I wanted to know without reading the paper (I have a background in bandits).

I have a couple of non-technical questions:

1. Can you share some background on how this work developed? I am guessing there were many attempts to improve PAM over the last three decades, right? And in hindsight, a bandit-based approach seems like a natural one to try, right? Did you start by trying to improve PAM and realize no one else had thought of a probabilistic/random approach?

2. Once you realized the multi-armed bandit approach was the way to go, did implementation of the idea and empirical evaluation take a lot of time? I am guessing most of the effort went into providing complexity guarantees, right?

3. The paper has an interesting set of authors from diverse areas - areas in which the k-medoids problem seems highly relevant. This was partly the reason I asked question 1 - was the project motivated by the need for such an algorithm in application areas, or was it by looking for an area to apply the insight that bandit-based approaches can actually perform better?

Overall, I really like the life-cycle of the entire paper. It started with a highly relevant and practical problem, gave an intuitive algorithm that comes with complexity bounds, has an accessible blog post to support the paper, and has what seems to be a very efficient implementation that can directly be used in production at scale. A lot of researchers miss the last part and move on to the next project (I am guilty of that) - kudos to you for spending time on the implementation! If you ever end up at UIUC, I'd love to buy you a coffee (:

PS: I am a grad student at UIUC and was scrolling by and stopped as I saw two familiar names: Ilan (took Random processes with him and loved it) and of course who in robotics wouldn't know Prof. Thrun (for those who don't, his Probabilistic Robotics is a mandatory reference in every robotics class).


Thank you for the positive feedback! Will definitely take you up on that coffee next time I visit Ilan :)

To answer your questions:

1. When looking at k-medoids algorithms, we realized that PAM hadn't been improved since the 80s! Other, faster algorithms had been developed but sacrificed solution quality (clustering loss) for speed. Simultaneously, prior work from my coauthors had recognized that the 1-medoid problem could be solved via randomized algorithms/multi-armed bandits -- much faster but returning the same solution. Our key insight was that every stage of PAM could be recast as a multi-armed bandit problem, and that reusing information across different stages would result in further speedups.

2. Actually, the complexity guarantees/theory were pretty easy because we were able to use proof techniques that are common in the multi-armed bandit literature. The hardest part was implementing it in C++ and making it available to both Python and R via bindings. For the original paper we did everything in Python and measured sample complexity, but to make the algorithm valuable for users we had to implement it in a more performant language.

3. To be honest, this project came about from a chance meeting at a conference (ICML 2019). I was randomly introduced to Martin Zhang (ironically I met him first at the conference even though we were both at Stanford). Martin and Ilan had deep expertise in multi-armed bandits/randomized algorithms (and solved the 1-medoid problem), and I got really interested in their work when talking with them, primarily because it seemed like a straightforward win and useful for a lot of people.


Simple explanation of the general idea, using nearest neighbor search as an example:

1. You want to calculate the L2-nearest neighbor to y from a set of points x1, x2, ..., xn, where each point is a d-dimensional vector.

2. The naive approach takes O(nd) computations. For very large d, say 1e4, this is very expensive.

Multi-armed bandit idea (a toy code sketch follows the steps below):

3. Estimate the distance of y to each point (xi) by sampling only along log(d) dimensions.

4. Keep the nearest n/2 points (throw away the rest)

5. Repeat steps 3 and 4 until you have one point left

6. The total runtime = n log(d) + n log(d)/2 + n log(d)/4 + ... = O(n log(d) log(n))

This is a huge improvement over O(nd) for large d (which is usually the case today)
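
A toy Python sketch of steps 3-5 (illustrative only: the real method uses confidence intervals from bandit theory rather than a fixed log(d) sample size per round, so this version comes with no guarantees):

    import numpy as np

    def approx_nearest_neighbor(X, y, seed=0):
        # Toy successive-halving nearest-neighbor search (steps 3-5 above).
        # Sketch only: fixed ~log(d) sampled dimensions per round, no
        # confidence intervals.
        rng = np.random.default_rng(seed)
        n, d = X.shape
        m = max(1, int(np.ceil(np.log2(d))))   # ~log(d) sampled dimensions
        candidates = np.arange(n)
        while len(candidates) > 1:
            dims = rng.choice(d, size=m, replace=False)
            # Estimated squared L2 distance, scaled up from the sampled coordinates
            est = ((X[candidates][:, dims] - y[dims]) ** 2).sum(axis=1) * (d / m)
            keep = np.argsort(est)[: max(1, len(candidates) // 2)]   # keep best half (step 4)
            candidates = candidates[keep]
        return candidates[0]

    X = np.random.randn(1000, 10_000)
    y = np.random.randn(10_000)
    print(approx_nearest_neighbor(X, y))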

This paper (https://ar5iv.org/abs/1805.08321) from Stanford pioneered this idea for nearest neighbors and many such household problems. It also theoretically proves that the above approach gives the right answer and empirically gets huge (100x) speed boost on large datasets.

The k-medoids paper is based on this work :)


And the biggest caveats IMO:

- Like any other probabilistic clustering algorithm that gets its speed boost from just ignoring large chunks of the data (like MiniBatch KMeans), this will not take outliers into account very well or very often.

- They're advertising this as not just a better KMedoids implementation, but a better drop-in replacement for KMeans. For generic clustering tasks, sure, fine, whatever, but the metric the new paper is trying to optimize is different from the one KMeans uses, so if KMeans has the right metric for a given task then you'll be switching to a (maybe) faster algorithm that just computes the wrong result. The easiest example that comes to mind is a dataset of non-intersecting hollow spheres. KMeans will spit out (for some choice of NClusters) the sphere centers, and KMedoids will spit out sphere boundary points, decreasing performance on the far side of each sphere and potentially allowing classifications to jump from one sphere to another (see the numpy sketch at the end of this comment).

Both of those things are just qualities you may or may not want, so KMedoids may be better purely because it has those biases and KMeans doesn't, but it's not totally uncommon to just want cluster centers minimizing some error and not care how you get there or how explainable they are (the bolt vector quantization algorithm comes to mind), where KMedoids would just be the wrong choice.
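
A quick numpy sketch of the hollow-sphere point for a single cluster: the mean of points on a sphere's surface sits near the center, while the medoid has to be a point on the surface.

    import numpy as np

    # ~Uniform points on the surface of a unit sphere (one "hollow" cluster)
    rng = np.random.default_rng(0)
    X = rng.normal(size=(500, 3))
    X /= np.linalg.norm(X, axis=1, keepdims=True)

    mean = X.mean(axis=0)                          # k-means-style center
    pairwise = np.linalg.norm(X[:, None] - X[None], axis=-1)
    medoid = X[pairwise.sum(axis=1).argmin()]      # k-medoids-style center

    print(np.linalg.norm(mean))    # close to 0: near the true sphere center
    print(np.linalg.norm(medoid))  # exactly 1: a point on the boundary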


I understand what you're saying, and I don't disagree with your analysis for a collection of datapoints on non-intersecting spherical surfaces. (Although I do frown at this type of dataset: since distribution densities tend to peak at their modes, the "in high enough dimensions everything lies on a sphere" idea is a misconception; it is only the radial histogram that peaks at a nonzero radius...)

Just thought it might be worth pointing out that there are valid use cases for clustering where a real data reference for each class is available: consider pictures. To be inspectable, it would be helpful to see a real picture instead of some blurry interpolated mess.

Would you mind providing a reference for the bolt vector quantization algorithm? It sounds interesting.


> frown on that sort of dataset

That example was definitely contrived and designed to strongly illustrate the point. I'll counter slightly that non-peaky topologies aren't uncommon, but they're unlikely to look like anything that would push KMedoids to a pathological state rather than just a slightly worse state ("worse" assuming that KMeans is the right choice for a given problem).

> worth pointing out .. data reference

Totally agreed. I hope my answer didn't come across as too negative. It's good work, and everyone else was talking about the positives, so I just didn't want to waste too much time echoing that again while getting the other points across.

> bolt reference

https://github.com/dblalock/bolt

They say as much in their paper, but they aren't the first vector quantization library by any stretch. Their contributions are, roughly:

1. If you're careful about selecting the right binning strategy then you can cancel out a meaningful amount of discretization error.

2. If you do that, you can afford to choose parameters that fit everything nicely into AVX2 machine words, turning 100s of branching instructions into 1-4 instructions.

3. Doing some real-world tests to show that (1-2) matter.

Last I checked their code wasn't very effective for the places I wanted to apply it, but the paper is pretty solid. I'd replace it with a faster KMeans approximation less likely to crash on big data (maybe even initializing with KMedoids :) ), and if the thing you're quantizing is trainable with some sort of gradient update step then you should do a few optimization passes in the discretized form as well.


We talk about exactly that (clustering pictures) in our blog post! https://ai.stanford.edu/blog/banditpam/


Right! Though to clarify a few nits: we use successive elimination instead of successive halving. And we talk about the Maximum Inner Product Search problem (very similar to NN problem) in our followup work: https://ar5iv.org/abs/2212.07551


Hi Mo, thanks for this work. It seems interesting.

I had the chance to play with it a little bit and wanted to compare it with KMeans. I relied on the sklearn KMeans implementation.

Furthermore, I ran some examples (mostly what is available). One interesting thing I did was generate some isotropic Gaussian blobs for clustering (using `make_blobs`) and then compare the two methods. BanditPAM was a little bit better on a couple of metrics I used, and also much faster. But when I increased `n_samples` from 1000 to 10000, I found that it is much slower than KMeans; see [1], and the code is in [2] (a rough sketch of the comparison follows the links). Is there a particular reason for that?

[1] https://imgur.com/a/VibpgNz

[2] https://paste.elashri.xyz/aXCE
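
Roughly the kind of comparison described above (a sketch, not the actual code in [2]; the BanditPAM KMedoids interface is the one assumed in the post at the top):

    import time
    from sklearn.cluster import KMeans
    from sklearn.datasets import make_blobs
    from banditpam import KMedoids  # interface assumed as in the post above

    X, _ = make_blobs(n_samples=10_000, centers=5, n_features=2, random_state=0)

    t0 = time.time()
    KMeans(n_clusters=5, n_init=10, random_state=0).fit(X)
    print("KMeans:   ", time.time() - t0, "s")

    t0 = time.time()
    KMedoids(n_medoids=5, algorithm="BanditPAM").fit(X, "L2")
    print("BanditPAM:", time.time() - t0, "s")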


Thanks for the bug report and repro steps! I've filed this issue at https://github.com/motiwari/BanditPAM/issues/244 on our repo.

I suspect that this is because the scikit-learn implementation of KMeans subsamples the data and uses some highly-optimized data structures for larger datasets. I've asked the team to see how we can use some of those techniques in BanditPAM and will update the Github repo as we learn more and improve our implementation.


Thanks for doing this, I will be sure to try it in ML competitions. I really like that you used Armadillo, which is something I personally want to do for my own projects.

Just out of technical curiosity, have you come across any particular developments or empirical evidence on the use of (invertible) data transformations to enhance clustering results? I am currently researching a particular problem within signal processing related to signal distribution transforms, and I am particularly interested in reading about potential applications. As an example, and since you mention JPM partial funding, how would a copula transformation affect the results of clustering (assuming an inverse exists etc. and we apply the inverse transformation afterwards)?


One thing that's important to note is that k-medoids supports arbitrary distance metrics -- in fact, your dissimilarity measure need not even be a metric (it can be negative, asymmetric, not satisfy the triangle inequality, etc.)

An implication of this is that applying an invertible data transformation and then clustering is equivalent to clustering the original data with a different dissimilarity measure. So you can avoid the transformation entirely if you're willing to engineer your dissimilarity measure instead.
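
A toy sketch of that equivalence (T and d here are arbitrary stand-ins):

    import numpy as np

    # Clustering T(x) under d is the same as clustering x under the induced
    # dissimilarity d'(x, y) = d(T(x), T(y)) -- no transform step needed.
    T = lambda x: x ** 3                      # any invertible transform (toy choice)
    d = lambda a, b: np.abs(a - b).sum()      # base dissimilarity (L1 here)
    d_prime = lambda x, y: d(T(x), T(y))      # engineered dissimilarity

    x, y = np.array([1.0, 2.0]), np.array([0.5, -1.0])
    print(d(T(x), T(y)) == d_prime(x, y))     # True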

Without more details, it's hard to say exactly what would happen to the clustering results under custom dissimilarity measures or data transformations -- but our package supports both use cases!


This is one of the best-written Show HNs in many months. If you're planning to launch a project in future, use this as your template.


It's a skill hammered into your head if you do a good PhD, for sure!


Hey, thanks! Shout out to Daniel @ HN who gave me a lot of great feedback on how to make this post better.


Really interesting, and I had not heard of k-medoids before!

Have you tried BanditPAM as an index creation technique for approximate-nearest-neighbor search?


Really funny that you mention that! Some of our more recent work focuses on using adaptive sampling techniques in approximate-nearest-neighbor search (actually, the related problem of maximum inner product search: https://ar5iv.org/abs/2212.07551).

We definitely think that our approach could be used to make an index structure for ANN search directly, for example in conjunction with Hierarchical Navigable Small World approaches.


Thanks for sharing your great work!

[Try this tool from arXiv Labs (just replace the x in the original URL with a 5)]

https://ar5iv.labs.arxiv.org/html/2006.06856

https://ar5iv.labs.arxiv.org/html/2212.07551


Oh cool, neat trick, thanks!


Is there a good heuristic for picking a reasonable number of clusters automatically for an arbitrary set of vectors?


I usually go with the Davies-Bouldin index, but there are a few methods (a quick sklearn sketch follows the links):

Python/Sklearn: https://scikit-learn.org/stable/modules/clustering.html#clus...

R: https://cran.r-hub.io/web/packages/clusterCrit/clusterCrit.p...
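
A quick sklearn sketch (the Davies-Bouldin index is lower for better clusterings, so pick the k that minimizes it):

    from sklearn.cluster import KMeans
    from sklearn.datasets import make_blobs
    from sklearn.metrics import davies_bouldin_score

    X, _ = make_blobs(n_samples=2000, centers=6, random_state=0)

    scores = {
        k: davies_bouldin_score(X, KMeans(n_clusters=k, n_init=10, random_state=0).fit_predict(X))
        for k in range(2, 11)
    }
    print(min(scores, key=scores.get))  # the k with the lowest Davies-Bouldin index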


You could always try to use a density based clusterer like DBSCAN, HDBSCAN or OPTICS to determine a likely number of clusters.


The elbow method is pretty common! https://en.wikipedia.org/wiki/Elbow_method_(clustering)

You can also use some regularization criterion (AIC, BIC, or other)
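
For example, picking the number of Gaussian mixture components by BIC with sklearn (a quick sketch):

    from sklearn.datasets import make_blobs
    from sklearn.mixture import GaussianMixture

    X, _ = make_blobs(n_samples=2000, centers=6, random_state=0)

    # Lower BIC is better; it trades off fit quality against model complexity.
    best_k = min(range(2, 11),
                 key=lambda k: GaussianMixture(n_components=k, random_state=0).fit(X).bic(X))
    print(best_k)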


Your algorithm looks surprisingly similar to "Affinity Propagation" which uses message passing techniques to (approximately) optimize the binary integer programming problem. Message passing algorithms have always fascinated me as they seem to be related to deep structure in the original linear programming problem.

For example, there are results for other binary problems that show a relationship between fixed points of message passing and optimal dual points to the relaxed linear programming problem (see below for an example with maximum weighted independent sets).

Back in the day I spent a long time trying to directly relate the affinity propagation messages to a coordinate-descent type of algorithm on the dual for k-medoids but despite the similarity in structure I could never make it work.

I'm curious if you're familiar with this class of algorithms and how they compare (both practically and theoretically) to the work you've presented here? Thanks for sharing!

References:
- https://www.science.org/doi/10.1126/science.1136800
- https://arxiv.org/abs/0807.5091


A 2-author Science paper in 2007 for an algorithm. The authors ought to be proud indeed; this is close to a holy grail dream of mine (a single-author original Nature/Science paper; the only one I've seen in recent decades is the Nature paper on the purported active ingredient of royal jelly, royalactin: https://www.nature.com/articles/nature10093)


Interesting, thanks for the references! I'm not too familiar with this line of work; let me read up on it and get back to you


If you end up spending any time on this and finding anything interesting I'd love to know! My email address is in my profile, feel free to hit me up anytime; also happy to share any of my original notes, I'm pretty sure I have them TeXed up somewhere accessible.


Hey, great work. Do you think this algorithm would be amenable to being run online? I'm the author of River (https://riverml.xyz), where we're looking for good online clustering algorithms.


Definitely possible, but it would require some extensions to the algorithm. More specifically, as new datapoints enter the stream, they could be compared with the existing medoids to see if swapping them would lower the clustering loss.
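
A toy sketch of that idea (illustrative only; this is not part of BanditPAM, and the names here are made up):

    import numpy as np

    def maybe_swap(medoids, window, x, dist=lambda a, b: np.linalg.norm(a - b)):
        # Hypothetical streaming step: when a new point x arrives, try swapping
        # it with each current medoid and keep the swap that most lowers the
        # clustering loss over a recent window of points (if any swap helps).
        def loss(meds):
            return sum(min(dist(p, m) for m in meds) for p in window)
        best, best_loss = medoids, loss(medoids)
        for i in range(len(medoids)):
            cand = [x if j == i else m for j, m in enumerate(medoids)]
            cand_loss = loss(cand)
            if cand_loss < best_loss:
                best, best_loss = cand, cand_loss
        return best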

This would be a nontrivial engineering effort and I likely won't be able to do it myself (I'm a PhD student about to graduate), but if you or your team is interested in adapting BanditPAM to the streaming setting, please feel free to reach out! My email's motiwari@stanford.edu


>3) doesn't require the means to be datapoints in the dataset.

I actually thought this was a k-means strength :)


I think the real strength of this method is that it doesn't require the data to live in a vector space. Once you give up that structure, you're pretty much locked into using points from the dataset as the cluster representatives, unless you assume some other structure.


A strength if you're strictly looking to minimize squared L2 loss from each point to its closest mean -- but for a lot of other applications, it's a weakness! As the other poster mentioned, with KMedoids you can use arbitrary loss functions and cluster exotic objects (not restricted to metrics on a vector space)


How does it fare on this [0] sort of benchmark?

[0] https://cdn-images-1.medium.com/v2/resize:fit:1600/1*yMQItRO...


Where is this benchmark from? We'd be happy to run BanditPAM on these datasets and report the results



> This is my first time posting on HN

Except for this time, I guess: https://news.ycombinator.com/item?id=35362384


Oh yes. But we invited them to repost it and put it in the second-chance pool (https://news.ycombinator.com/pool, explained at https://news.ycombinator.com/item?id=26998308), so it got a random placement on HN's front page.


You got me! As @dang mentioned, once I got feedback from him I was allowed to repost and enter the second-chance pool

(My real first post was submitted too hastily without receiving @dang's feedback)


Cool. Could this/similar sampling be used to initialize a Gaussian Mixture Model and speed that up too?


Hmmmm.... not sure exactly what you mean. I believe the setting you're describing is: we have a dataset we're trying to fit with a GMM, we know the number of components k, and we're trying to determine the parameters of the GMM, correct?

I suppose that you could adaptively sample points from the dataset to update your parameters of the GMM, and sample more points for parameters of the GMM that you're less certain about.

(To understand how the parameter estimates would converge to their true values, you'd likely need to use the delta method; see Appendix 3 in https://ar5iv.org/pdf/2212.07473.pdf for an example)

Is that what you had in mind?


>c) doesn't require the means to be datapoints in the dataset.

I actually thought this was a pro :)



