1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/examples/models/llama3_2_vision/cross_attention/cross_attention_mask.h>
10
11 #include <gtest/gtest.h>
12
13 using namespace ::testing;
14 using exec_aten::ScalarType;
15 using exec_aten::Tensor;
16 using exec_aten::TensorImpl;
17
// Exercises example::cross_attention_mask with a 12-token prompt that contains
// three image tokens (id 1, at positions 0, 1 and 7) and three 2x2 int32 image
// tensors. One mask is expected per image; for each mask, every element in the
// rows of the expected half-open interval [start, end) must be 1 and every
// element in the remaining rows must be 0.
TEST(CrossAttentxnMaskTest, TestCrossAttentionMask) {
  std::vector<int> tokens = {
      1, 1, 9673, 527, 1403, 12875, 13, 1, 1115, 374, 264, 8415};

  // Initialize image tensors. All three images share the same 2x2 shape,
  // dim order and strides; only the backing data differs.
  TensorImpl::SizesType sizes[2] = {2, 2};
  TensorImpl::DimOrderType dim_order[2] = {0, 1};
  TensorImpl::StridesType strides[2] = {2, 1};

  int32_t a_data[4] = {1, 2, 3, 4};
  auto a_impl =
      TensorImpl(ScalarType::Int, 2, sizes, a_data, dim_order, strides);
  Tensor a(&a_impl);

  int32_t b_data[4] = {5, 6, 7, 8};
  auto b_impl =
      TensorImpl(ScalarType::Int, 2, sizes, b_data, dim_order, strides);
  Tensor b(&b_impl);

  int32_t c_data[4] = {9, 10, 11, 12};
  auto c_impl =
      TensorImpl(ScalarType::Int, 2, sizes, c_data, dim_order, strides);
  Tensor c(&c_impl);

  std::vector<Tensor> images = {a, b, c};
  // Passed as the `out` argument of cross_attention_mask.
  // NOTE(review): presumably backs the returned mask tensors, so it must
  // outlive `output_masks` — confirm against the implementation.
  std::vector<std::vector<int>> mask_data;
  auto output_masks = example::cross_attention_mask(
      tokens,
      images,
      /*tile_size=*/1,
      /*patch_size=*/1,
      /*image_token_id=*/1,
      /*out=*/mask_data);

  // Check contents of the mask.
  // Each entry is the half-open row interval [start, end) expected to be 1.
  const std::vector<std::vector<size_t>> expected_intervals = {
      {0, 7}, {1, 7}, {7, 12}};

  // Without this guard an empty result would pass vacuously, and a result with
  // extra masks would index past the end of expected_intervals.
  ASSERT_EQ(output_masks.size(), expected_intervals.size());

  for (size_t mask_idx = 0; mask_idx < output_masks.size(); ++mask_idx) {
    const auto& output_tensor = output_masks[mask_idx];
    const size_t interval_begin = expected_intervals[mask_idx][0];
    const size_t interval_end = expected_intervals[mask_idx][1];

    // Hoist loop invariants and convert the signed size/stride values once so
    // the loop conditions compare like-signed operands.
    const size_t num_rows = static_cast<size_t>(output_tensor->size(0));
    // The stride of dim 0 is used as the number of elements per row, which
    // assumes the mask is a contiguous 2-D tensor.
    const size_t row_stride = static_cast<size_t>(output_tensor->strides()[0]);
    const int* mask_values = output_tensor->const_data_ptr<int>();

    // The mask must have enough rows to cover the whole expected interval;
    // otherwise part of the expected region would silently go unchecked.
    ASSERT_GE(num_rows, interval_end);

    for (size_t i = 0; i < num_rows; ++i) {
      const int expected_value =
          (i >= interval_begin && i < interval_end) ? 1 : 0;
      for (size_t j = 0; j < row_stride; ++j) {
        const size_t unrolled_index = i * row_stride + j;
        EXPECT_EQ(mask_values[unrolled_index], expected_value);
      }
    }
  }
}
70