Skip to content

Commit

Permalink
teachable machine - loading model works, image recognition could be b…
Browse files Browse the repository at this point in the history
…etter
  • Loading branch information
Brandon Lei authored and Brandon Lei committed Dec 6, 2024
1 parent 5f44627 commit 04e380a
Show file tree
Hide file tree
Showing 5 changed files with 882 additions and 55 deletions.
338 changes: 283 additions & 55 deletions extensions/src/doodlebot/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import { categoryByGesture, classes, emojiByGesture, gestureDetection, gestureMe
import { line0, line1, line2, line3, line4, line5, line6, line7, line8 } from './Points';
import { followLine } from "./LineFollowing";
import { createLineDetector } from "./LineDetection";
import tmPose from '@teachablemachine/pose';
import { calculateArcTime } from "./TimeHelper";
import tmImage from '@teachablemachine/image';
import * as speechCommands from '@tensorflow-models/speech-commands';

const details: ExtensionMenuDisplayDetails = {
name: "Doodlebot",
Expand Down Expand Up @@ -61,13 +64,26 @@ export default class DoodlebotBlocks extends extension(details, "ui", "indicator

// Image element used as the draw source when rendering camera frames onto a canvas.
imageStream: HTMLImageElement;
// Stage drawable that receives rendered canvases via `update(...)`.
videoDrawable: ReturnType<typeof this.createDrawable>;
// Per-model state keyed by model URL. Entries are populated with `model` and
// `modelType` once loading succeeds and with `topClass` after each prediction.
predictionState = {};
// Most recent result delivered by the speech-commands recognizer's listen callback.
latestAudioResults: any;
// Tags describing which Teachable Machine library produced a loaded model.
ModelType = {
POSE: 'pose',
IMAGE: 'image',
AUDIO: 'audio',
};
// URL of the currently active model (despite the name, it is used for every
// model type, not only image models). Falsy until a model has been imported.
teachableImageModel;

// Timestamp (ms, from Date.now()) of the last time a prediction was kicked off.
lastUpdate: number = null;
// Highest class confidence seen in the most recent prediction.
maxConfidence: number = null;
// Latest confidence per class name, written on every prediction.
modelConfidences = {};
// Counter of predictions currently in flight; the loop only starts a new one at 0.
isPredicting: number = 0;
// Minimum delay in milliseconds between prediction loop ticks.
INTERVAL = 16;
// NOTE(review): not referenced in the visible code — presumably [width, height]
// of the video frame; confirm against the rest of the file.
DIMENSIONS = [480, 360];

/**
 * Extension start-up hook: opens the "Connect" UI, marks the indicator as
 * disconnected, and starts the recurring prediction loop.
 *
 * @param env Environment handed in by the framework; not read in this method.
 */
init(env: Environment) {
this.openUI("Connect");
this.setIndicator("disconnected");

// idea: set up polling mechanism to try and disable unused sensors
// idea: set up polling mechanism to destroy gesture recognition loop
// `_loop` re-arms itself with setTimeout, so this single call keeps it running;
// nothing in this method stops it again.
this._loop();
}

setDoodlebot(doodlebot: Doodlebot) {
Expand Down Expand Up @@ -110,57 +126,6 @@ export default class DoodlebotBlocks extends extension(details, "ui", "indicator
return drawable;
}

/**
 * Command block: detects the line in the current camera frame and paints it
 * in red on top of that frame in the video drawable.
 *
 * The line detector is created lazily on first use and needs the robot's IP
 * address; without one the block logs an error and does nothing.
 */
@block({
  type: "command",
  text: "displayLine"
})
async displayLine() {
  console.log("displayLine");

  // Build the detector once; later calls reuse it.
  if (!this.lineDetector) {
    const robotIp = await this.doodlebot?.getIPAddress();
    console.log("DEBUG IP Address:", robotIp);
    if (!robotIp) {
      console.error("Unable to get IP address for line detection");
      return;
    }
    const stream = await this.getImageStream();
    this.lineDetector = createLineDetector(robotIp, stream);
  }

  const points = await this.lineDetector();
  console.log("Raw line coordinates:", points);
  if (points.length === 0) {
    console.log("No line detected");
    return;
  }
  console.log("Line coordinates:", JSON.stringify(points));

  if (!this.videoDrawable) {
    this.videoDrawable = await this.createVideoStreamDrawable();
  }

  // Off-screen canvas sized like the camera image (assumes width/height exist).
  const overlay = document.createElement('canvas');
  overlay.width = this.imageStream.width;
  overlay.height = this.imageStream.height;

  const pen = overlay.getContext('2d');
  if (!pen) {
    return;
  }

  // Camera frame first, then the detected polyline on top of it.
  pen.drawImage(this.imageStream, 0, 0, overlay.width, overlay.height);

  pen.beginPath();
  const start = points[0];
  pen.moveTo(start[0], start[1]);
  let index = 1;
  while (index < points.length) {
    const point = points[index];
    pen.lineTo(point[0], point[1]);
    index += 1;
  }
  pen.strokeStyle = 'red';
  pen.lineWidth = 2;
  pen.stroke();

  this.videoDrawable.update(overlay);
}

@buttonBlock("Connect Robot")
connect() {
this.openUI("Connect");
Expand Down Expand Up @@ -603,4 +568,267 @@ export default class DoodlebotBlocks extends extension(details, "ui", "indicator
: await this.doodlebot?.sendWebsocketCommand(candidates[0], ...splitArgsString(args));
}

}
/**
 * Command block "import model <url>": hands the user-supplied model
 * reference to `useModel`, which resolves it to a model URL and loads it.
 *
 * @param url Text typed into the block (defaults to the placeholder "URL HERE").
 */
@block({
type: "command",
text: (url) => `import model ${url}`,
arg: {
type: "string",
defaultValue: "URL HERE"
}
})
async importModel(url: string) {
await this.useModel(url);
}


/**
 * Hat block "when model detects <class>": fires while the active model's
 * most recent top class equals the selected class name.
 *
 * @param className Class label chosen from the dropdown.
 * @returns Result of `model_match(className)`.
 */
@block({
type: "hat",
text: (className) => `when model detects ${className}`,
arg: {
type: "string",
// Dropdown entries come from the active model's class labels.
// NOTE(review): relies on the framework invoking `options` with the
// extension instance as `this` — confirm.
options: function() {
if (!this) {
throw new Error('Context is undefined');
}
// Placeholder entry is used if getModelClasses() yields a falsy value.
return this.getModelClasses() || ["Select a class"];
},
defaultValue: "Select a class"
}
})
whenModelDetects(className: string) {
return this.model_match(className);
}

/**
 * Reporter block "model prediction": reports whatever `getModelPrediction`
 * returns for the currently active model.
 */
@block({
type: "reporter",
text: "model prediction",
})
modelPrediction() {
return this.getModelPrediction();
}


/**
 * Resolves a user-supplied model reference to a model URL, (re)loads that
 * model, and makes it the active model.
 *
 * The load is awaited so that the calling block only completes once the model
 * is ready (or loading has failed), and so the logged state reflects the
 * outcome rather than the still-empty placeholder.
 *
 * @param url Teachable Machine share link or bare model id.
 */
async useModel(url: string) {
  try {
    const modelUrl = this.modelArgumentToURL(url);
    console.log('Loading model from URL:', modelUrl);
    // Start loading first, then mark the model active *before* awaiting:
    // the loader refreshes the toolbox as soon as the model arrives, and that
    // refresh reads the active model to list its classes.
    const loading = this.startPredicting(modelUrl);
    this.updateStageModel(modelUrl);
    await loading;
    console.log('Model state:', this.predictionState[modelUrl]);
  } catch (e) {
    console.error('Error loading model:', e);
    this.teachableImageModel = null;
  }
}

/**
 * Converts what the user typed into the storage URL the model files are
 * served from. The result always ends in exactly one "/", so file names such
 * as "model.json" can be appended directly.
 *
 * Accepted inputs (surrounding whitespace is ignored):
 *  - a Teachable Machine share link, with or without a trailing slash
 *  - an already-redirected storage URL
 *  - a bare model id
 *
 * @param modelArg Share link, storage URL, or model id.
 * @returns Storage URL of the model directory, ending in "/".
 */
modelArgumentToURL(modelArg: string) {
  const endpointProvidedFromInterface = "https://teachablemachine.withgoogle.com/models/";
  // NOTE: It's possible Google will change this endpoint in the future, and that will break this extension.
  // TODO: https://github.com/mitmedialab/prg-extension-boilerplate/issues/343
  const redirectEndpoint = "https://storage.googleapis.com/tm-model/";

  const trimmed = String(modelArg).trim();
  const knownPrefix = [endpointProvidedFromInterface, redirectEndpoint]
    .find(prefix => trimmed.startsWith(prefix));
  const modelId = knownPrefix ? trimmed.slice(knownPrefix.length) : trimmed;

  // Strip trailing slashes so exactly one can be appended below; a missing
  // slash used to produce ".../<id>model.json" further down the line.
  return redirectEndpoint + modelId.replace(/\/+$/, "") + "/";
}

/**
 * Returns the prediction state for a model, starting a background load when
 * the model is not known yet or when a reload is forced.
 *
 * @param modelUrl Model URL used as the key into `predictionState`.
 * @param override When true, reload the model even if state already exists.
 * @returns The stored state, or null when no model URL was given or a load
 *          has just been started.
 */
getPredictionStateOrStartPredicting(modelUrl, override = false) {
  // No model selected yet: there is nothing to look up and nothing sensible
  // to load (this used to fire a load request for an "undefined" URL).
  if (!modelUrl) {
    return null;
  }
  const hasPredictionState = Object.prototype.hasOwnProperty.call(this.predictionState, modelUrl);
  if (!hasPredictionState || override) {
    // Deliberately not awaited: callers poll this method. The loader reports
    // its own failures, so the promise is safe to leave unobserved.
    void this.startPredicting(modelUrl);
    return null;
  }
  return this.predictionState[modelUrl];
}

/**
 * Loads (or reloads) the model at the given URL and records it in
 * `predictionState`, showing progress through the indicator UI.
 *
 * Failures are reported to the user and logged; they are not rethrown. After
 * a failure the entry for the URL is left as an empty object.
 *
 * @param modelDataUrl Model directory URL, ending in "/".
 */
async startPredicting(modelDataUrl) {
  // Only an entry that actually holds a model counts as "already loaded";
  // the empty placeholder left by a failed or in-progress load does not.
  const alreadyLoaded = Boolean(this.predictionState[modelDataUrl]?.model);
  // Claim the slot before the first await so that callers polling for state
  // see a placeholder instead of starting a second, concurrent load.
  this.predictionState[modelDataUrl] = {};
  try {
    const indicator = await this.indicate({
      type: "warning",
      msg: alreadyLoaded ? "Updating model" : "Loading model"
    });
    try {
      // https://github.com/googlecreativelab/teachablemachine-community/tree/master/libraries/image
      const { model, type } = await this.initModel(modelDataUrl);
      this.predictionState[modelDataUrl].modelType = type;
      this.predictionState[modelDataUrl].model = model;
      this.runtime.requestToolboxExtensionsUpdate();
    } finally {
      // Close the progress indicator on failure too, so it cannot linger
      // next to the error message.
      indicator.close();
    }
    this.indicateFor({ type: "success", msg: "Model loaded" }, 1);
  } catch (e) {
    this.predictionState[modelDataUrl] = {};
    console.log("Model initialization failure!", e);
    this.indicateFor({ type: "error", msg: "Unable to load model." }, 1);
  }
}

/**
 * Downloads the model at `modelUrl` and wraps it with the library matching
 * its metadata.
 *
 * The files are first loaded as an image model; the metadata of that load
 * decides whether the model is really an audio or pose model, in which case
 * it is loaded again with the corresponding library. Audio models start
 * listening immediately and publish results to `latestAudioResults`.
 *
 * @param modelUrl Model directory URL, ending in "/".
 * @returns The loaded model together with its `ModelType` tag.
 */
async initModel(modelUrl) {
  // Unique query string so the browser never serves stale model files.
  const cacheBuster = `?x=${Date.now()}`;
  const modelURL = `${modelUrl}model.json${cacheBuster}`;
  const metadataURL = `${modelUrl}metadata.json${cacheBuster}`;

  const customMobileNet = await tmImage.load(modelURL, metadataURL);
  const metadata = (customMobileNet as any)._metadata;

  if (metadata.hasOwnProperty('tfjsSpeechCommandsVersion')) {
    const recognizer = await speechCommands.create("BROWSER_FFT", undefined, modelURL, metadataURL);
    await recognizer.ensureModelLoaded();
    const listenOptions = {
      includeSpectrogram: true, // in case listen should return result.spectrogram
      probabilityThreshold: 0.75,
      invokeCallbackOnNoiseAndUnknown: true,
      overlapFactor: 0.50 // probably want between 0.5 and 0.75. More info in README
    };
    await recognizer.listen(async result => {
      this.latestAudioResults = result;
    }, listenOptions);
    return { model: recognizer, type: this.ModelType.AUDIO };
  }

  if (metadata.packageName === "@teachablemachine/pose") {
    const customPoseNet = await tmPose.load(modelURL, metadataURL);
    return { model: customPoseNet, type: this.ModelType.POSE };
  }

  console.log(customMobileNet.getMetadata(), customMobileNet.getTotalClasses(), customMobileNet.getClassLabels());
  return { model: customMobileNet, type: this.ModelType.IMAGE };
}

/**
 * Marks `modelUrl` as the active model, both on this extension and, when a
 * stage target exists, on the stage target itself.
 *
 * @param modelUrl Model directory URL to activate.
 */
updateStageModel(modelUrl) {
  const stageTarget = this.runtime.getTargetForStage();
  this.teachableImageModel = modelUrl;
  if (!stageTarget) {
    return;
  }
  (stageTarget as any).teachableImageModel = modelUrl;
}

/**
 * Tells whether the active model's latest top class equals the given class.
 *
 * @param state Class name to compare against (converted to a string).
 * @returns true on an exact match; false on a mismatch or when no prediction
 *          state is available for the active model yet.
 */
model_match(state) {
  const activeState = this.getPredictionStateOrStartPredicting(this.teachableImageModel);
  return activeState
    ? activeState.topClass === String(state)
    : false;
}

/**
 * Lists the class labels of the active model for use in block dropdowns.
 *
 * @returns The model's labels, or the single placeholder entry
 *          "Select a class" while no loaded model is active.
 */
getModelClasses(): string[] {
  const activeUrl = this.teachableImageModel;
  const entry = activeUrl && this.predictionState
    ? this.predictionState[activeUrl]
    : undefined;

  if (!entry || !entry.hasOwnProperty('model')) {
    return ["Select a class"];
  }

  // Audio recognizers expose their labels under a different method name.
  const isAudio = entry.modelType === this.ModelType.AUDIO;
  return isAudio ? entry.model.wordLabels() : entry.model.getClassLabels();
}

/**
 * Returns the most recent top class predicted by the active model.
 *
 * @returns The class name, or an empty string when no model state exists yet
 *          or the model has not produced a prediction so far.
 */
getModelPrediction() {
  const modelUrl = this.teachableImageModel;
  const predictionState: { topClass?: string } | null = this.getPredictionStateOrStartPredicting(modelUrl);
  if (!predictionState) {
    console.error("No prediction state found");
    return '';
  }
  // `topClass` is only set after the first prediction has completed; report
  // an empty string until then instead of leaking `undefined` to the reporter.
  return predictionState.topClass ?? '';
}

/**
 * Self-rescheduling tick that drives predictions.
 *
 * Each tick re-arms the timer first, so an error later in the tick cannot
 * stop the loop. It then starts one frame capture plus prediction, provided
 * at least `INTERVAL` ms have passed since the previous one and nothing is
 * still in flight.
 */
private _loop() {
  // `?? 0` keeps the delay a real number even if the runtime has not set a
  // step time yet (Math.max with undefined would yield NaN).
  setTimeout(this._loop.bind(this), Math.max(this.runtime.currentStepTime ?? 0, this.INTERVAL));
  const time = Date.now();
  if (this.lastUpdate === null) {
    this.lastUpdate = time;
  }
  if (!this.isPredicting) {
    this.isPredicting = 0;
  }
  const offset = time - this.lastUpdate;

  if (offset > this.INTERVAL && this.isPredicting === 0) {
    this.lastUpdate = time;
    // Mark the whole capture-and-predict chain as busy right away. Otherwise
    // the flag is raised only after the frame has been fetched, and slow
    // fetches would pile up overlapping requests.
    ++this.isPredicting;
    void this.getImageStreamAndPredict()
      .catch(error => console.error("Prediction loop error:", error))
      .finally(() => {
        --this.isPredicting;
      });
  }
}

/**
 * Captures the current camera frame and runs the active model on it.
 *
 * Does nothing while no model is selected. Errors from either the capture or
 * the prediction are logged here and never propagate to the caller.
 */
private async getImageStreamAndPredict() {
  // Without an active model there is nothing to predict, so skip the frame
  // capture entirely.
  if (!this.teachableImageModel) {
    return;
  }

  let imageBitmap: ImageBitmap | undefined;
  try {
    const imageStream = await this.getImageStream();
    if (!imageStream) {
      console.error("Failed to get image stream");
      return;
    }
    imageBitmap = await createImageBitmap(imageStream);
    // Awaited so that prediction failures land in the catch below rather
    // than surfacing as unhandled promise rejections.
    await this.predictAllBlocks(imageBitmap);
  } catch (error) {
    console.error("Error in getting image stream and predicting:", error);
  } finally {
    // Release the bitmap's backing memory promptly; one is created per tick.
    imageBitmap?.close();
  }
}

/**
 * Runs the active model on `frame` and stores the winning class name as
 * `topClass` in that model's prediction state.
 *
 * Models that are not loaded yet, or that are not the active one, are
 * skipped. `isPredicting` is raised for the duration of each prediction and
 * is always lowered again, even when the prediction throws.
 *
 * @param frame Camera frame to classify.
 */
private async predictAllBlocks(frame: ImageBitmap) {
  for (const modelUrl in this.predictionState) {
    if (!this.predictionState[modelUrl].model) {
      continue;
    }
    if (this.teachableImageModel !== modelUrl) {
      continue;
    }
    ++this.isPredicting;
    try {
      const prediction = await this.predictModel(modelUrl, frame);
      this.predictionState[modelUrl].topClass = prediction;
    } catch (error) {
      console.error('Prediction failed for model:', modelUrl, error);
    } finally {
      // Must run on failure too: a counter left above zero would stop the
      // prediction loop from ever scheduling another prediction.
      --this.isPredicting;
    }
  }
}

/**
 * Runs one prediction and determines the most likely class.
 *
 * Side effects: stores every class's confidence (two-decimal string) in
 * `modelConfidences` and the winning confidence (number) in `maxConfidence`.
 *
 * @param modelUrl Key of the model in `predictionState`.
 * @param frame Camera frame to classify (ignored by audio models).
 * @returns The winning class name; "" when every confidence rounds to 0;
 *          undefined when the model produced no predictions.
 */
private async predictModel(modelUrl: string, frame: ImageBitmap) {
  const predictions = await this.getPredictionFromModel(modelUrl, frame);
  if (!predictions) {
    return;
  }
  let maxProbability = 0;
  let maxClassName = "";
  for (const { className, probability } of predictions) {
    // toFixed() returns a string. Keep that string for display, but compare
    // numerically so the winner does not depend on lexicographic ordering
    // and `maxConfidence` really holds a number.
    const rounded = probability.toFixed(2);
    const roundedValue = Number(rounded);
    this.modelConfidences[className] = rounded;
    if (roundedValue > maxProbability) {
      maxClassName = className;
      maxProbability = roundedValue;
    }
  }
  this.maxConfidence = maxProbability;
  return maxClassName;
}

/**
 * Asks the model stored under `modelUrl` for per-class probabilities, using
 * the call sequence that matches its type.
 *
 * @param modelUrl Key of the model in `predictionState`.
 * @param frame Camera frame; required by image and pose models.
 * @returns A list of `{ className, probability }` entries; null when the
 *          required input (frame or audio result) is missing; undefined for
 *          an unrecognised model type.
 */
private async getPredictionFromModel(modelUrl: string, frame: ImageBitmap) {
  const entry = this.predictionState[modelUrl];
  const model = entry.model;
  const kind = entry.modelType;

  if (kind === this.ModelType.IMAGE) {
    return frame ? await model.predict(frame) : null;
  }

  if (kind === this.ModelType.POSE) {
    if (!frame) {
      return null;
    }
    // Pose models classify the pose-estimation output, not the raw frame.
    const estimate = await model.estimatePose(frame);
    return await model.predict(estimate.posenetOutput);
  }

  if (kind === this.ModelType.AUDIO) {
    if (!this.latestAudioResults) {
      return null;
    }
    // Pair every word label with its score from the latest listen callback.
    return model.wordLabels().map((label, i) => ({
      className: label,
      probability: this.latestAudioResults.scores[i]
    }));
  }

  return undefined;
}
}

4 changes: 4 additions & 0 deletions extensions/src/doodlebot/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
"license": "ISC",
"dependencies": {
"@mediapipe/tasks-vision": "^0.10.12",
"@teachablemachine/image": "^0.8.5",
"@teachablemachine/pose": "^0.8.6",
"@tensorflow-models/speech-commands": "^0.5.4",
"@tensorflow/tfjs": "^4.17.0",
"@types/web-bluetooth": "^0.0.20",
"axios": "^1.7.7",
"bezier-js": "^6.1.4",
Expand Down
Loading

0 comments on commit 04e380a

Please sign in to comment.