Large drop in accuracy in QKeras => hls4ml conversion regardless of quant/overflow params #614
Replies: 4 comments 6 replies
-
To show the output we get, the results of this sweep (modified to run the different configurations in parallel with a multiprocessing pool, and presented as "configuration: match %") are below. The only configurations that give a 100% match between the QKeras and hls4ml models all use 32-bit words, but for our use case that is excessive and would yield far too high resource usage.
|
Beta Was this translation helpful? Give feedback.
-
Hi,
At the step |
Beta Was this translation helpful? Give feedback.
-
As a follow-up: AutoQKeras should be able to perform per-layer quantization for us, right? Even when I adapt the AutoQKeras section of the CNN tutorial notebook, I still see large differences between the "best model" from AutoQKeras's search and the performance after conversion with hls4ml. I assume I'm still doing something wrong somewhere, but I can't figure out what. With the following script, I end up with an output of
i.e., the hls4ml model's training accuracy is essentially that of random guessing, even though AutoQKeras produced a model with reasonable accuracy. Could this be related to our use of Xilinx Vivado 2019.2? We still get good performance when we run the CNN notebook in the same virtualenv, so it doesn't seem to be a package/environment issue.
#!/usr/bin/env python
# coding: utf-8
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, MaxPooling2D, Flatten, Dense
from tensorflow.keras.models import Model
import numpy as np
import hls4ml
from qkeras.autoqkeras import AutoQKeras
from qkeras.autoqkeras.utils import print_qmodel_summary
from sklearn.metrics import accuracy_score
import os
# Prepend the Vivado 2019.2 bin directory to PATH (presumably so hls4ml can find
# the Xilinx tools when building a project). The default location is
# machine-specific; set the VIVADO_BIN_DIR environment variable to override it.
VIVADO_BIN_DIR = os.environ.get('VIVADO_BIN_DIR', '/media/jmack2545/data_drive/Xilinx/Vivado/2019.2/bin')
_existing_path = os.environ.get('PATH', '')
# os.pathsep is ':' on POSIX (the separator previously hard-coded) and ';' on
# Windows. The separator is skipped when PATH is empty, because a trailing empty
# entry would put the current directory on the search path.
os.environ['PATH'] = (VIVADO_BIN_DIR + os.pathsep + _existing_path) if _existing_path else VIVADO_BIN_DIR
def print_dict(d, indent=0, align=20):
    """Pretty-print a (possibly nested) dict, one key per line.

    Nested dicts are printed recursively, each level indented by one more
    space than its parent. A leaf value is printed on the same line as its
    key, after ':' and ``align - len(key) - 2 * indent`` padding spaces
    (no padding if that count is negative).

    Args:
        d: Mapping to print; values that are themselves dicts are recursed into.
        indent: Current nesting depth, i.e. number of leading spaces.
        align: Padding width used for leaf values. Defaults to 20, the width
            that was previously hard-coded.
    """
    for key, value in d.items():
        # Use the printed form of the key for the length too, so non-string
        # keys (ints, tuples, ...) do not raise TypeError in len().
        label = str(key)
        print(' ' * indent + label, end='')
        if isinstance(value, dict):
            print()
            print_dict(value, indent + 1, align)
        else:
            print(':' + ' ' * (align - len(label) - 2 * indent) + str(value))
def generate_mnist_dataset(one_hot = True):
    """Load the MNIST train/test splits via tf.keras.

    Args:
        one_hot: When True (the default), labels are converted with
            ``tf.keras.utils.to_categorical``; otherwise they are returned as
            loaded.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``. Images are returned exactly
        as ``tf.keras.datasets.mnist.load_data`` provides them (no scaling and
        no channel axis added here).
    """
    train_split, test_split = tf.keras.datasets.mnist.load_data()
    images_train, labels_train = train_split
    images_test, labels_test = test_split
    if not one_hot:
        return (images_train, labels_train), (images_test, labels_test)
    encode = tf.keras.utils.to_categorical
    return (images_train, encode(labels_train)), (images_test, encode(labels_test))
def generate_mnist_network(filt_size, n_filt, n_dense, max_layer_size=4096):
    """Build and compile the floating-point baseline CNN for 28x28x1 inputs.

    Layer stack: Conv2D ('valid' padding) -> BatchNormalization ('bn_conv_1')
    -> ReLU -> MaxPooling2D -> Flatten -> Dense(n_dense) -> BatchNormalization
    ('bn_dense_1') -> ReLU -> Dense(10, 'output_dense') -> softmax
    ('output_softmax').

    After building, the kernel element count of every Conv2D/Dense layer is
    printed. A warning is printed for any kernel larger than
    ``max_layer_size``, as a rough hls4ml synthesizability check at
    ReuseFactor = 1.

    Args:
        filt_size: Kernel size passed to Conv2D.
        n_filt: Number of Conv2D filters.
        n_dense: Number of units in the hidden Dense layer.
        max_layer_size: Kernel element count above which the warning is
            printed. Defaults to 4096, the previously hard-coded limit.

    Returns:
        The compiled Keras Model (categorical cross-entropy, Adam with
        amsgrad, accuracy metric).
    """
    in_layer = Input((28, 28, 1))
    x = Conv2D(filters=n_filt,
               kernel_size=filt_size,
               padding='valid',
               kernel_initializer='glorot_uniform',
               )(in_layer)
    x = BatchNormalization(name='bn_conv_1')(x)
    x = Activation(activation='relu')(x)
    x = MaxPooling2D()(x)
    x = Flatten()(x)
    x = Dense(n_dense)(x)
    x = BatchNormalization(name='bn_dense_1')(x)
    x = Activation(activation='relu')(x)
    # Output layers get an 'output' name prefix so AutoQKeras leaves them alone;
    # get_autoq_model's "tune_filters_exceptions": "^output" appears to rely on it.
    x = Dense(10, name='output_dense')(x)
    out_layer = Activation(activation='softmax', name='output_softmax')(x)
    model = Model(inputs=[in_layer], outputs=[out_layer])

    # Rough hls4ml synthesizability check with reuse factor = 1.
    for layer in model.layers:
        if layer.__class__.__name__ not in ('Conv2D', 'Dense'):
            continue
        # get_weights() returns [kernel, bias]; only the kernel is measured.
        # Kernel shapes have no batch dimension, so this is the true weight count.
        kernel = layer.get_weights()[0]
        layersize = np.prod(kernel.shape)
        print("{}: {}".format(layer.name, layersize))
        if layersize > max_layer_size:
            print("Layer {} is too large ({}), are you sure you want to train?".format(layer.name, layersize))

    # Finishing touches: loss, optimizer and metric.
    LOSS = tf.keras.losses.CategoricalCrossentropy()
    OPTIMIZER = tf.keras.optimizers.Adam(learning_rate=3E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-07, amsgrad=True)
    model.compile(loss=LOSS, optimizer=OPTIMIZER, metrics=["accuracy"])
    return model
def get_autoq_model(nn, x_train, y_train, x_test, y_test,
                    max_trials=2, epochs=5, output_dir='autoq_mnist_sweep'):
    """Run an AutoQKeras search starting from *nn* and return the best model.

    The search uses an energy goal, bayesian mode and per-layer filter tuning
    (layers matching '^output' are excluded). Candidate quantizers are listed
    in ``quantization_config`` below. The best model's quantization summary is
    printed before returning.

    Args:
        nn: Baseline (floating-point) Keras model to quantize.
        x_train, y_train: Training data passed to ``AutoQKeras.fit``.
        x_test, y_test: Used as validation data for the search.
        max_trials: Number of AutoQKeras trials. Defaults to 2, which is only
            enough for a demonstration; use as many as you can afford.
        epochs: Training epochs per trial. Defaults to 5.
        output_dir: Directory AutoQKeras writes its search state to.

    Returns:
        The model returned by ``autoqk.get_best_model()``.
    """
    quantization_config = {
        "kernel": {
            "quantized_bits(2,0,1,alpha=1.0)": 2,
            "quantized_bits(4,0,1,alpha=1.0)": 4,
            "quantized_bits(6,0,1,alpha=1.0)": 6,
            "quantized_bits(8,0,1,alpha=1.0)": 8,
        },
        "bias": {
            "quantized_bits(2,0,1,alpha=1.0)": 2,
            "quantized_bits(4,0,1,alpha=1.0)": 4,
            "quantized_bits(6,0,1,alpha=1.0)": 6,
            "quantized_bits(8,0,1,alpha=1.0)": 8,
        },
        "activation": {
            "quantized_relu(3,1)": 3,
            "quantized_relu(4,2)": 4,
            "quantized_relu(8,2)": 8,
            "quantized_relu(8,4)": 8,
            "quantized_relu(16,6)": 16
        },
        "linear": {
            "quantized_bits(2,0,1,alpha=1.0)": 2,
            "quantized_bits(4,0,1,alpha=1.0)": 4,
            "quantized_bits(6,0,1,alpha=1.0)": 6,
            "quantized_bits(8,0,1,alpha=1.0)": 8,
        }
    }
    # Layer-type limits on which configurations are acceptable
    # (presumably max bits for [kernel, bias, activation] -- check the AutoQKeras docs).
    limit = {
        "Dense": [8, 8, 16],
        "Conv2D": [8, 8, 16],
        "Activation": [16],
    }
    goal_energy = {
        "type": "energy",
        "params": {
            "delta_p": 8.0,
            "delta_n": 8.0,
            "rate": 1.5,
            "stress": 1.0,
            "process": "horowitz",
            "parameters_on_memory": ["sram", "sram"],
            "activations_on_memory": ["sram", "sram"],
            "rd_wr_on_io": [False, False],
            "min_sram_size": [0, 0],
            "source_quantizers": ["fp32"],
            "reference_internal": "int8",
            "reference_accumulator": "int32"
        }
    }
    run_config = {
        # Note: goal_bits seems to just not work well in general even though it's more of what we want
        "goal": goal_energy,
        "quantization_config": quantization_config,
        "learning_rate_optimizer": False,
        # False: trial models do NOT copy the baseline's weights (they start
        # from fresh initialization). NOTE(review): the previous comment said
        # "Don't randomly initialize weights", which contradicts this value;
        # set True if transferring nn's weights was the intent.
        "transfer_weights": False,
        "mode": "bayesian",  # This can be bayesian, random, hyperband
        "seed": 1000,
        "limit": limit,
        "tune_filters": "layer",
        "tune_filters_exceptions": "^output",
        "distribution_strategy": None,
        "max_trials": max_trials,
    }
    autoqk = AutoQKeras(nn, output_dir=output_dir, metrics=["acc"], custom_objects={}, **run_config)
    # NOTE: the test split doubles as validation data here, so the search is
    # tuned on the same data that is later used for evaluation.
    autoqk.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=epochs)
    aqmodel = autoqk.get_best_model()
    print_qmodel_summary(aqmodel)
    return aqmodel
def strip_autoq_model(aqmodel):
    """Rebuild *aqmodel* as a fresh functional Model from its own layers.

    The layers are re-chained in the order of ``aqmodel.layers``, which assumes
    a purely sequential topology. The rebuilt model is compiled with the same
    loss/optimizer settings used elsewhere in this script, and its summary and
    quantization summary are printed. Weights are written to
    ``autoq_mnist_finalparams.h5`` in the current directory and loaded back
    into the new model.

    Returns:
        The rebuilt, compiled Keras Model.
    """
    weights_path = "autoq_mnist_finalparams.h5"
    aqmodel.save_weights(weights_path)

    source_layers = list(aqmodel.layers)
    tensor = source_layers[0].output
    for lyr in source_layers[1:]:
        tensor = lyr(tensor)
    rebuilt = Model(inputs=[source_layers[0].input], outputs=[tensor])

    rebuilt.compile(
        loss=tf.keras.losses.CategoricalCrossentropy(),
        optimizer=tf.keras.optimizers.Adam(learning_rate=3E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-07, amsgrad=True),
        metrics=["accuracy"],
    )
    rebuilt.summary()
    rebuilt.load_weights(weights_path)
    print_qmodel_summary(rebuilt)
    return rebuilt
def convert_via_hls4ml(new_model):
    """Convert a Keras/QKeras model to an hls4ml model and compile it.

    Side effect: sets class-level attributes on hls4ml's
    ``OutputRoundingSaturationMode`` optimizer (Activation layers, AP_RND
    rounding, AP_SAT saturation). This affects every later conversion in the
    process.

    The per-layer ('name' granularity) config is printed, then an io_stream
    Vivado project is written to ``autoq_mnist_hls4ml/`` and compiled.
    Synthesis is not run here.

    NOTE(review): ``Model.Precision`` is set after ``config_from_keras_model``
    has already filled in per-layer precisions. It likely does not change
    them. If the input layer ends up with ``ap_fixed<16,6>`` (range
    [-32, 32)), inputs outside that range will overflow -- confirm.

    Returns:
        The compiled hls4ml model.
    """
    rounding_opt = hls4ml.model.optimizer.OutputRoundingSaturationMode
    rounding_opt.layers = ['Activation']
    rounding_opt.rounding_mode = 'AP_RND'
    rounding_opt.saturation_mode = 'AP_SAT'

    hls_config_aq = hls4ml.utils.config_from_keras_model(new_model, granularity='name')
    model_section = hls_config_aq['Model']
    model_section['ReuseFactor'] = 1
    model_section['Precision'] = 'ap_fixed<16,6>'
    model_section['Strategy'] = 'Latency'
    hls_config_aq['LayerName']['output_softmax']['Strategy'] = 'Stable'
    print_dict(hls_config_aq)

    cfg_aq = hls4ml.converters.create_config(backend='Vivado')
    cfg_aq.update({
        'IOType': 'io_stream',
        'HLSConfig': hls_config_aq,
        'KerasModel': new_model,
        'OutputDir': 'autoq_mnist_hls4ml/',
        'XilinxPart': 'xczu28dr-ffvg1517-2-e',
    })
    hls_model_aq = hls4ml.converters.keras_to_hls(cfg_aq)
    hls_model_aq.compile()
    return hls_model_aq
if __name__ == "__main__":
    (x_train, y_train), (x_test, y_test) = generate_mnist_dataset()
    # Scale pixels from [0, 255] to [0, 1] and add the explicit channel axis
    # that the model's Input((28, 28, 1)) declares. Unscaled pixels do not fit
    # in a fixed-point input type such as ap_fixed<16,6> (range [-32, 32)).
    # In the hls4ml model they overflow (ap_fixed wraps by default), while the
    # Keras model sees them exactly. That alone makes the two models disagree
    # regardless of the weight/activation quantization settings.
    x_train = np.expand_dims(x_train.astype(float) / 255.0, axis=-1)
    x_test = np.expand_dims(x_test.astype(float) / 255.0, axis=-1)

    # instantiate a baseline architecture
    nn = generate_mnist_network(filt_size=3, n_filt=5, n_dense=4)
    # execute the autoqkeras search
    aqmodel = get_autoq_model(nn, x_train, y_train, x_test, y_test)
    # train the best model found by autoqkeras a bit more
    aqmodel.fit(x_train,
                y_train,
                epochs=5,
                validation_data=(x_test, y_test),
                callbacks=[],
                verbose=1)
    # This model (apparently) has some remnants from the optimization procedure attached to it
    new_model = strip_autoq_model(aqmodel)
    # Create the hls4ml project directory and compile our C++ representation
    hls_model_aq = convert_via_hls4ml(new_model)

    # compare the two networks on the training set
    y_predict_aq = new_model.predict(x_train)
    y_predict_hls4ml_aq = hls_model_aq.predict(x_train)
    keras_labels = np.argmax(y_predict_aq, axis=1)
    hls_labels = np.argmax(y_predict_hls4ml_aq, axis=1)
    true_labels = np.argmax(y_train, axis=1)

    # np.sum stays in C; Python's builtin sum would iterate 60k numpy bools.
    num_matches = int(np.sum(keras_labels == hls_labels))
    num_images = y_predict_aq.shape[0]
    print(f"Match % (AutoQ Keras vs AutoQ hls4ml): {num_matches / num_images * 100.0}")
    accuracy_keras = 100 * float(accuracy_score(true_labels, keras_labels))
    accuracy_hls4ml = 100 * float(accuracy_score(true_labels, hls_labels))
    print("Accuracy AutoQ Keras: {}".format(accuracy_keras))
    print("Accuracy AutoQ hls4ml: {}".format(accuracy_hls4ml))
Beta Was this translation helpful? Give feedback.
-
follow up here, I didn't have much luck modifying hls4ml.model.optimizer.OutputRoundingSaturationMode.layers = ['Activation']
# Configure hls4ml's OutputRoundingSaturationMode optimizer (class-level
# settings): for 'Activation' layers use convergent rounding (AP_RND_CONV)
# and saturation (AP_SAT).
hls4ml.model.optimizer.OutputRoundingSaturationMode.rounding_mode = 'AP_RND_CONV'
hls4ml.model.optimizer.OutputRoundingSaturationMode.saturation_mode = 'AP_SAT'
# Per-layer ('name' granularity) config so each layer's precision can be edited.
# autoq_network: the model being converted -- defined outside this snippet
# (presumably the AutoQKeras best model).
hls_config = hls4ml.utils.config_from_keras_model(autoq_network, granularity='name')
hls_config['Model']['ReuseFactor'] = 1
hls_config['Model']['Strategy'] = 'Latency'
hls_config['LayerName']['output_softmax']['Strategy'] = 'Stable'
# Experiment: widen every layer's 'result' precision to a very wide type,
# to test whether output overflow explains the Keras/hls4ml mismatch.
quantizers_to_modify = ['result']
big_quantizer = 'ap_fixed<64,48,AP_RND_CONV,AP_SAT>'
# Every entry under 'LayerName' is indexed with ['Precision'] directly, so a
# layer without that key would raise KeyError.
# NOTE(review): only dict-valued Precision entries are changed. Layers whose
# Precision is a plain string -- possibly including the input layer -- keep
# their default type, so an overflow of the inputs themselves would not be
# fixed by this loop. Confirm.
for layer in hls_config['LayerName'].keys():
if isinstance(hls_config['LayerName'][layer]['Precision'], dict):
for quantizer in hls_config['LayerName'][layer]['Precision'].keys():
if quantizer in quantizers_to_modify:
hls_config['LayerName'][layer]['Precision'][quantizer] = big_quantizer
# Vivado backend, streaming I/O, output into hls4ml_proj/ for the given part.
cfg = hls4ml.converters.create_config(backend='Vivado')
cfg['IOType'] = 'io_stream'
cfg['HLSConfig'] = hls_config
cfg['KerasModel'] = autoq_network
cfg['OutputDir'] = f'hls4ml_proj/'
cfg['XilinxPart'] = 'xczu28dr-ffvg1517-2-e'
hls_model = hls4ml.converters.keras_to_hls(cfg)
# Compile the generated project (the output this produced is not preserved here).
hls_model.compile() which then gives
thanks for the suggestion there @thesps |
Beta Was this translation helpful? Give feedback.
-
We're in the process of training QKeras models with fixed precision and then converting them to hls4ml models (we tried hls4ml 0.6.0 and commit db943b7). However, we can't get the hls4ml model's outputs to match the baseline behavior/accuracy of our QKeras model. This is quite different from what we expected, since we assumed QKeras' quantization model should match hls4ml's for some selection of ap_fixed<*, *, Q, O> quantization and overflow modes.
At the bottom of this post is a proof-of-concept Python script that shows what we're doing with an MNIST sample network (sweeping through various word/integer widths and quantization/overflow modes, retraining, and testing). We never find a configuration where hls4ml matches QKeras more than ~80% of the time (with the exception of some extremely large 32-bit models). We've seen this behavior across different image-classification tasks too, so we believe whatever issue we're having is fairly data-independent.
If anyone has thoughts as to what we might be doing wrong, we'd appreciate any feedback.
Beta Was this translation helpful? Give feedback.
All reactions