from k1lib.imports import *
import gc
Tries to implement vision transformers. Works pretty well actually, though performance is lower than CNNs on complex datasets. It severely overfits if no augmentations are used, or if the augmentations are static. Performance summary:
| Arch | Dataset | Augmentations? | Loss (train) | Accuracy | Top 5 accuracy |
| -------- | ------- | -------------- | ------------ | -------- | -------------- |
| ViT | 3 | no | 0.2 | 12.6% | 33% |
| CNN | 3 | no | 0.8 | 24% | 52% |
| ViT | 3 | yes | 1.1 | 30% | 58% |
| ViT wide | 3 | yes | 2.2 | 27% | 56% |
| ViT deep | 3 | yes | 2.2 | 24% | 52% |
| CNN | 3 | yes | 1.3 | 47% | 77% |
| ResNet50 | 3 | yes | 0.8 | 46% | 72% |
| ViT | 1 | yes | 0.3 | 66% | 95% |
| CNN | 1 | yes | 0.05 | 82% | 98% |
| CNN | 1 | no | 0.4 | 70% | 97% |
Here, dataset 3 is more complicated and has more images and categories (80 total) than dataset 1 (only 9). In previous notebooks, we often used dataset 1 without augmentation.
%%time
base = "~/ssd/data/imagenet/set1/192px"
idxToCat = base | ls() | head(80) | op().split("/")[-1].all() | insertIdColumn() | toDict()
catToIdx = idxToCat.items() | permute(1, 0) | toDict()
# stage 1, (train/valid, classes, samples (url of img))
st1 = base | ls() | head(80) | apply(ls() | splitW()) | transpose() | deref() | aS(k1.Wrapper)
# stage 2, (train/valid, classes, samples, [img, class])
st2 = st1() | (apply(lambda x: [x | toImg() | toTensor(torch.uint8), catToIdx[x.split("/")[-2]]]) | repeatFrom(4) | apply(aS(tf.Resize(192)) | aS(tf.AutoAugment()) | op()/255, 0)).all(2) | deref() | aS(k1.Wrapper)
def dataF(bs): return st2() | apply(repeatFrom().all() | joinStreamsRandom() | batched(bs) | apply(transpose() | aS(torch.stack) + toTensor(torch.long))) | stagger.tv(10000/bs) | aS(list)
xb, yb = dataF(64) | item(2)
st1() | shape(), st2() | shape(), [xb, yb] | shape().all() | deref()
CPU times: user 2min 45s, sys: 3.14 s, total: 2min 48s Wall time: 21.3 s
((2, 9, 531, 59), (2, 9, 2124, 2, 3, 192, 192), [[64, 3, 192, 192], [64]])
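To unpack dataF a bit, here's the same pipeline spread out, with my reading of what each k1lib stage does (a sketch; treat the comments as assumptions, not authoritative docs):
def dataF(bs):  # same pipeline as above, just annotated
    return st2() | apply(            # for each of the train/valid splits:
        repeatFrom().all()           #   cycle each class's sample list indefinitely
        | joinStreamsRandom()        #   interleave all classes into one shuffled stream
        | batched(bs)                #   group samples into batches of size bs
        | apply(transpose()          #   [(img, cls), ...] -> [imgs, classes]
                | aS(torch.stack) + toTensor(torch.long))  # stack imgs; labels to int64
    ) | stagger.tv(10000/bs) | aS(list)  # alternate ~10000 samples of train, then valid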
32x32 patches, just to demonstrate that the patching works:
# thumbnail
einops.rearrange(xb[1], "c (h s1) (w s2) -> (c h w) s1 s2", s1=32, s2=32) | head(40) | plotImgs(10)
8x8 patches, so 24x24 = 576 patches in total:
xb[1].reshape(3, 24, 8, 24, 8).transpose(2, 3).reshape(-1, 8, 8) | head(20) | plotImgs(10)
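As a quick sanity check, the reshape/transpose route should agree exactly with the einops pattern used above (a sketch):
a = xb[1].reshape(3, 24, 8, 24, 8).transpose(2, 3).reshape(-1, 8, 8)
b = einops.rearrange(xb[1], "c (h s1) (w s2) -> (c h w) s1 s2", s1=8, s2=8)
assert torch.equal(a, b) # both flatten patches in (c, h, w) order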
The batched version should look something like this:
bs = xb.shape[0]
a1 = xb.reshape(bs, 3, 24, 8, 24, 8).transpose(3, 4).transpose(1, 3).reshape(bs, -1, 8*8*3)
N, L, F = a1 | shape(); N, L, F
(64, 576, 192)
Fp = 20; a2 = torch.randn(L, Fp).positionalEncode()[None].expand(64, -1, -1)
a3 = torch.concat([a1, a2], dim=2); F = a3.shape[2]
a2[0].T | aS(plt.imshow); a2 | shape(), a3 | shape()
((64, 576, 20), (64, 576, 212))
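positionalEncode() comes from k1lib; here's a minimal sketch of the classic sinusoidal encoding it's modeled after (an assumption about its internals: sin on even columns, cos on odd):
def sinusoidalEncode(L, F):                # hypothetical stand-in for .positionalEncode()
    pos = torch.arange(L)[:, None].float() # (L, 1) token positions
    div = torch.exp(torch.arange(0, F, 2).float() * (-math.log(10000.0) / F))
    pe = torch.zeros(L, F)
    pe[:, 0::2] = torch.sin(pos * div)     # even feature columns
    pe[:, 1::2] = torch.cos(pos * div)     # odd feature columns (assumes F is even)
    return pe
sinusoidalEncode(L, Fp) | shape()          # (576, 20), same as a2[0]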
clsTok = torch.zeros(N, 1, F); clsTok[:,:,range(0,F,2)] = 1
x = torch.concat([clsTok, a3], dim=1); x | shape(), clsTok | shape()
((64, 577, 212), (64, 1, 212))
Packaging everything:
def bells(xb, patchSize=8):
a1 = einops.rearrange(xb, "b c (h s1) (w s2) -> b (w h) (c s1 s2)", s1=patchSize, s2=patchSize)
N, L, F = a1 | shape(); Fp = 20; a2 = torch.randn(L, Fp).positionalEncode()[None].expand(N, -1, -1)
a3 = torch.concat([a1, a2], dim=2); F = a3.shape[2]
clsTok = torch.zeros(N, 1, F); clsTok[:,:,range(0,F,2)] = 1
return torch.concat([clsTok, a3], dim=1)
#clsTok = torch.randn(F); return torch.concat([clsTok[None, None].expand(N, -1, -1), a3], dim=1) # clsTok should be put outside if you were to use this version btw
bells(xb) | shape()
(64, 577, 212)
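Shape check: 577 = 24·24 patches + 1 class token, and 212 = 8·8·3 pixel values per patch + 20 positional-encoding features.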
class MultiheadAttention(nn.Module):
def __init__(self, qdim, kdim, vdim, embed, head=4, outdim=None):
"""Kinda like :class:`torch.nn.MultiheadAttention`, just simpler, shorter, and clearer.
Probably not as fast as the official version, and doesn't have masks and whatnot, but easy to read!
Example::
xb = torch.randn(32, 14, 35) # (N, S, F), or batch size 32, sequence size 14, feature size 35
# returns torch.Size([32, 14, 50])
MultiheadAttention(35, 35, 35, 9, 4, 50)(xb).shape
Although you can use this right away with no mods, I really encourage you to copy and paste the
source code of this and modify it to your needs.
:param qdim: Basic query, key and value dimensions
:param embed: a little different from :class:`torch.nn.MultiheadAttention`, as this is the per-head embedding size (after splitting into heads)
:param outdim: if not specified, then equals to ``embed * head``"""
super().__init__()
self.embed = embed; self.head = head; outdim = outdim or embed*head
self.qdim = qdim; self.wq = nn.Linear(qdim, head*embed)
self.kdim = kdim; self.wk = nn.Linear(kdim, head*embed)
self.vdim = vdim; self.wv = nn.Linear(vdim, head*embed)
self.outLin = nn.Linear(head*embed, outdim)
self.softmax = nn.Softmax(-1)
def forward(self, query, key=None, value=None):
"""If ``key`` or ``value`` is not specified, just default to ``query``."""
if key is None: key = query
if value is None: value = query
N, S, *_ = key.shape; F = self.embed; head = self.head
q = self.wq(query); k = self.wk(key); v = self.wv(value)
q = einops.rearrange(q, "N S1 (head F) -> (N head) S1 F", head=self.head) / math.sqrt(F)
k = einops.rearrange(k, "N S (head F) -> (N head) S F", head=self.head)
v = einops.rearrange(v, "N S (head F) -> (N head) S F", head=self.head)
mat = self.softmax(einops.einsum(q, k, "Nh S1 F, Nh S F -> Nh S1 S"))
return einops.rearrange(einops.einsum(mat, v, "Nh S1 S, Nh S F -> Nh S1 F"), "(N head) S1 F -> N S1 (head F)", head=self.head) | self.outLin
x | shape(), x | MultiheadAttention(212, 212, 212, 32) | shape()
((64, 577, 212), (64, 577, 128))
xx = x | nn.Linear(212, 128); nn.MultiheadAttention(128, 4)(xx, xx, xx)[0] | shape()
(64, 577, 128)
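One caveat on that comparison: nn.MultiheadAttention defaults to batch_first=False, so it's actually treating dim 0 as the sequence there. The shapes happen to line up either way; for a true apples-to-apples check you'd pass batch_first:
nn.MultiheadAttention(128, 4, batch_first=True)(xx, xx, xx)[0] | shape() # still (64, 577, 128), batch genuinely first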
class Block(nn.Module):
def __init__(self, embed=32, skips=1, nhead=4, patchSize=8):
super().__init__(); dim = embed*nhead; S = (192//patchSize)**2+1
self.n1 = nn.LayerNorm([S, dim])
self.ma = MultiheadAttention(dim, dim, dim, embed, nhead)
#self.ma = nn.MultiheadAttention(dim, 4)
self.n2 = nn.LayerNorm([S, dim])
self.act = nn.Sequential(*repeatF(lambda: knn.LinBlock(dim, dim), skips))
def forward(self, xb):
a = (xb | self.ma | self.n1) + xb
#a = self.ma(xb, xb, xb)[0] + xb
return (a | self.act | self.n2) + a
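# Note: nn.LayerNorm([S, dim]) normalizes each sample over tokens *and* features jointly.
# The standard ViT instead normalizes each token independently, i.e. nn.LayerNorm(dim);
# keeping the variant above as-is, just flagging the difference.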
class Net(nn.Module):
def __init__(self, nclasses=9, embed=32, skips=3, nhead=4, patchSize=8, finalDim=30, **kwargs):
super().__init__(); dim=embed*nhead; F = patchSize**2*3+20; self.l1 = knn.LinBlock(F, dim)
self.seq = nn.Sequential(*repeatF(lambda: Block(embed, skips, nhead, patchSize, **kwargs), skips))
self.l2 = knn.LinBlock(dim, finalDim); self.l3 = nn.Linear(finalDim, nclasses);
def forward(self, xb): return xb | self.l1 | self.seq | op()[:,0] | self.l2 | self.l3
n = Net(skips=1, nhead=8); x | n | shape()
(64, 9)
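For a rough sense of model size (a quick check, not an output from the original run):
sum(p.numel() for p in n.parameters()) # total trainable parameters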
Nice! Let's rewrite the dataloader (just adding bells really):
def dataF(bs, patchSize=8): return st2() | apply(repeatFrom().all() | joinStreamsRandom() | batched(bs) | apply(transpose() | (aS(torch.stack) | aS(bells, patchSize)) + toTensor(torch.long))) | stagger.tv(10000/bs) | aS(list)
def newL(bs=64, timeLimit=None, patchSize=8, lr=1e-2, cuda=True, model=None, **kwargs):
l = k1.Learner(); l.data = dataF(bs, patchSize); l.model = model or Net(patchSize=patchSize, **kwargs)
l.css = "skipBlock:HookModule; #distill Linear, #gen Skip #0:HookParam, HookModule"
l.opt = optim.AdamW(l.model.parameters(), lr=lr); l.cbs.add(Cbs.TimeLimit(timeLimit));
l.cbs.add(Cbs.LossCrossEntropy())
l.css = "#act > #1: HookModule; #act > #1 *: HookParam"
if cuda: l.cbs.add(Cbs.Cuda());
return l
l = newL(bs=32, lr=1e-3); l.run(1);
Progress: 100%, epoch: 0/1 (5.32 epochs/minute), batch: 311/312 (27.67 batches/s), elapsed: 11.24s, remaining: 0.04s, loss: 2.2055299282073975
#notest
l = newL(bs=64, lr=1e-3); l.run(1);
Progress: 99%, epoch: 0/1 (7.02 epochs/minute), batch: 155/156, elapsed: 8.49s, remaining: 0.05s, loss: 2.1818768978118896
#notest
l = newL(bs=128, lr=1e-3); l.run(1);
Progress: 99%, epoch: 0/1 (7.77 epochs/minute), batch: 77/78, elapsed: 7.62s, remaining: 0.1s, loss: 2.1729390621185303
#notest
l = newL(bs=64, lr=1e-2, skips=5); l.run(100);
Progress: 100%, epoch: 99/100 (4.11 epochs/minute), batch: 155/156, elapsed: 1460.74s, remaining: 0.09s, loss: 1.4815202951431274
l.Loss.plot(smooth())
Sliceable plot. You can use p[a:b] to focus on a specific range of the plot, and p.yscale("log") to perform operations as if you're using plt. Reminder: the slice you pass in applies to the training plot; the valid loss plot updates automatically to the same time frame.
l.Accuracy.plot(smooth())
Seems to work well, but the loss decreases and the accuracy increases so slowly. Well, it's a transformer after all. Let's do a grid scan of parameters:
from k1lib.imports import *
def gen():
for skip in [1, 3, 5]: # 45 hours total
for patchSize in [4, 8, 16, 32, 64]: # 15 hours total
for embed in [16, 32, 64]: # 3 hours total
for nhead in [4, 8, 16]: # 1 hour total
yield {"skips": skip, "patchSize": patchSize, "embed": embed, "nhead": nhead}
#notest
def scan():
"""This will grab lines in this notebook, dump it into a file, then run that file. That
file will then load up "scan.pth", add new data, then dump back saved data. This is required
because I run out of CUDA memory all the time, and freeing memory inside the notebook doesn't
seem to work."""
[] | aS(dill.dumps) | file("scan.pth")
"" | file("perif.txt")
for hyperparams in tqdm(gen() | aS(list)):
nb.cells("8-vit.ipynb") | filt(op()["cell_type"] == "code") | op()["source"].all() | joinStreams() | ~filt(op().startswith("%%")) | breakIf(op().startswith("l = newL")) | op().strip("\n").all() | file("run.py")
f"""l = newL(bs=32, lr=1e-3, timeLimit=20*60, **({hyperparams})); e = None; startMem = torch.cuda.memory_allocated(0)
try:
json.dumps({hyperparams}) >> file("perif.txt")
with k1.timer() as t: l.run(1000);
except Exception as _e: e = _e
finally:
data = cat("scan.pth", False) | aS(dill.loads)
data.append([{hyperparams}, e, torch.cuda.memory_allocated(0)-startMem, t(), l.progress, list(l.Loss.train[-30:]), list(l.Loss.valid[-30:]), list(l.Accuracy.train[-30:]), list(l.Accuracy.valid[-30:])])
data | aS(dill.dumps) | file("scan.pth")
""" >> file("run.py"); time.sleep(1)
with k1.captureStdout(): None | cmd("python run.py") | stdout()
scan()
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 135/135 [37:37:59<00:00, 1003.55s/it]
data = cat("scan.pth", False) | aS(dill.loads) | aS(k1.Wrapper); data() | shape()
(135, 9, 4, 5)
List of errors:
data() | cut(1) | filt(op()) | apply(str) | op().split(".")[0].all() | batched(4, True) | display()
CUDA out of memory (×25)
All out of memory errors. What are the typical configurations?
data() | filt(op(), 1) | cut(0) | batched(3, True) | display(None)
{'skips': 1, 'patchSize': 4, 'embed': 16, 'nhead': 8}   {'skips': 1, 'patchSize': 4, 'embed': 16, 'nhead': 16}  {'skips': 1, 'patchSize': 4, 'embed': 32, 'nhead': 8}
{'skips': 1, 'patchSize': 4, 'embed': 32, 'nhead': 16}  {'skips': 1, 'patchSize': 4, 'embed': 64, 'nhead': 8}   {'skips': 1, 'patchSize': 4, 'embed': 64, 'nhead': 16}
{'skips': 3, 'patchSize': 4, 'embed': 16, 'nhead': 4}   {'skips': 3, 'patchSize': 4, 'embed': 16, 'nhead': 8}   {'skips': 3, 'patchSize': 4, 'embed': 16, 'nhead': 16}
{'skips': 3, 'patchSize': 4, 'embed': 32, 'nhead': 4}   {'skips': 3, 'patchSize': 4, 'embed': 32, 'nhead': 8}   {'skips': 3, 'patchSize': 4, 'embed': 32, 'nhead': 16}
{'skips': 3, 'patchSize': 4, 'embed': 64, 'nhead': 4}   {'skips': 3, 'patchSize': 4, 'embed': 64, 'nhead': 8}   {'skips': 3, 'patchSize': 4, 'embed': 64, 'nhead': 16}
{'skips': 5, 'patchSize': 4, 'embed': 16, 'nhead': 4}   {'skips': 5, 'patchSize': 4, 'embed': 16, 'nhead': 8}   {'skips': 5, 'patchSize': 4, 'embed': 16, 'nhead': 16}
{'skips': 5, 'patchSize': 4, 'embed': 32, 'nhead': 4}   {'skips': 5, 'patchSize': 4, 'embed': 32, 'nhead': 8}   {'skips': 5, 'patchSize': 4, 'embed': 32, 'nhead': 16}
{'skips': 5, 'patchSize': 4, 'embed': 64, 'nhead': 4}   {'skips': 5, 'patchSize': 4, 'embed': 64, 'nhead': 8}   {'skips': 5, 'patchSize': 4, 'embed': 64, 'nhead': 16}
{'skips': 5, 'patchSize': 8, 'embed': 64, 'nhead': 16}
Common denominator here seems to be patch size. Makes sense: the attention matrix has $(192/\text{patchSize})^2$ tokens on each side, so memory used should scale like $1/\text{patchSize}^4$. Let's plot it:
data() | ~filt(op(), 1) | cut(0, 2) | ~apply(lambda x, y: [*x.values(), y]) | groupBy(1) | apply(cut(1, 4)) | joinStreams() | apply(op()/1e9, 1) | transpose() | ~aS(plt.plot, ".")
plt.xlabel("Patch size"); plt.ylabel("Memory used (GB)"); plt.grid(True); plt.xscale("log");
Best architecture (max valid accuracy)?
data() | ~filt(op(), 1) | transpose() | mtmS(*[iden()]*5, *[toMean().all()]*4) | transpose() | ~sort(-1) | cut(0, 8) | apply(round, 1, ndigits=3) | deref() | headOut()
[{'skips': 1, 'patchSize': 4, 'embed': 16, 'nhead': 4}, 0.646]
[{'skips': 1, 'patchSize': 4, 'embed': 32, 'nhead': 4}, 0.625]
[{'skips': 1, 'patchSize': 8, 'embed': 16, 'nhead': 16}, 0.614]
[{'skips': 1, 'patchSize': 8, 'embed': 32, 'nhead': 8}, 0.61]
[{'skips': 1, 'patchSize': 8, 'embed': 32, 'nhead': 4}, 0.602]
[{'skips': 1, 'patchSize': 8, 'embed': 16, 'nhead': 8}, 0.598]
[{'skips': 3, 'patchSize': 8, 'embed': 32, 'nhead': 4}, 0.598]
[{'skips': 3, 'patchSize': 8, 'embed': 16, 'nhead': 8}, 0.597]
[{'skips': 1, 'patchSize': 8, 'embed': 16, 'nhead': 4}, 0.571]
[{'skips': 3, 'patchSize': 16, 'embed': 16, 'nhead': 8}, 0.562]
The best networks are the ones with the smallest patch size? That's very surprising. Worst networks:
collateStats = ~filt(op(), 1) | transpose.wrap(mtmS(*[iden()]*5, *[toMean().all()]*4))
data() | collateStats | sort(-1) | cut(0, 8) | apply(round, 1, ndigits=3) | deref() | headOut()
[{'skips': 3, 'patchSize': 8, 'embed': 64, 'nhead': 16}, 0.185]
[{'skips': 5, 'patchSize': 8, 'embed': 32, 'nhead': 16}, 0.191]
[{'skips': 5, 'patchSize': 8, 'embed': 64, 'nhead': 8}, 0.234]
[{'skips': 3, 'patchSize': 8, 'embed': 32, 'nhead': 16}, 0.24]
[{'skips': 1, 'patchSize': 64, 'embed': 64, 'nhead': 8}, 0.26]
[{'skips': 5, 'patchSize': 64, 'embed': 16, 'nhead': 16}, 0.265]
[{'skips': 1, 'patchSize': 64, 'embed': 16, 'nhead': 8}, 0.282]
[{'skips': 5, 'patchSize': 64, 'embed': 32, 'nhead': 16}, 0.293]
[{'skips': 5, 'patchSize': 8, 'embed': 64, 'nhead': 4}, 0.296]
[{'skips': 5, 'patchSize': 64, 'embed': 64, 'nhead': 8}, 0.301]
legends = data() | collateStats | ~apply(lambda x, *y: [*x.values(), *y]) | cut(0, 1, 11) | groupBy(1) | transpose().all() | ~apply(lambda x, y, z: [y | item(), [x, z]]) | cut(0) & (cut(1) | apply(~aS(plt.plot, "."))) | deref() | item()
plt.xlabel("Skips"); plt.ylabel("Accuracy"); plt.legend(legends | apply(lambda x: f"Patch size: {x}") | deref()); plt.grid(True)
legends = data() | collateStats | ~apply(lambda x, *y: [*x.values(), *y]) | cut(1, 2, 11) | permute(1, 0, 2) | groupBy(1) | transpose().all() | ~apply(lambda x, y, z: [y | item(), [x, z]]) | cut(0) & (cut(1) | apply(~aS(plt.plot, "."))) | deref() | item()
plt.xlabel("Embed size"); plt.ylabel("Accuracy"); plt.legend(legends | apply(lambda x: f"Patch size: {x}") | deref()); plt.grid(True);
legends = data() | collateStats | ~apply(lambda x, *y: [*x.values(), *y]) | cut(1, 3, 11) | permute(1, 0, 2) | groupBy(1) | transpose().all() | ~apply(lambda x, y, z: [y | item(), [x, z]]) | cut(0) & (cut(1) | apply(~aS(plt.plot, "."))) | deref() | item()
plt.xlabel("Number of heads"); plt.ylabel("Accuracy"); plt.legend(legends | apply(lambda x: f"Patch size: {x}") | deref()); plt.grid(True);
These are long runs, so that we can reuse the results in other notebooks. Let's use set3 instead of the usual set1, because it has more categories. We're only gonna use the first 80 categories, because I can't fit more into my RAM!
%%time
base = "~/ssd/data/imagenet/set3/192px"
idxToCat = base | ls() | head(80) | op().split("/")[-1].all() | insertIdColumn() | toDict()
catToIdx = idxToCat.items() | permute(1, 0) | toDict()
# stage 1, (train/valid, classes, samples (url of img))
st1 = base | ls() | head(80) | apply(ls() | splitW()) | transpose() | deref() | aS(k1.Wrapper)
# stage 2, (train/valid, classes, samples, [img, class])
st2 = st1() | (apply(lambda x: [x | toImg() | toTensor(torch.uint8), catToIdx[x.split("/")[-2]]]) | repeatFrom(1) | apply(aS(tf.Resize(192)) | aS(tf.AutoAugment()) | op()/255, 0)).all(2) | deref() | aS(k1.Wrapper)
def dataF(bs): return st2() | apply(repeatFrom().all() | joinStreamsRandom() | batched(bs) | apply(transpose() | aS(torch.stack) + toTensor(torch.long))) | stagger.tv(10000/bs) | aS(list)
xb, yb = dataF(64) | item(2)
st1() | shape(), st2() | shape(), [xb, yb] | shape().all() | deref()
CPU times: user 12min 26s, sys: 13.8 s, total: 12min 40s Wall time: 1min 38s
((2, 80, 490, 59), (2, 80, 490, 2, 3, 192, 192), [[64, 3, 192, 192], [64]])
def dataF(bs, patchSize=8): return st2() | apply(repeatFrom().all() | joinStreamsRandom() | batched(bs) | apply(transpose() | (aS(torch.stack) | aS(bells, patchSize)) + toTensor(torch.long))) | stagger.tv(10000/bs) | aS(list)
class Block(nn.Module):
def __init__(self, embed=32, skips=1, nhead=4, patchSize=8):
super().__init__(); dim = embed*nhead; S = (192//patchSize)**2+1
self.n1 = nn.LayerNorm([S, dim])
self.ma = MultiheadAttention(dim, dim, dim, embed, nhead)
self.n2 = nn.LayerNorm([S, dim])
self.act = nn.Sequential(*repeatF(lambda: knn.LinBlock(dim, dim), skips))
def forward(self, xb):
a = (xb | self.n1 | self.ma) + xb
return (a | self.n2 | self.act) + a
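# Note: unlike the first Block above, this one norms *before* the attention and the MLP
# (pre-norm, as in the ViT paper), which tends to train more stably than post-norm.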
class ViT(nn.Module):
def __init__(self, nclasses=9, embed=32, skips=3, nhead=4, patchSize=8, finalDim=30, **kwargs):
super().__init__(); dim=embed*nhead; F = patchSize**2*3+20; self.l1 = knn.LinBlock(F, dim)
self.seq = nn.Sequential(*repeatF(lambda: Block(embed, skips, nhead, patchSize, **kwargs), skips))
self.l2 = knn.LinBlock(dim, finalDim); self.l3 = nn.Linear(finalDim, nclasses);
def forward(self, xb): return xb | self.l1 | self.seq | op()[:,0] | self.l2 | self.l3
def newLViT(bs=64, timeLimit=None, patchSize=8, lr=1e-2, cuda=True, model=None, **kwargs):
l = k1.Learner(); l.data = dataF(bs, patchSize); l.model = model or ViT(patchSize=patchSize, **kwargs)
l.css = "skipBlock:HookModule; #distill Linear, #gen Skip #0:HookParam, HookModule"
l.opt = optim.AdamW(l.model.parameters(), lr=lr); l.cbs.add(Cbs.TimeLimit(timeLimit));
l.cbs.add(Cbs.LossCrossEntropy())
l.css = "#act > #1: HookModule; #act > #1 *: HookParam"
if cuda: l.cbs.add(Cbs.Cuda());
return l
n = ViT(skips=1, nhead=8); x | n | shape()
(64, 9)
l = newLViT(bs=32, lr=1e-3, nclasses=80, finalDim=100); l.run(1);
Progress: 100%, epoch: 0/1 (3.23 epochs/minute), batch: 311/312 (16.77 batches/s), elapsed: 18.54s, remaining: 0.06s, loss: 4.39176082611084
From our experiment above, let's pick this config: [{'skips': 3, 'patchSize': 8, 'embed': 32, 'nhead': 4}, 0.598]. Not strictly the best arch, but we can't afford a patch size of 4 because we'd run out of memory, and I want the network to have lots of params so that it can generalize better. Let's check the speed first:
#notest
l = newLViT(bs=32, lr=1e-3, skips=3, patchSize=8, embed=32, nhead=4, nclasses=80, finalDim=100); l.run(10);
Progress: 100%, epoch: 9/10 (3.31 epochs/minute), batch: 311/312 (17.19 batches/s), elapsed: 181.39s, remaining: 0.06s, loss: 4.27975606918335
l.Loss.plot(~head(0.3) | smooth(30))
l.Accuracy.plot(smooth(30))
Let's now do the 8 hour run:
#notest
l = newLViT(bs=32, lr=1e-3, skips=3, patchSize=8, embed=32, nhead=4, nclasses=80, finalDim=100); l.run(3000); l.save("set3-vit")
Progress: 100%, epoch: 2999/3000 (3.99 epochs/minute), batch: 311/312 (20.72 batches/s), elapsed: 45165.21s, remaining: 0.05s, loss: 8.544167518615723 Saved to set3-vit
l.Loss.plot(smooth())
l.Loss.plot(~head(0.1) | smooth(1000))
l.Accuracy.plot(smooth(1000))
l.Accuracy.plot(head(0.1) | smooth(100))
l.AccuracyTop5.plot(smooth(1000))
l.AccuracyTop5.plot(head(0.1) | smooth(100))
Under a random policy the accuracy would be 1/80 = 1.25%, so at least this is better than chance, but it's not great performance. How does it compare to a regular CNN?
seq = nn.Sequential
class Skip(nn.Module):
def __init__(self, channel=3):
super().__init__(); self.seq = nn.Sequential(
nn.Conv2d(channel, channel, 1), nn.ReLU(), nn.BatchNorm2d(channel),
nn.Conv2d(channel, channel, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(channel),
nn.Conv2d(channel, channel, 1), nn.ReLU(), nn.BatchNorm2d(channel))
def forward(self, x): return self.seq(x) + x
class Distill(nn.Module):
def __init__(self, skips=5, **kwargs):
super().__init__();
self.c1 = seq(nn.Conv2d(3, 8, 7, 2, padding=3), nn.ReLU(), nn.BatchNorm2d(8)); self.s2 = seq(*[Skip(8) for e in range(skips)])
self.c3 = seq(nn.Conv2d(8, 16, 3, 2, padding=1), nn.ReLU(), nn.BatchNorm2d(16)); self.s4 = seq(*[Skip(16) for e in range(skips)])
self.c5 = seq(nn.Conv2d(16, 32, 3, 2, padding=1), nn.ReLU(), nn.BatchNorm2d(32)); self.s6 = seq(*[Skip(32) for e in range(skips)])
self.c7 = seq(nn.Conv2d(32, 64, 3, 2, padding=1), nn.ReLU(), nn.BatchNorm2d(64)); self.s8 = seq(*[Skip(64) for e in range(skips)])
def forward(self, x): return x | self.c1 | self.s2 | self.c3 | self.s4 | self.c5 | self.s6 | self.c7 | self.s8
class CNN(nn.Module):
def __init__(self, **kwargs):
super().__init__(); self.distill = Distill(**kwargs); self.pool = nn.AdaptiveAvgPool2d(1)
self.l1 = knn.LinBlock(64, 64); self.l2 = nn.Linear(64, 80)
def forward(self, xb): return xb | self.distill | self.pool | aS(einops.rearrange, "b c 1 1 -> b c") | self.l1 | self.l2
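# Shape trace (sketch) for a 192px input: (N,3,192,192) -c1-> (N,8,96,96) -c3-> (N,16,48,48)
# -c5-> (N,32,24,24) -c7-> (N,64,12,12) -pool-> (N,64,1,1) -rearrange-> (N,64) -> logits (N,80)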
def dataFCNN(bs, patchSize=8): return st2() | apply(repeatFrom().all() | joinStreamsRandom() | batched(bs) | apply(transpose() | aS(torch.stack) + toTensor(torch.long))) | stagger.tv(10000/bs) | aS(list)
def newLCNN(bs=64, timeLimit=None, patchSize=8, lr=1e-2, cuda=True, model=None, **kwargs):
l = k1.Learner(); l.data = dataFCNN(bs, patchSize); l.model = model or CNN(patchSize=patchSize, **kwargs)
l.css = "skipBlock:HookModule; #distill Linear, #gen Skip #0:HookParam, HookModule"
l.opt = optim.AdamW(l.model.parameters(), lr=lr); l.cbs.add(Cbs.TimeLimit(timeLimit));
l.cbs.add(Cbs.LossCrossEntropy()); l.css = "#act > #1: HookModule; #act > #1 *: HookParam"
if cuda: l.cbs.add(Cbs.Cuda());
return l
l = newLCNN(bs=128, lr=1e-3, cuda=False); l.run(1);
Progress: 99%, epoch: 0/1 (0.42 epochs/minute), batch: 77/78 (0.55 batches/s), elapsed: 139.68s, remaining: 1.81s, loss: 4.367627143859863
#notest
l = newLCNN(bs=128, lr=1e-3, timeLimit=3600); l.run(10000); l.save("set3-cnn");
Progress: 3%, epoch: 258/10000 (4.32 epochs/minute), batch: 77/78 (5.61 batches/s), elapsed: 3600.03s, remaining: 135403.76s, loss: 5.877784252166748 Run cancelled: Takes more than 3600 seconds!. Saved to set3-cnn
l.Loss.plot(smooth(100))
l.Accuracy.plot(smooth(100))
l.AccuracyTop5.plot(smooth(100))
Okay, CNNs are much better, and I'd say the performance is acceptable, but still not great. Let's test my own performance instead, to see how accurate I am. All the categories and their indices:
synset = cat("../2017_synsets.txt") | op().split(": ").all() | toDict()
idxToSynset = catToIdx.items() | lookup(synset, 0) | apply(op()[:20], 0) | permute(1, 0) | toDict()
idxToSynset.items() | sort(0) | join(" ").all() | batched(6, True) | display(None)
0 lorikeet  1 Band Aid  2 croquet ball  3 German short-haired  4 Lhasa, Lhasa apso  5 Chesapeake Bay retri
6 espresso maker  7 lycaenid, lycaenid b  8 Granny Smith  9 schooner  10 black stork, Ciconia  11 tennis ball
12 soup bowl  13 submarine, pigboat,  14 oxcart  15 Sealyham terrier, Se  16 banana  17 redshank, Tringa tot
18 water snake  19 German shepherd, Ger  20 bustard  21 sombrero  22 vine snake  23 starfish, sea star
24 howler monkey, howle  25 chow, chow chow  26 hair spray  27 French bulldog  28 swimming trunks, bat  29 hippopotamus, hippo,
30 miniskirt, mini  31 Persian cat  32 quail  33 wreck  34 king snake, kingsnak  35 affenpinscher, monke
36 lion, king of beasts  37 hummingbird  38 boxer  39 tape player  40 electric ray, crampf  41 isopod
42 toy terrier  43 sunglasses, dark gla  44 black-and-tan coonho  45 Pomeranian  46 Norwegian elkhound,  47 folding chair
48 lemon  49 beaker  50 hare  51 langur  52 otterhound, otter ho  53 macaque
54 Bouvier des Flandres  55 European gallinule,  56 Brittany spaniel  57 dalmatian, coach dog  58 toucan  59 ptarmigan
60 schipperke  61 pool table, billiard  62 power drill  63 thunder snake, worm  64 prairie chicken, pra  65 guacamole
66 mountain bike, all-t  67 soft-coated wheaten  68 sidewinder, horned r  69 leatherback turtle,  70 snowmobile  71 Windsor tie
72 miniature schnauzer  73 mixing bowl  74 wire-haired fox terr  75 patas, hussar monkey  76 Ibizan hound, Ibizan  77 frilled lizard, Chla
78 diaper, nappy, napki  79 rubber eraser, rubbe
a = dataFCNN(64) | item() | deref() | aS(k1.Wrapper); a() | shape()
(125, 2, 64, 3, 192, 192)
a() | cut(0) | joinStreams() | op().permute(1, 2, 0).all() | head(80) | plotImgs(10)
Actually, this is too hard for me. Let's just agree that the network's accuracy is pretty good as far as set3 goes. If you want to try it out yourself, here are the answers:
a() | transpose().all() | joinStreams() | toInt(1) | lookup(idxToSynset, 1) | apply(op()[:15], 1) | apply(op().permute(1, 2, 0), 0) | head(80) | plotImgs(10)
Finally, let's do ViT again, but this time with augmentations generated on the fly, instead of a fixed set of precomputed augmentations. I didn't generate them on the fly before because it's quite slow, even with multithreading.
%%time
base = "~/ssd/data/imagenet/set3/192px"
idxToCat = base | ls() | head(80) | op().split("/")[-1].all() | insertIdColumn() | toDict()
catToIdx = idxToCat.items() | permute(1, 0) | toDict()
# stage 1, (train/valid, classes, samples (url of img))
st1 = base | ls() | head(80) | apply(ls() | splitW()) | transpose() | deref() | aS(k1.Wrapper)
# stage 2, (train/valid, classes, samples, [img, class])
st2 = st1() | (apply(lambda x: [x | toImg() | toTensor(torch.uint8), catToIdx[x.split("/")[-2]]]) | repeatFrom(1) | apply(aS(tf.Resize(192)), 0)).all(2) | deref() | aS(k1.Wrapper)
def dataF(bs, patchSize=8): return st2() | apply(repeatFrom().all() | joinStreamsRandom() | batched(bs) | apply(transpose() | (apply(tf.AutoAugment()) | aS(list) | aS(torch.stack) | op()/255 | aS(bells, patchSize)) + toTensor(torch.long))) | stagger.tv(10000/bs) | aS(list)
xb, yb = dataF(64) | item(2)
st1() | shape(), st2() | shape(), [xb, yb] | shape().all() | deref()
CPU times: user 7min 20s, sys: 6.64 s, total: 7min 27s Wall time: 56 s
((2, 80, 490, 59), (2, 80, 490, 2, 3, 192, 192), [[64, 577, 212], [64]])
l = newLViT(bs=32, lr=1e-3, skips=3, patchSize=8, embed=32, nhead=4, nclasses=80, finalDim=100); l.run(1, 10);
Progress: 3%, epoch: 0/1 (1.01 epochs/minute), batch: 10/312 (5.24 batches/s), elapsed: 1.91s, remaining: 57.68s, loss: 4.419412136077881 Epoch cancelled: Batch 10 reached.
#notest
l = newLViT(bs=32, lr=1e-3, skips=3, patchSize=8, embed=32, nhead=4, nclasses=80, finalDim=100); l.run(3000); l.save("set3-vit-aug")
Progress: 100%, epoch: 2999/3000 (2.37 epochs/minute), batch: 311/312 (12.34 batches/s), elapsed: 75844.2s, remaining: 0.08s, loss: 4.192038536071777 Saved to set3-vit-aug
l.Loss.plot(~head(0.01) | smooth(100))
l.Accuracy.plot(smooth(100))
l.AccuracyTop5.plot(smooth(100))
Much better than before, and a little better than the CNN with no augmentations. Let's make the patches bigger:
#notest
l = newLViT(bs=32, lr=1e-3, skips=3, patchSize=16, embed=32, nhead=4, nclasses=80, finalDim=100, timeLimit=3600); l.run(3000); l.save("set3-vit-aug-wide")
Progress: 9%, epoch: 267/3000 (4.46 epochs/minute), batch: 202/312 (23.2 batches/s), elapsed: 3600.04s, remaining: 36751.97s, loss: 1.5382466316223145 Run cancelled: Takes more than 3600 seconds!. Saved to set3-vit-aug-wide
l.Loss.plot(~head(0.01) | smooth(100))
l.Accuracy.plot(smooth(100))
l.AccuracyTop5.plot(smooth(100))
A little worse than the previous ViT, but not by much. Considering the reduced memory requirements, this is probably for the better. Let's make it deeper:
#notest
l = newLViT(bs=32, lr=3e-4, skips=15, patchSize=16, embed=32, nhead=8, nclasses=80, finalDim=100, timeLimit=3600); l.run(3000); l.save("set3-vit-aug-deep")
Progress: 2%, epoch: 55/3000 (0.92 epochs/minute), batch: 57/312 (4.78 batches/s), elapsed: 3600.2s, remaining: 192123.33s, loss: 3.6798274517059326 Run cancelled: Takes more than 3600 seconds!. Saved to set3-vit-aug-deep
#notest
l.run(3000); l.run(3000); l.run(3000); l.save("set3-vit-aug-deep")
Progress: 2%, epoch: 56/3000 (0.95 epochs/minute), batch: 240/312 (4.92 batches/s), elapsed: 3600.1s, remaining: 186647.91s, loss: 2.558469533920288 Run cancelled: Takes more than 3600 seconds!. Progress: 2%, epoch: 57/3000 (0.96 epochs/minute), batch: 123/312 (4.97 batches/s), elapsed: 3599.98s, remaining: 184570.45s, loss: 1.8335938453674316 Run cancelled: Takes more than 3600 seconds!. Saved to set3-vit-aug-deep
l.Loss.plot(smooth(30))
l.Accuracy.plot(smooth(30))
l.AccuracyTop5.plot(smooth(30))
def dataFCNN(bs, patchSize=8): return st2() | apply(repeatFrom().all() | joinStreamsRandom() | batched(bs) | apply(transpose() | (apply(tf.AutoAugment()) | aS(list) | aS(torch.stack) | op()/255) + toTensor(torch.long))) | stagger.tv(10000/bs) | aS(list)
#notest
l = newLCNN(bs=128, lr=1e-3, timeLimit=3600); l.run(10000); l.save("set3-cnn-aug");
Progress: 2%, epoch: 242/10000 (4.04 epochs/minute), batch: 42/78 (5.25 batches/s), elapsed: 3600.16s, remaining: 144835.72s, loss: 1.1506781578063965 Run cancelled: Takes more than 3600 seconds!. Saved to set3-cnn-aug
l.Loss.plot(smooth())
l.Accuracy.plot(smooth())
l.AccuracyTop5.plot(smooth())
This is like, much better than all of our ViTs before. Kinda disappointing tbh. I guess ViT is only good if the domain is really complex or there's a lot more data. Let's just do a ResNet to provide a baseline, the network that everyone knows and loves, to compare against both our own ViT and CNN:
def newLResNet(bs=64, timeLimit=None, patchSize=8, lr=1e-2, cuda=True, model=None, **kwargs):
l = k1.Learner(); l.data = dataFCNN(bs, patchSize)
l.model = vision.models.resnet50(); l.model.fc = nn.Linear(2048, 80)
l.css = "skipBlock:HookModule; #distill Linear, #gen Skip #0:HookParam, HookModule"
l.opt = optim.AdamW(l.model.parameters(), lr=lr); l.cbs.add(Cbs.TimeLimit(timeLimit));
l.cbs.add(Cbs.LossCrossEntropy())
l.css = "#act > #1: HookModule; #act > #1 *: HookParam"
if cuda: l.cbs.add(Cbs.Cuda());
return l
l = newLResNet(bs=32, lr=1e-3); l.run(1);
Progress: 100%, epoch: 0/1 (2.11 epochs/minute), batch: 311/312 (10.99 batches/s), elapsed: 28.3s, remaining: 0.09s, loss: 4.325710773468018
#notest
l = newLResNet(bs=32, lr=1e-3, timeLimit=3600); l.run(500); l.save("set3-resnet-aug")
Progress: 25%, epoch: 123/500 (2.06 epochs/minute), batch: 169/312 (10.71 batches/s), elapsed: 3600.1s, remaining: 10970.3s, loss: 0.826676607131958 Run cancelled: Takes more than 3600 seconds!. Saved to set3-resnet-aug
l.Loss.plot(smooth())
l.Accuracy.plot(smooth())
l.AccuracyTop5.plot(smooth())
A little worse than our CNN, surprisingly. Also, this ResNet seems to overfit harder than our implementation does.
Last but not least, let's do it with set1, because lots of experiments in other notebooks use set1.
%%time
base = "~/ssd/data/imagenet/set1/192px"
idxToCat = base | ls() | op().split("/")[-1].all() | insertIdColumn() | toDict()
catToIdx = idxToCat.items() | permute(1, 0) | toDict()
# stage 1, (train/valid, classes, samples (url of img))
st1 = base | ls() | head(80) | apply(ls() | splitW()) | transpose() | deref() | aS(k1.Wrapper)
# stage 2, (train/valid, classes, samples, [img, class])
st2 = st1() | (apply(lambda x: [x | toImg() | toTensor(torch.uint8), catToIdx[x.split("/")[-2]]]) | repeatFrom(1) | apply(aS(tf.Resize(192)), 0)).all(2) | deref() | aS(k1.Wrapper)
def dataF(bs, patchSize=8): return st2() | apply(repeatFrom().all() | joinStreamsRandom() | batched(bs) | apply(transpose() | (apply(tf.AutoAugment()) | aS(list) | aS(torch.stack) | op()/255 | aS(bells, patchSize)) + toTensor(torch.long))) | stagger.tv(10000/bs) | aS(list)
xb, yb = dataF(64) | item(2)
st1() | shape(), st2() | shape(), [xb, yb] | shape().all() | deref()
CPU times: user 47.6 s, sys: 752 ms, total: 48.3 s Wall time: 6.05 s
((2, 9, 531, 59), (2, 9, 531, 2, 3, 192, 192), [[64, 577, 212], [64]])
#notest
l = newLViT(bs=32, lr=1e-3, skips=3, patchSize=8, embed=32, nhead=4, nclasses=9, finalDim=32, timeLimit=3600); l.run(3000); l.save("set1-vit-aug")
Progress: 8%, epoch: 243/3000 (4.06 epochs/minute), batch: 156/312 (21.1 batches/s), elapsed: 3600.03s, remaining: 40753.43s, loss: 0.26012831926345825 Run cancelled: Takes more than 3600 seconds!. Saved to set1-vit-aug
l.Loss.plot(smooth(30))
l.Accuracy.plot(smooth(30))
l.AccuracyTop5.plot(smooth(30))
#notest
l = newLCNN(bs=128, lr=1e-3, timeLimit=3600); l.run(10000); l.save("set1-cnn-aug");
Progress: 2%, epoch: 246/10000 (4.11 epochs/minute), batch: 29/78 (5.34 batches/s), elapsed: 3600.19s, remaining: 142527.44s, loss: 0.049694836139678955 Run cancelled: Takes more than 3600 seconds!. Saved to set1-cnn-aug
l.Loss.plot(~head(0.1) | smooth())
l.Accuracy.plot(smooth())
l.AccuracyTop5.plot(smooth())
def dataFCNN(bs, patchSize=8): return st2() | apply(repeatFrom().all() | joinStreamsRandom() | batched(bs) | apply(transpose() | (aS(torch.stack) | op()/255) + toTensor(torch.long))) | stagger.tv(10000/bs) | aS(list)
#notest
l = newL(bs=128, lr=1e-3, timeLimit=3600); l.run(10000); l.save("set1-cnn");
Progress: 3%, epoch: 277/10000 (4.62 epochs/minute), batch: 0/78 (6.0 batches/s), elapsed: 3600.01s, remaining: 126363.88s, loss: 1.2665777206420898 Run cancelled: Takes more than 3600 seconds!. Saved to set1-cnn
l.Loss.plot(smooth())
l.Accuracy.plot(smooth())
l.AccuracyTop5.plot(smooth())