diff --git a/BuddyGemmini/.gitignore b/BuddyGemmini/.gitignore new file mode 100644 index 0000000..8876939 --- /dev/null +++ b/BuddyGemmini/.gitignore @@ -0,0 +1,7 @@ +*.data +__pycache__ +*.pyc +/build +forward.mlir +subgraph0.mlir +*.o \ No newline at end of file diff --git a/BuddyGemmini/CMakeLists.txt b/BuddyGemmini/CMakeLists.txt new file mode 100644 index 0000000..8ddd313 --- /dev/null +++ b/BuddyGemmini/CMakeLists.txt @@ -0,0 +1,111 @@ +set(BUDDY_EXAMPLES_DIR ${BUDDY_MLIR_DIR}/examples/) +set(BUDDY_BINARY_DIR ${BUDDY_MLIR_DIR}/build/bin/) +set(RISCV_GNU_TOOLCHAIN ${BUDDY_MLIR_DIR}/build/thirdparty/riscv-gnu-toolchain) +set(CMAKE_CXX_COMPILER ${RISCV_GNU_TOOLCHAIN}/bin/riscv64-unknown-linux-gnu-g++) + +set(BUDDY_GEMMINI_DIR ${BUDDY_EXAMPLES_DIR}/BuddyGemmini) +set(INTERFACE_DIR ${BUDDY_MLIR_DIR}/frontend/Interfaces/) +set(INCLUDE_DIR ${BUDDY_GEMMINI_DIR}/include/) + + +add_custom_command( + OUTPUT forward.mlir subgraph0.mlir arg0.data + COMMAND export BUDDYGEMMINI_EXAMPLE_PATH=${BUDDY_GEMMINI_DIR} && + python3 ${BUDDY_GEMMINI_DIR}/buddy-lenet-import.py + DEPENDS buddy-lenet-import.py + COMMENT "Generating forward.mlir, subgraph0.mlir and parameter files" +) + +add_custom_command( + OUTPUT forward.o + COMMAND ${BUDDY_BINARY_DIR}/buddy-opt ${BUDDY_GEMMINI_DIR}/forward.mlir + -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" | + ${BUDDY_BINARY_DIR}/buddy-opt + -buffer-deallocation-simplification + -convert-linalg-to-loops + -eliminate-empty-tensors + -llvm-request-c-wrappers + -convert-math-to-llvm + -convert-math-to-libm + -convert-scf-to-cf + -convert-arith-to-llvm + -expand-strided-metadata + -finalize-memref-to-llvm + -convert-func-to-llvm + -reconcile-unrealized-casts | + ${BUDDY_BINARY_DIR}/buddy-translate --buddy-to-llvmir | + ${BUDDY_BINARY_DIR}/buddy-llc -filetype=obj -mtriple=riscv64 -O0 -mattr=+buddyext,+D -float-abi=hard -o forward.o + DEPENDS forward.mlir + COMMENT "Building forward.o" + VERBATIM) + +add_custom_command( + OUTPUT subgraph0.o + COMMAND ${BUDDY_BINARY_DIR}/buddy-opt ${BUDDY_GEMMINI_DIR}/subgraph0.mlir + -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" + > subgraph0_linalg.mlir + COMMAND ${BUDDY_BINARY_DIR}/buddy-opt ${BUDDY_GEMMINI_DIR}/subgraph0.mlir + -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" | + ${BUDDY_BINARY_DIR}/buddy-opt + -eliminate-empty-tensors + -convert-tensor-to-linalg + -linalg-bufferize + -batchmatmul-optimize + -convert-linalg-to-gemmini + > subgraph0_loops.mlir + COMMAND ${BUDDY_BINARY_DIR}/buddy-opt ${BUDDY_GEMMINI_DIR}/subgraph0.mlir + -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" | + ${BUDDY_BINARY_DIR}/buddy-opt + -eliminate-empty-tensors + -linalg-bufferize + -tensor-bufferize + -func-bufferize + -convert-linalg-to-gemmini + -expand-strided-metadata + -convert-linalg-to-loops + -convert-scf-to-cf + -llvm-request-c-wrappers + -lower-gemmini + -arith-bufferize + -buffer-deallocation + -finalizing-bufferize + -convert-arith-to-llvm + -convert-func-to-llvm + -finalize-memref-to-llvm + -reconcile-unrealized-casts | + ${BUDDY_BINARY_DIR}/buddy-translate --buddy-to-llvmir + > subgraph0.ll + COMMAND ${BUDDY_BINARY_DIR}/buddy-opt ${BUDDY_GEMMINI_DIR}/subgraph0.mlir + -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" | + ${BUDDY_BINARY_DIR}/buddy-opt + -eliminate-empty-tensors + -linalg-bufferize + -tensor-bufferize + -func-bufferize + -convert-linalg-to-gemmini + -expand-strided-metadata + -convert-linalg-to-loops + -convert-scf-to-cf + -llvm-request-c-wrappers + -lower-gemmini + -arith-bufferize + -buffer-deallocation + -finalizing-bufferize + -convert-arith-to-llvm + -convert-func-to-llvm + -finalize-memref-to-llvm + -reconcile-unrealized-casts | + ${BUDDY_BINARY_DIR}/buddy-translate --buddy-to-llvmir | + ${BUDDY_BINARY_DIR}/buddy-llc -filetype=obj -mtriple=riscv64 -O0 -mattr=+buddyext,+D -float-abi=hard -o subgraph0.o + DEPENDS subgraph0.mlir + COMMENT "Building subgraph0.o" + VERBATIM) + + +add_library(GemminiLENET STATIC subgraph0.o forward.o) +set_target_properties(GemminiLENET PROPERTIES LINKER_LANGUAGE CXX) + +add_executable(buddy-gemmini-lenet-run buddy-lenet-main.cpp) +add_dependencies(buddy-gemmini-lenet-run GemminiLENET) +target_include_directories(buddy-gemmini-lenet-run PRIVATE ${INTERFACE_DIR} ${INCLUDE_DIR}) +target_link_libraries(buddy-gemmini-lenet-run -static GemminiLENET) diff --git a/BuddyGemmini/README.md b/BuddyGemmini/README.md new file mode 100644 index 0000000..e96e4d7 --- /dev/null +++ b/BuddyGemmini/README.md @@ -0,0 +1,85 @@ +# BuddyGemmini LeNet E2E deployment on Firesim + +## Overview +This guide provides an example of end-to-end deployment of a DNN (LeNet) inference to a DSA backend (Gemmini) using the Buddy Compiler. + +We use FireSim, a platform for FPGA-accelerated cycle-accurate simulation, to run end-to-end DNN workloads that would take too long to run on Verilator/VCS. FireSim also allows users to check that their Gemmini hardware/software will work when running in a Linux environment. The FireSim used in this guide is installed locally on a Xilinx VCU118. + +## Preparation +Before proceed any further make sure that you installed dependencies below +1. Installation of [Buddy-mlir basic environment and cross-compilation toolchain](https://github.com/buddy-compiler/buddy-mlir/blob/main/docs/RVVEnvironment.md) + +2. Environment installation for [Chipyard](https://chipyard.readthedocs.io/en/1.11.0/) and [Firesim](https://docs.fires.im/en/1.18.0/). The environment for this guide is based on a local acceleration card, the VCU118, with configuration versions Chipyard 1.11.0 and FireSim 1.18.0. We recommend installing these versions (install firesim as a submodule of chipyard) and completing all the content in the FireSim documentation's [Getting Started Guide](https://docs.fires.im/en/1.18.0/Getting-Started-Guides/On-Premises-FPGA-Getting-Started/Repo-Setup/Xilinx-Alveo-U280.html). + +3. Complete the build of [gemmini](https://github.com/ucb-bar/gemmini), and building a complete bitstream file based on the default Gemmini configuration using the firesim buildbitstream command. + +## Cross-compilation +1. Activate your python environment. + +2. Build buddy-gemmini-lenet-run + +``` +$ mkdir build && cd build +$ cmake .. -DBUDDY_MLIR_DIR=/path/to/buddy-mlir/ # replace with your buddy-mlir directory path +$ make buddy-gemmini-lenet-run +``` + +## Deployment to FireSim +1. Copy the executable files (located in the `BuddyGemmini/build/`) and the required data files to Gemmini's software path +``` +$ cd chipyard # go to your chipyard root directory +$ mkdir ./generators/gemmini/software/overlay/root/BuddyGemmini/ +$ cp ${BUDDYGEMMINI_EXAMPLE_PATH}/build/buddy-gemmini-lenet-run ./generators/gemmini/software/overlay/root/BuddyGemmini/ +$ cp ${BUDDYGEMMINI_EXAMPLE_PATH}/arg0.data ./generators/gemmini/software/overlay/root/BuddyGemmini/ +$ cp -r ${BUDDYGEMMINI_EXAMPLE_PATH}/images/ ./generators/gemmini/software/overlay/root/BuddyGemmini/ +``` +2. Build software for the target platform +``` +$ cd chipyard +$ ./sims/firesim/sw/firesim-software/marshal -v build ./generators/gemmini/software/gemmini-tests-interactive.json && ./sims/firesim/sw/firesim-software/marshal -v install ./generators/gemmini/software/gemmini-tests-interactive.json +``` + +3. Activate your Firesim environment. +``` +$ cd chipyard/sim/firesim +$ source ./sourceme-manager.sh --skip-ssh-setup +``` + +4. In the `firesim/deploy/` path, there are four files that configure key information for FireSim's build workload, bitstream, runtime, etc. Please check the following configurations: + +- `config_build_recipes.yaml`: Configures the Gemmini configuration, such as `firesim_custom_gemmini_rocket_singlecore_no_nic` +- `config_build.yaml`: Under `builds_to_run`, select the Gemmini configuration, such as `firesim_custom_gemmini_rocket_singlecore_no_nic` +- `config_hwdb.yaml`: For `bitstream_tar`, configure the absolute path where your generated Gemmini bitstream is stored +- `config_runtime.yaml`: This file is for building the runtime. Please modify the `workload_name` to `gemmini-tests-interactive.json`. We will execute this interactive configuration after starting the simulation later. + + +5. Build and deploy simulation infrastructure to the Run Farm Machines. Each time you change your workload content, please re-execute `step 2` to `step 5`. + +``` +$ firesim infrasetup +``` + +6. Start simulation on Run Farm Machines. After executing the command below, the terminal will display a background monitor of the simulation running. + +``` +$ firesim runworkload +``` + +7. SSH connect to `BUILD_FARM_IP`, open a new terminal connection to the screen created by Run Farm Machines (please refer to the FireSim documentation to confirm you can correctly connect to Run Farm Machines). + +``` +$ ssh BUILD_FARM_IP +$ screen -r fsim0 +``` + +## Final step! +Now, you can login to the system! The username is root and there is no password. The steps described here are for manual execution. You can also refer to the writing of `gemmini-tests.json` and `overlay/root/run-tests.sh` to write your own automated execution script. This will change the manual operations after firesim runworkload to automatic execution. The corresponding log files will be recorded in the `/firesim/deploy/results-workload` folder. + +``` +$ cd ./BuddyGemmini +$ export BUDDYGEMMINI_EXAMPLE_PATH=$PWD +$ ./buddy-gemmini-lenet-run +``` + +If all steps go well, you will see the output below. Good luck. +![demo](./doc/demo.png) diff --git a/BuddyGemmini/buddy-lenet-import.py b/BuddyGemmini/buddy-lenet-import.py new file mode 100644 index 0000000..4be05e3 --- /dev/null +++ b/BuddyGemmini/buddy-lenet-import.py @@ -0,0 +1,76 @@ +# ===- buddy-lenet-import.py --------------------------------------------------- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ===--------------------------------------------------------------------------- +# +# This is the LeNet model AOT importer. +# +# ===--------------------------------------------------------------------------- + +import os +from pathlib import Path + +import numpy as np +import torch +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.graph import GraphDriver +from buddy.compiler.graph.transform import simply_fuse +from buddy.compiler.ops import tosa +from model import LeNet + +# Retrieve the LeNet model path from environment variables. +model_path = os.environ.get("BUDDYGEMMINI_EXAMPLE_PATH") +if model_path is None: + raise EnvironmentError( + "The environment variable 'LENET_MODEL_PATH' is not set or is invalid." + ) + +model = LeNet() +model = torch.load(model_path + "/lenet-model.pth") +model = model.eval() + +# Initialize Dynamo Compiler with specific configurations as an importer. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +data = torch.randn([1, 1, 28, 28]) +# Import the model into MLIR module and parameters. +with torch.no_grad(): + graphs = dynamo_compiler.importer(model, data) + +assert len(graphs) == 1 +graph = graphs[0] +params = dynamo_compiler.imported_params[graph] +pattern_list = [simply_fuse] +graphs[0].fuse_ops(pattern_list) +driver = GraphDriver(graphs[0]) +driver.subgraphs[0].lower_to_top_level_ir() +path_prefix = os.path.dirname(os.path.abspath(__file__)) +with open(os.path.join(path_prefix, "subgraph0.mlir"), "w") as module_file: + print(driver.subgraphs[0]._imported_module, file=module_file) +with open(os.path.join(path_prefix, "forward.mlir"), "w") as module_file: + print(driver.construct_main_graph(True), file=module_file) + +params = dynamo_compiler.imported_params[graph] +current_path = os.path.dirname(os.path.abspath(__file__)) + +float32_param = np.concatenate( + [param.detach().numpy().reshape([-1]) for param in params] +) + +float32_param.tofile(Path(current_path) / "arg0.data") diff --git a/BuddyGemmini/buddy-lenet-main.cpp b/BuddyGemmini/buddy-lenet-main.cpp new file mode 100644 index 0000000..8af5b7e --- /dev/null +++ b/BuddyGemmini/buddy-lenet-main.cpp @@ -0,0 +1,140 @@ +//===- buddy-lenet-main.cpp -----------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include +#include +#include "include/gemmini_testutils.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +constexpr size_t ParamsSize = 44426; +const std::string ImgName = "8.bmp"; + +/// Declare LeNet forward function. +extern "C" void _mlir_ciface_forward(MemRef *output, + MemRef *arg0, + dip::Image *input); + +/// Print [Log] label in bold blue format. +void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; } + +/// Load parameters into data container. +void loadParameters(const std::string ¶mFilePath, + MemRef ¶ms) { + const auto loadStart = std::chrono::high_resolution_clock::now(); + // Open the parameter file in binary mode. + std::ifstream paramFile(paramFilePath, std::ios::in | std::ios::binary); + if (!paramFile.is_open()) { + throw std::runtime_error("[Error] Failed to open params file!"); + } + printLogLabel(); + std::cout << "Loading params..." << std::endl; + printLogLabel(); + // Print the canonical path of the parameter file. + std::cout << "Params file: " << std::filesystem::canonical(paramFilePath) + << std::endl; + // Read the parameter data into the provided memory reference. + paramFile.read(reinterpret_cast(params.getData()), + sizeof(float) * (params.getSize())); + if (paramFile.fail()) { + throw std::runtime_error("Error occurred while reading params file!"); + } + paramFile.close(); + const auto loadEnd = std::chrono::high_resolution_clock::now(); + const std::chrono::duration loadTime = + loadEnd - loadStart; + printLogLabel(); + std::cout << "Params load time: " << (double)(loadTime.count()) / 1000 + << "s\n" + << std::endl; +} + +/// Softmax function to convert logits to probabilities. +void softmax(float *input, size_t size) { + size_t i; + float max_value = -INFINITY; + double sum = 0.0; + // Find the maximum value in the input array for numerical stability. + for (i = 0; i < size; ++i) { + if (max_value < input[i]) { + max_value = input[i]; + } + } + // Calculate the sum of the exponentials of the input elements, normalized by + // the max value. + for (i = 0; i < size; ++i) { + sum += exp(input[i] - max_value); + } + // Normalize the input array with the softmax calculation. + for (i = 0; i < size; ++i) { + input[i] = exp(input[i] - max_value) / sum; + } +} + +int main() { + // Print the title of this example. + const std::string title = "LeNet Inference Powered by Buddy Compiler"; + std::cout << "\033[33;1m" << title << "\033[0m" << std::endl; + + // Define the sizes of the output tensors. + intptr_t sizesOutput[2] = {1, 10}; + + // Create input and output containers for the image and model output. + std::string lenetDir = getenv("BUDDYGEMMINI_EXAMPLE_PATH"); + std::string imgPath = lenetDir + "/images/" + ImgName; + dip::Image input(imgPath, dip::DIP_GRAYSCALE, true /* norm */); + MemRef output(sizesOutput); + + // Load model parameters from the specified file. + std::string paramsDir = lenetDir + "/arg0.data"; + MemRef paramsContainer({ParamsSize}); + loadParameters(paramsDir, paramsContainer); + + unsigned long start = read_cycles(); + // Call the forward function of the model. + _mlir_ciface_forward(&output, ¶msContainer, &input); + unsigned long end = read_cycles(); + + // Apply softmax to the output logits to get probabilities. + auto out = output.getData(); + softmax(out, 10); + // gemmini profiling + std::cout << "Inference Cycles taken: " << end-start << std::endl; + + // Find the classification and print the result. + float maxVal = 0; + float maxIdx = 0; + for (int i = 0; i < 10; ++i) { + if (out[i] > maxVal) { + maxVal = out[i]; + maxIdx = i; + } + } + + std::cout << "Results: " << std::endl; + std::cout << "Classification: " << maxIdx << std::endl; + std::cout << "Probability: " << maxVal << std::endl; + + return 0; +} diff --git a/BuddyGemmini/doc/demo.png b/BuddyGemmini/doc/demo.png new file mode 100644 index 0000000..db87b55 Binary files /dev/null and b/BuddyGemmini/doc/demo.png differ diff --git a/BuddyGemmini/images/3.png b/BuddyGemmini/images/3.png new file mode 100644 index 0000000..0402de2 Binary files /dev/null and b/BuddyGemmini/images/3.png differ diff --git a/BuddyGemmini/images/8.bmp b/BuddyGemmini/images/8.bmp new file mode 100644 index 0000000..7a9e02a Binary files /dev/null and b/BuddyGemmini/images/8.bmp differ diff --git a/BuddyGemmini/include/gemmini_testutils.h b/BuddyGemmini/include/gemmini_testutils.h new file mode 100644 index 0000000..e7e4e3f --- /dev/null +++ b/BuddyGemmini/include/gemmini_testutils.h @@ -0,0 +1,13 @@ +#ifndef GEMMINI_TESTUTILS_H +#define GEMMINI_TESTUTILS_H + +#include + +// gemmini profile tool +static inline uint64_t read_cycles() { + uint64_t cycles; + asm volatile ("rdcycle %0" : "=r" (cycles)); + return cycles; +} + +#endif // GEMMINI_TESTUTILS_H diff --git a/BuddyGemmini/lenet-model.pth b/BuddyGemmini/lenet-model.pth new file mode 100644 index 0000000..bf91d6e Binary files /dev/null and b/BuddyGemmini/lenet-model.pth differ diff --git a/BuddyGemmini/model.py b/BuddyGemmini/model.py new file mode 100644 index 0000000..c017366 --- /dev/null +++ b/BuddyGemmini/model.py @@ -0,0 +1,41 @@ +# ===- model.py ---------------------------------------------------------------- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ===--------------------------------------------------------------------------- +# +# LeNet model definition. +# +# ===--------------------------------------------------------------------------- + +import torch +import torch.nn as nn + +class LeNet(nn.Module): + def __init__(self): + super(LeNet, self).__init__() + self.conv1 = nn.Conv2d(1, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 4 * 4, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool(torch.relu(self.conv1(x))) + x = self.pool(torch.relu(self.conv2(x))) + x = x.view(-1, 16 * 4 * 4) + x = torch.relu(self.fc1(x)) + x = torch.relu(self.fc2(x)) + x = self.fc3(x) + return x diff --git a/BuddyGemmini/pytorch-lenet-inference.py b/BuddyGemmini/pytorch-lenet-inference.py new file mode 100644 index 0000000..772262b --- /dev/null +++ b/BuddyGemmini/pytorch-lenet-inference.py @@ -0,0 +1,66 @@ +# ===- pytorch-lenet-inference.py ---------------------------------------------- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ===--------------------------------------------------------------------------- +# +# LeNet inference with PyTorch runtime. +# +# ===--------------------------------------------------------------------------- + +import torch +from torchvision import transforms +from PIL import Image + +from model import LeNet + +# Load model +model = LeNet() +torch.load("./lenet-model.pth") +# Set the model to evaluation mode +model.eval() + +# Prepare image and convert to grayscale +image_path = "./images/3.png" +image = Image.open(image_path).convert("L") + +# Resize image to match the model's expected input dimensions +# Convert to tensor +# - This conversion is achieved by dividing the original pixel values by 255. +# - Before: An image with pixel values typically in the range [0, 255]. +# - After: A PyTorch tensor with the shape (C, H, W) and pixel values +# normalized to [0.0, 1.0]. +# Normalize +# - This step normalizes each channel of the tensor to have a mean of 0.5 and +# a standard deviation of 0.5. +# - Before: A tensor with pixel values in the range [0.0, 1.0]. +# - After: A tensor with pixel values normalized to the range [-1.0, 1.0], +# making the network training process more stable and faster by +# standardizing the range of input values. +transform = transforms.Compose( + [ + transforms.Resize((28, 28)), + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)), + ] +) +# Add batch dimension: [CHW] -> [NCHW] +image = transform(image).unsqueeze(0) + +# Perform inference +# No gradient tracking in this block +with torch.no_grad(): + output = model(image) + prediction = output.argmax() + +print(f"Classification: {prediction.item()}") diff --git a/README.md b/README.md deleted file mode 100644 index 9d17ee6..0000000 --- a/README.md +++ /dev/null @@ -1 +0,0 @@ -# buddy-examples \ No newline at end of file