| /* |
| Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| */ |
| |
| package tensorflow_test |
| |
import (
	"archive/zip"
	"bufio"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"os"
	"path/filepath"
	"strings"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)
| |
| func Example() { |
| // An example for using the TensorFlow Go API for image recognition |
| // using a pre-trained inception model (http://arxiv.org/abs/1512.00567). |
| // |
| // Sample usage: <program> -dir=/tmp/modeldir -image=/path/to/some/jpeg |
| // |
| // The pre-trained model takes input in the form of a 4-dimensional |
| // tensor with shape [ BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3 ], |
| // where: |
| // - BATCH_SIZE allows for inference of multiple images in one pass through the graph |
| // - IMAGE_HEIGHT is the height of the images on which the model was trained |
| // - IMAGE_WIDTH is the width of the images on which the model was trained |
| // - 3 is the (R, G, B) values of the pixel colors represented as a float. |
| // |
| // And produces as output a vector with shape [ NUM_LABELS ]. |
| // output[i] is the model-implied probability of the input image having |
| // the i-th label. |
| // |
| // A separate file contains a list of string labels corresponding to the |
| // integer indices of the output. |
| // |
| // This example: |
| // - Loads the serialized representation of the pre-trained model into a Graph |
| // - Creates a Session to execute operations on the Graph |
| // - Converts an image file to a Tensor to provide as input to a Session run |
| // - Executes the Session and prints out the label with the highest probability |
| // |
| // To convert an image file to a Tensor suitable for input to the Inception model, |
| // this example: |
| // - Constructs another TensorFlow graph to normalize the image into a |
| // form suitable for the model (for example, resizing the image) |
| // - Creates and executes a Session to obtain a Tensor in this normalized form. |
| modeldir := flag.String( |
| "dir", |
| "testdata/saved_model/inception5h", |
| "Directory containing the trained model files. The directory will be"+ |
| "created and the model downloaded into it if necessary", |
| ) |
| imagefile := flag.String( |
| "image", |
| "testdata/label_image/grace_hopper.jpg", |
| "Path of a JPEG-image to extract labels for", |
| ) |
| flag.Parse() |
| |
| // Load the serialized GraphDef from a file. |
| modelfile, labelsfile, err := modelFiles(*modeldir) |
| if err != nil { |
| log.Fatal(err) |
| } |
| |
| labels, err := readLabelsFile(labelsfile) |
| if err != nil { |
| log.Fatal(err) |
| } |
| |
| model, err := ioutil.ReadFile(modelfile) |
| if err != nil { |
| log.Fatal(err) |
| } |
| |
| // Construct an in-memory graph from the serialized form. |
| graph := tf.NewGraph() |
| if err := graph.Import(model, ""); err != nil { |
| log.Fatal(err) |
| } |
| |
| // Create a session for inference over graph. |
| session, err := tf.NewSession(graph, nil) |
| if err != nil { |
| log.Fatal(err) |
| } |
| defer session.Close() |
| |
| // Run inference on *imageFile. |
| // For multiple images, session.Run() can be called in a loop (and |
| // concurrently). Alternatively, images can be batched since the model |
| // accepts batches of image data as input. |
| tensor, err := makeTensorFromImage(*imagefile) |
| if err != nil { |
| log.Fatal(err) |
| } |
| output, err := session.Run( |
| map[tf.Output]*tf.Tensor{ |
| graph.Operation("input").Output(0): tensor, |
| }, |
| []tf.Output{ |
| graph.Operation("output").Output(0), |
| }, |
| nil) |
| if err != nil { |
| log.Fatal(err) |
| } |
| // output[0].Value() is a vector containing probabilities of |
| // labels for each image in the "batch". The batch size was 1. |
| // Find the most probable label index. |
| probabilities := output[0].Value().([][]float32)[0] |
| printBestLabel(probabilities, labels) |
| // // Output: |
| // // BEST MATCH: (29% likely) military uniform |
| } |
| |
| func printBestLabel(probabilities []float32, labels []string) { |
| idx := argmax(probabilities) |
| fmt.Printf( |
| "BEST MATCH: (%2.0f%% likely) %s", |
| probabilities[idx]*100.0, |
| labels[idx], |
| ) |
| } |
| |
| // Convert the image in filename to a Tensor suitable as input to the Inception model. |
| func makeTensorFromImage(filename string) (*tf.Tensor, error) { |
| bytes, err := ioutil.ReadFile(filename) |
| if err != nil { |
| return nil, err |
| } |
| // DecodeJpeg uses a scalar String-valued tensor as input. |
| tensor, err := tf.NewTensor(string(bytes)) |
| if err != nil { |
| return nil, err |
| } |
| // Construct a graph to normalize the image |
| graph, input, output, err := constructGraphToNormalizeImage() |
| if err != nil { |
| return nil, err |
| } |
| // Execute that graph to normalize this one image |
| session, err := tf.NewSession(graph, nil) |
| if err != nil { |
| return nil, err |
| } |
| defer session.Close() |
| normalized, err := session.Run( |
| map[tf.Output]*tf.Tensor{input: tensor}, |
| []tf.Output{output}, |
| nil) |
| if err != nil { |
| return nil, err |
| } |
| return normalized[0], nil |
| } |
| |
| // The inception model takes as input the image described by a Tensor in a very |
| // specific normalized format (a particular image size, shape of the input tensor, |
| // normalized pixel values etc.). |
| // |
| // This function constructs a graph of TensorFlow operations which takes as |
| // input a JPEG-encoded string and returns a tensor suitable as input to the |
| // inception model. |
| func constructGraphToNormalizeImage() (graph *tf.Graph, input, output tf.Output, err error) { |
| // Some constants specific to the pre-trained model at: |
| // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip |
| // |
| // - The model was trained after with images scaled to 224x224 pixels. |
| // - The colors, represented as R, G, B in 1-byte each were converted to |
| // float using (value - Mean)/Scale. |
| const ( |
| H, W = 224, 224 |
| Mean = float32(117) |
| Scale = float32(1) |
| ) |
| // - input is a String-Tensor, where the string the JPEG-encoded image. |
| // - The inception model takes a 4D tensor of shape |
| // [BatchSize, Height, Width, Colors=3], where each pixel is |
| // represented as a triplet of floats |
| // - Apply normalization on each pixel and use ExpandDims to make |
| // this single image be a "batch" of size 1 for ResizeBilinear. |
| s := op.NewScope() |
| input = op.Placeholder(s, tf.String) |
| output = op.Div(s, |
| op.Sub(s, |
| op.ResizeBilinear(s, |
| op.ExpandDims(s, |
| op.Cast(s, |
| op.DecodeJpeg(s, input, op.DecodeJpegChannels(3)), tf.Float), |
| op.Const(s.SubScope("make_batch"), int32(0))), |
| op.Const(s.SubScope("size"), []int32{H, W})), |
| op.Const(s.SubScope("mean"), Mean)), |
| op.Const(s.SubScope("scale"), Scale)) |
| graph, err = s.Finalize() |
| return graph, input, output, err |
| } |
| |
| func modelFiles(dir string) (modelfile, labelsfile string, err error) { |
| const URL = "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip" |
| var ( |
| model = filepath.Join(dir, "tensorflow_inception_graph.pb") |
| labels = filepath.Join(dir, "imagenet_comp_graph_label_strings.txt") |
| zipfile = filepath.Join(dir, "inception5h.zip") |
| ) |
| if filesExist(model, labels) == nil { |
| return model, labels, nil |
| } |
| log.Println("Did not find model in", dir, "downloading from", URL) |
| if err := os.MkdirAll(dir, 0755); err != nil { |
| return "", "", err |
| } |
| if err := download(URL, zipfile); err != nil { |
| return "", "", fmt.Errorf("failed to download %v - %v", URL, err) |
| } |
| if err := unzip(dir, zipfile); err != nil { |
| return "", "", fmt.Errorf("failed to extract contents from model archive: %v", err) |
| } |
| os.Remove(zipfile) |
| return model, labels, filesExist(model, labels) |
| } |
| |
// filesExist reports whether every path in files can be stat'ed. It returns
// nil when all of them can (including when files is empty); otherwise it
// returns an error for the first path that fails, in argument order.
func filesExist(files ...string) error {
	for i := range files {
		path := files[i]
		_, statErr := os.Stat(path)
		if statErr == nil {
			continue
		}
		return fmt.Errorf("unable to stat %s: %v", path, statErr)
	}
	return nil
}
| |
// download fetches URL with an HTTP GET and stores the response body in
// filename.
//
// filename is created if missing and truncated otherwise, so a stale or
// partially written file from an earlier attempt is fully replaced. A non-200
// response is returned as an error instead of being written to disk. An error
// from closing the written file is also reported, since it can indicate lost
// data.
func download(URL, filename string) (err error) {
	resp, err := http.Get(URL)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected HTTP status: %s", resp.Status)
	}
	file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644)
	if err != nil {
		return err
	}
	defer func() {
		// Keep the first error; otherwise surface a failed Close.
		if cerr := file.Close(); err == nil {
			err = cerr
		}
	}()
	_, err = io.Copy(file, resp.Body)
	return err
}
| |
// unzip extracts every entry of the zip archive zipfile into dir.
//
// Directory entries are created, parent directories of file entries are
// created as needed, and existing files are truncated before being
// overwritten. An entry whose name would resolve outside dir (for example
// "../evil") is rejected with an error to guard against path-traversal
// ("zip slip") archives.
func unzip(dir, zipfile string) error {
	r, err := zip.OpenReader(zipfile)
	if err != nil {
		return err
	}
	defer r.Close()
	for _, f := range r.File {
		if err := extractZipEntry(dir, f); err != nil {
			return err
		}
	}
	return nil
}

// extractZipEntry writes the single archive entry f beneath dir, closing both
// the entry reader and the destination file before it returns.
func extractZipEntry(dir string, f *zip.File) (err error) {
	dest := filepath.Join(dir, f.Name)
	rel, err := filepath.Rel(dir, dest)
	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return fmt.Errorf("illegal file path in zip archive: %q", f.Name)
	}
	log.Println("Extracting", f.Name)
	if f.FileInfo().IsDir() {
		return os.MkdirAll(dest, 0755)
	}
	if err := os.MkdirAll(filepath.Dir(dest), 0755); err != nil {
		return err
	}
	src, err := f.Open()
	if err != nil {
		return err
	}
	defer src.Close()
	// O_TRUNC so a pre-existing, longer file does not keep trailing bytes.
	dst, err := os.OpenFile(dest, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
	if err != nil {
		return err
	}
	defer func() {
		// Keep the first error; otherwise surface a failed Close.
		if cerr := dst.Close(); err == nil {
			err = cerr
		}
	}()
	_, err = io.Copy(dst, src)
	return err
}
| |
// readLabelsFile reads the file at path f and returns its lines in order, so
// labels[i] is the i-th line of the file. Line terminators ("\n" or "\r\n")
// are stripped. It returns an error if the file cannot be opened or if reading
// fails part-way through.
func readLabelsFile(f string) ([]string, error) {
	file, err := os.Open(f)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	var labels []string
	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		labels = append(labels, scanner.Text())
	}
	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return labels, nil
}
| |
// argmax returns the index of the largest element of a. Ties resolve to the
// earliest index, and an empty slice yields 0 (which is not a valid index
// into it, so callers must check the length themselves).
func argmax(a []float32) int {
	best := 0
	for i, v := range a {
		if v > a[best] {
			best = i
		}
	}
	return best
}