Skip to main content

miri/shims/x86/
avx512.rs

1use rustc_abi::CanonAbi;
2use rustc_middle::ty::Ty;
3use rustc_span::Symbol;
4use rustc_target::callconv::FnAbi;
5
6use super::{
7    packssdw, packsswb, packusdw, packuswb, permute, permute2, pmaddbw, pmaddwd, psadbw, pshufb,
8};
9use crate::*;
10
11impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
12pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
13    fn emulate_x86_avx512_intrinsic(
14        &mut self,
15        link_name: Symbol,
16        abi: &FnAbi<'tcx, Ty<'tcx>>,
17        args: &[OpTy<'tcx>],
18        dest: &MPlaceTy<'tcx>,
19    ) -> InterpResult<'tcx, EmulateItemResult> {
20        let this = self.eval_context_mut();
21        // Prefix should have already been checked.
22        let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.avx512.").unwrap();
23
24        match unprefixed_name {
25            // Used by the ternarylogic functions.
26            "pternlog.d.128" | "pternlog.d.256" | "pternlog.d.512" => {
27                this.expect_target_feature_for_intrinsic(link_name, "avx512f")?;
28                if matches!(unprefixed_name, "pternlog.d.128" | "pternlog.d.256") {
29                    this.expect_target_feature_for_intrinsic(link_name, "avx512vl")?;
30                }
31
32                let [a, b, c, imm8] =
33                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
34
35                assert_eq!(dest.layout, a.layout);
36                assert_eq!(dest.layout, b.layout);
37                assert_eq!(dest.layout, c.layout);
38
39                // The signatures of these operations are:
40                //
41                // ```
42                // fn vpternlogd(a: i32x16, b: i32x16, c: i32x16, imm8: i32) -> i32x16;
43                // fn vpternlogd256(a: i32x8, b: i32x8, c: i32x8, imm8: i32) -> i32x8;
44                // fn vpternlogd128(a: i32x4, b: i32x4, c: i32x4, imm8: i32) -> i32x4;
45                // ```
46                //
47                // The element type is always a 32-bit integer, the width varies.
48
49                let (a, _a_len) = this.project_to_simd(a)?;
50                let (b, _b_len) = this.project_to_simd(b)?;
51                let (c, _c_len) = this.project_to_simd(c)?;
52                let (dest, dest_len) = this.project_to_simd(dest)?;
53
54                // Compute one lane with ternary table.
55                let tern = |xa: u32, xb: u32, xc: u32, imm: u32| -> u32 {
56                    let mut out = 0u32;
57                    // At each bit position, select bit from imm8 at index = (a << 2) | (b << 1) | c
58                    for bit in 0..32 {
59                        let ia = (xa >> bit) & 1;
60                        let ib = (xb >> bit) & 1;
61                        let ic = (xc >> bit) & 1;
62                        let idx = (ia << 2) | (ib << 1) | ic;
63                        let v = (imm >> idx) & 1;
64                        out |= v << bit;
65                    }
66                    out
67                };
68
69                let imm8 = this.read_scalar(imm8)?.to_u32()? & 0xFF;
70                for i in 0..dest_len {
71                    let a_lane = this.project_index(&a, i)?;
72                    let b_lane = this.project_index(&b, i)?;
73                    let c_lane = this.project_index(&c, i)?;
74                    let d_lane = this.project_index(&dest, i)?;
75
76                    let va = this.read_scalar(&a_lane)?.to_u32()?;
77                    let vb = this.read_scalar(&b_lane)?.to_u32()?;
78                    let vc = this.read_scalar(&c_lane)?.to_u32()?;
79
80                    let r = tern(va, vb, vc, imm8);
81                    this.write_scalar(Scalar::from_u32(r), &d_lane)?;
82                }
83            }
84            // Used to implement the _mm512_sad_epu8 function.
85            "psad.bw.512" => {
86                this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
87
88                let [left, right] =
89                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
90
91                psadbw(this, left, right, dest)?
92            }
93            // Used to implement the _mm512_madd_epi16 function.
94            "pmaddw.d.512" => {
95                this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
96
97                let [left, right] =
98                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
99
100                pmaddwd(this, left, right, dest)?;
101            }
102            // Used to implement the _mm512_maddubs_epi16 function.
103            "pmaddubs.w.512" => {
104                let [left, right] =
105                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
106
107                pmaddbw(this, left, right, dest)?;
108            }
109            // Used to implement the _mm512_permutexvar_epi32/_mm512_permutexvar_epi64 functions.
110            "permvar.si.512" | "permvar.di.512" => {
111                let [left, right] =
112                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
113
114                permute(this, left, right, dest)?;
115            }
116            // Used to implement the _mm512_permutex2var_epi64 intrinsic.
117            "vpermi2var.q.512" => {
118                let [left, indices, right] =
119                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
120
121                permute2(this, left, indices, right, dest)?;
122            }
123            // Used to implement the _mm512_shuffle_epi8 intrinsic.
124            "pshuf.b.512" => {
125                let [left, right] =
126                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
127
128                pshufb(this, left, right, dest)?;
129            }
130
131            // Used to implement the _mm512_dpbusd_epi32 function.
132            "vpdpbusd.512" | "vpdpbusd.256" | "vpdpbusd.128" => {
133                this.expect_target_feature_for_intrinsic(link_name, "avx512vnni")?;
134                if matches!(unprefixed_name, "vpdpbusd.128" | "vpdpbusd.256") {
135                    this.expect_target_feature_for_intrinsic(link_name, "avx512vl")?;
136                }
137
138                let [src, a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
139
140                vpdpbusd(this, src, a, b, dest)?;
141            }
142            // Used to implement the _mm512_packs_epi16 function
143            "packsswb.512" => {
144                this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
145
146                let [a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
147
148                packsswb(this, a, b, dest)?;
149            }
150            // Used to implement the _mm512_packus_epi16 function
151            "packuswb.512" => {
152                this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
153
154                let [a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
155
156                packuswb(this, a, b, dest)?;
157            }
158            // Used to implement the _mm512_packs_epi32 function
159            "packssdw.512" => {
160                this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
161
162                let [a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
163
164                packssdw(this, a, b, dest)?;
165            }
166            // Used to implement the _mm512_packus_epi32 function
167            "packusdw.512" => {
168                this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
169
170                let [a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
171
172                packusdw(this, a, b, dest)?;
173            }
174            _ => return interp_ok(EmulateItemResult::NotSupported),
175        }
176        interp_ok(EmulateItemResult::NeedsReturn)
177    }
178}
179
180/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in `a` with corresponding signed
181/// 8-bit integers in `b`, producing 4 intermediate signed 16-bit results. Sum these 4 results with
182/// the corresponding 32-bit integer in `src` (using wrapping arighmetic), and store the packed
183/// 32-bit results in `dst`.
184///
185/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_dpbusd_epi32>
186/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_dpbusd_epi32>
187/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_dpbusd_epi32>
188fn vpdpbusd<'tcx>(
189    ecx: &mut crate::MiriInterpCx<'tcx>,
190    src: &OpTy<'tcx>,
191    a: &OpTy<'tcx>,
192    b: &OpTy<'tcx>,
193    dest: &MPlaceTy<'tcx>,
194) -> InterpResult<'tcx, ()> {
195    let (src, src_len) = ecx.project_to_simd(src)?;
196    let (a, a_len) = ecx.project_to_simd(a)?;
197    let (b, b_len) = ecx.project_to_simd(b)?;
198    let (dest, dest_len) = ecx.project_to_simd(dest)?;
199
200    // fn vpdpbusd(src: i32x16, a: u8x64, b: i8x64) -> i32x16;
201    // fn vpdpbusd256(src: i32x8, a: u8x32, b: i8x32) -> i32x8;
202    // fn vpdpbusd128(src: i32x4, a: u8x16, b: i8x16) -> i32x4;
203    assert_eq!(src_len, dest_len);
204    assert_eq!(a_len, dest_len.strict_mul(4));
205    assert_eq!(b_len, a_len);
206
207    for i in 0..dest_len {
208        let src = ecx.read_scalar(&ecx.project_index(&src, i)?)?.to_i32()?;
209        let dest = ecx.project_index(&dest, i)?;
210
211        let mut intermediate_sum: i32 = 0;
212        for j in 0..4 {
213            let idx = i.strict_mul(4).strict_add(j);
214            let a = ecx.read_scalar(&ecx.project_index(&a, idx)?)?.to_u8()?;
215            let b = ecx.read_scalar(&ecx.project_index(&b, idx)?)?.to_i8()?;
216
217            let product = i32::from(a).strict_mul(i32::from(b));
218            intermediate_sum = intermediate_sum.strict_add(product);
219        }
220
221        // Use `wrapping_add` because `src` is an arbitrary i32 and the addition can overflow.
222        let res = Scalar::from_i32(intermediate_sum.wrapping_add(src));
223        ecx.write_scalar(res, &dest)?;
224    }
225
226    interp_ok(())
227}