Revision 241bf7ad2b806f6677a5e23534247f35f3a70f10 authored by rballester on 19 February 2023, 19:35:27 UTC, committed by rballester on 19 February 2023, 19:35:27 UTC
1 parent be80cb2
automata.py
``````import torch
import tntorch as tn

"""
Accepts a string iff its number of 1's equals (or is in) `weight`

:param N: number of dimensions
:param weight: an integer (or list thereof): recognized weight(s)
:param nsymbols: slices per core (default is 2)

"""

# NOTE(review): the `def` line that should own the docstring above and the
# body below is missing from this listing: N, weight and nsymbols are never
# bound here, and `return t` at module level is a SyntaxError. TODO restore
# the function header (taking N, weight, nsymbols=2 per the docstring).
# Accept a single integer by wrapping it into a one-element list.
if not hasattr(weight, '__len__'):
weight = [weight]
# Integer index tensor, used below to select columns of the last core.
weight = torch.tensor(weight).long()
# NOTE(review): only the first entry is checked; a negative entry later in
# the list would index from the end (torch accepts negative indices) in the
# column selection below instead of being rejected.
assert weight[0] >= 0
# One-hot weight tensor with r = max(weight) + 1, so every requested weight
# has its own slot (presumably the module's weight_one_hot, re-exported as
# tn.weight_one_hot -- confirm).
t = tn.weight_one_hot(N, int(max(weight) + 1), nsymbols)
# Add up the last core's slots for the requested weights into one column
# (keepdim keeps the core 3-D with a trailing size of 1). Repeated entries in
# `weight` are summed more than once.
t.cores[-1] = torch.sum(t.cores[-1][:, :, weight], dim=2, keepdim=True)
return t

def weight_one_hot(N, r=None, nsymbols=2):
    """
    Given a string with :math:`k` 1's, it produces a vector that represents :math:`k` in `one hot encoding <https://en.wikipedia.org/wiki/One-hot>`_

    More generally, symbol ``s`` at any position adds ``s`` to the weight
    :math:`k`, so for binary strings :math:`k` is the number of 1's.

    :param N: number of dimensions
    :param r: length of the one-hot vector, i.e. one more than the largest
        representable weight (default is N + 1). Strings whose weight is
        r or more map to the all-zero vector. With nsymbols > 2 the default
        cannot represent every reachable weight.
    :param nsymbols: slices per core (default is 2): an integer, or a list
        with one integer per dimension

    :return: a :class:`Tensor` whose last core keeps a trailing rank of size
        `r`. For each string, that trailing index holds a vector of `r`
        elements, all zero except element :math:`k` (counting from 0), which
        is 1.
    """

    if not hasattr(nsymbols, '__len__'):
        nsymbols = [nsymbols] * N
    assert len(nsymbols) == N
    if r is None:
        r = N + 1

    # Built once; every slice assignment below copies from it.
    eye = torch.eye(r)
    cores = []
    for n in range(N):
        core = torch.zeros([r, nsymbols[n], r])
        # Row = incoming weight, column = outgoing weight.
        # Symbol 0 leaves the running weight unchanged...
        core[:, 0, :] = eye
        # ...while symbol s maps weight j to j + s. Transitions past r - 1
        # fall off the slice, which is what truncates weights >= r.
        for s in range(1, nsymbols[n]):
            core[:, s, s:] = eye[:, :-s]
        cores.append(core)
    # Every string starts with weight 0: keep only row 0 of the first core.
    cores[0] = cores[0][0:1, :, :]
    return tn.Tensor(cores)

def weight(N, nsymbols=2):
    """
    For any string, counts how many 1's it has

    More generally, the value at string ``(s_1, ..., s_N)`` is
    ``s_1 + ... + s_N``; for binary strings this is the number of 1's.

    :param N: number of dimensions
    :param nsymbols: slices per core (default is 2): an integer, or a list
        with one integer per dimension

    :return: a :class:`Tensor` with inner ranks 2
    """

    # Accept one alphabet size for all dimensions or one per dimension,
    # matching weight_one_hot.
    if not hasattr(nsymbols, '__len__'):
        nsymbols = [nsymbols] * N
    assert len(nsymbols) == N

    cores = []
    for n in range(N):
        # Each core maps the row state [partial_sum, 1] to
        # [partial_sum + s, 1] when the symbol at this position is s.
        core = torch.eye(2)[:, None, :].repeat(1, nsymbols[n], 1)
        core[1, :, 0] = torch.arange(nsymbols[n])
        cores.append(core)
    # Start from the state [0, 1] (row 1) and read out the partial sum
    # (column 0) at the end.
    cores[0] = cores[0][1:2, :, :]
    cores[-1] = cores[-1][:, :, 0:1]
    return tn.Tensor(cores)

def length(N):
    """
    Placeholder automaton; not implemented yet.

    :todo: implement

    :param N: presumably the number of dimensions, as in the other
        constructors of this module (currently unused)
    :raises NotImplementedError: always
    """
    raise NotImplementedError()

def accepted_inputs(t):
    """
    Returns all strings accepted by an automaton, in alphabetical order.

    Note: each string s will appear as many times as the value t[s]
    (per-prefix counts are rounded to the nearest integer, so `t` is
    expected to hold non-negative integer counts).

    :param t: a non-batched :class:`Tensor`

    :return Xs: a Torch matrix, each row is one string

    :raises ValueError: if `t` is a batched tensor
    """

    # Validate up front. Previously this check ran inside the recursion,
    # i.e. only after tn.sum(t).item() below had already been evaluated.
    if t.batch:
        raise ValueError('Batched tensors are not supported.')

    dtype = t.cores[0].dtype
    device = t.cores[0].device
    N = t.dim()

    # Precomputed right-product chains: rights[mu] contracts cores mu..N-1,
    # each summed over its symbol index, so multiplying a prefix's row vector
    # by it gives the sum of t over all completions of that prefix.
    rights = [torch.ones(1, dtype=dtype, device=device)]
    for core in t.cores[::-1]:
        rights.append(torch.matmul(torch.sum(core, dim=1), rights[-1]))
    rights = rights[::-1]

    Xs = torch.zeros([round(tn.sum(t).item()), N], dtype=torch.long)

    def _fill(left, start, mu):
        """Writes column `mu` for every accepted string whose prefix is
        encoded by the row vector `left`; those strings occupy consecutive
        rows of Xs beginning at row `start`."""
        if mu == N:
            return
        fiber = torch.einsum('ijk,k->ij', (t.cores[mu], rights[mu + 1]))
        # counts[i] = number of accepted strings extending the prefix with i
        counts = torch.matmul(left, fiber).double().round()
        # Rows for symbol i are offsets[i]:offsets[i + 1] relative to `start`.
        # Python ints keep the slicing independent of the cores' device.
        offsets = torch.cat(
            (counts.new_zeros(1), counts.cumsum(dim=0))).long().tolist()
        for i in range(len(counts)):
            lo, hi = start + offsets[i], start + offsets[i + 1]
            if lo == hi:  # Unproductive prefix, don't go further
                continue
            Xs[lo:hi, mu] = i
            _fill(torch.matmul(left, t.cores[mu][..., i, :]), lo, mu + 1)

    _fill(torch.ones(1, dtype=dtype, device=device), 0, 0)
    return Xs
``````

Computing file changes ...