Conversation

@meetdoshi90 commented Nov 28, 2025

This PR proposes a fix for #180 by replacing the existing `torch.quantile` implementation with one based on `torch.kthvalue`. Because `torch.quantile` limits input tensors to roughly 16 million elements, it fails when computing codebooks for more than 50,000 vectors of dimension greater than 335.

The new implementation does not have that limitation. For linear interpolation it is about 2x slower on GPU and 10x slower on CPU, but it has been verified to produce correct results. Suggestions to speed up the CPU implementation are welcome.
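For readers unfamiliar with the trick: `kthvalue` returns the k-th smallest element, so a quantile reduces to one or two order-statistic lookups plus a linear interpolation. A minimal pure-Python sketch of the same math (sorting stands in for `torch.kthvalue`, and the function name here is made up for illustration):

```python
import math

def quantile_via_kth(values, q):
    """Quantile of a 1-D sequence via k-th order statistics,
    mirroring the kthvalue + lerp approach (illustrative sketch)."""
    # Sorting once stands in for kthvalue lookups: the k-th smallest
    # element (1-indexed) is s[k - 1].
    s = sorted(values)
    n = len(s)
    idx = q * (n - 1)          # fractional rank, 'linear' interpolation
    lo = math.floor(idx)
    hi = math.ceil(idx)
    if lo == hi:               # exact rank: a single lookup suffices
        return s[lo]
    weight = idx - lo
    return s[lo] + (s[hi] - s[lo]) * weight   # lerp between the bounds

print(quantile_via_kth(range(1, 11), 0.5))   # 5.5
```

For the `'lower'`, `'higher'`, and `'nearest'` modes the fractional rank would simply be rounded instead of interpolated, which is why those rows run roughly twice as fast in the benchmark below (one lookup instead of two).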

A corresponding fix in Rust would require more effort, since it would mean replacing the underlying ATen quantile kernel.

**Runtime comparison**

Params:
- Vector dims: (50000, 768) and (50000, 384)
- Quantiles: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
- Reduction dim: 0

Tensor dim = (50000, 768)

| Interpolation | Torch (CPU) | This implementation (CPU) | Torch (GPU) | This implementation (GPU) |
| --- | --- | --- | --- | --- |
| linear | 0.4742 s | 4.9824 s | 0.0908 s | 0.2105 s |
| lower | 0.4488 s | 2.4526 s | 0.0090 s | 0.0829 s |
| higher | 0.4555 s | 2.4452 s | 0.0090 s | 0.0835 s |
| midpoint | 0.4431 s | 4.8587 s | 0.0092 s | 0.1697 s |
| nearest | 0.4648 s | 2.4603 s | 0.0090 s | 0.0831 s |

Tensor dim = (50000, 384)

| Interpolation | Torch (CPU) | This implementation (CPU) | Torch (GPU) | This implementation (GPU) |
| --- | --- | --- | --- | --- |
| linear | 0.2071 s | 3.0143 s | 0.0887 s | 0.1025 s |
| lower | 0.2115 s | 1.5261 s | 0.0052 s | 0.0273 s |
| higher | 0.2098 s | 1.5005 s | 0.0051 s | 0.0271 s |
| midpoint | 0.2095 s | 3.0220 s | 0.0054 s | 0.0600 s |
| nearest | 0.2073 s | 1.4914 s | 0.0051 s | 0.0272 s |

@NohTow (Collaborator) commented Dec 4, 2025

Sorry for the delay, I totally forgot to answer.
Thanks for digging into this and making such a thorough benchmark!!

To be honest, I find the slowdown quite massive, so I do not know what to think about this...
Maybe we could keep both implementations and only go this route when we are above a threshold?
@raphaelsty wdyt?

@meetdoshi90 (Author) commented

Yeah, I agree, the slowdown is massive. I'm not sure whether a similar performance degradation happens with the Rust implementation, since it also uses the same kth-value function. Adding a threshold would be an easy branch to implement, but it wouldn't be clean.

@raphaelsty (Collaborator) commented Dec 4, 2025

Already fixed in fast-plaid: I went with this alternative and linear interpolation only. Even if it's a bit slower, it's not that impactful on overall search time, since it is only called to build and load the index.

pytorch/pytorch#157431 (comment)

```rust
/// Computes a single quantile for a 1D tensor using `kthvalue`.
///
/// This function calculates the value below which a given percentage of data falls.
/// It employs linear interpolation between the two nearest ranks if the target
/// index is not an integer.
///
/// # Arguments
///
/// * `tensor` - The input 1D tensor.
/// * `q` - The quantile to compute (between 0.0 and 1.0).
///
/// # Returns
///
/// A scalar `Tensor` containing the computed quantile value.
fn scalar_quantile_kthvalue(tensor: &Tensor, q: f64) -> Tensor {
    let n = tensor.size()[0];

    // 1. Calculate target float index
    let idx_float = q * (n - 1) as f64;
    let lower_idx = idx_float.floor() as i64;
    let upper_idx = idx_float.ceil() as i64;

    // Optimization: If the index is exactly an integer, we only need one lookup.
    if lower_idx == upper_idx {
        return tensor.kthvalue(lower_idx + 1, 0, true).0;
    }

    // 2. Retrieve bounds
    let (lower_val, _) = tensor.kthvalue(lower_idx + 1, 0, true);
    let (upper_val, _) = tensor.kthvalue(upper_idx + 1, 0, true);

    // 3. Linear Interpolation (Lerp)
    let weight = idx_float - lower_idx as f64;
    lower_val.lerp(&upper_val, weight)
}
```

@NohTow (Collaborator) commented Dec 5, 2025

So given that:

  1. the function is not called that often, so I guess it is OK
  2. the OG PLAID index is meant to be deprecated at some point

I do not think it is that big of a deal. That being said, I believe @raphaelsty might have some suggestions to speed things up a bit (?).
Also, although it bloats the code a bit, I wonder if we could still check a threshold and only go this way if dim * sampled elements is bigger than some x? I agree that this is ugly, but the PLAID code is bloated already anyway, and this keeps the old performance the same. Your call @raphaelsty
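As a rough illustration of the threshold branch being discussed (all names here are hypothetical, and `2 ** 24` merely approximates the ~16M-element `torch.quantile` limit mentioned at the top of the PR, not an exact API constant):

```python
# Hypothetical cutoff standing in for torch.quantile's ~16M-element limit.
QUANTILE_NUMEL_LIMIT = 2 ** 24  # ~16.7M elements

def choose_quantile_path(numel, limit=QUANTILE_NUMEL_LIMIT):
    """Return which implementation a size-based branch would pick
    for an input with `numel` elements (illustrative sketch)."""
    return "kthvalue" if numel > limit else "native_quantile"

# The benchmarked (50000, 768) case exceeds the limit; (50000, 335) would not:
print(choose_quantile_path(50_000 * 768))   # kthvalue
print(choose_quantile_path(50_000 * 335))   # native_quantile
```

Small inputs would keep the old `torch.quantile` performance, and only oversized ones would pay for the kthvalue path.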
