-
Notifications
You must be signed in to change notification settings - Fork 26
/
score.py
87 lines (71 loc) · 4.01 KB
/
score.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import json
import time
import torch
import datetime
import numpy as np
import torch.nn as nn
from io import StringIO
import torch.nn.functional as F
class NeuralNework(nn.Module):
    """Fully connected 784-512-512-10 classifier over 28x28 inputs.

    The misspelled class name is kept on purpose because ``init()``
    instantiates it by this name. Layer attribute names (``layer1``,
    ``layer2``, ``output``) are unchanged so existing ``state_dict``
    checkpoints still load.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(28 * 28, 512)
        self.layer2 = nn.Linear(512, 512)
        self.output = nn.Linear(512, 10)

    def forward(self, x):
        """Return softmax scores of shape ``(batch, 10)``.

        Args:
            x: A single flattened image ``(784,)``, a batch ``(batch, 784)``
                or images ``(batch, 1, 28, 28)``. The total element count
                must be a multiple of 784.

        Returns:
            Tensor of shape ``(batch, 10)`` whose rows sum to 1.
        """
        # Flatten to (batch, 784) first, like CNN.forward views its input.
        # Without this, a 1-D input makes softmax(dim=1) raise because there
        # is no second axis to normalize over.
        x = x.reshape(-1, 28 * 28)
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return F.softmax(self.output(x), dim=1)
class CNN(nn.Module):
    """Two-stage convolutional classifier for single-channel 28x28 images.

    Attribute names match the original definition so saved ``state_dict``
    checkpoints remain loadable.
    """

    def __init__(self):
        super().__init__()
        # Feature maps go 1 -> 10 -> 20 with 5x5 kernels.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # Two conv(5) + max-pool(2) stages shrink 28 -> 24 -> 12 -> 8 -> 4,
        # leaving 20 maps of 4x4 = 320 features.
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return softmax scores of shape ``(batch, 10)``.

        ``x`` is viewed as ``(-1, 1, 28, 28)``, so its element count must be
        a multiple of 784 and it must be view-compatible (contiguous).
        """
        images = x.view(-1, 1, 28, 28)
        stage1 = F.relu(F.max_pool2d(self.conv1(images), 2))
        # Channel dropout is applied to the conv2 output before pooling.
        stage2 = self.conv2_drop(self.conv2(stage1))
        stage2 = F.relu(F.max_pool2d(stage2, 2))
        flat = stage2.view(-1, 320)
        hidden = F.dropout(F.relu(self.fc1(flat)), training=self.training)
        logits = self.fc2(hidden)
        return F.softmax(logits, dim=1)
# Module-level handles populated by init() and read by run().
model, device = None, None


def init(model_path='outputs/model.pth'):
    """Load trained weights into a NeuralNework and prepare it for scoring.

    Sets the module-level ``model`` and ``device`` globals that ``run`` uses.

    Args:
        model_path: Path of the saved ``state_dict``. It defaults to the
            previously hard-coded ``'outputs/model.pth'``, so existing
            no-argument callers are unaffected.

    Warns:
        RuntimeWarning: if the checkpoint supplies no weights for some of the
            network's parameters.
    """
    global model, device
    import warnings  # only needed for the checkpoint-mismatch warning below

    device = torch.device('cpu')
    model = NeuralNework()
    state_dict = torch.load(model_path, map_location=device)
    # strict=False tolerates key mismatches. However, a checkpoint saved from
    # a different architecture (e.g. the CNN class) would then load nothing
    # and leave the network on random initial weights. Report that instead of
    # hiding it.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        warnings.warn(
            f"Checkpoint {model_path!r} provided no weights for: "
            f"{', '.join(result.missing_keys)}",
            RuntimeWarning,
        )
    model.to(device)
    model.eval()
def run(raw_data):
    """Score one image posted as JSON and return the prediction with timing.

    Uses the module-level ``model`` and ``device`` globals set by ``init()``.

    Args:
        raw_data: JSON text whose ``'image'`` key holds comma-separated pixel
            intensities in 0-255. Either one row of 28*28 = 784 values or a
            28-line grid is accepted, since the array is flattened into a
            single-image batch before scoring.

    Returns:
        dict with ``'time'`` (seconds spent in this call, float),
        ``'prediction'`` (index of the highest score, int) and ``'scores'``
        (the model's per-class outputs as a list of floats).
    """
    start = time.perf_counter()  # monotonic clock, suited to measuring durations
    post = json.loads(raw_data)
    # Parse the CSV text and scale pixel intensities from [0, 255] to [0, 1].
    image = np.loadtxt(StringIO(post['image']), delimiter=',') / 255.
    # np.loadtxt squeezes a single row to shape (784,). Add an explicit batch
    # axis so the model receives (1, 784) and the [0] below selects that one
    # image's scores.
    x = torch.from_numpy(image).float().reshape(1, -1).to(device)
    with torch.no_grad():
        pred = model(x).cpu().numpy()[0]
    inference_time = time.perf_counter() - start
    return {
        'time': inference_time,
        'prediction': int(np.argmax(pred)),
        'scores': pred.tolist(),
    }
if __name__ == "__main__":
img = '0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,21,118,164,255,138,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,76,231,254,254,254,248,176,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,74,242,254,254,254,219,254,247,31,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,75,249,254,213,87,4,3,37,242,123,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,30,243,254,226,41,0,0,0,0,67,13,22,25,50,72,52,0,0,0,0,0,0,0,0,0,0,0,0,153,254,235,28,0,0,0,27,87,91,146,247,254,254,254,157,0,0,0,0,0,0,0,0,0,0,0,34,238,254,89,0,0,69,154,248,254,254,254,254,231,141,56,6,0,0,0,0,0,0,0,0,0,0,0,58,254,254,9,17,146,251,254,254,254,253,171,124,26,0,0,0,0,0,0,0,0,0,0,0,0,0,0,58,254,254,108,212,254,254,254,229,155,74,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,58,254,254,254,254,254,226,86,10,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,58,254,254,254,247,141,21,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,87,254,254,254,73,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,51,245,254,254,254,43,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,53,195,254,253,247,254,83,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,152,254,254,138,63,251,165,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,34,238,254,201,4,0,245,183,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,101,254,254,49,1,75,248,101,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,158,254,254,140,157,254,254,43,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,146,254,254,254,254,254,144,5,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,16,192,254,254,244,126,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0'
data = {
'image': img
}
init()
out = run(json.dumps(data))
print(out)