Skip to content

Commit 815dfa6

Browse files
ospillingerdeliahu
authored andcommitted
Update examples
(cherry picked from commit 6aea035)
1 parent bc653d7 commit 815dfa6

File tree

6 files changed

+16
-14
lines changed

6 files changed

+16
-14
lines changed

examples/pytorch/image-classifier/predictor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import requests
2+
import torch
3+
import torchvision
4+
from torchvision import transforms
25
from PIL import Image
36
from io import BytesIO
4-
from torchvision import transforms
5-
import torchvision
6-
import torch
77

88
model = torchvision.models.alexnet(pretrained=True)
99
model.eval()
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
"url": "https://bowwowinsurance.com.au/wp-content/uploads/2018/10/akita-700x700.jpg"
2+
"url": "https://i.imgur.com/PzXprwl.jpg"
33
}

examples/pytorch/iris-classifier/src/predictor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import torch
33
from model import IrisNet
44

5-
labels = ["iris-setosa", "iris-versicolor", "iris-virginica"]
65

76
model = IrisNet()
87

@@ -12,6 +11,9 @@ def init(model_path, metadata):
1211
model.eval()
1312

1413

14+
labels = ["iris-setosa", "iris-versicolor", "iris-virginica"]
15+
16+
1517
def predict(sample, metadata):
1618
input_tensor = torch.FloatTensor(
1719
[

examples/pytorch/text-generator/cortex.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
predictor:
77
path: predictor.py
88
metadata:
9-
num_words: 20
9+
num_words: 50
1010
device: cuda # use "cpu" to run on CPUs
1111
compute:
1212
gpu: 1

examples/pytorch/text-generator/predictor.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1+
# This file includes code which was modified from https://github.com/huggingface/transformers/blob/master/examples/run_generation.py
2+
13
from __future__ import absolute_import, division, print_function, unicode_literals
24

3-
import numpy as np
4-
import argparse
5-
import logging
6-
from tqdm import trange
7-
import torch.nn.functional as F
85
import torch
6+
import torch.nn.functional as F
97
from transformers import GPT2Tokenizer, GPT2LMHeadModel
8+
from tqdm import trange
109

10+
11+
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
1112
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
1213
model.eval()
13-
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
1414

15-
# adapted from: https://github.com/huggingface/transformers/blob/master/examples/run_generation.py
15+
1616
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
1717
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
1818
Args:
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
"url": "https://bowwowinsurance.com.au/wp-content/uploads/2018/10/akita-700x700.jpg"
2+
"url": "https://i.imgur.com/PzXprwl.jpg"
33
}

0 commit comments

Comments
 (0)