# Revision ac62d292d7de9b49855b7ca53f2645a10c63246a authored by rballester on 16 September 2022, 15:17:15 UTC, committed by rballester on 16 September 2022, 15:17:15 UTC
# 1 parent 363d461
# automata.py
import torch
import tntorch as tn
def weight_mask(N, weight, nsymbols=2):
    """
    Accepts a string iff its number of 1's equals (or is in) `weight`

    :param N: number of dimensions
    :param weight: a non-negative integer (or list thereof): recognized weight(s)
    :param nsymbols: slices per core (default is 2)
    :return: a mask tensor
    :raises AssertionError: if no weight is given, or if any weight is negative
    """

    if not hasattr(weight, '__len__'):
        weight = [weight]
    # Flatten so that a 0-dimensional tensor is treated like a plain scalar
    weight = torch.tensor(weight).long().reshape(-1)
    assert weight.numel() > 0, 'At least one weight is required'
    # Check *every* entry: a negative value would otherwise be interpreted as a
    # from-the-end index when slicing the last core below
    assert bool((weight >= 0).all()), 'All weights must be non-negative'

    # One-hot encoder of the weight, large enough to represent the biggest requested weight
    t = tn.weight_one_hot(N, int(weight.max()) + 1, nsymbols)
    # Collapse the one-hot output: keep (and add up) only the entries of the recognized weights
    t.cores[-1] = torch.sum(t.cores[-1][:, :, weight], dim=2, keepdim=True)
    return t
def weight_one_hot(N, r=None, nsymbols=2):
    """
    Given a string with :math:`k` 1's, it produces a vector that represents :math:`k` in `one hot encoding <https://en.wikipedia.org/wiki/One-hot>`_

    Slice `s` of every core is the identity matrix shifted `s` columns to the right, so
    reading symbol `s` advances the running count by `s`.

    :param N: number of dimensions
    :param r: size of the one-hot vector, also used as the size of every core (default is `N + 1`)
    :param nsymbols: slices per core (default is 2): an integer, or a list with one integer per dimension
    :return: a tensor whose output is a vector of zeros, except its :math:`k`-th element which is a 1
    """

    per_dim = nsymbols if hasattr(nsymbols, '__len__') else [nsymbols]*N
    assert len(per_dim) == N
    rank = N + 1 if r is None else r

    identity = torch.eye(rank)

    def build_core(n_slices):
        core = torch.zeros([rank, n_slices, rank])
        # Symbol 0 leaves the count untouched
        core[:, 0, :] = identity
        # Symbol `shift` moves the count `shift` positions forward
        for shift in range(1, n_slices):
            core[:, shift, shift:] = identity[:, :-shift]
        return core

    cores = [build_core(per_dim[n]) for n in range(N)]
    # The count starts at 0: keep only the first row of the first core
    cores[0] = cores[0][0:1, :, :]
    return tn.Tensor(cores)
def weight(N, nsymbols=2):
    """
    For any string, adds up its symbols (for binary strings: counts how many 1's it has)

    Slice `s` of every core is the 2x2 matrix [[1, 0], [s, 1]]; the first core keeps only
    its second row and the last core only its first column, so chaining the cores
    accumulates the sum of the symbols.

    :param N: number of dimensions
    :param nsymbols: slices per core (default is 2): an integer, or a list with one integer per dimension
    :return: a tensor whose entry at each string is the sum of that string's symbols
    """

    # Accept either one alphabet size for all dimensions, or one size per dimension
    if not hasattr(nsymbols, '__len__'):
        nsymbols = [nsymbols]*N
    assert len(nsymbols) == N

    cores = []
    for n in range(N):
        core = torch.eye(2)[:, None, :].repeat(1, nsymbols[n], 1)
        # Lower-left entry of slice `s` holds the value contributed by symbol `s`
        core[1, :, 0] = torch.arange(nsymbols[n])
        cores.append(core)
    cores[0] = cores[0][1:2, :, :]
    cores[-1] = cores[-1][:, :, 0:1]
    return tn.Tensor(cores)
def length(N):
    """
    Placeholder: this function is not implemented yet.

    :todo: provide an implementation
    :param N: currently unused
    :return: never returns; always raises `NotImplementedError`
    """

    raise NotImplementedError()
def accepted_inputs(t):
    """
    Returns all strings accepted by an automaton, in alphabetical order.

    Note: each string s will appear as many times as the value t[s]

    :param t: a :class:`Tensor` (batched tensors are not supported)
    :return Xs: a Torch matrix, each row is one string
    :raises ValueError: if `t` is batched
    """

    # Reject unsupported input once and up front, before doing any computation on `t`
    if t.batch:
        raise ValueError('Batched tensors are not supported.')

    dtype = t.cores[0].dtype
    device = t.cores[0].device
    N = t.dim()

    def recursion(Xs, left, rights, bound, mu):
        # `left` is the product of the core slices selected by the current prefix;
        # `bound` is the first row of `Xs` reserved for strings with that prefix
        if mu == N:
            return
        # Contract core `mu` with the sum over all possible suffixes
        fiber = torch.einsum('ijk,k->ij', (t.cores[mu], rights[mu + 1]))
        # per_point[i]: number of accepted strings that extend the prefix with symbol i
        per_point = torch.matmul(left, fiber).double().round()
        # Cumulative counts give, for each symbol, its range of rows relative to `bound`
        c = torch.cat((per_point.new_zeros(1), per_point.cumsum(dim=0))).long()
        for i in range(per_point.shape[0]):
            if c[i] == c[i + 1]:  # Improductive prefix, don't go further
                continue
            Xs[bound + c[i]:bound + c[i + 1], mu] = i
            recursion(Xs, torch.matmul(left, t.cores[mu][..., i, :]), rights, bound + c[i], mu + 1)

    # One row per accepted string (counted with multiplicity), one column per dimension
    Xs = torch.zeros([round(tn.sum(t).item()), N], dtype=torch.long)
    rights = [torch.ones(1, dtype=dtype, device=device)]  # Precomputed right-product chains
    for core in t.cores[::-1]:
        rights.append(torch.matmul(torch.sum(core, dim=1), rights[-1]))
    rights = rights[::-1]
    recursion(Xs, torch.ones(1, dtype=dtype, device=device), rights, 0, 0)
    return Xs
# Computing file changes ...