@@ -192,110 +192,147 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T}
192192
193193 commit! (cmdbuf)
194194
195+ wait_completed (cmdbuf)
196+
195197 return B
196198end
197199
198200
"""
    (A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}}) \\ (B::MtlVecOrMat{T})

Left-divide `B` by the LU factorization `A` without modifying `B`.

A copy of `B` is made, the in-place `LinearAlgebra.ldiv!(A, C)` is run on the copy,
and the copy is returned.
"""
function LinearAlgebra.:(\)(
    A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}},
    B::MtlVecOrMat{T},
) where {T<:MtlFloat}
    # A shallow `copy` is enough: only the array storage needs duplicating so
    # that `ldiv!` can overwrite it, and `deepcopy` adds nothing for this purpose.
    C = copy(B)
    LinearAlgebra.ldiv!(A, C)
    return C
end
206+
207+
"""
    ldiv!(A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}}, B::MtlVecOrMat{T})

Solve the linear system described by the LU factorization `A` for the right-hand
side `B` with `MPSMatrixSolveLU`, overwrite `B` with the result, and return `B`.

The factors and the right-hand side are transposed into temporaries before the
solve and the result is transposed back into `B`.
NOTE(review): presumably this is because MPS reads the buffers in row-major
order while Julia arrays are column-major; confirm against `MPSMatrix`.
"""
function LinearAlgebra.ldiv!(
    A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}},
    B::MtlVecOrMat{T},
) where {T<:MtlFloat}
    # `size(B, 2)` is 1 for a vector, so vectors are handled as one-column matrices.
    M, N = size(B, 1), size(B, 2)
    dev = current_device()
    queue = global_queue(dev)

    # Temporaries holding the transposed inputs and the (transposed) solution.
    At = similar(A.factors)
    Bt = similar(B, (N, M))
    # Pivot indices minus one, laid out as a 1 x M matrix for the kernel
    # (presumably converting Julia's one-based pivots to zero-based ones).
    P = reshape((A.ipiv .- UInt32(1)), (1, M))
    X = similar(B, (N, M))

    transpose!(At, A.factors)
    transpose!(Bt, B)

    mps_a = MPSMatrix(At)
    mps_b = MPSMatrix(Bt)
    mps_p = MPSMatrix(P)
    mps_x = MPSMatrix(X)

    buf = MTLCommandBuffer(queue) do cmdbuf
        kernel = MPSMatrixSolveLU(dev, false, M, N)
        encode!(cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x)
    end

    # Wait for the solve before reading `X`, as the triangular solvers in this
    # file do after encoding their kernels.
    wait_completed(buf)

    transpose!(B, X)
    return B
end
222234
223- function LinearAlgebra. ldiv! (A:: UnitUpperTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
224- M,N = size (B)
235+
"""
    ldiv!(A::UpperTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T})

Solve the upper-triangular system `A` for the right-hand side `B` with
`MPSMatrixSolveTriangular`, overwrite `B` with the result, and return `B`.

`B` may be a vector or a matrix with any number of columns.
"""
function LinearAlgebra.ldiv!(
    A::UpperTriangular{T,<:MtlMatrix{T}},
    B::MtlVecOrMat{T},
) where {T<:MtlFloat}
    # `size(B, 2)` is 1 for a vector, so vectors are handled as one-column matrices.
    M, N = size(B, 1), size(B, 2)
    dev = current_device()
    queue = global_queue(dev)

    # Same layout as the UnitUpperTriangular method: materialize `A`, view `B` as
    # an M x N matrix, and allocate a solution buffer of the same shape. The
    # previous `(M, M)` temporaries only matched `B` when it was square.
    Ad = MtlMatrix(A)
    Br = reshape(B, (M, N))
    X = similar(Br)

    mps_a = MPSMatrix(Ad)
    mps_b = MPSMatrix(Br)
    mps_x = MPSMatrix(X)

    buf = MTLCommandBuffer(queue) do cmdbuf
        # Flags mirror the UnitUpperTriangular method with the unit-diagonal flag
        # turned off.
        # NOTE(review): flag order assumed to be (right, upper, transpose, unit);
        # confirm against the MPSMatrixSolveTriangular wrapper.
        kernel = MPSMatrixSolveTriangular(dev, true, false, false, false, M, N, 1.0)
        encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
    end

    wait_completed(buf)

    # `Br` is a reshape of `B`, so this copy is intended to update `B` itself.
    copy!(Br, X)
    return B
end
243261
244- function LinearAlgebra. ldiv! (A:: LowerTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
245- M,N = size (B)
262+
"""
    ldiv!(A::UnitUpperTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T})

Solve the unit-upper-triangular system `A` for the right-hand side `B` with
`MPSMatrixSolveTriangular`, copy the result into `B`'s storage, and return `B`.
"""
function LinearAlgebra.ldiv!(
    A::UnitUpperTriangular{T,<:MtlMatrix{T}},
    B::MtlVecOrMat{T},
) where {T<:MtlFloat}
    # A vector right-hand side reports a second dimension of 1.
    nrows = size(B, 1)
    ncols = size(B, 2)
    device = current_device()
    cmdqueue = global_queue(device)

    dense_a = MtlMatrix(A)
    rhs = reshape(B, (nrows, ncols))
    solution = similar(rhs)

    mps_a = MPSMatrix(dense_a)
    mps_rhs = MPSMatrix(rhs)
    mps_solution = MPSMatrix(solution)

    submitted = MTLCommandBuffer(cmdqueue) do cbuf
        solver = MPSMatrixSolveTriangular(
            device, true, false, false, true, nrows, ncols, 1.0
        )
        encode!(cbuf, solver, mps_a, mps_rhs, mps_solution)
    end
    wait_completed(submitted)

    # `rhs` is a reshape of `B`; copying into it is how the result reaches `B`.
    copy!(rhs, solution)
    return B
end
263287
264- function LinearAlgebra. ldiv! (A:: UnitLowerTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
265- M,N = size (B)
288+
"""
    ldiv!(A::LowerTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T})

Solve the lower-triangular system `A` for the right-hand side `B` with
`MPSMatrixSolveTriangular`, copy the result into `B`'s storage, and return `B`.
"""
function LinearAlgebra.ldiv!(
    A::LowerTriangular{T,<:MtlMatrix{T}},
    B::MtlVecOrMat{T},
) where {T<:MtlFloat}
    # A vector right-hand side reports a second dimension of 1.
    nrows = size(B, 1)
    ncols = size(B, 2)
    device = current_device()
    cmdqueue = global_queue(device)

    dense_a = MtlMatrix(A)
    rhs = reshape(B, (nrows, ncols))
    solution = similar(rhs)

    mps_a = MPSMatrix(dense_a)
    mps_rhs = MPSMatrix(rhs)
    mps_solution = MPSMatrix(solution)

    submitted = MTLCommandBuffer(cmdqueue) do cbuf
        solver = MPSMatrixSolveTriangular(
            device, true, true, false, false, nrows, ncols, 1.0
        )
        encode!(cbuf, solver, mps_a, mps_rhs, mps_solution)
    end
    wait_completed(submitted)

    # `rhs` is a reshape of `B`; copying into it is how the result reaches `B`.
    copy!(rhs, solution)
    return B
end
283313
284- # function (\)(A::AbstractMatrix, B::AbstractVecOrMat)
285- # require_one_based_indexing(A, B)
286- # m, n = size(A)
287- # if m == n
288- # if istril(A)
289- # if istriu(A)
290- # return Diagonal(A) \ B
291- # else
292- # return LowerTriangular(A) \ B
293- # end
294- # end
295- # if istriu(A)
296- # return UpperTriangular(A) \ B
297- # end
298- # return lu(A) \ B
299- # end
300- # return qr(A, ColumnNorm()) \ B
301- # end
314+
"""
    ldiv!(A::UnitLowerTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T})

Solve the unit-lower-triangular system `A` for the right-hand side `B` with
`MPSMatrixSolveTriangular`, copy the result into `B`'s storage, and return `B`.
"""
function LinearAlgebra.ldiv!(
    A::UnitLowerTriangular{T,<:MtlMatrix{T}},
    B::MtlVecOrMat{T},
) where {T<:MtlFloat}
    # A vector right-hand side reports a second dimension of 1.
    nrows = size(B, 1)
    ncols = size(B, 2)
    device = current_device()
    cmdqueue = global_queue(device)

    dense_a = MtlMatrix(A)
    rhs = reshape(B, (nrows, ncols))
    solution = similar(rhs)

    mps_a = MPSMatrix(dense_a)
    mps_rhs = MPSMatrix(rhs)
    mps_solution = MPSMatrix(solution)

    submitted = MTLCommandBuffer(cmdqueue) do cbuf
        solver = MPSMatrixSolveTriangular(
            device, true, true, false, true, nrows, ncols, 1.0
        )
        encode!(cbuf, solver, mps_a, mps_rhs, mps_solution)
    end
    wait_completed(submitted)

    # `rhs` is a reshape of `B`; copying into it is how the result reaches `B`.
    copy!(rhs, solution)
    return B
end
0 commit comments