# Source code for objax.jaxboard

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import os
from time import time
from typing import Union, Callable, Tuple, ByteString

import numpy as np
from tensorboard.compat.proto import event_pb2
from tensorboard.compat.proto import summary_pb2
from tensorboard.summary.writer.event_file_writer import EventFileWriter
from tensorboard.util.tensor_util import make_tensor_proto

from objax import util

class Reducer(enum.Enum):
    """Reduction functions that collapse a batch of values into a single value.

    NOTE: every attribute below is a plain function. ``enum.Enum`` treats functions
    as methods, not members. As a result, ``Reducer.MEAN`` is the function itself,
    ``Reducer`` has no iterable members, and each entry is used by calling it
    directly, e.g. ``Reducer.MEAN(values)``.
    """

    def FIRST(x):
        # Keep only the first value of the batch.
        return x[0]

    def LAST(x):
        # Keep only the last value of the batch.
        return x[-1]

    def MEAN(x):
        # Average of the batch, computed with numpy.
        return np.mean(x)

class DelayedScalar:
    """Collects scalar values and reduces them to a single value when called."""

    def __init__(self, reduce: Union[Callable, Reducer]):
        # Values appended by callers (Summary.scalar appends to this list directly).
        self.values = []
        # Reduction applied to ``values`` when this object is called.
        self.reduce = reduce

    def __call__(self):
        """Return ``reduce`` applied to the list of collected values."""
        reducer = self.reduce
        collected = self.values
        return reducer(collected)

class Image:
    """Image entry of a summary: the source array shape plus its PNG encoding."""

    def __init__(self, shape: Tuple[int, int, int], png: ByteString):
        # ``shape`` is later read as (channels, height, width) when the protobuf is built;
        # ``png`` holds the encoded image bytes.
        self.shape, self.png = shape, png

class Text:
    """Text entry of a summary."""

    def __init__(self, text: str):
        # Kept as a Python str; UTF-8 encoding happens only when the summary protobuf is built.
        self.text = text

class Summary(dict):
    """Tag-keyed collection of entries that converts itself into a `Summary` protocol buffer.

    Values are ``DelayedScalar``, ``Image`` or ``Text`` objects added through
    :meth:`scalar`, :meth:`image` and :meth:`text`. Calling the instance builds the protobuf.
    """

    def image(self, tag: str, image: np.ndarray):
        """Store ``image`` under ``tag``, replacing any previous entry.

        Float image in [-1, 1] in CHW format expected; PNG encoding is done by ``util.image.to_png``.
        """
        shape = image.shape
        encoded = util.image.to_png(image)
        self[tag] = Image(shape, encoded)

    def scalar(self, tag: str, value: float, reduce: Union[Callable, Reducer] = Reducer.MEAN):
        """Append ``value`` to the scalar series stored under ``tag``.

        The first call for a tag creates the series with ``reduce``. Later calls reuse
        that series and ignore their own ``reduce`` argument. If a non-scalar entry
        already uses the tag, the ``.values`` access fails.
        """
        try:
            series = self[tag]
        except KeyError:
            series = DelayedScalar(reduce)
            self[tag] = series
        series.values.append(value)

    def text(self, tag: str, text: str):
        """Store ``text`` under ``tag``, replacing any previous entry."""
        self[tag] = Text(text)

    def __call__(self):
        """Build a ``summary_pb2.Summary`` with one value per stored entry, in insertion order.

        Raises:
            NotImplementedError: if an entry is not a DelayedScalar, Image or Text.
        """
        return summary_pb2.Summary(value=[self._entry_to_value(tag, entry) for tag, entry in self.items()])

    @staticmethod
    def _entry_to_value(tag, entry):
        """Convert one tagged entry into a ``summary_pb2.Summary.Value``."""
        if isinstance(entry, DelayedScalar):
            return summary_pb2.Summary.Value(tag=tag, simple_value=entry())
        if isinstance(entry, Image):
            # The stored shape is read as (channels, height, width).
            proto_image = summary_pb2.Summary.Image(encoded_image_string=entry.png,
                                                    colorspace=entry.shape[0],
                                                    height=entry.shape[1],
                                                    width=entry.shape[2])
            return summary_pb2.Summary.Value(tag=tag, image=proto_image)
        if isinstance(entry, Text):
            # Text is tagged for TensorBoard's 'text' plugin and stored as a 1-element string tensor.
            metadata = summary_pb2.SummaryMetadata(
                plugin_data=summary_pb2.SummaryMetadata.PluginData(plugin_name='text'))
            tensor = make_tensor_proto(values=entry.text.encode('utf-8'), shape=(1,))
            return summary_pb2.Summary.Value(tag=tag, metadata=metadata, tensor=tensor)
        raise NotImplementedError(tag, entry)
class SummaryWriter:
    """Writes entries to event files in the logdir to be consumed by Tensorboard."""

    def __init__(self, logdir: str, queue_size: int = 5, write_interval: int = 5):
        """Creates SummaryWriter instance.

        Args:
            logdir: directory where event file will be written; created (including parents) if missing.
            queue_size: size of the queue for pending events and summaries
                        before one of the 'add' calls forces a flush to disk.
            write_interval: how often, in seconds, to write the pending events and summaries to disk.
        """
        # exist_ok=True already tolerates an existing directory, so a separate isdir() check is redundant.
        os.makedirs(logdir, exist_ok=True)
        self.writer = EventFileWriter(logdir, queue_size, write_interval)

    def write(self, summary: Summary, step: int):
        """Adds an event to the event file.

        Args:
            summary: summary to record; it is called to build its protocol buffer.
            step: step number stored on the event.
        """
        self.writer.add_event(event_pb2.Event(step=step, summary=summary(), wall_time=time()))

    def flush(self):
        """Flushes pending events to disk without closing the file."""
        self.writer.flush()

    def close(self):
        """Flushes the event file to disk and closes the file."""
        self.writer.close()

    def __enter__(self):
        """Returns this writer for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Closes the writer when leaving the ``with`` block; exceptions are not suppressed."""
        self.close()