プリミティブは球体のみ、シェーディングはフラットシェーディング、カメラは正投影、解像度固定。
import math
from itertools import product
from functools import partial

# Output image size in pixels.  create_ray() and output() both default to
# these so the ray grid and the PPM header cannot disagree.
WIDTH = 200
HEIGHT = 200


def range2(rx, ry):
    """Return an iterator over all (i, j) with 0 <= i < rx, 0 <= j < ry; i is outermost."""
    return product(range(rx), range(ry))


class Vector(object):
    """Minimal mutable 3-component vector."""

    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z

    def normalize(self):
        """Scale this vector in place to unit length.

        Raises ZeroDivisionError for a zero-length vector.
        """
        length = math.sqrt(self.x * self.x + self.y * self.y + self.z * self.z)
        self.x /= length
        self.y /= length
        self.z /= length

    def __mul__(self, v):
        """Return the dot product with *v* (a scalar, not a vector)."""
        return self.x * v.x + self.y * v.y + self.z * v.z

    def __sub__(self, v):
        """Return the component-wise difference as a new Vector."""
        return Vector(self.x - v.x, self.y - v.y, self.z - v.z)

    def printval(self):
        """Print the three components separated by spaces."""
        # Single-argument print(...) behaves identically on Python 2 and 3.
        print('%s %s %s' % (self.x, self.y, self.z))


class Color(Vector):
    """RGB color (channels nominally 0.0-1.0) stored in x/y/z."""

    def __init__(self, r, g, b):
        super(Color, self).__init__(r, g, b)

    def printval(self):
        """Print the color as one plain-PPM pixel line, channels scaled by 255."""
        # '%d' truncates, e.g. 0.5 * 255 = 127.5 -> 127.
        print('%d %d %d' % (self.x * 255, self.y * 255, self.z * 255))


class Point(Vector):
    """A position in space (same representation as Vector)."""

    def __init__(self, x, y, z):
        super(Point, self).__init__(x, y, z)


class Ray(object):
    """Half-line with origin ``p`` and unit-length direction ``v``."""

    def __init__(self, p, v):
        self.p = p
        # Normalize a private copy so the caller's vector is never modified;
        # create_ray() shares one direction vector across every ray.
        self.v = Vector(v.x, v.y, v.z)
        self.v.normalize()


class Primitive(object):
    """Base class for scene objects."""

    def __init__(self):
        pass

    def intersect(self, ray):
        """Return an (object, t) pair; t is None on a miss.  The base never hits."""
        return (None, None)

    def color(self):
        """Return the object's flat Color (None in the base class)."""
        pass


class Back(Primitive):
    """Background: every ray hits it at infinite distance."""

    def __init__(self):
        super(Back, self).__init__()

    def intersect(self, ray):
        return (self, float('inf'))

    def color(self):
        return Color(0.5, 0.5, 0.5)


class Sphere(Primitive):
    """Sphere with center ``cp``, radius ``rad`` and flat color ``col``."""

    def __init__(self, p, rad, col):
        super(Sphere, self).__init__()
        self.cp = p
        self.rad = rad
        self.col = col

    def _pmin(self, a, b):
        """Given roots a <= b, return the nearest non-negative one, or None."""
        if b < 0.0:
            return None  # both roots behind the ray origin
        if a < 0.0:
            return b  # origin inside the sphere: use the exit point
        return a

    def intersect(self, ray):
        """Solve |p + t*v - cp|^2 = rad^2 for t (ray.v is unit length)."""
        v = ray.p - self.cp
        n = v * ray.v
        # Quarter discriminant of t^2 + 2*n*t + (|v|^2 - rad^2) = 0.
        det = n * n - v * v + self.rad * self.rad
        if det < 0.0:
            return (self, None)
        d = math.sqrt(det)
        return (self, self._pmin(-n - d, -n + d))

    def color(self):
        return self.col


def _pixel_center(i, n):
    """Map pixel index i in [0, n) to the center of its cell within [-1, 1]."""
    return -1.0 + 2.0 * (i + 0.5) / n


def create_ray(rx=WIDTH, ry=HEIGHT):
    """Build one orthographic camera ray per pixel, in PPM row-major order.

    The view square [-1, 1] x [-1, 1] is sampled at pixel centers.  Rays start
    at z = 2.0 and point along -z.  Rows run from top (y near +1) to bottom and
    columns from left (x near -1) to right, matching the order output() writes.
    """
    direction = Vector(0.0, 0.0, -1.0)
    # range2(ry, rx) yields (row, col) with the row as the outer loop.
    return [Ray(Point(_pixel_center(col, rx), -_pixel_center(row, ry), 2.0),
                direction)
            for row, col in range2(ry, rx)]


def get_objs():
    """Return the fixed scene: a red and a green sphere plus the background."""
    return [Sphere(Point(0.0, 0.0, 0.0), 0.5, Color(1.0, 0.0, 0.0)),
            Sphere(Point(0.5, 0.5, -1.0), 0.5, Color(0.0, 1.0, 0.0)),
            Back()]


def trace(ray, objs=None):
    """Return the color of the nearest object hit by *ray*.

    *objs* defaults to a fresh get_objs() scene; pass a prebuilt list to avoid
    rebuilding the scene for every ray.  Only t is None counts as a miss, so a
    hit at t == 0.0 is kept.  On equal distances the earliest object wins.
    Raises ValueError if nothing is hit (the default scene's Back() always is).
    """
    if objs is None:
        objs = get_objs()
    hits = [(hit, t) for hit, t in (obj.intersect(ray) for obj in objs)
            if t is not None]
    if not hits:
        raise ValueError('ray hit no object; add Back() to the scene')
    # min() returns the first minimal item, preserving the old stable-sort ties.
    nearest, _ = min(hits, key=lambda h: h[1])
    return nearest.color()


def output(cs, width=WIDTH, height=HEIGHT):
    """Print colors *cs* (row-major order) to stdout as a plain PPM (P3) image."""
    print('P3\n%d %d\n255' % (width, height))
    # A real loop, not map(): map() is lazy on Python 3 and would print nothing.
    for c in cs:
        c.printval()


def main():
    """Render the fixed scene and print it as a PPM image."""
    objs = get_objs()  # build the scene once instead of once per ray
    output([trace(ray, objs) for ray in create_ray()])


if __name__ == '__main__':
    main()