Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions src/python/tblis_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def _contract(subscripts, *tensors, **kwargs):
indices = ''.join(sub_idx)
c_dtype = kwargs.get('dtype', numpy.result_type(*tensors))
if ('...' in subscripts or
not (numpy.issubdtype(c_dtype, numpy.float) or
numpy.issubdtype(c_dtype, numpy.complex))):
not (numpy.issubdtype(c_dtype, numpy.floating) or
numpy.issubdtype(c_dtype, numpy.complexfloating))):
return numpy_einsum(subscripts, *tensors)

if '->' not in subscripts:
Expand Down Expand Up @@ -118,6 +118,7 @@ def _contract(subscripts, *tensors, **kwargs):
tblis_dtype[c_dtype], alpha, beta)
return c


def einsum(subscripts, *tensors, **kwargs):
subscripts = subscripts.replace(' ','')
if len(tensors) <= 1:
Expand All @@ -134,3 +135,47 @@ def einsum(subscripts, *tensors, **kwargs):
out = einsum(subscripts, *tensors, **kwargs)
return out


einsum_symbols = 'abcdefghijklmnopqrstuvwxyz'


def tensordot(x, y, axes=2, **kwargs):
"""Simple translation of tensordot syntax to einsum.
"""
# convert int argument to (list[int], list[int])
if isinstance(axes, int):
axes = range(x.ndim - axes, x.ndim), range(axes)

# convert (int, int) to (list[int], list[int])
if isinstance(axes[0], int):
axes = (axes[0],), axes[1]
if isinstance(axes[1], int):
axes = axes[0], (axes[1],)

# initialize empty indices
x_ix = [None] * x.ndim
y_ix = [None] * y.ndim
out_ix = []

# fill in repeated indices
available_ix = iter(einsum_symbols)
for ax1, ax2 in zip(*axes):
repeat = next(available_ix)
x_ix[ax1] = repeat
y_ix[ax2] = repeat

# fill in the rest, and maintain output order
for i in range(x.ndim):
if x_ix[i] is None:
leave = next(available_ix)
x_ix[i] = leave
out_ix.append(leave)
for i in range(y.ndim):
if y_ix[i] is None:
leave = next(available_ix)
y_ix[i] = leave
out_ix.append(leave)

# form full string and contract!
einsum_str = "{},{}->{}".format(*map("".join, (x_ix, y_ix, out_ix)))
return _contract(einsum_str, x, y, **kwargs)