@@ -192,3 +192,108 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T}
192192
193193 return B
194194end
195+
196+
197+ function LinearAlgebra. ldiv! (A:: LU{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
198+ # TODO
199+ end
200+
201+ function LinearAlgebra. ldiv! (A:: UpperTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
202+ M,N = size (B)
203+ dev = current_device ()
204+ queue = global_queue (dev)
205+ cmdbuf = MTLCommandBuffer (queue)
206+ enqueue! (cmdbuf)
207+
208+ X = MtlMatrix {T} (undef, size (B))
209+
210+ mps_a = MPSMatrix (A)
211+ mps_b = MPSMatrix (B) # TODO reshape to matrix if B is a vector
212+ mps_x = MPSMatrix (X)
213+
214+ solve_kernel = MPSMatrixSolveTriangular (dev, false , false , false , false , M, N, 1.0 ) # TODO : likely N, M is the correct order
215+ encode! (cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
216+ commit! (cmdbuf)
217+
218+ return X
219+ end
220+
221+ function LinearAlgebra. ldiv! (A:: UnitUpperTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
222+ M,N = size (B)
223+ dev = current_device ()
224+ queue = global_queue (dev)
225+ cmdbuf = MTLCommandBuffer (queue)
226+ enqueue! (cmdbuf)
227+
228+ Bh = reshape (B, )
229+ X = MtlMatrix {T} (undef, size (B))
230+
231+ mps_a = MPSMatrix (A)
232+ mps_b = MPSMatrix (Bh) # TODO reshape to matrix if B is a vector
233+ mps_x = MPSMatrix (X)
234+
235+ solve_kernel = MPSMatrixSolveTriangular (dev, false , false , false , true , M, N, 1.0 )
236+ encode! (cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
237+ commit! (cmdbuf)
238+
239+ return X
240+ end
241+
242+ function LinearAlgebra. ldiv! (A:: LowerTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
243+ M,N = size (B)
244+ dev = current_device ()
245+ queue = global_queue (dev)
246+ cmdbuf = MTLCommandBuffer (queue)
247+ enqueue! (cmdbuf)
248+
249+ X = MtlMatrix {T} (undef, size (B))
250+
251+ mps_a = MPSMatrix (A)
252+ mps_b = MPSMatrix (B) # TODO reshape to matrix if B is a vector
253+ mps_x = MPSMatrix (X)
254+
255+ solve_kernel = MPSMatrixSolveTriangular (dev, false , true , false , false , M, N, 1.0 )
256+ encode! (cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
257+ commit! (cmdbuf)
258+
259+ return X
260+ end
261+
262+ function LinearAlgebra. ldiv! (A:: UnitLowerTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
263+ M,N = size (B)
264+ dev = current_device ()
265+ queue = global_queue (dev)
266+ cmdbuf = MTLCommandBuffer (queue)
267+ enqueue! (cmdbuf)
268+
269+ X = MtlMatrix {T} (undef, size (B))
270+
271+ mps_a = MPSMatrix (A)
272+ mps_b = MPSMatrix (B) # TODO reshape to matrix if B is a vector
273+ mps_x = MPSMatrix (X)
274+
275+ solve_kernel = MPSMatrixSolveTriangular (dev, false , true , false , true , M, N, 1.0 )
276+ encode! (cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
277+ commit! (cmdbuf)
278+
279+ return X
280+ end
281+
282+ # function (\)(A::AbstractMatrix, B::AbstractVecOrMat)
283+ # require_one_based_indexing(A, B)
284+ # m, n = size(A)
285+ # if m == n
286+ # if istril(A)
287+ # if istriu(A)
288+ # return Diagonal(A) \ B
289+ # else
290+ # return LowerTriangular(A) \ B
291+ # end
292+ # end
293+ # if istriu(A)
294+ # return UpperTriangular(A) \ B
295+ # end
296+ # return lu(A) \ B
297+ # end
298+ # return qr(A, ColumnNorm()) \ B
299+ # end
0 commit comments