Add bf16 complex dot product for NEON #163

Open
ashvardanian opened this issue Sep 6, 2024 · 9 comments
Labels
enhancement New feature or request good first issue Good for newcomers help wanted Extra attention is needed

Comments

@ashvardanian
Owner

The vbfmlaltq_f32 and vbfmlalbq_f32 intrinsics already have the benefit of skipping the even and odd entries, respectively.

@ashvardanian ashvardanian added enhancement New feature or request help wanted Extra attention is needed good first issue Good for newcomers labels Sep 7, 2024
@ashvardanian ashvardanian changed the title bf16 complex dot product for NEON Add bf16 complex dot product for NEON Sep 7, 2024
@MarkReedZ
Contributor

MarkReedZ commented Sep 9, 2024

A complex dot product already exists for NEON, but it converts the inputs to f32 first, and we want to operate on the bf16 inputs directly. The complex vector is laid out as real, imag, real, imag, ...

Original

// ab_real += ar * br - ai * bi;
// ab_imag += ar * bi + ai * br;
ab_real_vec = vfmaq_f32(ab_real_vec, a_real_vec, b_real_vec);
ab_real_vec = vfmsq_f32(ab_real_vec, a_imag_vec, b_imag_vec);
ab_imag_vec = vfmaq_f32(ab_imag_vec, a_real_vec, b_imag_vec);
ab_imag_vec = vfmaq_f32(ab_imag_vec, a_imag_vec, b_real_vec);    

The new version might look like this (vbfmlaltq_f32 is an FMA over the odd entries, while vbfmlalbq_f32 covers the even ones):

    ab_real_vec = vbfmlaltq_f32(ab_real_vec, a_vec, b_vec);
    ab_real_vec = vbfmlalbq_f32(ab_real_vec, vnegq_f16(a_vec), b_vec);  // ar*br + (-ai*bi)
    
    ab_imag_vec = vbfmlaltq_f32(ab_imag_vec, a_vec, vrev32q_bf16(b_vec));   // vrev32q swaps imag and real
    ab_imag_vec = vbfmlalbq_f32(ab_imag_vec, a_vec, vrev32q_bf16(b_vec));  // ar * bi + ai * br;   
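
For checking a vectorized kernel like the one above, a scalar reference can help. The following is a minimal sketch, not the library's implementation: the function names are hypothetical, bf16 values are assumed to be stored as raw 16-bit patterns, and the layout is assumed to be interleaved with real parts in the even positions and imaginary parts in the odd positions.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Widen a bf16 bit pattern to f32: bf16 is the upper 16 bits of an f32. */
static float bf16_to_f32(uint16_t bits) {
    uint32_t wide = (uint32_t)bits << 16;
    float value;
    memcpy(&value, &wide, sizeof value);
    return value;
}

/* Narrow an f32 to bf16 by truncation (no rounding); enough for test inputs. */
static uint16_t f32_to_bf16_truncate(float value) {
    uint32_t wide;
    memcpy(&wide, &value, sizeof wide);
    return (uint16_t)(wide >> 16);
}

/* Hypothetical scalar reference for the non-conjugated complex dot product.
 * `n_scalars` counts bf16 values, so it is twice the number of complex pairs. */
static void dot_bf16c_reference(const uint16_t *a, const uint16_t *b,
                                size_t n_scalars, float *real, float *imag) {
    assert(n_scalars % 2 == 0);
    float ab_real = 0.0f, ab_imag = 0.0f;
    for (size_t i = 0; i + 1 < n_scalars; i += 2) {
        float ar = bf16_to_f32(a[i]), ai = bf16_to_f32(a[i + 1]); /* even = real, odd = imag */
        float br = bf16_to_f32(b[i]), bi = bf16_to_f32(b[i + 1]);
        ab_real += ar * br - ai * bi;
        ab_imag += ar * bi + ai * br;
    }
    *real = ab_real;
    *imag = ab_imag;
}
```

With a = (1+2i, 3+4i) and b = (5+6i, 7+8i), this accumulates (-7+16i) + (-11+52i) = -18+68i, which a NEON variant should reproduce for inputs that are exactly representable in bf16.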

@ashvardanian
Owner Author

Indeed, you are right! I suppose the new version must be a lot faster, right?

@MarkReedZ
Contributor

10% faster. There are no bf16 versions of the neg and rev32 intrinsics, so we still have to jump through hoops. I confirmed that the new function's output matches the old one and that the tests pass. I will take a look to see if we can do this better before making a PR.

Assembly code: https://godbolt.org/z/4hzr9f943

        // ar*br + (-ai*bi)
        ab_real_vec = vbfmlaltq_f32(ab_real_vec, a_vec, b_vec);
        //ab_real_vec = vbfmlalbq_f32(ab_real_vec, vreinterpretq_bf16_f16(vnegq_f16(vreinterpretq_f16_bf16(a_vec))), b_vec);  
        ab_real_vec = vbfmlalbq_f32(ab_real_vec, vreinterpretq_bf16_u16(veorq_u16(vreinterpretq_u16_bf16(a_vec), vdupq_n_u16(0x8000))), b_vec);
    
        // vrev32q swaps imag and real
        // ar * bi + ai * br;   
        ab_imag_vec = vbfmlaltq_f32(ab_imag_vec, a_vec, vreinterpretq_bf16_u16(vrev32q_u16(vreinterpretq_u16_bf16(b_vec))));
        ab_imag_vec = vbfmlalbq_f32(ab_imag_vec, a_vec, vreinterpretq_bf16_u16(vrev32q_u16(vreinterpretq_u16_bf16(b_vec))));
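
The two reinterpret tricks above can be modelled in scalar C to see why they work on raw bf16 bits. This is only an illustrative sketch with hypothetical helper names: XOR with 0x8000 flips the sign bit of a 16-bit float pattern (the scalar counterpart of veorq_u16 with vdupq_n_u16(0x8000)), and swapping adjacent 16-bit elements mirrors what vrev32q_u16 does inside each 32-bit word.

```c
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Widen a bf16 bit pattern to f32 so the result can be inspected. */
static float bf16_bits_to_f32(uint16_t bits) {
    uint32_t wide = (uint32_t)bits << 16;
    float value;
    memcpy(&value, &wide, sizeof value);
    return value;
}

/* Negate a bf16 value by flipping its top (sign) bit. */
static uint16_t bf16_negate(uint16_t bits) {
    return (uint16_t)(bits ^ 0x8000u);
}

/* Swap each adjacent pair of 16-bit lanes in place, turning
 * (real, imag, real, imag, ...) into (imag, real, imag, real, ...).
 * A trailing unpaired lane, if any, is left untouched. */
static void swap_pairs_u16(uint16_t *lanes, size_t n) {
    for (size_t i = 0; i + 1 < n; i += 2) {
        uint16_t first = lanes[i];
        lanes[i] = lanes[i + 1];
        lanes[i + 1] = first;
    }
}
```

For example, 0x3F80 is the bf16 pattern for 1.0, and bf16_negate(0x3F80) gives 0xBF80, which widens to -1.0.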

@ashvardanian
Owner Author

Interestingly, the godbolt.org snippet you've provided breaks Clang 18.1 if you add -O3. Without it, the assembly contains a lot of noise and is a bit hard to read. Still, @MarkReedZ, the source looks really good! I was hoping the speedup would be at least 20% 😢

@MarkReedZ
Contributor

MarkReedZ commented Sep 9, 2024

Good catch. I was playing around with another compiler on there.

Ubuntu 24.04's clang 18.1 sees the same bug when building this code. Issue opened: llvm/llvm-project#107810

I'll try to move code around to avoid this later.

Code: MarkReedZ@09e89bb

@MarkReedZ
Contributor

Clang is choking on the flipping of the sign bit. I haven't come up with an alternative to these two. No amount of moving code around fixes the clang bug if either veorq or vnegq is used to flip the bit.

       //ab_real_vec = vbfmlalbq_f32(ab_real_vec, vreinterpretq_bf16_f16(vnegq_f16(vreinterpretq_f16_bf16(a_vec))), b_vec);  
       ab_real_vec = vbfmlalbq_f32(ab_real_vec, vreinterpretq_bf16_u16(veorq_u16(vreinterpretq_u16_bf16(a_vec), vdupq_n_u16(0x8000))), b_vec);
 

@ashvardanian
Owner Author

Hi @MarkReedZ! Any chance you have an update on this?

@MarkReedZ
Contributor

This is fixed in clang 19.1. I'm not sure what our approach to handling this should be, as 18 will remain the default for some time. We could check the clang version number defines, though apparently in some cases those may be overridden.
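
One way the version check could look is sketched below. The macro and function names are hypothetical, not something the project defines, and the cutoff of 19 is taken from the observation above that the bug is fixed in clang 19.1. `__clang__` and `__clang_major__` are clang's own predefined macros, but vendor builds may number their releases differently, so this guard should be treated as a heuristic.

```c
/* Hypothetical guard: disable the bf16 complex NEON path on clang older than 19,
 * where the sign-bit flip was reported to break compilation at -O3. */
#if defined(__clang__) && defined(__clang_major__) && (__clang_major__ < 19)
#define EXAMPLE_BF16C_NEON_FAST_PATH 0
#else
#define EXAMPLE_BF16C_NEON_FAST_PATH 1
#endif

/* Report which path was selected at compile time (1 = new bf16 kernel allowed). */
static int example_bf16c_fast_path_enabled(void) {
    return EXAMPLE_BF16C_NEON_FAST_PATH;
}
```

Code that dispatches between kernels could then wrap the new implementation in `#if EXAMPLE_BF16C_NEON_FAST_PATH` and fall back to the existing f32-converting version otherwise.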

@ashvardanian
Owner Author

ashvardanian commented Nov 8, 2024

@MarkReedZ, can you please submit a PR that works with 19, and I'll try a few more ideas around your prototype?
