def endpoint(s: str = "london - england + japan") -> str:
    """Does vector arithmetic on word vectors and returns the closest words.

    The expression is evaluated as a signed sum of word vectors, e.g.
    ``"london - england + japan"`` -> ``+v(london) - v(england) + v(japan)``.
    Every row of the module-level ``vectors`` matrix is then scored against
    that sum, and the best-scoring words are returned as a text table.

    :param s: statement, which consists of terms separated by '+' or '-'
        operations, with whitespace between every token (e.g.
        ``"a + b - c + d"``). Can also just be a single word. A leading
        operator is optional; without one the first term is added.
    :return: a formatted table of up to 100 ``score word`` pairs (scores
        with 2 decimals, 7 pairs per row), or a human-readable error message
        if the expression is malformed, uses an unknown operation, or
        contains a term missing from the dictionary.
    """
    malformed = "Malformed expression. Make sure it's in the form 'a + b - c + d'"
    # split() with no argument drops the empty tokens that repeated, leading
    # or trailing whitespace would create ("a  -  b " parses like "a - b");
    # split(" ") turned those into "" operators/terms.
    tokens = s.split()
    if not tokens:
        return malformed
    # A leading term without an explicit sign is added.
    if tokens[0] not in ("+", "-"):
        tokens = ["+", *tokens]
    # After normalization the tokens must alternate: op, term, op, term, ...
    if len(tokens) % 2 == 1:
        return malformed
    ops, terms = tokens[0::2], tokens[1::2]

    invalidOps = [o for o in ops if o not in ("+", "-")]
    if invalidOps:
        return f"Don't understand operation '{invalidOps[0]}'. Valid operations include '+' and '-'."

    # Fetched once rather than once per term; assumes tok2Id() returns the
    # same token -> id mapping on every call (NOTE(review): confirm).
    t2i = tok2Id()
    for term in terms:
        if term not in t2i:
            return f"Term '{term}' not available in the dictionary"

    # Rows of `vectors` for the requested terms: vectors is indexed along
    # dim 0 by token id, giving shape (n_terms, dim).
    v = vectors[torch.tensor([t2i[t] for t in terms], dtype=torch.long)]
    signs = torch.tensor([1 if o == "+" else -1 for o in ops]).to(vectors.device)
    u = (signs[:, None] * v).sum(0)  # signed sum of the term vectors, shape (dim,)

    # Score of each vocabulary row = mean over the feature axis of its
    # elementwise product with u, i.e. the dot product divided by dim. This
    # is only a cosine similarity if `vectors` is pre-normalized (TODO confirm).
    scores = (vectors * u[None]).mean(1)
    # topk raises when k exceeds the number of rows, so clamp to the vocab size.
    k = min(100, scores.shape[0])
    return scores.topk(k) | op().values & op().indices | transpose() | lookup(id2Tok(), 1) | apply(lambda x: f"{x:.2f}", 0) | batched(7) | join(" ").all(2) | pretty() | join("\n")