@@ -56,62 +56,52 @@ def test_one_block_mask(self):
5656 bidirectional_mask = np .asarray ([[0 , 1 , 1 , 1 , 0 , 0 ]])
5757 # pylint: disable=protected-access
5858 block_mask = _make_bidirectional_block_mask (bidirectional_mask )
59- expected_mask = np .asarray (
60- [
61- [
62- [False , False , False , False , False , False ],
63- [False , True , True , True , False , False ],
64- [False , True , True , True , False , False ],
65- [False , True , True , True , False , False ],
66- [False , False , False , False , False , False ],
67- [False , False , False , False , False , False ],
68- ]
69- ]
70- )
59+ expected_mask = np .asarray ([[
60+ [False , False , False , False , False , False ],
61+ [False , True , True , True , False , False ],
62+ [False , True , True , True , False , False ],
63+ [False , True , True , True , False , False ],
64+ [False , False , False , False , False , False ],
65+ [False , False , False , False , False , False ],
66+ ]])
7167 np .testing .assert_array_equal (block_mask , expected_mask )
7268
7369 def test_two_blocks_mask (self ):
7470 bidirectional_mask = np .asarray ([[0 , 1 , 1 , 0 , 1 , 1 ]])
7571 # pylint: disable=protected-access
7672 block_mask = _make_bidirectional_block_mask (bidirectional_mask )
77- expected_mask = np .asarray (
78- [
79- [
80- [False , False , False , False , False , False ],
81- [False , True , True , False , False , False ],
82- [False , True , True , False , False , False ],
83- [False , False , False , False , False , False ],
84- [False , False , False , False , True , True ],
85- [False , False , False , False , True , True ],
86- ]
87- ]
88- )
73+ expected_mask = np .asarray ([[
74+ [False , False , False , False , False , False ],
75+ [False , True , True , False , False , False ],
76+ [False , True , True , False , False , False ],
77+ [False , False , False , False , False , False ],
78+ [False , False , False , False , True , True ],
79+ [False , False , False , False , True , True ],
80+ ]])
8981 np .testing .assert_array_equal (block_mask , expected_mask )
9082
9183 def test_batch_block_masks (self ):
9284 bidirectional_mask = np .asarray ([[0 , 1 , 1 , 1 , 0 , 0 ], [0 , 1 , 1 , 0 , 1 , 1 ]])
9385 # pylint: disable=protected-access
9486 block_mask = _make_bidirectional_block_mask (bidirectional_mask )
95- expected_mask = np .asarray (
87+ expected_mask = np .asarray ([
9688 [
97- [
98- [False , False , False , False , False , False ],
99- [False , True , True , True , False , False ],
100- [False , True , True , True , False , False ],
101- [False , True , True , True , False , False ],
102- [False , False , False , False , False , False ],
103- [False , False , False , False , False , False ],
104- ],
105- [
106- [False , False , False , False , False , False ],
107- [False , True , True , False , False , False ],
108- [False , True , True , False , False , False ],
109- [False , False , False , False , False , False ],
110- [False , False , False , False , True , True ],
111- [False , False , False , False , True , True ],
112- ],
113- ]
114- )
89+ [False , False , False , False , False , False ],
90+ [False , True , True , True , False , False ],
91+ [False , True , True , True , False , False ],
92+ [False , True , True , True , False , False ],
93+ [False , False , False , False , False , False ],
94+ [False , False , False , False , False , False ],
95+ ],
96+ [
97+ [False , False , False , False , False , False ],
98+ [False , True , True , False , False , False ],
99+ [False , True , True , False , False , False ],
100+ [False , False , False , False , False , False ],
101+ [False , False , False , False , True , True ],
102+ [False , False , False , False , True , True ],
103+ ],
104+ ])
115105 np .testing .assert_array_equal (block_mask , expected_mask )
116106
117107 def test_empty_block_mask (self ):
@@ -141,34 +131,24 @@ def test_combine_with_causal_mask(self):
141131 # pylint: disable=protected-access
142132 image_mask = _make_bidirectional_block_mask (bidirectional_mask )
143133 combined_mask = causal_mask | image_mask [:, None , None , ...]
144- expected_mask = np .asarray (
145- [
146- [
147- [
148- [
149- [True , False , False , False , False , False ],
150- [True , True , True , True , False , False ],
151- [True , True , True , True , False , False ],
152- [True , True , True , True , False , False ],
153- [True , True , True , True , True , False ],
154- [True , True , True , True , True , True ],
155- ]
156- ]
157- ],
158- [
159- [
160- [
161- [True , False , False , False , False , False ],
162- [True , True , True , False , False , False ],
163- [True , True , True , False , False , False ],
164- [True , True , True , True , False , False ],
165- [True , True , True , True , True , True ],
166- [True , True , True , True , True , True ],
167- ]
168- ]
169- ],
170- ]
171- )
134+ expected_mask = np .asarray ([
135+ [[[
136+ [True , False , False , False , False , False ],
137+ [True , True , True , True , False , False ],
138+ [True , True , True , True , False , False ],
139+ [True , True , True , True , False , False ],
140+ [True , True , True , True , True , False ],
141+ [True , True , True , True , True , True ],
142+ ]]],
143+ [[[
144+ [True , False , False , False , False , False ],
145+ [True , True , True , False , False , False ],
146+ [True , True , True , False , False , False ],
147+ [True , True , True , True , False , False ],
148+ [True , True , True , True , True , True ],
149+ [True , True , True , True , True , True ],
150+ ]]],
151+ ])
172152 np .testing .assert_array_equal (combined_mask , expected_mask )
173153
174154
0 commit comments